見出し画像

Stable BaselinesでLSTMポリシーを使う

1. LSTM

「LSTM」は、時系列を扱えるニューラルネットワークで、主に動画分類、自然言語処理、音声認識などに利用されます。
強化学習では、通常「現在の環境」の状態に応じて「エージェント」が「行動」を決定しますが、「LSTM」を利用することで「過去の環境」の状態も「行動」決定の判断材料に使えるようになります。

2. Stable BaselinesでLSTMポリシーを使う

Stable BaselinesでLSTMポリシーを使うコードは、次のとおり。

import gym
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import PPO2

# 環境の生成
env = gym.make('CartPole-v1')
env = DummyVecEnv([lambda: env])

# モデルの生成
model = PPO2('MlpLstmPolicy', env, nminibatches=1, verbose=1)

# モデルの学習
model.learn(total_timesteps=100000)

# LSTMパラメータの準備
lstm_state = None
done = [False for _ in range(env.num_envs)]

# モデルのテスト
state = env.reset()
for i in range(200):
    # 環境の描画
    env.render()

    # モデルの推論
    action, lstm_state = model.predict(state, state=lstm_state, mask=done)

    # 1ステップ実行
    state, rewards, done, info = env.step(action)
    if done:
        break

◎モデルの生成
MLPでLSTMを利用するにはMlpLstmPolichy、CNNでLSTMを利用するにはCnnLstmPolicyを使います。並列実行する環境数はnminibatchesの倍数であるため、nminibatches=1を指定します。

◎モデルの推論
LSTMで学習したモデルの推論を行うには、LSTM状態「state」とLSTM状態のリセットを促す「mask」を指定します。戻り値として、行動といっしょにLSTM状態を受け取り、次の推論に使います。


この記事が気に入ったらサポートをしてみませんか?