見出し画像

Stable Baselines入門 / GIFアニメでの出力

Stable Baselinesのエージェントのテストの様子をGIFアニメで出力します。

画像1

import gym
import imageio
import numpy as np
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import PPO2

# 環境の生成
env = gym.make('CartPole-v1')
env = DummyVecEnv([lambda: env])

# モデルの生成
model = PPO2(MlpPolicy, env, verbose=1)

# モデルの学習
model.learn(total_timesteps=10000)

# 画像配列の準備
images = []

# モデルのテスト
state = env.reset()
for i in range(200):
   images.append(model.env.render(mode='rgb_array')) # 画像の追加

   env.render()
   action, _ = model.predict(state)
   state, reward, done, info = env.step(action)

env.close()

# 画像配列をGIFファイルに変換して保存
imageio.mimsave('cartpole.gif', images, 'GIF', **{'duration': 1.0/30.0})


この記事が気に入ったらサポートをしてみませんか?