見出し画像

StanとRでベイズ統計モデリングをPyMC Ver.5で写経~第8章「8.2 複数の階層を持つモデル」

第8章「階層モデル」

書籍の著者 松浦健太郎 先生


この記事は、テキスト第8章「階層モデル」の8.2節「複数の階層を持つモデル」の PyMC5写経 を取り扱います。
モデルが複雑になるにつれて、ひとまず収束条件を満たすものの、MCMCサンプルデータに発散が含まれる状況になりました。

はじめに


StanとRでベイズ統計モデリングの紹介

この記事は書籍「StanとRでベイズ統計モデリング」(共立出版、「テキスト」と呼びます)のベイズモデルを用いて、PyMC Ver.5で「実験的」に写経する翻訳的ドキュメンタリーです。

テキストは、2016年10月に発売され、ベイズモデリングのモデル式とプログラミングに関する丁寧な解説とモデリングの改善ポイントを網羅するチュートリアル「実践解説書」です。もちろん素晴らしいです!
アヒル本」の愛称で多くのベイジアンに愛されてきた書籍です!

テキストに従ってStanとRで実践する予定でしたが、RのStan環境を整えることができませんでした(泣)
そこでこのシリーズは、テキストのベイズモデルをPyMC Ver.5に書き換えて実践します。

引用表記

この記事は、出典に記載の書籍に掲載された文章及びコードを引用し、適宜、掲載文章とコードを改変して書いています。
【出典】
「StanとRでベイズ統計モデリング」初版第13刷、著者 松浦健太郎、共立出版

記事中のイラストは、「かわいいフリー素材集いらすとや」さんのイラストをお借りしています。
ありがとうございます!

PyMC環境の準備

Anacondaを用いる環境構築とGoogle ColaboratoryでPyMCを動かす方法について、次の記事にまとめています。
「PyMCを動かすまでの準備」章をご覧ください。


8.2 複数の階層を持つ階層モデル


インポート

### インポート

# 数値・確率計算
import pandas as pd
import numpy as np
import scipy.stats as stats

# PyMC
import pymc as pm
import arviz as az

# 描画
import matplotlib.pyplot as plt
import seaborn as sns
plt.rcParams['font.family'] = 'Meiryo'

# ワーニング表示の抑制
import warnings
warnings.simplefilter('ignore')

データの読み込み

サンプルコードのデータを読み込みます。

### データの読み込み ◆データファイル8.2 data-salary-3.txt
# X:年齢-23, Y:年収, KID:勤務会社ID, GID:会社が所属する業界ID

data = pd.read_csv('./data/data-salary-3.txt')
print('data.shape: ', data.shape)
display(data.head())

【実行結果】

8.2.1 解析の目的とデータの分布の確認

散布図で可視化します。
テキスト図8.5の左右のグラフに相当します。
まず左側の「すべてのデータをまとめてプロット」です。
単回帰の傾き・切片は scipy.stats の linregress で求めます。

### 散布図の描画 ◆図8.5左

## 単回帰分析の実行
#傾きと切片を取得
slope, intercept, _, _, _ = stats.linregress(x=data.X, y=data.Y)
print(f'傾き: {slope:.3f}, 切片: {intercept:.3f}')
# 回帰直線描画用のxとyの算出
x_lm = np.linspace(data.X.min() - 1, data.X.max() + 1, 2)
y_lm = intercept + slope * x_lm

## 描画処理
# マーカーの設定
markers = {1: 'o', 2: '^', 3: 'X', 4: 'd'}
# 描画領域の設定
plt.figure(figsize=(5, 4))
ax = plt.subplot()
# 回帰直線の描画
ax.plot(x_lm, y_lm, color='black', lw=3, alpha=0.4)
# 散布図の描画
sns.scatterplot(ax=ax, data=data, x='X', y='Y', hue='GID', style='GID',
                s=100, markers=markers, palette='tab10', alpha=0.5)
# 修飾
ax.set(xlabel='年齢 $X$ [-23歳]', ylabel='年収 $Y$ [万円]',
       title='年齢 $X$ と年収 $Y$ の散布図')
