StanとRでベイズ統計モデリングをPyMC Ver.5で写経~第9章「9.5.2 欠損値」
第9章「一歩進んだ文法」
書籍の著者 松浦健太郎 先生
この記事は、テキスト第9章「一歩進んだ文法」・9.5節「トラブルシューティング」の 9.5.2項「欠損値」の PyMC5写経 を取り扱います。
テキストは第9章で Stan の文法上の工夫を取り扱っています。
Stanの文法とPyMCの文法は異なっており、Stan の工夫点をPyMCに取り入れることができない、または、取り入れる方法が分からない場合があります。
したがいまして第9章では、私個人のスキルと相談して、PyMC化に意味を見いだせるテーマを選択して取り上げています。
選択の結果、9.5.1「int型のパラメータ」、9.5.3「Stanのエラーメッセージ」、9.5.4「print関数を使ったデバッグ」並びに「第9章 練習問題」の写経は省略しました。
はじめに
StanとRでベイズ統計モデリングの紹介
この記事は書籍「StanとRでベイズ統計モデリング」(共立出版、「テキスト」と呼びます)のベイズモデルを用いて、PyMC Ver.5で「実験的」に写経する翻訳的ドキュメンタリーです。
テキストは、2016年10月に発売され、ベイズモデリングのモデル式とプログラミングに関する丁寧な解説とモデリングの改善ポイントを網羅するチュートリアル「実践解説書」です。もちろん素晴らしいです!
「アヒル本」の愛称で多くのベイジアンに愛されてきた書籍です!
テキストに従ってStanとRで実践する予定でしたが、RのStan環境を整えることができませんでした(泣)
そこでこのシリーズは、テキストのベイズモデルをPyMC Ver.5に書き換えて実践します。
引用表記
この記事は、出典に記載の書籍に掲載された文章及びコードを引用し、適宜、掲載文章とコードを改変して書いています。
【出典】
「StanとRでベイズ統計モデリング」初版第13刷、著者 松浦健太郎、共立出版
記事中のイラストは、「かわいいフリー素材集いらすとや」さんのイラストをお借りしています。
ありがとうございます!
PyMC環境の準備
Anacondaを用いる環境構築とGoogle ColaboratoryでPyMCを動かす方法について、次の記事にまとめています。
「PyMCを動かすまでの準備」章をご覧ください。
9.5.2 欠損値
Stanはデータに欠損値が含まれるとエラーになるそうです。
一方、PyMCはデータの欠損値を含めて推論できます。
PyMC写経では欠損値を削除せず、欠損値を含めた推論を行います。
なお、欠損値を含むデータの場合、私の知るところでは次の2点の制約があります。
Data に定義できません(ConstantData等のアレです)
標準のNUTSサンプラーを使用する必要があります(numpyro等が使えないです)
インポート
### インポート
# 数値・確率計算
import pandas as pd
import numpy as np
# PyMC
import pymc as pm
import pytensor.tensor as pt
import arviz as az
# 描画
import matplotlib.pyplot as plt
plt.rcParams['font.family'] = 'Meiryo'
# ワーニング表示の抑制
import warnings
warnings.simplefilter('ignore')
データの読み込み・確認
サンプルコードのデータを読み込みます。
### データの読み込み ◆データファイル9.4 data-conc-2-NA-wide.txt
# PersonID:個人ID, TimeT:時間Tの血中濃度
data5 = pd.read_csv('./data/data-conc-2-NA-wide.txt')
print('data5.shape: ', data5.shape)
display(data5.head())
【実行結果】
個人別の時系列折れ線グラフを描画します。
### 個人別折れ線グラフの描画
## 描画用のデータ加工
# 列名の変更:時間を数値に
data5.columns = ['PersonID', 1, 2, 4, 8, 12, 24]
# インデックスの設定
data5.set_index('PersonID')
## 描画処理
# 描画領域の設定
plt.figure(figsize=(8, 4.5))
sns.lineplot(data5.set_index('PersonID').T, markers=True, dashes=False)
plt.xlabel('$Time[hour]$', fontsize=12)
plt.ylabel('$Y$', fontsize=12)
plt.xticks([1, 2, 4, 8, 12, 24])
plt.legend(bbox_to_anchor=(1, 1.03), title='Person ID')
plt.grid(lw=0.5);
【実行結果】
PyMCのモデル定義
PyMCでモデル式9-6を実装します。
データの前処理を行います。
### データ前処理 縦持ちに変換
# データに欠損値を含めたいので、テキストのlong.txtを使わず、自作する
data5_2 = data5.copy()
data5_2.columns = ['PersonID', 1, 2, 4, 8, 12, 24]
data5_2 = data5_2.melt(id_vars=['PersonID'], value_vars=[1, 2, 4, 8, 12, 24],
var_name='Time', value_name='Y')
data5_2 = data5_2.sort_values(['PersonID', 'Time']).reset_index(drop=True)
data5_2['Time'] = data5_2['Time'].astype(int)
display(data5_2)
【実行結果】
モデルの定義です。
Y の観測データは ConstantData に設定できない点に留意します。
### モデルの定義 ◆モデル式9-6 model9-6.stan
# 時間値のリストの作成
time_cat = [1, 2, 4, 8, 12, 24]
# モデルの定義
with pm.Model() as model5:
### データ関連定義
## coordの定義
model5.add_coord('data', values=data5_2.index, mutable=True)
model5.add_coord('person', values=sorted(data5_2['PersonID'].unique()),
mutable=True)
## dataの定義
# 目的変数 Y 欠損値を含むため data に定義できない
Y = data5_2['Y'].values
# 説明変数 Time
Time = pm.ConstantData('Time', value=data5_2['Time'], dims='data')
# インデックス personIdx 0始まり化
personIdx = pm.ConstantData('personIdx', value=data5_2['PersonID'].values - 1,
dims='data')
### 事前分布
a0 = pm.Uniform('a0', lower=-10, upper=10)
b0 = pm.Uniform('b0', lower=-10, upper=10)
sigmaA = pm.Uniform('sigmaA', lower=0, upper=10)
sigmaB = pm.Uniform('sigmaB', lower=0, upper=10)
sigmaY = pm.Uniform('sigmaY', lower=0, upper=10)
logA = pm.Normal('logA', mu=a0, sigma=sigmaA, dims='person')
logB = pm.Normal('logB', mu=b0, sigma=sigmaB, dims='person')
### 線形予測子
a = pm.Deterministic('a', pt.exp(logA), dims='person')
b = pm.Deterministic('b', pt.exp(logB), dims='person')
mu = pm.Deterministic('mu',
a[personIdx]*(1 - pt.exp(-b[personIdx]*Time)),
dims='data')
### 尤度関数
obs = pm.Normal('obs', mu=mu, sigma=sigmaY, observed=Y, dims='data')
モデルの定義内容を見ます。
欠損値を含む場合、アンダースコアを含む変数が自動作成されるので、KaTexエラーが発生しました。
### モデルの表示
model5.basic_RVs
【実行結果】
### モデルの可視化
pm.model_to_graphviz(model5)
【実行結果】
obs_unobserved は Y の欠損値に相当する変数です。
MCMCの実行と収束確認
MCMCを実行します。
### 事後分布からのサンプリング 4分30秒 ※numpyro未使用
with model5:
idata5 = pm.sample(draws=1000, tune=1000, chains=4, target_accept=0.95,
random_seed=1234)
【実行結果】
Pythonで事後分布からのサンプリングデータの確認を行います。
Rhatの確認から。
テキストの収束条件は「chainを3以上にして$${\hat{R}<1.1}$$のとき」です。
### r_hat>1.1の確認
# 設定
idata_in = idata5 # idata名
threshold = 1.01 # しきい値
# しきい値を超えるR_hatの個数を表示
print((az.rhat(idata_in) > threshold).sum())
【実行結果】
収束条件を満たしています。
事後統計量を表示します。
### 推論データの要約統計情報の表示
var_names = ['a0', 'b0', 'sigmaA', 'sigmaB', 'sigmaY', 'logA', 'logB']
pm.summary(idata5, hdi_prob=0.95, var_names=var_names, round_to=3)
【実行結果】
### 推論データの要約統計情報の表示
var_names = ['a', 'b', 'mu']
pm.summary(idata5, hdi_prob=0.95, var_names=var_names, round_to=3)
【実行結果】
トレースプロットを描画します。
### トレースプロットの表示
var_names = ['a0', 'b0', 'sigmaA', 'sigmaB', 'sigmaY', 'logA', 'logB',
'a', 'b', 'mu']
pm.plot_trace(idata5, compact=True, var_names=var_names)
plt.tight_layout();
【実行結果】
推定結果の解釈
事後分布の要約統計量を算出します。
算出関数を定義します。
### mean,sd,2.5%,25%,50%,75%,97.5%パーセンタイル点をデータフレーム化する関数の定義
def make_stats_df(y):
probs = [2.5, 25, 50, 75, 97.5]
columns = ['mean', 'sd'] + [str(s) + '%' for s in probs]
quantiles = pd.DataFrame(np.percentile(y, probs, axis=0).T, index=y.columns)
tmp_df = pd.concat([y.mean(axis=0), y.std(axis=0), quantiles], axis=1)
tmp_df.columns=columns
return tmp_df
要約統計量を算出します。
### パラメータの要約統計量を確認
vars = ['a0', 'b0', 'sigmaA', 'sigmaB', 'sigmaY']
param_samples = idata5.posterior[vars].to_dataframe().reset_index(drop=True)
display(make_stats_df(param_samples).round(2))
【実行結果】
テキストに事後分布の推定値が掲載されていないので、PyMCモデルによる推論の適否は不明です。
欠損値があった個人の$${\mu}$$のベイズ信用区間の推移を描画します。
テキスト図9.2に相当します。
### muの事後分布の描画 ◆図9.2
# muの事後分布データの取り出し
mu_samples = (idata5.posterior.mu.stack(sample=('chain', 'draw'))
.data.reshape(16, 6, 4000))
# 描画領域の設定
fig, ax = plt.subplots(2, 2, figsize=(7, 5), sharex=True, sharey=True)
# 4人分の描画を繰り返し処理
for i, person in enumerate([0, 1, 2, 15]):
# muの事後分布の中央値と95%CIを算出
mu_median = np.median(mu_samples[person], axis=1)
mu_95ci = np.quantile(mu_samples[person], q=[0.025, 0.975], axis=1)
# axesの値を算出
pos = divmod(i, 2)
# Yの観測値の描画
ax[pos].plot(time_cat, data5.iloc[person, 1:], 'o', color='tab:blue',
label='$Y$ の観測値')
# muの事後分布・中央値の折れ線グラフの描画
ax[pos].plot(time_cat, mu_median, color='tab:red',
label='$\mu$ の事後分布:中央値')
# muの事後分布・95%CIの折れ線グラフの描画
ax[pos].fill_between(time_cat, mu_95ci[0], mu_95ci[1], color='tomato',
alpha=0.2, label='$\mu$ の事後分布:95%区間')
# 修飾
ax[pos].set(title=f'ID:{i+1}', xticks=time_cat)
ax[pos].grid(lw=0.5)
ax[1, 1].legend()
# 全体修飾
fig.supxlabel('Time[hour]')
fig.supylabel('Y : 血中濃度[mg/mL]')
plt.tight_layout();
【実行結果】
テキストに近い結果になったと思います!
9.5.2 項は以上です。
シリーズの記事
次の記事
前の記事
目次
ブログの紹介
note で7つのシリーズ記事を書いています。
ぜひ覗いていってくださいね!
1.のんびり統計
統計検定2級の問題集を手がかりにして、確率・統計をざっくり掘り下げるブログです。
雑談感覚で大丈夫です。ぜひ覗いていってくださいね。
統計検定2級公式問題集CBT対応版に対応しています。
Python、EXCELのサンプルコードの配布もあります。
2.実験!たのしいベイズモデリング1&2をPyMC Ver.5で
書籍「たのしいベイズモデリング」・「たのしいベイズモデリング2」の心理学研究に用いられたベイズモデルを PyMC Ver.5で描いて分析します。
この書籍をはじめ、多くのベイズモデルは R言語+Stanで書かれています。
PyMCの可能性を探り出し、手軽にベイズモデリングを実践できるように努めます。
身近なテーマ、イメージしやすいテーマですので、ぜひぜひPyMCで動かして、一緒に楽しみましょう!
3.実験!岩波データサイエンス1のベイズモデリングをPyMC Ver.5で
書籍「実験!岩波データサイエンスvol.1」の4人のベイジアンによるベイズモデルを PyMC Ver.5で描いて分析します。
この書籍はベイズプログラミングのイロハをざっくりと学ぶことができる良書です。
楽しくPyMCモデルを動かして、ベイズと仲良しになれた気がします。
みなさんもぜひぜひPyMCで動かして、一緒に遊んで学びましょう!
4.楽しい写経 ベイズ・Python等
ベイズ、Python、その他の「書籍の写経活動」の成果をブログにします。
主にPythonへの翻訳に取り組んでいます。
写経に取り組むお仲間さんのサンプルコードになれば幸いです🍀
5.RとStanではじめる心理学のための時系列分析入門 を PythonとPyMC Ver.5 で
書籍「RとStanではじめる心理学のための時系列分析入門」の時系列分析をPythonとPyMC Ver.5 で実践します。
この書籍には時系列分析のテーマが盛りだくさん!
時系列分析の懐の深さを実感いたしました。
大好きなPythonで楽しく時系列分析を学びます。
6.データサイエンスっぽいことを綴る
統計、データ分析、AI、機械学習、Pythonのコラムを不定期に綴っています。
統計・データサイエンス書籍にまつわる記事が多いです。
「統計」「Python」「数学とPython」「R」のシリーズが生まれています。
7.Python機械学習プログラミング実践記
書籍「Python機械学習プログラミング PyTorch & scikit-learn編」を学んだときのさまざまな思いを記事にしました。
この書籍は、scikit-learnとPyTorchの教科書です。
よかったらぜひ、お試しくださいませ。
最後までお読みいただきまして、ありがとうございました。