見出し画像

天鳳の牌譜を学習してMjx(v0.1.0)で使えるAgentを作る

以前の記事で、麻雀の強化学習環境Mjxを触り、強化学習を試してみました。その時は、ShantenAgentの行動を教師あり学習したAgentをもとに強化学習を行っていました。

今回は、その初期Agentを天鳳の牌譜から学習してみます。

mjx-projectのリポジトリに天鳳の牌譜データをMjxのStateの形式に変換するスクリプトが用意されており、天鳳の牌譜データやMjxのObservationのデータ形式の知識がなくても、簡単に実施できます。

牌譜データ

天鳳の牌譜は公式ページのリンクからダウンロードできます。
ランキング上位のユーザーについて、ユーザごとに過去の全牌譜をまとめたzipファイルのリンクが掲載されています。
いくつか選んでダウンロードします。

Mjxで使えるjson形式への変換

zipファイルを解凍すると、1半荘ごとのmjlogファイルが出てきます。
これを、mjx-convertを使ってjsonファイルに変換します。
環境はcolabを使いました。
集めたmjlogをmjlog_dirフォルダに入れ、出力用にoutフォルダを作っておきます。

!git clone https://github.com/mjx-project/mjx-convert.git
%cd mjx-convert/
!make install
!mjxc convert ./mjlog_dir ./out --to-mjxproto

git cloneしてインストールした後、mjxcコマンドにmjlogの入ったフォルダと、jsonを出力するフォルダを指定すると、全ファイルを一括で変換できます。

ObservationとActionペアへの変換

こうして作ったjsonファイルは、1行が1局のjson文字列になっています。
このファイルは10行あるので、10局で1半荘だったことになります。

mjlogファイルを変換したjsonファイル

この1行分のデータは、MjxのStateインスタンス(終局時)と同じ情報を持っています。この文字列から終局時のStateインスタンスを作成できます。
そして、Stateインスタンスの._cpp_obj.past_decisionsメソッドを使うことで、この1局の初めから終局までの観測とアクションの全ての履歴をループで取得できます。

from mjx import Observation, State, Action 
with open(path) as f:
  lines = f.readlines()

  for line in lines:
    state = State(line)

    for cpp_obs, cpp_act in state._cpp_obj.past_decisions():
      obs = Observation._from_cpp_obj(cpp_obs)
      feature = obs.to_features(feature_name="mjx-small-v0")

      action = Action._from_cpp_obj(cpp_act)
      action_idx = action.to_idx()

ループで得られるのは観測とアクションのcppのオブジェクトです。
対応するPythonのクラスの_from_cpp_objクラスメソッドを使うことでPythonのインスタンスにしています。

その後、obs.to_featuresと、action.to_idxを使うことでモデルの学習に使えるnumpy形式に変換してファイルを出力しました。
この処理はそこそこ重く、1半荘を変換して保存するのに30秒~1分程度かかりました。

教師あり学習

前回の記事と同じように教師あり学習を行いました。前回とほぼ同じなので詳細は割愛します。
ネットワークは隠れ1層のシンプルな構成にしており、6000半荘から作った、約400万の観測-アクションペアデータをバッチサイズ1024で10epoch学習しました。(GPU1枚で数分)
網羅的な実験はしていませんが、バッチサイズ大きめが良い傾向で、hidden_sizeや隠れ層の層数はあまりtest_lossに影響しませんでした。

from torch import optim, nn, utils, Tensor
import pytorch_lightning as pl
import torch
from torch import nn
class MLP(pl.LightningModule):
    def __init__(self, obs_size=544, n_actions=181, hidden_size=544):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_actions),
        )
        self.loss_module = nn.CrossEntropyLoss()

    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self.forward(x)
        loss = self.loss_module(preds, y)
        self.log("train_loss", loss)
        return loss

    def forward(self, x):
        return self.net(x.float())

前回はShantenAgentの行動を学習しましたが、今回はうまいプレイヤーの行動を学習しているので、より強くなっているはずです。

ルールベースのエージェントとの対戦

前回と同様ShantenAgent、MenzenAgent(鳴かないShantenAgent)、RandomAgentと100局対戦させて比較してみました。

前回:1着:9%、2着:16%、3着:56%、4着:19%
今回:1着:49%、2着:24%、3着:24%、4着:3%

半分ぐらい1着が取れており、かなり強くなっています!
もう私より強そうです。いくつかの局面での判断を見てみましょう。

局面の判断サンプル

Mjxの可視化機能を使うと、Agentの判断を可視化して確認できます。

4筒を引いたタイミング

アクション:ActionType.DISCARD TileType.P1
4筒を引いて、1筒を切っています。
Actionインスタンスの中身は以下のようにするとアクションのタイプと牌の種類を表示できます。

print(action.type())
if action.tile() is not None:
  print(action.tile().type())
次順、3筒を引いたタイミング

アクション:ActionType.DISCARD TileType.P2
その後、3筒を引いて、2筒を切る判断。
私の雀力ではどの選択が良いのかわからないので、何切るシミュレータで見てみます。

https://pystyle.info/apps/mahjong-nanikiru-simulator/

3索か5索のほうが受け入れが多いですが、今回は手牌と捨牌で待ちがかなり見えていて2筒を切るほうが枚数が多そうです。
また、自分で8筒を切っていてフリテンになる可能性があることも考慮しているかもしれません。

一向聴から一向聴に待ち変えするケースもかなりできていそうです。

ニューラルネットの判断根拠を可視化する手法を使って、こういったタイミングにfeatureのどの部分に着目して判定しているかの可視化も行ってみたいですね。


この記事が気に入ったらサポートをしてみませんか?