見出し画像

上がるの? 下がるの? 二値分類 AIを使って日経平均株価の予測に挑戦 過学習への対策は始まったばかり編


前回の記事より、今後の課題を振り返る

前回は、RNN LSTMの構造を使用して日経平均株価の予測に対する二値分類を行いました。

AIモデルの予測に関しては、Accuracy(分類精度)が50.29%であったため、未だ山勘レベルです。

しかし、私にとっては、ようやくスタートラインに立てたと感じた結果でした。

さて、前回の結果でAccuracy(分類精度)が低くなった原因として、過学習が挙げられます。

過学習とは、AIモデルが学習データで最適化されすぎた結果、条件が異なる他の評価データに対して学習した性能が十分に発揮されない状態です。

RNN LSTM AIモデル 学習曲線 Affine 出力ニューロン数 100 前回の結果
前回の結果における学習曲線

上記は、前回の結果における学習曲線を表しています。

Epoch(AIモデルが学習した回数)が増えるに従い、赤い実線で示された学習時のエラー(TRAINING ERROR)が低下しているのが確認できます。

一方で、Epochの増加に従い、赤い破線で示された評価時のエラー(VALIDATION ERROR)は増加しています。

この状態が過学習を表しています。

一般的に、過学習への対策は以下の通りです。

  • 過学習への対策

    • 学習データ(説明変数)の見直し

    • AIモデルの見直し

そこで、今回は、学習データの見直しを行うことにしました。

学習データの見直し

これまで使用してきた学習データのフォーマットは、下記の通りです。

  • これまでの学習データのフォーマット

    • 1行7列のベクトルデータ

    • 左から終値始値高値安値5日移動平均25日移動平均75日移動平均

    • RNNに使用する場合は、1行7列で表される1日分のデータを2日分横に並べて1行14列としている

今回使用する学習データのフォーマットは、下記の通りです。

  • 今回の学習データのフォーマット

    • 1日分のデータを左から終値始値高値安値とする

    • 5日分のデータを時系列に従い左から右に並べて1行20列のベクトルデータとする

RNN LSTMの過学習への対策のために用意した学習データの基データを下記に示します。

RNN LSTM 過学習への対策 学習データ 基データ
RNN LSTM向け過学習への対策用の学習データの基データ

U列のラベルは、翌営業日の終値が当日の終値以上であれば1, そうでなければ0としています。

例えば、セルU2IF(Q2<=Q3, 1, 0)としています。

今回は学習データのフォーマットを変更しましたので、これに伴いLSTMのパラメーターも修正する必要があります。

AIモデルに使用したRNN LSTMの構造図

学習データのフォーマットに合わせてパラメーターを修正したLSTMの構造図を下記に示します。

RNN LSTM 構造図 学習データ フォーマット 修正
修正したRNN LSTMによるAIモデルの構造図

修正箇所は、下記の2カ所です。

  • AIモデルの修正箇所

    • InputSize20に修正

    • ReshapeOutShape5,4に修正

Reshapeで1行20列のベクトルデータを5行4列(5×(終値、始値、高値、安値))のベクトルデータに変換しています。

そして、RecurrentInputおよびRecurrentOutputで挟まれた層では、分割された1行4列(終値、始値、高値、安値)のベクトルデータに対して、処理を合わせて5回行います。

AIモデルの学習および評価を実行

修正したAIモデルに今回作成したRNN用の学習データを学習させた際の学習曲線を下記に示します。

修正 RNN LSTM AIモデル 二値分類 学習曲線
修正したRNN LSTMのAIモデルによる二値分類の学習曲線

前回の結果と同じく、AIモデルの学習が進むにつれてVALIDATION ERRORが増加する傾向となっていますが、途中でリセットされたような状況も確認できます。

学習の途中で何が起きたのかは分かっておりません。

続いて、今回の評価結果に対する混同行列を示します。

修正 RNN LSTM AIモデル 二値分類 混同行列
修正したRNN LSTMのAIモデルによる二値分類の混同行列

比較のため、下記に前回の評価結果に対する混同行列を示します。

前回 RNN LSTM AIモデル 二値分類 混同行列
前回のRNN LSTMのAIモデルによる二値分類の混同行列

僅かではありますが、Accuracy(分類精度)が50.29%から56.73%に改善していることが確認できます。

今後の課題

引き続き、過学習への対策を検討していきます。

  • 過学習への対策

    • 学習データ(説明変数)の見直し

    • AIモデルの見直し

学習データの見直しについては、移動平均やボリンジャーバンド、MACD, 等のテクニカル分析で使用される指標の追加を検討します。

AIモデルの見直しについては、中間層の多層化やCNN(Convolutional Neural Network)の使用も検討していきたいと考えています。

ただし、LSTMの構造に切り替えてからはAIモデルの学習時間が数時間に及ぶため、効率の良いやり方も検討したいと思います。

AIモデルのデータについて

今回作成したAIモデルのデータは、Googleドライブにて共有しています。

URL: https://drive.google.com/drive/folders/1eYkB4ob_VThhObaH3WqvalDDjem-UqUr?usp=drive_link

  • N225_LSTM_Affine100N_5Days.sdcproj

    • Neural Network Console用のプロジェクトファイル

この記事が気に入ったらサポートをしてみませんか?