見出し画像

Stable Baselines入門 / マルチプロセッシング

1. マルチプロセッシング

PPO2のように複数環境で訓練するアルゴリズムは、マルチプロセスで効率よく学習することができます。マルチプロセスが必要な場合は「SubprocVecEnv」を使います。

ステップあたり1環境でエージェントを学習する代わりに、複数環境でそれを学習します。エージェントと環境がやりとりする「行動」「状態」「報酬」「エピソード完了」「情報」は複数次元のベクトルになります。

マルチプロセッシングのコードは次の通りです。今回は4環境で同時に訓練とテストを行っています。

import gym
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import SubprocVecEnv
from stable_baselines import PPO2
from stable_baselines.common import set_global_seeds

# 環境を生成する関数
def make_env(env_id, rank, seed=0):
   def _init():
       env = gym.make(env_id)
       env.seed(seed + rank)
       return env
   set_global_seeds(seed)
   return _init

# 環境の生成
env_id = "CartPole-v1" # 環境ID
num_env = 4 # 環境の数
env = SubprocVecEnv([make_env(env_id, i) for i in range(num_env)])

# エージェントの生成
agent = PPO2(MlpPolicy, env, verbose=1)

# エージェントの学習
agent.learn(total_timesteps=10000)

# テスト
state = env.reset()
for i in range(200):
   env.render()
   action, _ = agent.predict(state)
   state, reward, done, info = env.step(action)

◎set_global_seeds
PythonとTensorFlowとNumpyとGym Spaceの乱数シードを指定します。


この記事が気に入ったらサポートをしてみませんか?