見出し画像

Stable Baselines チュートリアル(4) / コールバックとハイパーパラメータの調整

以下のColabが面白かったので、ざっくり訳してみました。

Stable Baselines Tutorial - Callbacks and hyperparameter tuning

1. はじめに

このノートブックでは、監視、自動保存、モデル操作、進行状況バーなどを実行できるコールバックの使用方法を学習します。また、適切なハイパーパラメータを見つけることがRLで成功するための鍵であることもわかります。

2. pipを使用して依存関係と安定したベースラインをインストール

!apt install swig
!pip install tqdm==4.36.1
!pip install stable-baselines[mpi]==2.8.0
import gym
from stable_baselines import A2C, SAC, PPO2, TD3

3. ハイパーパラメータの調整の重要性

「教師あり学習」と比較した場合、「深層強化学習」は、「学習率」「ニューロンの数」「層の数」「オプティマイザー」などの「ハイパーパラメータ」の選択にはるかに敏感です。「ハイパーパラメータ」の選択が適切でないと、収束が不安定/不安定になる可能性があります。

Pendulum環境に適用された「SAC」アルゴリズムで「ハイパーパラメータ」の重要性を示します。「デフォルト」のパラメータと「調整済み」パラメータの間のパフォーマンスの変化に注目してください。

import numpy as np

def evaluate(model, env, num_episodes=100):
    # この関数は、単一の環境でのみ機能します。
    all_episode_rewards = []
    for i in range(num_episodes):
        episode_rewards = []
        done = False
        obs = env.reset()
        while not done:
            action, _states = model.predict(obs)
            obs, reward, done, info = env.step(action)
            episode_rewards.append(reward)

        all_episode_rewards.append(sum(episode_rewards))

    mean_episode_reward = np.mean(all_episode_rewards)
    return mean_episode_reward
eval_env = gym.make('Pendulum-v0')

「デフォルト」のパラメータで評価します。

default_model = SAC('MlpPolicy', 'Pendulum-v0', verbose=1).learn(8000)
evaluate(default_model, eval_env, num_episodes=100)
-1224.8244208398855

「調整済み」パラメータで評価します。

tuned_model = SAC('MlpPolicy', 'Pendulum-v0', batch_size=256, verbose=1, policy_kwargs=dict(layers=[256, 256])).learn(8000)
evaluate(tuned_model, eval_env, num_episodes=100)
-335.3693346865786

ハイパーパラメータの調整については、このチュートリアルの範囲外です。ただし、「RL Zoo」で調整されたハイパーパラメータと、「Optuna」を使用した自動ハイパーパラメータ最適化を提供していることを知っておく必要があります。

4. ヘルパー関数

これは、「コールバック」が変数を格納するのを支援するためのものですが、クラスメソッドを渡すことも可能です。

def get_callback_vars(model, **kwargs):
    """
    コールバック関数の変数を保存できます
    :param model: (BaseRLModel)
    :param **kwargs: コールバック変数の初期値
    """
    # 呼び出された属性をモデルに保存
    if not hasattr(model, "_callback_vars"):
        model._callback_vars = dict(**kwargs)
    else: # すべてのkwargsがコールバック変数にあることを確認
        for (name, val) in kwargs.items():
            if name not in model._callback_vars:
                model._callback_vars[name] = val
    return model._callback_vars # dict参照を返す(可変)

5. コールバック

◎関数的アプローチ
「コールバック関数」は、モデルからlocals()変数とglobals()変数を取得し、訓練を続行するかどうかのブール値を返します。モデル変数、特に_locals["self"]へのアクセスのおかげで、訓練を停止したり、モデルのコードを変更したりすることなく、モデルのパラメータを変更することができます。

ここには、2回しか呼び出せない単純なコールバックがあります。

def simple_callback(_locals, _globals):
    """
    各ステップ(DQNその他の場合)またはnステップ後(ACERまたはPPO2を参照)に呼び出されるコールバック
    :param _locals: (dict)
    :param _globals: (dict)
    """
    # 初期化されていない場合はデフォルト値でコールバック変数を取得
    callback_vars = get_callback_vars(_locals["self"], called=False)

    if not callback_vars["called"]:
        print("callback - first call")
        callback_vars["called"] = True
        return True # Trueを返し、訓練を継続
    else:
        print("callback - second call")
        return False # Falseを返し、訓練を停止

model = SAC('MlpPolicy', 'Pendulum-v0', verbose=1)
model.learn(8000, callback=simple_callback)

