Stable Baselines 3 入門 (2) - Monitor
「Stable Baselines 3」の「Monitor」の使い方をまとめました。
前回
1. Monitor
「Monitor」は、「報酬」(r)「エピソード長」(l)「時間」(t)をログ出力するためのラッパーです。使い方は、EnvをMonitorでラップするだけです。
import gym
import os
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
# ログフォルダの準備
log_dir = './logs/'
os.makedirs(log_dir, exist_ok=True)
# 学習環境の準備
env = gym.make('CartPole-v1')
env = Monitor(env, log_dir, allow_early_resets=True) # Monitorの利用
# モデルの準備
model = PPO('MlpPolicy', env, verbose=1)
# 学習の実行
model.learn(total_timesteps=128000)
# 推論の実行
state = env.reset()
while True:
# 学習環境の描画
env.render()
# モデルの推論
action, _ = model.predict(state, deterministic=True)
# 1ステップ実行
state, rewards, done, info = env.step(action)
# エピソード完了
if done:
break
# 学習環境の解放
env.close()
実行すると、logフォルダに以下のログが出力されます。
・monitor.csv
#{"t_start": 1661003542.954774, "env_id": "CartPole-v1"}
r,l,t
12.0,12,0.138155
46.0,46,0.155652
11.0,11,0.159838
:
2. グラフのプロット
グラフのプロットには、「qt」が必要なのでインストールします。
$ pip install PyQt5
グラフをプロットするコードは、次のとおりです。
・monitor_plot.py
import pandas as pd
import matplotlib.pyplot as plt
# monitor.csvの読み込み
df = pd.read_csv('./logs/monitor.csv', names=['r', 'l','t'])
df = df.drop(range(2)) # 1〜2行目の削除
# 報酬のプロット
x = range(len(df['r']))
y = df['r'].astype(float)
plt.plot(x, y)
plt.xlabel('episode')
plt.ylabel('reward')
plt.show()
# エピソード長のプロット
x = range(len(df['l']))
y = df['l'].astype(float)
plt.plot(x, y)
plt.xlabel('episode')
plt.ylabel('episode len')
plt.show()
実行すると、グラフがプロットされます。
$ python monitor_plot.py
この記事が気に入ったらサポートをしてみませんか?