見出し画像

Stable Baselines入門 / GAIL

1. GAIL

「GAIL」(Generative Adversarial Imitaiton Learning)は「模倣学習」のひとつで、人間のデモ(一連の観測と行動)を与えて、模倣できていたら報酬を与える学習法です。現在の実装では画像による学習がまだ対応していません。

今回は、「CartPole-v0」環境を「GAIL」で学習します。

2. OpenMPI

OpenMPI並列計算のライブラリです。OpenMPIに依存するアルゴリズム(GAIL、DDPG、TRPO、PPO1)の利用に必要です。以下のコマンドでインストールします。

$ pip install stable-baselines[mpi]

3. 人間のデモの記録

はじめに、人間のデモを記録するコードを作成します。

描画するウィンドウと、キー操作するウィンドウに別れています。1つのウィンドウにするは大変なので、分離したままとします。
キー操作のウィンドウを選択して、何かキーを押した時に記録開始となります。CartPoleは左右キーで操作します。
5エピソード分の記録で終了となります。人間のデモが「cartpole_traj.npz」として出力されます。

画像1

画像2

import random
import pyglet
import gym
import time
from pyglet.window import key
from stable_baselines.gail import generate_expert_traj

# 環境の生成
env = gym.make('CartPole-v1')
env.reset()
env.render()

# キーイベント用のウィンドウの生成
win = pyglet.window.Window(width=300, height=100, vsync=False)
key_handler = pyglet.window.key.KeyStateHandler()
win.push_handlers(key_handler)
pyglet.app.platform_event_loop.start()

# キー状態の取得
def get_key_state():
   key_state = set()
   win.dispatch_events()
   for key_code, pressed in key_handler.items():
       if pressed:
           key_state.add(key_code)
   return key_state

# キー入力待ち
while len(get_key_state()) == 0:
   time.sleep(1.0/30.0)

# デモの行動の指定
def dummy_expert(_obs):
   # キー状態の取得
   key_state = get_key_state()

   # 行動の選択
   action = 0
   if key.LEFT in key_state:
       action = 0
   elif key.RIGHT in key_state:
       action = 1

   # スリープ
   time.sleep(1.0/2.0)

   # 環境の描画
   env.render()

   # 行動の選択
   return action

# デモの記録
generate_expert_traj(dummy_expert, 'cartpole_traj', env, n_episodes=5)

コードの説明は「Stable Baselines入門 / Behavior Cloning」を参照。

4. GAILによる学習

人間のデモ「mountaincar_traj.npz」を使ってモデルを学習します。

import gym
import time
from stable_baselines import GAIL
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines.gail import ExpertDataset, generate_expert_traj

# 環境の生成
env = gym.make('CartPole-v1')
env = DummyVecEnv([lambda: env])

# デモの読み込み
dataset = ExpertDataset(expert_path='cartpole_traj.npz', verbose=1)

# モデルの生成
model = GAIL(MlpPolicy, env, dataset, verbose=1)

# モデルの学習
model.learn(total_timesteps=10000)

# モデルのテスト
state = env.reset()
while True:
   time.sleep(1.0/10.0)
   env.render()
   action, _ = model.predict(state)
   state, reward, done, info = env.step(action)
   if done:
       env.reset()

◎デモの読み込み
デモの読み込みは、ExpertDataset()を使います。

◎モデルの生成と学習
GAILで学習するには「GAIL」を使います。引数にデモも渡します。学習は他の強化学習と同様にlearn()を使います。


この記事が気に入ったらサポートをしてみませんか?