◎ 最初の例 : 最適なモデルの自動保存
RLでは、訓練中にモデルのクリーンバージョンを保持しておくと非常に便利です。不適切なポリシーが焼き付いてしまう可能性があるためです。これは、コールバックの典型的な使用例です。

Monitorラッパーを使用して、環境の統計を保存し、それらを使用して平均報酬を計算できます。これにより、訓練中に最適なモデルを保存できます。これはRLエージェントを評価する適切な方法ではないことに注意してください。

テスト環境を作成し、コールバックでエージェントのパフォーマンスを評価する必要があります。簡単にするために、訓練報酬をプロキシとして使用します。

import os

import numpy as np

from stable_baselines.bench import Monitor
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines.results_plotter import load_results, ts2xy

def auto_save_callback(_locals, _globals):
    """
    各ステップ(DQNその他の場合)またはnステップ後(ACERまたはPPO2を参照)に呼び出されるコールバック
    :param _locals: (dict)
    :param _globals: (dict)
    """
    # 初期化されていない場合はデフォルト値でコールバック変数を取得
    callback_vars = get_callback_vars(_locals["self"], n_steps=0, best_mean_reward=-np.inf)

    # 20ステップごと
    if callback_vars["n_steps"] % 20 == 0:
        # ポリシーのパフォーマンスを評価する
        x, y = ts2xy(load_results(log_dir), 'timesteps')
        if len(x) > 0:
            mean_reward = np.mean(y[-100:])

            # 新しい最高のモデル、ここでエージェントを保存できる
            if mean_reward > callback_vars["best_mean_reward"]:
                callback_vars["best_mean_reward"] = mean_reward
                # Example for saving best model
                print("Saving new best model at {} timesteps".format(x[-1]))
                _locals['self'].save(log_dir + 'best_model')
    callback_vars["n_steps"] += 1
    return True

# logフォルダの生成
log_dir = "/tmp/gym/"
os.makedirs(log_dir, exist_ok=True)

# 環境を作成してラップ
env = gym.make('CartPole-v1')
env = Monitor(env, log_dir, allow_early_resets=True)
env = DummyVecEnv([lambda: env])

model = A2C('MlpPolicy', env, verbose=0)
model.learn(total_timesteps=10000, callback=auto_save_callback)

◎ 2番目の例 : パフォーマンスのリアルタイムプロット
訓練中に、一時的な報酬と比較して、訓練が時間とともにどのように進行するかが役立つ場合があります。このため、「Stable Baselines」は「TensorBoard」をサポートしていますが、これは特にディスクスペースの使用において非常に面倒です。

注:残念ながら、ライブプロットはColabですぐに使用できません。

ここで、「コールバック」を再度使用して、Monitorラッパーを使用して、一時的な報酬をリアルタイムでプロットできます。

import matplotlib.pyplot as plt
import numpy as np
%matplotlib notebook

def plotting_callback(_locals, _globals):
    """
    各ステップ(DQNその他の場合)またはnステップ後(ACERまたはPPO2を参照)に呼び出されるコールバック
    :param _locals: (dict)
    :param _globals: (dict)
    """
    # 初期化されていない場合はデフォルト値でコールバック変数を取得
    callback_vars = get_callback_vars(_locals["self"], plot=None)

    # モニターのデータの取得
    x, y = ts2xy(load_results(log_dir), 'timesteps')
    if callback_vars["plot"] is None: # make the plot
        plt.ion()
        fig = plt.figure(figsize=(6,3))
        ax = fig.add_subplot(111)
        line, = ax.plot(x, y)
        callback_vars["plot"] = (line, ax, fig)
        plt.show()
    else: # update and rescale the plot
        callback_vars["plot"][0].set_data(x, y)
        callback_vars["plot"][-2].relim()
        callback_vars["plot"][-2].set_xlim([_locals["total_timesteps"] * -0.02,
                                           _locals["total_timesteps"] * 1.02])
        callback_vars["plot"][-2].autoscale_view(True,True,True)
        callback_vars["plot"][-1].canvas.draw()

# logフォルダの生成
log_dir = "/tmp/gym/"
os.makedirs(log_dir, exist_ok=True)

# 環境を作成してラップ
env = gym.make('MountainCarContinuous-v0')
env = Monitor(env, log_dir, allow_early_resets=True)
env = DummyVecEnv([lambda: env])

model = PPO2('MlpPolicy', env, verbose=0)
model.learn(20000, callback=plotting_callback)

◎ 3番目の例 : 進行状況バー
RLを開発および使用する場合、実験環境の質の向上は常に歓迎されます。
ここでは、tqdmを使用して、訓練の進行状況バー、1秒あたりのタイムステップ数、および訓練の終了までの残り時間を表示しました。

