StanとRでベイズ統計モデリングをPyMC Ver.5で写経~第11章「11.4 Latent Dirichlet Allocation」
第11章「離散値をとるパラメータを使う」
書籍の著者 松浦健太郎 先生
この記事は、テキスト第11章「離散値をとるパラメータを使う」の 11.4節「Latent Dirichlet Allocation」の PyMC5写経 を取り扱います。
周辺化消去をせず、2つのカテゴリカル分布を重ねましたが、長時間に渡るMCMCの結果、収束できませんでした。
はじめに
StanとRでベイズ統計モデリングの紹介
この記事は書籍「StanとRでベイズ統計モデリング」(共立出版、「テキスト」と呼びます)のベイズモデルを用いて、PyMC Ver.5で「実験的」に写経する翻訳的ドキュメンタリーです。
テキストは、2016年10月に発売され、ベイズモデリングのモデル式とプログラミングに関する丁寧な解説とモデリングの改善ポイントを網羅するチュートリアル「実践解説書」です。もちろん素晴らしいです!
「アヒル本」の愛称で多くのベイジアンに愛されてきた書籍です!
テキストに従ってStanとRで実践する予定でしたが、RのStan環境を整えることができませんでした(泣)
そこでこのシリーズは、テキストのベイズモデルをPyMC Ver.5に書き換えて実践します。
引用表記
この記事は、出典に記載の書籍に掲載された文章及びコードを引用し、適宜、掲載文章とコードを改変して書いています。
【出典】
「StanとRでベイズ統計モデリング」初版第13刷、著者 松浦健太郎、共立出版
記事中のイラストは、「かわいいフリー素材集いらすとや」さんのイラストをお借りしています。
ありがとうございます!
PyMC環境の準備
Anacondaを用いる環境構築とGoogle ColaboratoryでPyMCを動かす方法について、次の記事にまとめています。
「PyMCを動かすまでの準備」章をご覧ください。
11.4 Latent Dirichlet Allocation
モデリングの準備
インポート
### インポート
# 数値・確率計算
import pandas as pd
import numpy as np
# PyMC
import pymc as pm
import arviz as az
# 描画
import matplotlib.pyplot as plt
import seaborn as sns
plt.rcParams['font.family'] = 'Meiryo'
# ワーニング表示の抑制
import warnings
warnings.simplefilter('ignore')
データの読み込みと確認
サンプルコードのデータを読み込みます。
### データの読み込み ◆データファイル11.6 data-lda.txt
# 顧客がどの商品を購入したかを記録した架空データ。PersonID: 顧客ID, ItemID:商品ID
data = pd.read_csv('./data/data-lda.txt')
print('data.shape: ', data.shape)
display(data.head())
【実行結果】
軽くデータの外観を確認します。
各変数のユニークな要素数を確認します。
### 各変数のユニークな要素数のカウント
data.nunique()
【実行結果】
顧客数は 50、アイテム数は 112 です。
要約統計量を算出します。
### 要約統計量の表示
data.describe().round(1)
【実行結果】
アイテムIDは1から始まっていないようです。
11.4.1 解析の目的とデータの分布の確認
商品と顧客のクロス集計表を可視化します。
テキスト図11.5に相当します。
### 商品と顧客のクロス集計表を可視化 ◆図11.5
# クロス集計表の作成
sum_df = (data.groupby(['PersonID', 'ItemID'])['ItemID'].count().to_frame()
.rename({'ItemID': 'count'}, axis=1).reset_index())
# 散布図の描画
sns.scatterplot(data=sum_df, x='ItemID', y='PersonID', hue='count',
palette='ch:s=0.25, rot=-0.25')
# 修飾
plt.legend(bbox_to_anchor=(1, 1), title='購入数')
plt.grid(lw=0.5)
【実行結果】
顧客ごとの購入回数と商品ごとの購入回数のヒストグラムを描画します。
テキスト図11.6に相当します。
### 顧客ごとの購入回数/商品ごとの購入回数のヒストグラム ◆図11.6
# 描画領域の設定
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
# 顧客ごとの購入回数のヒストグラムの描画
sns.histplot(data.groupby('PersonID')['PersonID'].count(), bins=25, ec='white',
ax=ax1)
ax1.set_title('顧客ごとの購入回数')
ax1.grid(lw=0.5)
sns.histplot(data.groupby('ItemID')['ItemID'].count(), bins=25, ec='white',
ax=ax2)
ax2.set_title('商品ごとの購入回数')
ax2.grid(lw=0.5)
plt.tight_layout();
【実行結果】
11.4.4 Rでシミュレーション
Pythonでシミュレーションします。
次の2つの確率的なプロセスを通じて、顧客がどの商品を選ぶかをシミュレートします。
顧客ごとに異なる6面のサイコロを振って、6つの商品タグのいずれかを特定する(確率$${\theta}$$)。
特定した商品タグに相当する120面のサイコロを振って、120の商品のいずれかを特定する(確率$${\phi}$$)。
詳細はぜひテキストをお読みください。
11.4.2項と11.4.3項でデータが生成されるメカニズムとモデル式を丁寧に説明されています。
### シミュレーション ◆sim-model11-8.R
## 設定
# 初期値設定
N = 50 # 顧客数
I = 120 # 商品数
K = 6 # タグ数(K面サイコロ)
# 乱数生成器の設定
rng = np.random.default_rng(seed=123)
# K面サイコロ θ[n]の作成
alpha0 = np.repeat(0.8, repeats=K) # 6面
theta = rng.dirichlet(alpha=alpha0, size=N) # 6面✕50人
# I面サイコロ φ[dice]の作成
alpha1 = np.repeat(0.2, repeats=I) # 120面
phi = rng.dirichlet(alpha=alpha1, size=K) # 120面✕6タグ
# 顧客ごとの購入回数(イベント数) ※対数正規分布で生成
num_items_by_n = (np.exp(rng.normal(loc=2.0, scale=0.5, size=N))
.round().astype(int))
## シミュレーションの実行
# シミュレーション結果を格納するデータフレームの準備
d = pd.DataFrame()
# 顧客ごとにItemIDを乱数から取得する処理を繰り返し実施
for n in range(N):
# K面サイコロを振って商品のタグを取得(購入回数分)
z = rng.choice(a=range(K), size=num_items_by_n[n], replace=True, p=theta[n])
# I面サイコロを振って商品IDを特定(購入回数分)
item = [rng.choice(a=range(I), size=1, p=phi[k]) for k in z]
# データフレームに顧客IDと商品IDを追加
d = pd.concat([d, pd.DataFrame({'PersonID': n+1,
'ItemID': np.array(item).flatten() + 1})],
axis=0)
display(d)
【実行結果】
シミュレーションデータを用いて、商品と顧客のクロス集計表を可視化します。
### 商品と顧客のクロス集計表を可視化 シミュレーション
# クロス集計表の作成
sum_df2 = (d.groupby(['PersonID', 'ItemID'])['ItemID'].count().to_frame()
.rename({'ItemID': 'count'}, axis=1).reset_index())
# 散布図の描画
sns.scatterplot(data=sum_df2, x='ItemID', y='PersonID', hue='count',
palette='ch:s=0.25, rot=-0.25')
# 修飾
plt.legend(bbox_to_anchor=(1, 1), title='購入数')
plt.grid(lw=0.5)
【実行結果】
顧客ごとの購入回数と商品ごとの購入回数のヒストグラムを描画します。
### 顧客ごとの購入回数/商品ごとの購入回数のヒストグラム シミュレーション
# 描画領域の設定
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
# 顧客ごとの購入回数のヒストグラムの描画
sns.histplot(d.groupby('PersonID')['PersonID'].count(), bins=25, ec='white',
ax=ax1)
ax1.set_title('顧客ごとの購入回数')
ax1.grid(lw=0.5)
sns.histplot(d.groupby('ItemID')['ItemID'].count(), bins=25, ec='white', ax=ax2)
ax2.set_title('商品ごとの購入回数')
ax2.grid(lw=0.5)
plt.tight_layout();
【実行結果】
11.4.5 Stanで実装
PyMCのモデル定義
PyMCでモデル式11-8を実装します。
周辺化消去は行わず、式(11.13)と式(11.14)を用いています。
モデルの定義です。
### モデルの定義 ◆モデル式11-8(式(11.13), 式(11.14)) model11-8.stan
# 周辺化消去を行わず、パラメータに離散型分布を適用
## 分析に用いる観測データ
data_obs = data
## 初期値設定
N = 50 # 顧客数
K = 6 # タグ数(K面サイコロ)
I = 120 # 商品アイテム数
alpha = np.repeat(0.5, I) # phiのディリクレ分布のパラメータalpha
## モデルの定義
with pm.Model() as model:
### データ関連定義
## coordの定義
model.add_coord('data', values=data_obs.index, mutable=True)
model.add_coord('person', values=sorted(data_obs['PersonID'].unique()),
mutable=True)
model.add_coord('item', values=range(1, I+1), mutable=True)
model.add_coord('Kdice', values=range(1, K+1), mutable=True)
## dataの定義
# 目的変数 Y 購入した商品ID 0始まり
Y = pm.ConstantData('Y', value=data_obs['ItemID'].values - 1, dims='data')
# 顧客インデックス personIdx 0始まり
personIdx = pm.ConstantData('personIdx', value=data_obs['PersonID'].values - 1,
dims='data')
### 事前分布
theta = pm.Dirichlet('theta', a=np.ones(K), dims=('person', 'Kdice'))
phi = pm.Dirichlet('phi', a=alpha, dims=('Kdice', 'item'))
dice = pm.Categorical('dice', p=theta[personIdx], dims='data')
### 尤度関数 カテゴリカル分布
obs = pm.Categorical('obs', p=phi[dice], observed=Y, dims='data')
モデルの定義内容を見ます。
### モデルの表示
model
【実行結果】
### モデルの可視化
pm.model_to_graphviz(model)
【実行結果】
MCMCの実行と収束確認
MCMCを実行します。
約2時間40分かかりました。
### 事後分布からのサンプリング ◆run-model11-8.R 162分38秒
with model:
idata = pm.sample(draws=1000, tune=1000, chains=4, target_accept=0.8,
random_seed=1234)
【実行結果】
Pythonで事後分布からのサンプリングデータの確認を行います。
Rhatの確認から。
テキストの収束条件は「chainを3以上にして$${\hat{R}<1.1}$$のとき」です。
### r_hat>1.1の確認
# 設定
idata_in = idata # idata名
threshold = 1.1 # しきい値
# しきい値を超えるR_hatの個数を表示
print((az.rhat(idata_in) > threshold).sum())
【実行結果】
$${\hat{R}>1.1}$$のパラメータが多数存在するため、収束条件を満たしていません。
事後統計量を表示します。
### 推論データの要約統計情報の表示
pm.summary(idata, hdi_prob=0.95, round_to=3)
【実行結果】
r_hat の値は 1.1 前後であり、もう少し何かを調整すると収束条件を満たすかもしれません。
トレースプロットを描画します。
パラメータの一部です。
### トレースプロットの表示
pm.plot_trace(idata, compact=False)
plt.tight_layout();
【実行結果】
発散を示すバーコードのような黒い線が多数表示されています。
推論結果の解釈
収束していないので、MCMCサンプルを分析できる状態ではありませんが、コードの雰囲気を確認する目的で、テキストのグラフ描画に取り組みます。
テキスト図11.11の左のコードを実装してみます。
### タグごとに商品の出現確率φ[k,y]の中央値を棒グラフで可視化 ◆図11.11左
# 推論データからφのMCMCサンプルデータを取り出し
phi_samples = idata.posterior.phi.stack(sample=('chain', 'draw')).data
# 描画領域の設定
fig, ax = plt.subplots(2, 3, figsize=(10, 6), sharex=True, sharey=True)
# 中央値の算出と横棒グラフの描画の処理をタグ数繰り返す
for i, sample in enumerate(phi_samples):
# 中央値の算出
sample_median = np.median(sample, axis=1)
# 表示axesの算出
pos = divmod(i, 3)
# 横棒グラフの描画
ax[pos].barh(range(1, sample.shape[0]+1), sample_median)
# 修飾
ax[pos].set(title=f'タグ {i+1}')
ax[pos].grid(lw=0.5)
# 全体修飾
fig.supxlabel('phi[k,y]')
fig.supylabel('ItemID')
plt.tight_layout();
【実行結果】
続いて図11.11の右のコードを実装してみます。
### θ[1,k]とθ[50,k]の中央値と50%CIを棒グラフ・エラーバーで可視化 ◆図11.11右
# 推論データからθのMCMCサンプルデータを取り出し
theta_samples = idata.posterior.theta.stack(sample=('chain', 'draw')).data
# 描画領域の設定
fig, ax = plt.subplots(2, 1, figsize=(6, 8), sharex=True, sharey=True)
# 第1~3四分位数の算出とグラフ描画の処理を2人分繰り返す
for i, person in enumerate([0, 49]):
# 25%点, 50%点(中央値), 75%点を算出
sample_quantile = np.quantile(theta_samples[person], q=[0.25, 0.50, 0.75],
axis=1)
# 中央値の横棒グラフの描画
ax[i].barh(range(1, 7), sample_quantile[1], alpha=0.5)
# 50%CI区間のエラーバーの描画
ax[i].errorbar(y=range(1, 7), x=sample_quantile[1],
xerr=[sample_quantile[0], sample_quantile[1]], ls='',
color='tab:red', capsize=5, alpha=0.5)
# 修飾
ax[i].set(title=f'顧客 {person+1}')
ax[i].grid(lw=0.5)
# 全体修飾
fig.supxlabel('theta[n,k]')
fig.supylabel('tag')
plt.tight_layout();
【実行結果】
変分推論法
テキストはNUTSを使うと推定に時間がかかるため、変分ベイズ法の一種である ADVI を用いています。
このPyMCモデルは離散パラメータを含むため、ADVIを実行するとエラーになりました。
with model:
mean_field = pm.fit(method=pm.ADVI(), n=4000, obj_optimizer=pm.adam())
【実行結果】
(パラメータ化エラー:離散変数はVIでサポートされていません)
11.4 節は以上です。
シリーズの記事
次の記事
前の記事
目次
ブログの紹介
note で7つのシリーズ記事を書いています。
ぜひ覗いていってくださいね!
1.のんびり統計
統計検定2級の問題集を手がかりにして、確率・統計をざっくり掘り下げるブログです。
雑談感覚で大丈夫です。ぜひ覗いていってくださいね。
統計検定2級公式問題集CBT対応版に対応しています。
Python、EXCELのサンプルコードの配布もあります。
2.実験!たのしいベイズモデリング1&2をPyMC Ver.5で
書籍「たのしいベイズモデリング」・「たのしいベイズモデリング2」の心理学研究に用いられたベイズモデルを PyMC Ver.5で描いて分析します。
この書籍をはじめ、多くのベイズモデルは R言語+Stanで書かれています。
PyMCの可能性を探り出し、手軽にベイズモデリングを実践できるように努めます。
身近なテーマ、イメージしやすいテーマですので、ぜひぜひPyMCで動かして、一緒に楽しみましょう!
3.実験!岩波データサイエンス1のベイズモデリングをPyMC Ver.5で
書籍「実験!岩波データサイエンスvol.1」の4人のベイジアンによるベイズモデルを PyMC Ver.5で描いて分析します。
この書籍はベイズプログラミングのイロハをざっくりと学ぶことができる良書です。
楽しくPyMCモデルを動かして、ベイズと仲良しになれた気がします。
みなさんもぜひぜひPyMCで動かして、一緒に遊んで学びましょう!
4.楽しい写経 ベイズ・Python等
ベイズ、Python、その他の「書籍の写経活動」の成果をブログにします。
主にPythonへの翻訳に取り組んでいます。
写経に取り組むお仲間さんのサンプルコードになれば幸いです🍀
5.RとStanではじめる心理学のための時系列分析入門 を PythonとPyMC Ver.5 で
書籍「RとStanではじめる心理学のための時系列分析入門」の時系列分析をPythonとPyMC Ver.5 で実践します。
この書籍には時系列分析のテーマが盛りだくさん!
時系列分析の懐の深さを実感いたしました。
大好きなPythonで楽しく時系列分析を学びます。
6.データサイエンスっぽいことを綴る
統計、データ分析、AI、機械学習、Pythonのコラムを不定期に綴っています。
統計・データサイエンス書籍にまつわる記事が多いです。
「統計」「Python」「数学とPython」「R」のシリーズが生まれています。
7.Python機械学習プログラミング実践記
書籍「Python機械学習プログラミング PyTorch & scikit-learn編」を学んだときのさまざまな思いを記事にしました。
この書籍は、scikit-learnとPyTorchの教科書です。
よかったらぜひ、お試しくださいませ。
最後までお読みいただきまして、ありがとうございました。