ax.legend(bbox_to_anchor=(1, 1), title='業界ID')
ax.grid(lw=0.5)
plt.show()

【実行結果】

上記グラフの凡例を会社別のグラフで使う目的で、凡例データを保存します。

### 会社別グラフで用いる凡例データを保存
handles, labels = ax.get_legend_handles_labels()

続いて図8.5右の「会社別ごとに分割してプロット」です。

### 会社別散布図の描画 ■図8.5右

# 描画領域の指定
fig, axes = plt.subplots(1, 3, figsize=(10, 4), sharex=True, sharey=True)

# 業界IDごとに繰り返し描画処理(処理的にはaxesごとに繰り返し処理)
for i, ax in enumerate(axes.ravel()):

    ## 描画用データの作成
    # 業界を1つ取り出す
    tmp_df = data[data['GID'] == i + 1]
    # 当該業界のデータで回帰直線の傾きと切片を取得
    slope2, intercept2, _, _ ,_ = stats.linregress(x=tmp_df.X, y=tmp_df.Y)
    # 当該業界の回帰直線描画用のxとyの算出
    x_lm2 = np.linspace(tmp_df.X.min(), tmp_df.X.max(), 2)
    y_lm2 = intercept2 + slope2 * x_lm2
    
    ## 描画処理
    # 全業界の回帰直線の描画
    ax.plot(x_lm, y_lm, color='black', lw=3, alpha=0.4)
    # 当該業界の回帰直線の描画
    ax.plot(x_lm2, y_lm2, color='red', lw=2, ls='--')
    # 散布図の描画
    sns.scatterplot(ax=ax, data=tmp_df, x='X', y='Y', style='GID',
                    s=100, markers=markers, color=plt.cm.tab10(i/10), alpha=0.8,
                    legend=None)
    # 修飾
    ax.set(xlabel=None, ylabel=None, title=f'業界ID: {i+1}')
    ax.grid(lw=0.3)

# 図8.5左のグラフから取得した凡例を表示
fig.legend(handles=handles, labels=labels, bbox_to_anchor=(1.07, 0.85),
           title='業界ID')
# 全体修飾
fig.supxlabel('年齢 $X$ [-23歳]')
fig.supylabel('年収 $Y$ [万円]')
fig.suptitle('年齢 $X$ と年収 $Y$ の散布図:業界別')
plt.tight_layout();

【実行結果】

続いて図8.6の会社ごとに直線をあてはめて求めた切片$${a}$$と傾き$${b}$$のヒストグラムを描画します。

### 会社別の回帰直線の切片aと傾きbのヒストグラム ◆図8.6

## 描画用データの作成:業界別会社別の切片と傾きを算出
# 一時リストの初期化
lm_data = []
# 会社IDごとに繰り返し処理
for k in data['KID'].unique():
    # 当該会社IDのデータを取り出し
    tmp_k = data[data['KID'] == k]
    # 業界IDを取得
    g = tmp_k['GID'].head(1).values[0]
    # 回帰分析を実施して切片a, 傾きbを取得
    b, a, _, _ ,_ = stats.linregress(x=tmp_k.X, y=tmp_k.Y)
    # 一時リストに業界ID, 会社ID, 切片, 傾きを格納
    lm_data.append([g, k, a, b])
# 一時リストをデータフレーム化
lm_df = pd.DataFrame(lm_data, columns=['GID', 'KID', '切片a', '傾きb'])

## 描画用のパラメータの設定
# ヒストグラムのビンの設定
bins_a = np.linspace(0, 800, 21)
bins_b = np.linspace(0, 50, 21)

