見出し画像

TF-Agentsによる強化学習

TensorFlowがバージョンアップして強化学習用のライブラリ TF-Agentsが使えるようになったようだ。fastaiは強化深層学習はサポートしないそうなので、RLがしたいときにはこれを使えば良い。

ただプログラムはあまり綺麗ではなく、Pythonのバージョンも2のようだ。ChainerもRLに力を入れているようなので、比較して良い方を使うべきだろう。

SAC(Soft Actor Critic)などの新しめの手法も実装しているようで、色々比較して使いたいときにはTF-Agentsが良いだろう。

まずは色々インポートして(省略)から、定数パラメータを準備する。

env_name = 'CartPole-v0'  # @param
num_iterations = 20000  # @param
initial_collect_steps = 1000  # @param
collect_steps_per_iteration = 1  # @param
replay_buffer_capacity = 100000  # @param
fc_layer_params = (100,)
batch_size = 64  # @param
learning_rate = 1e-3  # @param
log_interval = 200  # @param
num_eval_episodes = 10  # @param
eval_interval = 1000  # @param

Gymを使って倒立振子の環境を準備する。

env = suite_gym.load(env_name)

環境を観測するには以下のようにしてobservation属性をみる。

 env.time_step_spec().observation
>>>
Observation Spec:
BoundedArraySpec(shape=(4,), dtype=dtype('float32'), name=None, minimum=[-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], maximum=[4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38])

観測は、台車の現在位置と速度、振子の角度と速度である。

行動は0(左方向に動かす)と1(右方向に動かす)である。

env.action_spec()
>>>
Action Spec:
BoundedArraySpec(shape=(), dtype=dtype('int64'), name=None, minimum=0, maximum=1)

Q値(状態と行動の組を与えると価値を返す関数)はニューラルネットであり、観測と行動と層パラメータを与えて構築する。

q_net = q_network.QNetwork(
   train_env.observation_spec(),
   train_env.action_spec(),
   fc_layer_params=fc_layer_params)

最適化はAdamとし、エージェントを生成し、初期化しておく。

optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
train_step_counter = tf.compat.v2.Variable(0)
tf_agent = dqn_agent.DqnAgent(
   train_env.time_step_spec(),
   train_env.action_spec(),
   q_network=q_net,
   optimizer=optimizer,
   td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
   train_step_counter=train_step_counter)

tf_agent.initialize()

次いで、方策を定義しておく。

eval_policy = tf_agent.policy
collect_policy = tf_agent.collect_policy
random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),
                                               train_env.action_spec())

評価値(価値関数)を計算する関数を準備する。

def compute_avg_return(environment, policy, num_episodes=10):
 total_return = 0.0
 for _ in range(num_episodes):
   time_step = environment.reset()
   episode_return = 0.0
   while not time_step.is_last():
     action_step = policy.action(time_step)
     time_step = environment.step(action_step.action)
     episode_return += time_step.reward
   total_return += episode_return
 avg_return = total_return / num_episodes
 return avg_return.numpy()[0]
compute_avg_return(eval_env, random_policy, num_eval_episodes)

モダンな強化学習はリプレイバッファを利用する。

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
   data_spec=tf_agent.collect_data_spec,
   batch_size=train_env.batch_size,
   max_length=replay_buffer_capacity)

最初にランダム方策を実行してリプレイバッファに溜めておく。これを用いて学習する。

def collect_step(environment, policy):
 time_step = environment.current_time_step()
 action_step = policy.action(time_step)
 next_time_step = environment.step(action_step.action)
 traj = trajectory.from_transition(time_step, action_step, next_time_step)
 # Add trajectory to the replay buffer
 replay_buffer.add_batch(traj)
for _ in range(initial_collect_steps):
 collect_step(train_env, random_policy)
dataset = replay_buffer.as_dataset(
   num_parallel_calls=3, sample_batch_size=batch_size, num_steps=2).prefetch(3)
iterator = iter(dataset)

最後にエージェントを訓練する。

tf_agent.train = common.function(tf_agent.train)
# Reset the train step
tf_agent.train_step_counter.assign(0)
# Evaluate the agent's policy once before training.
avg_return = compute_avg_return(eval_env, tf_agent.policy, num_eval_episodes)
returns = [avg_return]
for _ in range(num_iterations):
 # Collect one step using collect_policy and save to the replay buffer.
 collect_step(train_env, tf_agent.collect_policy)
 # Sample a batch of data from the buffer and update the agent's network.
 experience, unused_info = next(iterator)
 train_loss = tf_agent.train(experience)
 step = tf_agent.train_step_counter.numpy()
 if step % log_interval == 0:
   print('step = {0}: loss = {1}'.format(step, train_loss.loss))
 if step % eval_interval == 0:
   avg_return = compute_avg_return(eval_env, tf_agent.policy, num_eval_episodes)
   print('step = {0}: Average Return = {1}'.format(step, avg_return))
   returns.append(avg_return)

5分くらい回すと冒頭に示したように訓練が行われ、振子が倒れないようになる。




この記事が気に入ったらサポートをしてみませんか?