麻雀の強化学習環境Mjx(v0.1.0)を触る(2/3)教師あり学習編
前回の記事では、Mjxの環境でゲームを進行させ、報酬を取得するところまで見てみました。
今回は、学習できるパラメタを持ったモデルを用意して実際に学習をしてみます。
学習可能なAgentの実装
Observationの確認
Agentの入力になるObservation(観測)には、プレイヤーの席順、点数、手配、捨て牌など様々な情報が含まれています。
Mjxには、観測を機械学習モデルで扱いやすい行列にするメソッドto_featuresが用意されています。
obs_dict["player_1"].to_features(feature_name="mjx-small-v0")
array([[0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1],
[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32)
# shape = ((16, 34))
この行列が↓の画像に対応しています。手前のプレイヤーの観測なので他家の手牌は隠れています。
実装はこのファイルにあるようです。
https://github.com/mjx-project/mjx/blob/master/include/mjx/internal/observation.cpp#L147
Actionの確認
Actionには0~180のidxが割り振られており、action.to_idx()で取得できるようです。このidxを予測するモデルを作ればよさそうです。
idxの割り当てはこの実装を見るとわかりそうです。
https://github.com/mjx-project/mjx/blob/master/include/mjx/internal/action.cpp#L162
また、現在のタイミングで選ぶことのできるActionのmaskを取得するメソッドも用意されています。
AgentのActionを決める方法としては、idxの確率分布をモデルで予測した後に、このmaskをかけてargmaxをすると良さそうです。
obs_dict["player_0"].action_mask()
# array([1., 1., 1., 0., 1., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)
actions["player_0"].to_idx()
# 10
初期モデルを作る(教師データ作成)
強化学習のみでランダムな状態からモデルを学習することは難しいので、ShantenAgentのログを使って教師あり学習したモデルを用意します。
100半荘動かし、観測とアクションのペアデータを約10万件作成しファイルに出力しました。
import mjx
from mjx.agents import ShantenAgent
import json
agent = ShantenAgent()
env = mjx.MjxEnv()
obs_dict = env.reset()
obs_hist = []
action_hist = []
for j in range(100):
while not env.done():
actions = {}
for player_id, obs in obs_dict.items():
legal_actions = obs.legal_actions()
action = agent.act(obs)
actions[player_id] = action
# 選択できるアクションが複数ある場合、obsとactionを保存する
if len(legal_actions) > 1:
obs_hist.append(obs.to_features(feature_name="mjx-small-v0").ravel())
action_hist.append(action.to_idx())
obs_dict = env.step(actions)
env.reset()
# ファイルに書き出し
np.save("shanten_obs.npy", np.stack(obs_hist))
np.save("shanten_actions.npy", np.array(action_hist, dtype=np.int32))
初期モデルを作る(教師あり学習)
適当ですが、中間層が一つあるニューラルネットワークのモデルを用意します。
import torch
from torch import optim, nn, utils, Tensor
import pytorch_lightning as pl
class MLP(pl.LightningModule):
def __init__(self, obs_size=544, n_actions=181, hidden_size=128):
super().__init__()
self.net = nn.Sequential(
nn.Linear(obs_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, n_actions),
)
self.loss_module = nn.CrossEntropyLoss()
def training_step(self, batch, batch_idx):
x, y = batch
preds = self.forward(x)
loss = self.loss_module(preds, y)
self.log("train_loss", loss)
return loss
def configure_optimizers(self):
optimizer = optim.Adam(self.parameters(), lr=1e-3)
return optimizer
def forward(self, x):
return self.net(x.float())
データを読み込んで、Loaderを作ります。
import numpy as np
inps = np.load("./shanten_obs.npy")
tgts = np.load("./shanten_actions.npy")
from torch.utils.data import TensorDataset, DataLoader
dataset = TensorDataset(torch.Tensor(inps), torch.LongTensor(tgts))
loader = DataLoader(dataset, batch_size=2)
モデルをインスタンス化して、TrainerにLoaderを渡して訓練します。
model = MLP()
trainer = pl.Trainer(max_epochs=1)
trainer.fit(model=model, train_dataloaders=train_loader)
torch.save(model.state_dict(), './model_shanten_100.pth')
Tensorboardでtrain_lossの値を確認すると、一応学習してくれているようです。
モデルの動作確認
初期モデルができたので、これをMjxのAgentから使えるように、モデルの出力からActionを作るコードを実装します。
作成したモデルの予測にmaskを掛けて、選べるアクションの中からargmaxで選んでみます。
アクションのidxからActionオブジェクトを作るときは、Mjxに用意されているmjx.Action.select_fromが使えます。
import random
from mjx import Agent, Observation, Action
class MLPAgent(Agent):
def __init__(self) -> None:
super().__init__()
def act(self, observation: Observation) -> Action:
legal_actions = observation.legal_actions()
if len(legal_actions) == 1:
return legal_actions[0]
# 予測
feature = observation.to_features(feature_name="mjx-small-v0")
with torch.no_grad():
action_logit = model(Tensor(feature.ravel()))
action_proba = torch.sigmoid(action_logit).numpy()
# アクション決定
mask = observation.action_mask()
action_idx = (mask * action_proba).argmax()
return mjx.Action.select_from(action_idx, legal_actions)
ShanteAgentの出力で教師あり学習したMLPAgentを使って1半荘回したところ、ちゃんとアガれているようなのである程度学習できていそうです。
次の記事で、このモデルを強化学習で学習していきます。
麻雀の強化学習環境Mjx(v0.1.0)を触る(3/3)強化学習編
https://note.com/oshizo/n/n2477c103609e
この記事が気に入ったらサポートをしてみませんか?