## 描画処理
# 描画領域の指定
fig, ax = plt.subplots(3, 2, figsize=(8, 7), sharey=True)
# 業界IDごとに繰り返し描画処理(処理的にはaxesごとに繰り返し処理)
for i, g in enumerate(sorted(lm_df['GID'].unique())):
    ## 切片aのヒストグラムの描画処理
    sns.histplot(ax=ax[i, 0], x=lm_df[lm_df['GID']==g]['切片a'],
                 bins=bins_a, kde=True, ec='white')
    ax[i, 0].set(title=f'業界{g}', yticks=range(0, 6), xlabel=None, ylabel=None)
    ax[i, 0].grid(lw=0.3)
    ## 傾きbのヒストグラムの描画処理
    sns.histplot(ax=ax[i, 1], x=lm_df[lm_df['GID']==g]['傾きb'],
                 bins=bins_b, kde=True, ec='white')
    ax[i, 1].set(title=f'業界{g}', yticks=range(0, 6), xlabel=None, ylabel=None)
    ax[i, 1].grid(lw=0.3)
# 全体修飾
ax[2, 0].set_xlabel('切片$a$')
ax[2, 1].set_xlabel('傾き$b$')
fig.supylabel('count')
plt.tight_layout();

【実行結果】

8.2.4 Stanで実装-その1

PyMCでモデル式8-5を実装します。
モデルに用いる追加データの作成です。

### データの追加的作成
# 会社インデックス、会社が所属する業界のインデックスの作成(0始まりにする)

# 会社インデックス
k_idx = data.KID - 1
display(k_idx)

# 会社が所属する業界IDを格納するデータフレーム
k2g_df = data[['KID', 'GID']].drop_duplicates().sort_values(['KID']) - 1
display(k2g_df)

【実行結果】

続いてモデルの定義です。

### モデルの定義 ◆モデル式8-5 model8-5.stan

with pm.Model() as model1:
    
    ### データ関連定義
    ## coordの定義
    model1.add_coord('data', values=data.index, mutable=True)
    model1.add_coord('kaisha', values=sorted(data.KID.unique()), mutable=True)
    model1.add_coord('gyokai', values=sorted(data.GID.unique()), mutable=True)
    ## dataの定義
    # 目的変数 Y
    Y = pm.ConstantData('Y', value=data['Y'].values, dims='data')
    # 説明変数 X
    X = pm.ConstantData('X', value=data['X'].values, dims='data')
    # 説明変数 KIdx 会社インデックス
    KIdx = pm.ConstantData('KIdx', value=k_idx.values, dims='data')
    # 説明変数 K2G 会社が所属する業界ID
    K2G = pm.ConstantData('K2G', value=k2g_df.GID.values, dims='kaisha')

    ### 事前分布
    a0 = pm.Uniform('a0', lower=-10000, upper=10000)
    b0 = pm.Uniform('b0', lower=-10000, upper=10000)
    sigmaAg = pm.Uniform('sigmaAg', lower=0, upper=10000)
    sigmaBg = pm.Uniform('sigmaBg', lower=0, upper=10000)
    sigmaA = pm.Uniform('sigmaA', lower=0, upper=10000)
    sigmaB = pm.Uniform('sigmaB', lower=0, upper=10000)
    sigmaY = pm.Uniform('sigmaY', lower=0, upper=10000)
    ag = pm.Normal('ag', mu=a0, sigma=sigmaAg, dims='gyokai')
    bg = pm.Normal('bg', mu=b0, sigma=sigmaBg, dims='gyokai')
    a = pm.Normal('a', mu=ag[K2G], sigma=sigmaA, dims='kaisha')
    b = pm.Normal('b', mu=bg[K2G], sigma=sigmaB, dims='kaisha')

    ### 線形予測子
    mu = pm.Deterministic('mu', a[KIdx] + b[KIdx] * X, dims='data')

    ### 尤度関数
    obs = pm.Normal('obs', mu=mu, sigma=sigmaY, observed=Y, dims='data')

モデルの定義内容を見ます。

### モデルの表示
model1

【実行結果】

### モデルの可視化
pm.model_to_graphviz(model1)

【実行結果】
会社の階層と業界の階層を確認できます。

MCMCを実行します。

### 事後分布からのサンプリング 20秒
with model1:
    idata1  = pm.sample(draws=1000, tune=1000, chains=4, target_accept=0.9,
                        nuts_sampler='numpyro', random_seed=1234)

【実行結果】省略

Pythonで事後分布からのサンプリングデータの確認を行います。
Rhatの確認から。
テキストの収束条件は「chainを3以上にして$${\hat{R}<1.1}$$のとき」です。

