見出し画像

麻雀の強化学習環境Mjx(v0.1.0)を触る(3/3)強化学習編

前回の記事で、ShantenAgentの入出力を保存し、教師あり学習でモデルを作成し動作確認しました。
今回は、強化学習でこのモデルをさらに学習していきます。

MjxのリポジトリにExampleが公開されています。
https://github.com/mjx-project/mjx/blob/master/examples/rl_gym.py

MjxEnvをOpenAI Gymと同じIFで使うためのGymEnvクラスと、報酬を使ってモデルのパラメタを更新するREINFORCEクラスが実装されています。
examples/rl_gym.pyには、中間層1層のMLPを初期化して訓練する例が実装されていますが、今回は前の記事で教師あり学習したモデルを学習します。

まずは、RandomAgentに対して100戦した勝率を見ておきます。

import mjx
from mjx.agents import RandomAgent

random_agent = RandomAgent()
env = GymEnv(
    opponent_agents=[random_agent, random_agent, random_agent],
    reward_type="game_tenhou_7dan",
    done_type="game",
    feature_type="mjx-small-v0",
)

model = MLP()
model.load_state_dict(torch.load('./model_shanten_100.pth'))

opt = optim.Adam(model.parameters(), lr=1e-3)
agent = REINFORCE(model, opt)

# 報酬(=順位)をカウントする
rank_counter = Counter()

for i in range(100):
    obs, info = env.reset()
    done = False
    while not done:
        a = agent.act(obs, info["action_mask"])
        obs, r, done, info = env.step(a)
    rank_counter[r] += 1

rank_counter
# => Counter({90: 96, 45: 3, 0: 1})

100半荘の報酬の内訳は{90: 96, 45: 3, 0: 1}となり、RandomAgentに対してはほとんど1着が取れています。

次に、相手を少し強めに変更してみます。

Mjxに実装されているShantenAgentと、このShantenAgentからポン・チー・ミンカンをしないようにしたMenzenAgentを実装しておきます。
ShantenAgentは役がなくなる鳴きも行うので、このMenzenAgentのほうがアガリが出やすいと思います。

from mjx import ActionType, Agent, Observation, Action
class MenzenAgent(Agent):
    def __init__(self) -> None:
        super().__init__()

    def act(self, observation: Observation) -> Action:

        # (略)

        # if it can apply chi/pon/open-kan, pass
        steal_actions = [
            a
            for a in legal_actions
            if a.type() in [ActionType.CHI, ActionType.PON, ActionType, ActionType.OPEN_KAN]
        ]

        # 鳴きができる場合、代わりにPASSを選択する
        if len(steal_actions) >= 1:
            pass_action = [a for a in legal_actions if a.type() == ActionType.PASS][0]
            return pass_action

        # (略)

MenzenAgent、ShantenAgent、RandomAgentの3人と対戦させてみます。

import mjx
from mjx.agents import RandomAgent, ShantenAgent

agent1 = MenzenAgent()
agent2 = ShantenAgent()
agent3 = RandomAgent()

env = GymEnv(
    opponent_agents=[agent1, agent2, agent3],
    reward_type="game_tenhou_7dan",
    done_type="game",
    feature_type="mjx-small-v0",
)

model = MLP()
model.load_state_dict(torch.load('./model_shanten_100.pth'))

opt = optim.Adam(model.parameters(), lr=1e-3)
agent = REINFORCE(model, opt)

rank_counter = Counter()

for i in range(100):
    obs, info = env.reset()
    done = False
    while not done:
        a = agent.act(obs, info["action_mask"])
        obs, r, done, info = env.step(a)
    rank_counter[r] += 1
rank_counter
# => Counter({0: 56, -135: 19, 45: 16, 90: 9})

100半荘の報酬の内訳は{0: 56, -135: 19, 45: 16, 90: 9}で、RandomAgentがいるわりに4着が19局と多く見えますが、1着が取れることもあるようです。

では、この状態からしばらく学習して、MenzenAgentやShantenAgentに勝てるようになっていくかを見てみます。

for i in range(10000):
    obs, info = env.reset()
    done = False
    R = 0
    while not done:
        a = agent.act(obs, info["action_mask"])
        obs, r, done, info = env.step(a)
        R += r
    agent.update_gradient(R)
半荘数(横軸)に対し、平均報酬が減っていく様子

はじめは平均的な順位が取れていましたが、数千局の学習で平均報酬が-60あたりで落ち着くようになりました。
4着が-135、3着が0ポイントなので、4着と3着を半々ぐらいでとっている状態です。
はじめはShantenAgent程度の性能があったはずですが、強化学習によって対戦相手の中で最弱のRandomAgent並みの性能になってしまったようです(悲しい)。

少しでもうまくいっていればfeatureの設計やモデル構造の工夫に進もうと思っていましたが、強化学習での報酬の作り方やパラメタ更新の方法を工夫する必要がありそうです。

Mjxはまだバージョン0.1.0で開発も継続的に行われているようなので、アップデートに期待しつつまた試してみたいと思います。

この記事が気に入ったらサポートをしてみませんか?