StanとRでベイズ統計モデリングをPyMC Ver.5で写経~第9章「9.3.2 行列演算を使った重回帰」
第9章「一歩進んだ文法」
書籍の著者 松浦健太郎 先生
この記事は、テキスト第9章「一歩進んだ文法」・9.3節「ベクトルや行列の数学的性質の利用」の 9.3.2項「行列演算を使った重回帰」の PyMC5写経 を取り扱います。
テキストは第9章で Stan の文法上の工夫を取り扱っています。
Stanの文法とPyMCの文法は異なっており、Stan の工夫点をPyMCに取り入れることができない、または、取り入れる方法が分からない場合があります。
したがいまして第9章では、私個人のスキルと相談して、PyMC化に意味を見いだせるテーマを選択して取り上げています。
はじめに
StanとRでベイズ統計モデリングの紹介
この記事は書籍「StanとRでベイズ統計モデリング」(共立出版、「テキスト」と呼びます)のベイズモデルを用いて、PyMC Ver.5で「実験的」に写経する翻訳的ドキュメンタリーです。
テキストは、2016年10月に発売され、ベイズモデリングのモデル式とプログラミングに関する丁寧な解説とモデリングの改善ポイントを網羅するチュートリアル「実践解説書」です。もちろん素晴らしいです!
「アヒル本」の愛称で多くのベイジアンに愛されてきた書籍です!
テキストに従ってStanとRで実践する予定でしたが、RのStan環境を整えることができませんでした(泣)
そこでこのシリーズは、テキストのベイズモデルをPyMC Ver.5に書き換えて実践します。
引用表記
この記事は、出典に記載の書籍に掲載された文章及びコードを引用し、適宜、掲載文章とコードを改変して書いています。
【出典】
「StanとRでベイズ統計モデリング」初版第13刷、著者 松浦健太郎、共立出版
記事中のイラストは、「かわいいフリー素材集いらすとや」さんのイラストをお借りしています。
ありがとうございます!
PyMC環境の準備
Anacondaを用いる環境構築とGoogle ColaboratoryでPyMCを動かす方法について、次の記事にまとめています。
「PyMCを動かすまでの準備」章をご覧ください。
9.3.2 行列演算を使った重回帰
線形予測子に行列演算(行列積、ドット積)を用いてモデリングします。
インポート
### インポート
# 数値・確率計算
import pandas as pd
import numpy as np
# PyMC
import pymc as pm
import pytensor.tensor as pt
import arviz as az
# 描画
import matplotlib.pyplot as plt
import seaborn as sns
plt.rcParams['font.family'] = 'Meiryo'
# ワーニング表示の抑制
import warnings
warnings.simplefilter('ignore')
データの読み込み・確認
サンプルコードのデータを読み込みます。
### データの読み込み ◆データファイル9.2 data-attendance-5.txt
# A:バイト好き区分(1:好き), Score:学問の興味の強さ(0~200), X3:説明変数3,
# X4:説明変数4, Y:1年間の出席率
data2 = pd.read_csv('./data/data-attendance-5.txt')
print('data2.shape: ', data2.shape)
display(data2.head())
【実行結果】
データの外観を確認しましょう。
まずは要約統計量から。
### 要約統計量の表示
data2.describe().round(2)
【実行結果】
続いて相関係数。
### 相関係数の表示
data2.corr().round(2)
【実行結果】
アルバイト好き区分 A と年間出席率 Y (目的変数)に強めの負の相関があります。
散布図行列で可視化して変数間の相関関係を見てみましょう。
HUE に A を指定したので、A 単体のカラムは表示されず、各変数の散布図・ヒストグラムに A の分類が明示されます。
### 相関行列の描画
sns.pairplot(data=data2, hue='A', diag_kind='hist',
plot_kws={'s': 100, 'alpha': 0.5}, diag_kws={'ec': 'white'})
plt.show()
【実行結果】
青色はアルバイト好きでない、オレンジ色はアルバイト好き、です。
Y の行を見ると、アルバイト好きなオレンジ点が下方に、アルバイトが好きでない青い点が上方になる傾向が見られます。
PyMCのモデル定義
PyMCでモデル式9-3を実装します。
データの前処理を行います。
Score を 1/200 にスケール変更し、定数1の切片用変数 const を追加します。
### データの前処理
# data2のコピーを作成
data2_2 = data2.copy()
# Scoreのスケールを1/200に変換
data2_2['Score'] = data2_2['Score'] / 200
# 説明変数に定数(切片用)を追加
data2_2['const'] = 1
【実行結果】なし
モデルの定義です。
テキストの「行列演算」に合わせて、線形予測子 mu の計算式に行列積 dot を用いています。
行列演算をバラすと次のような計算式になります。
### モデルの定義 ◆モデル式9-3 model9-3.stan
with pm.Model() as model2:
### データ関連定義
## coordの定義
model2.add_coord('data', values=data2_2.index, mutable=True)
model2.add_coord('param', values=['const', 'A', 'Score', 'X3', 'X4'],
mutable=True)
## dataの定義
# 目的変数 Y
Y = pm.ConstantData('Y', value=data2_2['Y'].values, dims='data')
# 説明変数 X
X = pm.ConstantData(
'X',
value=data2_2[['const', 'A', 'Score', 'X3', 'X4']].values,
dims=('data', 'param'))
### 事前分布
# 重回帰の係数
b = pm.Uniform('b', lower=-10, upper=10, dims='param')
# 尤度関数の標準偏差
sigma = pm.Uniform('sigma', lower=0, upper=10)
# 線形予測子
mu = pm.Deterministic('mu', pt.dot(X, b), dims='data')
### 尤度関数
obs = pm.Normal('obs', mu=mu, sigma=sigma, observed=Y, dims='data')
モデルの定義内容を見ます。
### モデルの表示
model2
【実行結果】
### モデルの可視化
pm.model_to_graphviz(model2)
【実行結果】
MCMCの実行と収束確認
MCMCを実行します。
### 事後分布からのサンプリング 15秒
with model2:
idata2 = pm.sample(draws=1000, tune=1000, chains=4, target_accept=0.8,
nuts_sampler='numpyro', random_seed=1234)
【実行結果】省略
Pythonで事後分布からのサンプリングデータの確認を行います。
Rhatの確認から。
テキストの収束条件は「chainを3以上にして$${\hat{R}<1.1}$$のとき」です。
### r_hat>1.1の確認
# 設定
idata_in = idata2 # idata名
threshold = 1.01 # しきい値
# しきい値を超えるR_hatの個数を表示
print((az.rhat(idata_in) > threshold).sum())
【実行結果】
収束条件を満たしています。
事後統計量を表示します。
### 推論データの要約統計情報の表示
pm.summary(idata2, hdi_prob=0.95, round_to=3)
【実行結果】
トレースプロットを描画します。
### トレースプロットの表示
pm.plot_trace(idata2, compact=True)
plt.tight_layout();
【実行結果】
推定結果の解釈
事後分布の要約統計量を算出します。
算出関数を定義します。
### mean,sd,2.5%,25%,50%,75%,97.5%パーセンタイル点をデータフレーム化する関数の定義
def make_stats_df(y):
probs = [2.5, 25, 50, 75, 97.5]
columns = ['mean', 'sd'] + [str(s) + '%' for s in probs]
quantiles = pd.DataFrame(np.percentile(y, probs, axis=0).T, index=y.columns)
tmp_df = pd.concat([y.mean(axis=0), y.std(axis=0), quantiles], axis=1)
tmp_df.columns=columns
return tmp_df
要約統計量を算出します。
### 要約統計量の算出・表示
vars = ['sigma']
param_samples = idata2.posterior[vars].to_dataframe().reset_index(drop=True)
cols = idata2.posterior.coords['param'].data
b_samples = pd.DataFrame(
idata2.posterior.b.stack(sample=('chain', 'draw')).data.T,
columns=[f'b[{cols[i]}]' for i in range(5)])
param_samples = pd.concat([b_samples, param_samples], axis=1)
display(make_stats_df(param_samples).round(3))
【実行結果】
テキストに事後分布の推定値が掲載されていないので、PyMCモデルによる推論の適否は不明です。
事後予測
Yの事後予測サンプリングを行って、観測値と予測値のプロットを描画します。
事後予測サンプリングを実行します。
### 事後予測サンプリングの実行
with model2:
idata2.extend(pm.sample_posterior_predictive(idata2, random_seed=1234))
【実行結果】
事後予測プロットで推論値の良し悪しを確認します。
### 事後予測プロットの描画
pm.plot_ppc(idata2, num_pp_samples=100, random_seed=1234);
【実行結果】
観測値の黒線に対して、事後予測推論値のオレンジ線(平均値)・青細線(サンプリングデータの一部)は、そこそこ近い曲線を描いています。
観測値と予測値のプロットを描画します。
### 観測値と予測値のプロット
## 描画用データの作成 yPredの個人別の中央値と80%区間を算出
# 推論データからy_predのMCMCサンプルデータを取り出し shape=(32, 2, 4000)
y_pred_samples = (idata2.posterior_predictive.obs
.stack(sample=('chain', 'draw')).data)
# サンプリングデータの10%,50%,90%パーセンタイル点を算出してデータフレーム化
y_pred_df = pd.DataFrame(
np.quantile(y_pred_samples, q=[0.1, 0.5, 0.9], axis=1).T,
columns=['10%', 'median', '90%'])
y_pred_df = pd.concat([data2, y_pred_df], axis=1)
# 中央値と10%点の差、90%点と中央値の差を算出: errorbarで利用
y_pred_df['err_lower'] = y_pred_df['median'] - y_pred_df['10%']
y_pred_df['err_upper'] = y_pred_df['90%'] - y_pred_df['median']
# アルバイト0とアルバイト1に分離
y_pred_A0 = y_pred_df[y_pred_df['A']==0]
y_pred_A1 = y_pred_df[y_pred_df['A']==1]
## 描画処理
# 描画領域の指定
plt.figure(figsize=(6, 6))
ax = plt.subplot()
# アルバイト0の描画(エラーバー付き散布図)
ax.errorbar(y_pred_A0['Y'], y_pred_A0['median'],
yerr=[y_pred_A0['err_lower'], y_pred_A0['err_upper']],
color='tab:blue', alpha=0.7, marker='o', ms=10, linestyle='none',
label='0')
# アルバイト1の描画(エラーバー付き散布図)
ax.errorbar(y_pred_A1['Y'], y_pred_A1['median'],
yerr=[y_pred_A1['err_lower'], y_pred_A1['err_upper']],
color='tab:orange', alpha=0.7, marker='^', ms=10, linestyle='none',
label='1')
# 赤い対角線の描画
ax.plot([0, 0.5], [0, 0.5], color='red', ls='--')
# 修飾
ax.set(xlabel='Observed: $Y$の観測値', ylabel='Predicted: $Y$の予測値(中央値)',
title='$Y$ の観測値と予測値(中央値)のプロット 80%区間バー付き')
ax.legend(title='アルバイト好き', bbox_to_anchor=(1, 1))
ax.grid(lw=0.5);
【実行結果】
80%区間はほぼ赤い観測値=予測値のラインに乗っている感じです。
まあまあの予測精度かもです。
9.3.2 項は以上です。
シリーズの記事
次の記事
前の記事
目次
ブログの紹介
note で7つのシリーズ記事を書いています。
ぜひ覗いていってくださいね!
1.のんびり統計
統計検定2級の問題集を手がかりにして、確率・統計をざっくり掘り下げるブログです。
雑談感覚で大丈夫です。ぜひ覗いていってくださいね。
統計検定2級公式問題集CBT対応版に対応しています。
Python、EXCELのサンプルコードの配布もあります。
2.実験!たのしいベイズモデリング1&2をPyMC Ver.5で
書籍「たのしいベイズモデリング」・「たのしいベイズモデリング2」の心理学研究に用いられたベイズモデルを PyMC Ver.5で描いて分析します。
この書籍をはじめ、多くのベイズモデルは R言語+Stanで書かれています。
PyMCの可能性を探り出し、手軽にベイズモデリングを実践できるように努めます。
身近なテーマ、イメージしやすいテーマですので、ぜひぜひPyMCで動かして、一緒に楽しみましょう!
3.実験!岩波データサイエンス1のベイズモデリングをPyMC Ver.5で
書籍「実験!岩波データサイエンスvol.1」の4人のベイジアンによるベイズモデルを PyMC Ver.5で描いて分析します。
この書籍はベイズプログラミングのイロハをざっくりと学ぶことができる良書です。
楽しくPyMCモデルを動かして、ベイズと仲良しになれた気がします。
みなさんもぜひぜひPyMCで動かして、一緒に遊んで学びましょう!
4.楽しい写経 ベイズ・Python等
ベイズ、Python、その他の「書籍の写経活動」の成果をブログにします。
主にPythonへの翻訳に取り組んでいます。
写経に取り組むお仲間さんのサンプルコードになれば幸いです🍀
5.RとStanではじめる心理学のための時系列分析入門 を PythonとPyMC Ver.5 で
書籍「RとStanではじめる心理学のための時系列分析入門」の時系列分析をPythonとPyMC Ver.5 で実践します。
この書籍には時系列分析のテーマが盛りだくさん!
時系列分析の懐の深さを実感いたしました。
大好きなPythonで楽しく時系列分析を学びます。
6.データサイエンスっぽいことを綴る
統計、データ分析、AI、機械学習、Pythonのコラムを不定期に綴っています。
統計・データサイエンス書籍にまつわる記事が多いです。
「統計」「Python」「数学とPython」「R」のシリーズが生まれています。
7.Python機械学習プログラミング実践記
書籍「Python機械学習プログラミング PyTorch & scikit-learn編」を学んだときのさまざまな思いを記事にしました。
この書籍は、scikit-learnとPyTorchの教科書です。
よかったらぜひ、お試しくださいませ。
最後までお読みいただきまして、ありがとうございました。