from tqdm.auto import tqdm

# このコールバックは「with」ブロックを使用して、正しい初期化と破棄を可能にする
class progressbar_callback(object):
    def __init__(self, total_timesteps): # init object with total timesteps
        self.pbar = None
        self.total_timesteps = total_timesteps

    def __enter__(self): # create the progress bar and callback, return the callback
        self.pbar = tqdm(total=self.total_timesteps)

        def callback_progressbar(local_, global_):
            self.pbar.n = local_["self"].num_timesteps
            self.pbar.update(0)

        return callback_progressbar

    def __exit__(self, exc_type, exc_val, exc_tb): # close the callback
        self.pbar.n = self.total_timesteps
        self.pbar.update(0)
        self.pbar.close()

model = TD3('MlpPolicy', 'Pendulum-v0', verbose=0)
with progressbar_callback(2000) as callback: # this the garanties that the tqdm progress bar closes correctly
    model.learn(2000, callback=callback)

◎ 4番目の例 : Composition
コールバックの関数的性質のおかげで、コールバックのCompositionを単一のコールバックにできます。これは、最適なモデルを自動保存し、訓練の進行状況バーと一時的な報酬を表示できることを意味します。

%matplotlib notebook

def compose_callback(*callback_funcs): # takes a list of functions, and returns the composed function.
    def _callback(_locals, _globals):
        continue_training = True
        for cb_func in callback_funcs:
            if cb_func(_locals, _globals) is False: # as a callback can return None for legacy reasons.
                continue_training = False
        return continue_training
    return _callback

# logフォルダの生成
log_dir = "/tmp/gym/"
os.makedirs(log_dir, exist_ok=True)

# 環境を作成してラップ
env = gym.make('CartPole-v1')
env = Monitor(env, log_dir, allow_early_resets=True)
env = DummyVecEnv([lambda: env])

model = PPO2('MlpPolicy', env, verbose=0)
with progressbar_callback(10000) as progress_callback:
    model.learn(10000, callback=compose_callback(progress_callback, plotting_callback, auto_save_callback))

◎ 【演習】 独自のコールバックのコーディング
前の例では、コールバックとは何か、それを使用して何をするかの基本を示しました。この演習の目標は、テスト環境を使用してモデルを評価し、これが最もよく知られているモデルである場合に保存するコールバックを作成することです。物事を簡単にするために、マジックメソッド__call__を持つ関数の代わりにクラスを使用します。

class EvalCallback(object):
  """
  エージェントを評価するためのコールバック

  :param eval_env: (gym.Env) 初期化に使用される環境
  :param n_eval_episodes: (int) エージェントをテストするエピソード数
  :param eval_freq: (int) コールバックのeval_freq呼び出しごとにエージェントを評価
  """
  def __init(self, eval_env, n_eval_episodes=5, eval_freq=20):
    super(EvalCallback, self).__init__()
    self.eval_env = eval_env
    self.n_eval_episodes = n_eval_episodes
    self.eval_freq = eval_freq
    self.n_calls = 0
    self.best_mean_reward = -np.inf

  def __call__(self, locals_, globals_):
    """
    このメソッドはモデルによって呼び出される
    前の例を使用したコールバック関数と同等
    :param locals_: (dict)
    :param globals_: (dict)
    :return: (bool)
    """
    # モデルのselfオブジェクトを取得
    self_ = locals['self']

    if self.n_calls % self.eval_freq == 0:
      # === YOU CODE HERE ===#
      # エージェントの評価 : self.eval_envを使用してself.n_eval_episodesループを実行する必要がある
      # ヒント : self_.predict(obs)を使用できる

      # S必要に応じてエージェントを保存しself.best_mean_rewardを更新

      print("Best mean reward: {:.2f}".format(self.best_mean_reward))


      # ====================== #

    self.n_calls += 1

    return True

コールバックをテストします。

# Env used for training
env = gym.make("CartPole-v1")
# Env for evaluating the agent
eval_env gym.make("CartPole-v1")

# === YOU CODE HERE ===#
# コールバックオブジェクトの作成
callback = None

# RLモデルの作成
model = None

# ====================== #

# Train the RL model
model.learn(int(100000), callback=callback)

6. 参照

・Github repo: https://github.com/araffin/rl-tutorial-jnrr19
・Stable-Baselines: https://github.com/hill-a/stable-baselines
・Documentation: https://stable-baselines.readthedocs.io/en/master/
・RL Baselines zoo: https://github.com/araffin/rl-baselines-zoo


この記事が気に入ったらサポートをしてみませんか?