見出し画像

CartPoleをPytorchで実装する

強化学習を実装する際は、EnvironmentやAgentなどを別々のclassとして実装します。
これはclassに分けることによってそれぞれのclass毎でdebugすることができるようにすることが1つの理由だと聞きました。
しかし、実際にdebugする際にサンプルコードがないと困ります。
Pytorchによるサンプルコードがパッと見たところ見当たらなかったので、メモとしてここに示すことにします。(ほとんど自分のため)
扱う問題はOpenAI gymのCartPoleです。これをDeep Q-Learning(DQN)で実装します。

CartPoleの概要

https://github.com/openai/gym/wiki/Leaderboard#cartpole-v0には、次のように説明されていました。

CartPoleは摩擦のない軌道を移動するカートに、ポールが無動作ジョイントで取り付けられています。カートに+1または-1の力を加えることで制御されます。振り子は直立した状態でスタートし、倒れないようにすることが目標です。振り子が直立したままであれば、1タイムステップごとに+1の報酬を与えます。ポールが垂直から15°以上離れるか、カートが中心から2.4ユニット以上移動するとエピソードが終了します。

出典:https://github.com/openai/gym/wiki/Leaderboard#cartpole-v0


出典:https://github.com/openai/gym/wiki/Leaderboard#cartpole-v0

action, state, reward, episodeの終了など。

  • action:台を左に押す(0)か 右に押す(1)の二択です。

  • state:stateの情報として、台車の位置、台車の速度、棒の角度、棒の先端の速度の4つが得られます。

  • reward:1が与え続けられます。

  • episodeの終了条件:以下のどれかの条件を満たした場合に、
    エピソードが終了したと判定されます。CartPole_v1を用いるのでepisodeの上限は500らしいです。

    • ポールのアングルが±15°以内

    • 台車の位置が±2.4以内

    • エピソードの長さが500以上

Codeの紹介

以下にコードを示します。
結果の描写にmatplotlib、モデルの実装にPytorch、環境の実装にgymを用いています。
自分の環境では動作しましたが、他の環境では問題があるかもしれません。ご了承ください。
CPU で動作させることを前提としているため、to(device)等は書いてません。

import matplotlib.pyplot as plt
import numpy as np
import torch, copy, random, gym
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque



class ReplayBuffer:
    def __init__(self, buffer_size, batch_size):
        self.buffer = deque(maxlen=buffer_size)
        self.batch_size = batch_size

    def add(self, state, action, reward, next_state, done):
        data = (state, action, reward, next_state, done)
        self.buffer.append(data)

    def __len__(self):
        return len(self.buffer)

    def get_batch(self):
        data = random.sample(self.buffer, self.batch_size)
        state = torch.tensor([x[0] for x in data], dtype=torch.float32)
        action = np.array([x[1] for x in data])
        reward = np.array([x[2] for x in data], dtype=np.float32)
        next_state = torch.tensor([x[3] for x in data], dtype=torch.float32)
        done = np.array([x[4] for x in data], dtype=np.float32).astype(np.int32)

        return state, action, reward, next_state, done

# ニューラルネットワークの設定
# 単純なlinear層を用いる
class QNet(nn.Module):
    def __init__(self, action_size):
        super().__init__()
        # stateの情報が台車の位置、台車の速度、棒の角度、棒の先端の速度の4つなので入力は4
        # cartpoleではaction_sizeは右か左かの2つ
        self.l1 = nn.Linear(4, 128)
        self.l2 = nn.Linear(128, 128)
        self.l3 = nn.Linear(128, action_size)

    def forward(self, x):
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        x = self.l3(x)
        return x


class DQNAgent:
    def __init__(self):
        # ハイパーパラメータの設定
        self.gamma = 0.98
        self.lr = 0.0005
        self.epsilon = 0.1
        self.buffer_size = 10000
        self.batch_size = 32
        self.action_size = 2
        # 経験再生の設定
        self.replay_buffer = ReplayBuffer(self.buffer_size, self.batch_size)
        # モデル等の設定
        # Target Networkを実装し、学習の安定化を図る
        self.qnet = QNet(self.action_size)
        self.qnet_target = QNet(self.action_size)
        self.optimizer = optim.Adam(self.qnet.parameters(), self.lr)

    def get_action(self, state):
        # ε-greedy法によるactionの取得
        if np.random.rand() < self.epsilon:
            return np.random.choice(self.action_size)
        else:
            qs = self.qnet(torch.tensor(state))
            return qs.data.argmax().numpy()

    def update(self, state, action, reward, next_state, done):
        # replay bufferに経験をためる
        self.replay_buffer.add(state, action, reward, next_state, done)
        if len(self.replay_buffer) < self.batch_size:
            return

        # モデルの学習
        # replay bufferにためたデータからミニバッチを作成して学習に用いる
        state, action, reward, next_state, done = self.replay_buffer.get_batch()
        # Target Networkの出力値を得る
        next_qs = self.qnet_target(next_state)
        next_q = next_qs.max(axis=1).values
        # main Networkの出力値を得る
        self.qnet.train()
        qs = self.qnet(state)
        q = qs[np.arange(self.batch_size), action]

        self.optimizer.zero_grad()
        # モデルの更新
        target = torch.tensor(reward, dtype=torch.float32) + torch.mul(torch.mul(next_q, self.gamma), torch.tensor(1.0 - done))
        loss = nn.MSELoss()(q.to(torch.float32), target.to(torch.float32))

        loss.backward()
        self.optimizer.step()

    # Target NetworkとMain Networkを同期させる
    def sync_qnet(self):
        self.qnet_target = copy.deepcopy(self.qnet)

# 設定いろいろ
episodes = 300
sync_interval = 20
env = gym.make('CartPole-v1')
agent = DQNAgent()
reward_history = []
# コードを動かす
for episode in range(episodes):
    # env.reset()をするとstateとしてnext_state, reward, done, infoの4つと謎の値1つの合計5つと謎の値を含むtupleが返ってくる
    state = env.reset()
    done = False
    total_reward = 0

    while not done:
        # replay bufferでミニバッチを作成する際に型やsizeでエラーが出る
        # env.reset()で得たstateはtupleであることが問題
        # tupleの時は0番目の要素が必要な情報
        if isinstance(state, tuple):
            state = state[0]
        # actionを取得する
        action = agent.get_action(state)
        # env.step(action)をするとstateとしてnext_state, reward, done, infoの4つと謎の値1つの合計5つが返ってくる
        next_state, reward, done, info, _ = env.step(action)
        # モデルの学習を行う
        agent.update(state, action, reward, next_state, done)
        state = next_state
        total_reward += reward

    if episode % sync_interval == 0:
        agent.sync_qnet()

    reward_history.append(total_reward)
    if episode % 10 == 0:
        print("episode :{}, total reward : {}".format(episode, total_reward))


# 結果をPlotする
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.plot(range(len(reward_history)), reward_history)
plt.show()


# 学習結果を確かめる
agent.epsilon = 0  # greedy policy
state = env.reset()
done = False
total_reward = 0

while not done:
    action = agent.get_action(state)
    next_state, reward, done, info, _ = env.step(action)
    state = next_state
    total_reward += reward
    env.render()
print('Total Reward:', total_reward)



この記事が気に入ったらサポートをしてみませんか?