### r_hat>1.1の確認
# 設定
idata_in = idata1        # idata名
threshold = 1.01         # しきい値

# しきい値を超えるR_hatの個数を表示
print((az.rhat(idata_in) > threshold).sum())

【実行結果】
収束条件を満たしています。

事後統計量を表示します。

### 推論データの要約統計情報の表示
var_names = ['a0', 'b0', 'sigmaAg', 'sigmaBg', 'ag', 'bg', 'sigmaA', 'sigmaB',
             'sigmaY']
pm.summary(idata1, hdi_prob=0.95, var_names=var_names, round_to=3)

【実行結果】

### 推論データの要約統計情報の表示
var_names = ['a', 'b']
pm.summary(idata1, hdi_prob=0.95, var_names=var_names, round_to=3)

【実行結果】

トレースプロットを描画します。

### トレースプロットの表示
pm.plot_trace(idata1, compact=True)
plt.tight_layout();

【実行結果】
発散が見られます。

8.2.7 Stanで実装-その2

PyMCでモデル式8-6を実装します。
モデルに用いる追加データの作成です。

### データの追加的作成
# 業界インデックスの作成(0始まりにする)
g_idx = data.GID - 1
display(g_idx)

【実行結果】

続いてモデルの定義です。

### モデルの定義 ◆モデル式8-6 model8-6.stan

with pm.Model() as model2:
    
    ### データ関連定義
    ## coordの定義
    model2.add_coord('data', values=data.index, mutable=True)
    model2.add_coord('kaisha', values=sorted(data.KID.unique()), mutable=True)
    model2.add_coord('gyokai', values=sorted(data.GID.unique()), mutable=True)
    ## dataの定義
    # 目的変数 Y
    Y = pm.ConstantData('Y', value=data['Y'].values, dims='data')
    # 説明変数 X
    X = pm.ConstantData('X', value=data['X'].values, dims='data')
    # 説明変数 KIdx 会社インデックス
    KIdx = pm.ConstantData('KIdx', value=k_idx.values, dims='data')
    # 説明変数 KIdx 会社インデックス
    GIdx = pm.ConstantData('GIdx', value=g_idx.values, dims='data')
    # 説明変数 K2G 会社が所属する業界ID
    K2G = pm.ConstantData('K2G', value=k2g_df.GID.values, dims='kaisha')

    ### 事前分布
    a0 = pm.Uniform('a0', lower=-10000, upper=10000)
    b0 = pm.Uniform('b0', lower=-1000, upper=1000)
    sigmaAg = pm.Uniform('sigmaAg', lower=0, upper=10000)
    sigmaBg = pm.Uniform('sigmaBg', lower=0, upper=1000)
    sigmaA = pm.Uniform('sigmaA', lower=0, upper=10000, dims='gyokai')
    sigmaB = pm.Uniform('sigmaB', lower=0, upper=1000, dims='gyokai')
    sigmaY = pm.Uniform('sigmaY', lower=0, upper=10000, dims='gyokai')
    ag = pm.Normal('ag', mu=a0, sigma=sigmaAg, dims='gyokai')
    bg = pm.Normal('bg', mu=b0, sigma=sigmaBg, dims='gyokai')
    a = pm.Normal('a', mu=ag[K2G], sigma=sigmaA[K2G], dims='kaisha')
    b = pm.Normal('b', mu=bg[K2G], sigma=sigmaB[K2G], dims='kaisha')

    ### 線形予測子
    mu = pm.Deterministic('mu', a[KIdx] + b[KIdx] * X, dims='data')

    ### 尤度関数
    obs = pm.Normal('obs', mu=mu, sigma=sigmaY[GIdx], observed=Y, dims='data')

モデルの定義内容を見ます。

### モデルの表示
model2

【実行結果】

### モデルの可視化
pm.model_to_graphviz(model2)

【実行結果】
会社の階層と業界の階層を確認できます。

MCMCを実行します。

### 事後分布からのサンプリング 40秒
with model2:
    idata2  = pm.sample(draws=1000, tune=1000, chains=4, target_accept=0.9,
                        nuts_sampler='numpyro', random_seed=1234)

