Stable Baselines入門 / マルチプロセッシング
1. マルチプロセッシング
PPO2のように複数環境で訓練するアルゴリズムは、マルチプロセスで効率よく学習することができます。マルチプロセスが必要な場合は「SubprocVecEnv」を使います。
ステップあたり1環境でエージェントを学習する代わりに、複数環境でそれを学習します。エージェントと環境がやりとりする「行動」「状態」「報酬」「エピソード完了」「情報」は複数次元のベクトルになります。
マルチプロセッシングのコードは次の通りです。今回は4環境で同時に訓練とテストを行っています。
import gym
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import SubprocVecEnv
from stable_baselines import PPO2
from stable_baselines.common import set_global_seeds
# 環境を生成する関数
def make_env(env_id, rank, seed=0):
def _init():
env = gym.make(env_id)
env.seed(seed + rank)
return env
set_global_seeds(seed)
return _init
# 環境の生成
env_id = "CartPole-v1" # 環境ID
num_env = 4 # 環境の数
env = SubprocVecEnv([make_env(env_id, i) for i in range(num_env)])
# エージェントの生成
agent = PPO2(MlpPolicy, env, verbose=1)
# エージェントの学習
agent.learn(total_timesteps=10000)
# テスト
state = env.reset()
for i in range(200):
env.render()
action, _ = agent.predict(state)
state, reward, done, info = env.step(action)
◎set_global_seeds
PythonとTensorFlowとNumpyとGym Spaceの乱数シードを指定します。
この記事が気に入ったらサポートをしてみませんか?