Stable Baselines入門 / GAIL
1. GAIL
「GAIL」(Generative Adversarial Imitaiton Learning)は「模倣学習」のひとつで、人間のデモ(一連の観測と行動)を与えて、模倣できていたら報酬を与える学習法です。現在の実装では画像による学習がまだ対応していません。
今回は、「CartPole-v0」環境を「GAIL」で学習します。
2. OpenMPI
OpenMPI並列計算のライブラリです。OpenMPIに依存するアルゴリズム(GAIL、DDPG、TRPO、PPO1)の利用に必要です。以下のコマンドでインストールします。
$ pip install stable-baselines[mpi]
3. 人間のデモの記録
はじめに、人間のデモを記録するコードを作成します。
描画するウィンドウと、キー操作するウィンドウに別れています。1つのウィンドウにするは大変なので、分離したままとします。
キー操作のウィンドウを選択して、何かキーを押した時に記録開始となります。CartPoleは左右キーで操作します。
5エピソード分の記録で終了となります。人間のデモが「cartpole_traj.npz」として出力されます。
import random
import pyglet
import gym
import time
from pyglet.window import key
from stable_baselines.gail import generate_expert_traj
# 環境の生成
env = gym.make('CartPole-v1')
env.reset()
env.render()
# キーイベント用のウィンドウの生成
win = pyglet.window.Window(width=300, height=100, vsync=False)
key_handler = pyglet.window.key.KeyStateHandler()
win.push_handlers(key_handler)
pyglet.app.platform_event_loop.start()
# キー状態の取得
def get_key_state():
key_state = set()
win.dispatch_events()
for key_code, pressed in key_handler.items():
if pressed:
key_state.add(key_code)
return key_state
# キー入力待ち
while len(get_key_state()) == 0:
time.sleep(1.0/30.0)
# デモの行動の指定
def dummy_expert(_obs):
# キー状態の取得
key_state = get_key_state()
# 行動の選択
action = 0
if key.LEFT in key_state:
action = 0
elif key.RIGHT in key_state:
action = 1
# スリープ
time.sleep(1.0/2.0)
# 環境の描画
env.render()
# 行動の選択
return action
# デモの記録
generate_expert_traj(dummy_expert, 'cartpole_traj', env, n_episodes=5)
コードの説明は「Stable Baselines入門 / Behavior Cloning」を参照。
4. GAILによる学習
人間のデモ「mountaincar_traj.npz」を使ってモデルを学習します。
import gym
import time
from stable_baselines import GAIL
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines.gail import ExpertDataset, generate_expert_traj
# 環境の生成
env = gym.make('CartPole-v1')
env = DummyVecEnv([lambda: env])
# デモの読み込み
dataset = ExpertDataset(expert_path='cartpole_traj.npz', verbose=1)
# モデルの生成
model = GAIL(MlpPolicy, env, dataset, verbose=1)
# モデルの学習
model.learn(total_timesteps=10000)
# モデルのテスト
state = env.reset()
while True:
time.sleep(1.0/10.0)
env.render()
action, _ = model.predict(state)
state, reward, done, info = env.step(action)
if done:
env.reset()
◎デモの読み込み
デモの読み込みは、ExpertDataset()を使います。
◎モデルの生成と学習
GAILで学習するには「GAIL」を使います。引数にデモも渡します。学習は他の強化学習と同様にlearn()を使います。
この記事が気に入ったらサポートをしてみませんか?