見出し画像

強化学習でトレードポイントを探れるか?


注意!

今回のコードはあくまでも試作段階のため、参考程度と御認識いただければと思います。また、今回のコードを使用し実際の取引をする場合は自己責任でお願いいたします。

強化学習でトレードポイントを探れるのか?

と、いうことで今回は日経平均の過去データを使用して強化学習を用いてトレンド予測をおこうなものとなります。得られる結果は買い、売り、保持でわかるようにチャートに表示するものです。理想を言えば、買い時か売り時がわかればOKというのが、今回の目的となります。

実行結果

とりあえず、先に実行結果を確認します。2回実行しましたが、実行結果が異なります。単純に考えるとどちらも利益が出ているようには見えませんが、こんなに結果が違うと別の視点でコードを見直す必要がありそうです。


今後の課題

何回やっても利益が出ているようなチャートがそもそもできるのか?そのためにはいくつか考えるべき点がありそうです。そもそも学習の内容が利益を出せるものになっているのか?ということでしょう。
そのための特徴量の考察、報酬の設定の仕方、その他の機械学習と併せてのプログラム作成を次回の課題として、今回作成したコードを書きに残します。

chatGPTに作ってもらったコード

import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

import yfinance as yf
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from stable_baselines3 import DQN
from stable_baselines3.common.vec_env import DummyVecEnv
import gym

# 日経平均株価のデータを取得
ticker = "^N225"
data = yf.download(ticker, start="2004-01-01", end="2024-07-01")

# 必要なカラムの選択と欠損値の削除
data = data[['Close']].dropna()

# 移動平均やボリンジャーバンドなどのテクニカル指標を計算
data['SMA'] = data['Close'].rolling(window=20).mean()
data['UpperBand'] = data['Close'].rolling(window=20).mean() + 2 * data['Close'].rolling(window=20).std()
data['LowerBand'] = data['Close'].rolling(window=20).mean() - 2 * data['Close'].rolling(window=20).std()

# 欠損値の削除
data = data.dropna()

# 前処理済みデータの確認
print(data.tail())

class TrendEnv(gym.Env):
    def __init__(self, data):
        super(TrendEnv, self).__init__()
        self.data = data
        self.current_step = 0
        self.action_space = gym.spaces.Discrete(3)  # 0: Hold, 1: Buy, 2: Sell
        self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(4,), dtype=np.float32)

    def reset(self):
        self.current_step = 0
        return self._get_observation()

    def step(self, action):
        self.current_step += 1
        done = self.current_step >= len(self.data) - 1
        reward = self._calculate_reward(action)
        return self._get_observation(), reward, done, {}

    def _get_observation(self):
        return self.data.iloc[self.current_step][['Close', 'SMA', 'UpperBand', 'LowerBand']].values

    def _calculate_reward(self, action):
        if self.current_step == 0:
            return 0
        if action == 1:  # Buy
            return self.data.iloc[self.current_step]['Close'] - self.data.iloc[self.current_step - 1]['Close'] - 0.001
        elif action == 2:  # Sell
            return self.data.iloc[self.current_step - 1]['Close'] - self.data.iloc[self.current_step]['Close'] - 0.001
        else:  # Hold
            return 0

# データのスケーリング
from sklearn.preprocessing import MinMaxScaler

scaler = MinMaxScaler()
data_scaled = scaler.fit_transform(data[['Close', 'SMA', 'UpperBand', 'LowerBand']])

# 環境の作成
env = DummyVecEnv([lambda: TrendEnv(pd.DataFrame(data_scaled, columns=['Close', 'SMA', 'UpperBand', 'LowerBand']))])

# モデルのトレーニング
model = DQN('MlpPolicy', env, verbose=1)
model.learn(total_timesteps=50000)

# 学習したモデルを用いてトレンド判断を行う
state = env.reset()
trends = []
for _ in range(len(data)):
    action, _states = model.predict(state)
    trends.append(action)
    state, reward, done, info = env.step(action)
    if done:
        break

# 最後の行が欠けていた場合の対策として、trendsリストの長さを確認
if len(trends) < len(data):
    trends.append(0)  # デフォルトで 'Hold' を追加

data['Trend'] = trends[:len(data)]
data['Trend'] = data['Trend'].replace({0: 'Hold', 1: 'Buy', 2: 'Sell'})

# 直近3ヶ月のデータに限定
recent_data = data.tail(90)

plt.figure(figsize=(15, 7))
plt.plot(recent_data.index, recent_data['Close'], label='Close Price')
plt.plot(recent_data.index, recent_data['SMA'], label='SMA')
plt.plot(recent_data.index, recent_data['UpperBand'], label='UpperBand')
plt.plot(recent_data.index, recent_data['LowerBand'], label='LowerBand')

# トレンドに応じたバブルポイントをプロット
buy_signals = recent_data[recent_data['Trend'] == 'Buy']
sell_signals = recent_data[recent_data['Trend'] == 'Sell']

plt.scatter(buy_signals.index, buy_signals['Close'], label='Buy Signal', color='green', marker='^', alpha=1)
plt.scatter(sell_signals.index, sell_signals['Close'], label='Sell Signal', color='red', marker='v', alpha=1)

plt.legend()
plt.show()

この記事が気に入ったらサポートをしてみませんか?