【実行結果】省略

Pythonで事後分布からのサンプリングデータの確認を行います。
Rhatの確認から。
テキストの収束条件は「chainを3以上にして$${\hat{R}<1.1}$$のとき」です。

### r_hat>1.1の確認
# 設定
idata_in = idata2        # idata名
threshold = 1.02         # しきい値

# しきい値を超えるR_hatの個数を表示
print((az.rhat(idata_in) > threshold).sum())

【実行結果】
収束条件を満たしています。

事後統計量を表示します。

### 推論データの要約統計情報の表示
var_names = ['a0', 'b0', 'sigmaAg', 'sigmaBg', 'ag', 'bg', 'sigmaA', 'sigmaB',
             'sigmaY']
pm.summary(idata2, hdi_prob=0.95, var_names=var_names, round_to=3)

【実行結果】

### 推論データの要約統計情報の表示
var_names = ['a', 'b']
pm.summary(idata2, hdi_prob=0.95, var_names=var_names, round_to=3)

【実行結果】

トレースプロットを描画します。

### トレースプロットの表示
pm.plot_trace(idata2, compact=True)
plt.tight_layout();

【実行結果】
発散が見られます。

8.2 節は以上です。


シリーズの記事

次の記事

前の記事

目次

ブログの紹介


note で7つのシリーズ記事を書いています。
ぜひ覗いていってくださいね!

1.のんびり統計

統計検定2級の問題集を手がかりにして、確率・統計をざっくり掘り下げるブログです。
雑談感覚で大丈夫です。ぜひ覗いていってくださいね。
統計検定2級公式問題集CBT対応版に対応しています。
Python、EXCELのサンプルコードの配布もあります。

2.実験!たのしいベイズモデリング1&2をPyMC Ver.5で

書籍「たのしいベイズモデリング」・「たのしいベイズモデリング2」の心理学研究に用いられたベイズモデルを PyMC Ver.5で描いて分析します。
この書籍をはじめ、多くのベイズモデルは R言語+Stanで書かれています。
PyMCの可能性を探り出し、手軽にベイズモデリングを実践できるように努めます。
身近なテーマ、イメージしやすいテーマですので、ぜひぜひPyMCで動かして、一緒に楽しみましょう!

3.実験!岩波データサイエンス1のベイズモデリングをPyMC Ver.5で

書籍「実験!岩波データサイエンスvol.1」の4人のベイジアンによるベイズモデルを PyMC Ver.5で描いて分析します。
この書籍はベイズプログラミングのイロハをざっくりと学ぶことができる良書です。
楽しくPyMCモデルを動かして、ベイズと仲良しになれた気がします。
みなさんもぜひぜひPyMCで動かして、一緒に遊んで学びましょう!

4.楽しい写経 ベイズ・Python等

ベイズ、Python、その他の「書籍の写経活動」の成果をブログにします。
主にPythonへの翻訳に取り組んでいます。
写経に取り組むお仲間さんのサンプルコードになれば幸いです🍀

5.RとStanではじめる心理学のための時系列分析入門 を PythonとPyMC Ver.5 で

書籍「RとStanではじめる心理学のための時系列分析入門」の時系列分析をPythonとPyMC Ver.5 で実践します。
この書籍には時系列分析のテーマが盛りだくさん!
時系列分析の懐の深さを実感いたしました。
大好きなPythonで楽しく時系列分析を学びます。

6.データサイエンスっぽいことを綴る

統計、データ分析、AI、機械学習、Pythonのコラムを不定期に綴っています。
統計・データサイエンス書籍にまつわる記事が多いです。
「統計」「Python」「数学とPython」「R」のシリーズが生まれています。

7.Python機械学習プログラミング実践記

書籍「Python機械学習プログラミング PyTorch & scikit-learn編」を学んだときのさまざまな思いを記事にしました。
この書籍は、scikit-learnとPyTorchの教科書です。
よかったらぜひ、お試しくださいませ。

最後までお読みいただきまして、ありがとうございました。

いいなと思ったら応援しよう!

この記事が参加している募集