見出し画像

StanとRでベイズ統計モデリングをPyMC Ver.5で写経~第5章「5.3 ロジスティック回帰」

第5章「基本的な回帰とモデルのチェック」

書籍の著者 松浦健太郎 先生


この記事は、テキスト第5章「基本的な回帰とモデルのチェック」の5.3節「ロジスティック回帰」の PyMC5写経 を取り扱います。
尤度関数にベルヌーイ分布を用います。
ベルヌーイ分布のパラメータである「成功確率$${p}$$」はロジスティック関数を介して線形モデルが適用されます。

はじめに


StanとRでベイズ統計モデリングの紹介

この記事は書籍「StanとRでベイズ統計モデリング」(共立出版、「テキスト」と呼びます)のベイズモデルを用いて、PyMC Ver.5で「実験的」に写経する翻訳的ドキュメンタリーです。

テキストは、2016年10月に発売され、ベイズモデリングのモデル式とプログラミングに関する丁寧な解説とモデリングの改善ポイントを網羅するチュートリアル「実践解説書」です。もちろん素晴らしいです!
アヒル本」の愛称で多くのベイジアンに愛されてきた書籍です!

テキストに従ってStanとRで実践する予定でしたが、RのStan環境を整えることができませんでした(泣)
そこでこのシリーズは、テキストのベイズモデルをPyMC Ver.5に書き換えて実践します。

引用表記

この記事は、出典に記載の書籍に掲載された文章及びコードを引用し、適宜、掲載文章とコードを改変して書いています。
【出典】
「StanとRでベイズ統計モデリング」初版第13刷、著者 松浦健太郎、共立出版

記事中のイラストは、「かわいいフリー素材集いらすとや」さんのイラストをお借りしています。
ありがとうございます!

PyMC環境の準備

Anacondaを用いる環境構築とGoogle ColaboratoryでPyMCを動かす方法について、次の記事にまとめています。
「PyMCを動かすまでの準備」章をご覧ください。


5.3 ロジスティック回帰


インポート

### インポート

# 数値・確率計算
import pandas as pd
import numpy as np
import scipy.stats as stats

# PyMC
import pymc as pm
import arviz as az

# ROC-AUC
from sklearn.metrics import roc_curve, roc_auc_score

# 描画
import matplotlib.pyplot as plt
import seaborn as sns
plt.rcParams['font.family'] = 'Meiryo'

# ワーニング表示の抑制
import warnings
warnings.simplefilter('ignore')

サンプルコードのデータを読み込みます。

### データの読み込み ◆データファイル5.3 data-attendance-3.txtの構成
# PersonID:学生ID、A:バイト好き区分(1:好き), Weather:天気(A:晴れ、B:曇り、C:雨)
# Y:出欠区分:授業に出席したかどうか(0:欠席、1:出席)

data = pd.read_csv('./data/data-attendance-3.txt')
print('data.shape: ', data.shape)
display(data.head())

【実行結果】

5.3.2 データの分布の確認

天気の種類ごとに目的変数$${Y}$$を集計します。

### 天気ごとにYを集計する ◆テキスト70ページ
data.pivot_table(index='Weather', columns='Y', values='PersonID',
                 aggfunc='count')

【実行結果】

散布図行列を描画しやすいようにデータを加工します。

### 描画しやすいように天気を数値に変換
Weather_dict = {'A': 1, 'B': 2, 'C': 3}
data['Weather_num'] = data['Weather'].map(Weather_dict)
data.head()

【実行結果】

散布図行列を描画します。
スピアマンの順位相関係数の算出コードを少々簡素化しました。

### 散布図行列の描画

## 描画領域の指定
fig, ax = plt.subplots(4, 4, figsize=(10, 10))
ax = ax.ravel() # 1次元でaxesを指定したいので

## 番地0,0:ヒストグラムの描画(棒グラフを使用)
bar_A = data.A.value_counts().sort_index()
sns.barplot(ax=ax[0], x=bar_A.index, y=bar_A, hue=bar_A.index, palette='tab10',
            alpha=0.5, ec='white', legend=None)
ax[0].set(ylabel='A', xlabel=None)
ax[0].grid(lw=0.5)

## 番地1,0:箱ひげ図+ストリッププロットの描画
sns.boxplot(ax=ax[4], x=data.A, y=data.Score, hue=data.A, fill=False,
            legend=None)
sns.stripplot(ax=ax[4], x=data.A, y=data.Score, hue=data.A, size=8, alpha=0.01,
              legend=None)
ax[4].set(xlabel=None)
ax[4].grid(lw=0.5)

## 番地1,1:ヒストグラムの描画
sns.histplot(ax=ax[5], data=data, x='Score', hue='A', bins=10, kde=True,
             ec='white', legend=None)
ax[5].set(xlabel=None, ylabel=None)
ax[5].grid(lw=0.5)

## 番地2,0:箱ひげ図+ストリッププロットの描画
sns.boxplot(ax=ax[8], x=data.A, y=data.Weather_num, hue=data.A, fill=False,
            legend=None)
sns.stripplot(ax=ax[8], x=data.A, y=data.Weather_num, hue=data.A, size=8,
              alpha=0.01, legend=None)
ax[8].grid(lw=0.5)

## 番地2,1:散布図の描画
sns.scatterplot(ax=ax[9], data=data, x='Score', y='Weather_num', hue='A',
                size='A', style='A', markers=['o', '^'], sizes=(80, 80),
                alpha=0.5, legend=None)
ax[9].set(ylabel=None)
ax[9].grid(lw=0.5)

## 番地2,2:ヒストグラムの描画
sns.histplot(ax=ax[10], data=data, x='Weather_num', hue='A', bins=10, kde=True,
             ec='white', legend=None)
ax[10].set(ylabel=None)
ax[10].grid(lw=0.5)

## 番地3,0:箱ひげ図+ストリッププロットの描画
sns.boxplot(ax=ax[12], x=data.A, y=data.Y, hue=data.A, fill=False,
            legend=None)
sns.stripplot(ax=ax[12], x=data.A, y=data.Y, hue=data.A, size=8, alpha=0.01,
              legend=None)
ax[12].grid(lw=0.5)

## 番地3,1:散布図の描画
sns.scatterplot(ax=ax[13], data=data, x='Score', y='Y', hue='A', size='A',
                style='A', markers=['o', '^'], sizes=(80, 80), alpha=0.5,
                legend=None)
ax[13].set(ylabel=None)
ax[13].grid(lw=0.5)

## 番地3,2:散布図の描画
sns.scatterplot(ax=ax[14], data=data, x='Weather_num', y='Y', hue='A', size='A',
                style='A', markers=['o', '^'], sizes=(80, 80), alpha=0.5,
                legend=None)
ax[14].set(ylabel=None)
ax[14].grid(lw=0.5)

## 番地3,3:ヒストグラムの描画(棒グラフを使用)
bar_Y = data.Y.value_counts().sort_index()
sns.barplot(ax=ax[15], x=bar_Y.index, y=bar_Y, hue=bar_Y.index, palette='tab10',
            alpha=0.5, ec='white', legend=None)
ax[15].set(ylabel='Y', xlabel=None)
ax[15].grid(lw=0.5)

## スピアマンの順位相関係数を上三角のaxesに表示
# 列名をリスト化
cols = ['A', 'Score', 'Weather_num', 'Y']
# 列名の組み合わせ 行i,列j ごとにテキスト表示を繰り返す
for i, col1 in enumerate(cols):
    for j, col2 in enumerate(cols):
        # 上三角の位置は 行i < 列j のとき
        if i < j:
            # axesの番号を取得
            pos = i * len(cols) + j
            # スピアマンの順位相関係数を算出
            corr, pval = stats.spearmanr(data[col1], data[col2])
            # 枠線等を削除
            ax[pos].set_axis_off()
            # テキスト表示:中央表示に関連する引数: x,y,va,ha,transform
            ax[pos].text(x=0.5, y=0.5, s=round(corr * 100), fontsize=30,
                         va='center', ha='center', transform=ax[pos].transAxes)
# 全体修飾
plt.tight_layout();

【実行結果】

5.3.3 背景知識を使った値の変換

### 天気の重み係数を設定 ◆テキスト71ページに記述の変換処理
Weather_dict2 = {'A': 0, 'B': 0.2, 'C': 1}
data['Weather_w'] = data['Weather'].map(Weather_dict2)
data.head()

【実行結果】

5.3.6 Stanで実装

PyMC Ver.5 で実装します。
モデルの定義です。

### モデルの定義 ◆model5-5.stan

with pm.Model() as model:
    
    ### データ関連定義
    ## coordの定義
    model.add_coord('data', values=data.index, mutable=True)
    model.add_coord('beta', values=[1, 2, 3, 4], mutable=True)
    ## dataの定義
    # 目的変数 Y
    Y = pm.ConstantData('Y', value=data['Y'].values, dims='data')
    # 説明変数 A
    A = pm.ConstantData('A', value=data['A'].values, dims='data')
    # 説明変数 Score / 200
    Score = pm.ConstantData('Score', value=data['Score'].values / 200,
                            dims='data')
    # 説明変数 W(Weather)
    W = pm.ConstantData('W', value=data['Weather_w'].values, dims='data')

    ### 事前分布
    b = pm.Uniform('b', lower=-10, upper=10, dims='beta')

    ### 線形予測子の逆ロジット変換
    q = pm.Deterministic(
        'q',
        pm.invlogit(b[0] + b[1] * A + b[2] * Score + b[3] * W),
        dims='data')
    
    ### 尤度関数
    obs = pm.Bernoulli('obs', p=q, observed=Y, dims='data')

モデルの定義内容を見ます。

### モデルの表示
model

【実行結果】

### モデルの可視化
pm.model_to_graphviz(model)

【実行結果】

PythonでMCMCを実行します。

### 事後分布からのサンプリング 25秒 ◆run-model5-5.R
with model:
    idata = pm.sample(draws=1000, tune=1000, chains=4, target_accept=0.8,
                      nuts_sampler='numpyro', random_seed=1234)

【実行結果】省略

Pythonで事後分布からのサンプリングデータの確認を行います。
Rhatの確認から。
テキストの収束条件は「chainを3以上にして$${\hat{R}<1.1}$$のとき」です。

### r_hat>1.1の確認
# 設定
idata_in = idata         # idata名
threshold = 1.01         # しきい値

# しきい値を超えるR_hatの個数を表示
print((az.rhat(idata_in) > threshold).sum())

【実行結果】
収束条件を満たしています。

事後統計量を表示します。

### 推論データの要約統計情報の表示
var_names = ['b', 'q']
pm.summary(idata, hdi_prob=0.95, var_names=var_names, round_to=3)

【実行結果】

トレースプロットを描画します。
パラメータ q は一部のみの表示です。

### トレースプロットの表示
pm.plot_trace(idata, compact=False, var_names=var_names)
plt.tight_layout();

【実行結果】

$${Y}$$(モデル的には obs )の事後予測サンプリングを行います。

### Yの事後予測分布のサンプリング
with model:
    idata.extend(pm.sample_posterior_predictive(idata, random_seed=1234))

【実行結果】

ppcプロットを描画します。

### ppcプロットの描画
pm.plot_ppc(idata, num_pp_samples=100);

【実行結果】

パラメータの事後統計量の要約を算出します。

### パラメータの要約を確認

## 統計量算出関数:mean,sd,2.5%,25%,50%,75%,97.5%点をデータフレーム化する
def make_stats_df(y):
    probs = [2.5, 25, 50, 75, 97.5]
    columns = ['mean', 'sd'] + [str(s) + '%' for s in probs]
    quantiles = pd.DataFrame(np.percentile(y, probs, axis=0).T, index=y.columns)
    tmp_df = pd.concat([y.mean(axis=0), y.std(axis=0), quantiles], axis=1)
    tmp_df.columns=columns
    return tmp_df

## 要約統計量の算出・表示
# 事後分布サンプリングデータidataからパラメータbを取り出してデータフレーム化
param_samples = pd.DataFrame(
    idata.posterior.b.stack(sample=('chain', 'draw')).T.data,
    columns=['b1', 'b2', 'b3', 'b4'])
# 上記データフレームを統計量算出関数に与えて事後統計量データフレームを作成
params_stats_df = make_stats_df(param_samples)
# 事後統計量データフレームの表示
display(params_stats_df.round(2))

【実行結果】

5.3.7 図によるモデルのチェック

■ 図5.9
説明変数$${Score}$$に応じた目的変数$${Y}$$の観測値と確率パラメータ$${q}$$の事後分布を描画します。
描画に利用する$${q}$$の事後分布データはこのコードで生成します。
点が重ならないようにする jitter 処理は、手作りで対応しました。

### qの事後分布の描画 ◆図5.9
# 描画条件:A=0(アルバイトが好きでない), Weather='A'(晴れ)

## 設定
# 乱数生成器の設定
rng = np.random.default_rng(seed=1234)
# ロジスティック関数の定義
def logistic(x):
    return 1 / (1 + np.exp(-x))

## 描画用データの作成:確率 q の事後分布
# パラメータbをMCMCサンプルから取り出し shape=(4, 4000)
b_samples = idata.posterior.b.stack(sample=('chain', 'draw')).data
# x軸のscoreの30~200の値を作成
score_vals = np.arange(30, 201)
# パラメータqの計算:b0 + b2 * score / 200
q_plot = np.array([
    logistic(b_samples[0] + b_samples[2] * score_val / 200)
    for score_val in score_vals])
# qの中央値と80%CIを算出
q_plot_median = np.median(q_plot, axis=1)
q_plot_80ci = np.quantile(q_plot, q=[0.1, 0.9], axis=1)

## 描画用データの作成:観測値 Y
# Yのうち描画条件に該当するデータを抽出
y_plot = data.query("A==0 & Weather=='A'")[['Score', 'Y']]
# Yの描画の際に重なりを避けるためのjitterデータの作成
y_plot['Y_jitter'] = (
    y_plot['Y'] + rng.uniform(low=-0.1, high=0.1, size=y_plot.shape[0]))

## 描画処理
# qの事後分布の中央値の折れ線グラフの描画
plt.plot(score_vals, q_plot_median)
# qの事後分布の80%CIの塗りつぶしの描画
plt.fill_between(score_vals, q_plot_80ci[0], q_plot_80ci[1], alpha=0.2)
# Yの観測値の散布図の描画(y_jitterを描画)
sns.scatterplot(x=y_plot['Score'], y=y_plot['Y_jitter'], color='tab:red',
                alpha=0.3, label='A=0')
# 修飾
plt.title('確率パラメータ $q$ の事後分布\n中央値と80%ベイズ区間')
plt.legend(bbox_to_anchor=(1, 1), title='アルバイト区分')
plt.grid(lw=0.5);

【実行結果】

■ 図5.10
確率パラメータ$${q}$$の事後分布と目的変数$${Y}$$の観測値の描画です。
点が重ならないようにする jitter 処理は、seaborn の stripplot で対応しました。

### 確率と観測値のプロット ◆図5.10

## 描画用データの作成
# 確率qの事後分布の中央値の算出
q_medians = idata.posterior.q.stack(s=('chain', 'draw')).median(axis=1).data
# 描画項目 Y, A, q_madian をデータフレーム化
plot_data = pd.concat([data[['Y', 'A']], pd.DataFrame({'q': q_medians})], axis=1)
# Yをカテゴリ型に変換
plot_data['Y'] = plot_data['Y'].astype('category')

## 描画処理
# 描画領域の設定
plt.figure(figsize=(8, 5))
# ヴァイオリンプロットの描画
sns.violinplot(data=plot_data, x='q', y='Y', order=['1', '0'], inner=None,
               fill=False, color='darkgray', linewidth=5)
# ストリッププロットの描画
sns.stripplot(data=plot_data, x='q', y='Y', hue='A', alpha=0.5)
# 修飾
plt.title('確率 $q$(中央値)の事後分布と観測値 $Y$ のプロット')
plt.legend(bbox_to_anchor=(1.2, 1), title='アルバイト区分')
plt.grid(lw=0.5);

【実行結果】

■ 図5.11
ROC曲線を描画します。まずデータ加工から。

### ROC曲線の描画 ◆図5.11

## 描画用データの作成
# 確率qの事後分布サンプルデータの取り出し
q_samples = idata.posterior.q.stack(s=('chain', 'draw')).data
# qの80%区間をデータフレームに追加
plot_data[['q10%', 'q90%']] = np.quantile(q_samples, q=[0.1, 0.9], axis=1).T
display(plot_data)

【実行結果】

ROC曲線を描画します。
80%区間は不正確な方法で描画していますので、ざっくりとご覧ください。

## ROC曲線描画用のデータ作成
# ROC曲線のx軸:fpr、y軸:tprの算出
fpr, tpr, threshold = roc_curve(plot_data['Y'], plot_data['q'])
fpr10, tpr10, threshold10 = roc_curve(plot_data['Y'], plot_data['q10%'])
fpr90, tpr90, threshold90 = roc_curve(plot_data['Y'], plot_data['q90%'])
# ROC-AUCスコアの算出
q_score = roc_auc_score(plot_data['Y'], plot_data['q'])
q10_score = roc_auc_score(plot_data['Y'], plot_data['q10%'])
q90_score = roc_auc_score(plot_data['Y'], plot_data['q90%'])

## 描画処理
# 描画領域の設定
plt.figure(figsize=(5, 5))
ax = plt.subplot()
# ROC曲線の青線の描画
ax.plot(fpr, tpr)
# 80%ベイズ信頼区間の描画(ただし、10%と90%のy軸が一致していない為、概算です)
ax.fill_betweenx(tpr, fpr10, fpr90, alpha=0.3)
# 対角線の赤点線の描画
ax.plot([0, 1], [0, 1], color='tab:red', ls='--')
# 修飾
ax.set(xlabel='False Positive: FPR', ylabel='True Positive: TPR',
       title=f'ROC曲線\nAUC: '
             f'{q_score:.4f} [{q10_score:.4f}, {q90_score:.4f}]')
plt.grid(lw=0.5);

【実行結果】

MCMCサンプルの散布図行列を描画します。
スピアマンの順位相関係数の表示を前回記事から変更しました。

### MCMCサンプルの散布図行列の描画

## 描画用データの作成
# MCMCサンプリングデータからq1, q50を取り出し
q1_samples = (idata.posterior['q'].to_dataframe().reset_index()
              .query('data==0').rename({'q': 'q1'}, axis=1))
q50_samples = (idata.posterior['q'].to_dataframe().reset_index()
               .query('data==49').rename({'q': 'q50'}, axis=1))
# 描画対象パラメータをデータフレーム化
plot_df = pd.concat([pd.DataFrame(b_samples.T, columns=['b1', 'b2', 'b3', 'b4']),
                     q1_samples.reset_index(drop=True)['q1'],
                     q50_samples.reset_index(drop=True)['q50']], axis=1)

## 描画処理
# 相関行列プロットの描画
g = sns.pairplot(plot_df, diag_kws={'kde': True, 'ec': 'white'})
# スピアマンの順位相関係数の表示のためのaxフラット化
ax = g.axes.ravel()

## スピアマンの順位相関係数を上三角のaxesに表示
# 列名をリスト化
cols = plot_df.columns
# 列名の組み合わせ行i, 列j ごとにテキスト表示を繰り返す
for i, col1 in enumerate(cols):
    for j, col2 in enumerate(cols):
        # 上三角の位置は 行i < 列j のとき
        if i < j:
            # axesの番号を取得
            pos = i * len(cols) + j
            # スピアマンの順位相関係数を算出
            corr, pval = stats.spearmanr(plot_df[col1], plot_df[col2])
            # テキスト表示:中央表示に関連する引数: x,y,va,ha,transform
            ax[pos].text(x=0.5, y=0.5, s=round(corr * 100), fontsize=30,
                         va='center', ha='center', transform=ax[pos].transAxes,
                         bbox=dict(boxstyle='round', facecolor='white'))

【実行結果】

5.3 節は以上です。


シリーズの記事

次の記事

前の記事

目次

ブログの紹介


note で7つのシリーズ記事を書いています。
ぜひ覗いていってくださいね!

1.のんびり統計

統計検定2級の問題集を手がかりにして、確率・統計をざっくり掘り下げるブログです。
雑談感覚で大丈夫です。ぜひ覗いていってくださいね。
統計検定2級公式問題集CBT対応版に対応しています。
Python、EXCELのサンプルコードの配布もあります。

2.実験!たのしいベイズモデリング1&2をPyMC Ver.5で

書籍「たのしいベイズモデリング」・「たのしいベイズモデリング2」の心理学研究に用いられたベイズモデルを PyMC Ver.5で描いて分析します。
この書籍をはじめ、多くのベイズモデルは R言語+Stanで書かれています。
PyMCの可能性を探り出し、手軽にベイズモデリングを実践できるように努めます。
身近なテーマ、イメージしやすいテーマですので、ぜひぜひPyMCで動かして、一緒に楽しみましょう!

3.実験!岩波データサイエンス1のベイズモデリングをPyMC Ver.5で

書籍「実験!岩波データサイエンスvol.1」の4人のベイジアンによるベイズモデルを PyMC Ver.5で描いて分析します。
この書籍はベイズプログラミングのイロハをざっくりと学ぶことができる良書です。
楽しくPyMCモデルを動かして、ベイズと仲良しになれた気がします。
みなさんもぜひぜひPyMCで動かして、一緒に遊んで学びましょう!

4.楽しい写経 ベイズ・Python等

ベイズ、Python、その他の「書籍の写経活動」の成果をブログにします。
主にPythonへの翻訳に取り組んでいます。
写経に取り組むお仲間さんのサンプルコードになれば幸いです🍀

5.RとStanではじめる心理学のための時系列分析入門 を PythonとPyMC Ver.5 で

書籍「RとStanではじめる心理学のための時系列分析入門」の時系列分析をPythonとPyMC Ver.5 で実践します。
この書籍には時系列分析のテーマが盛りだくさん!
時系列分析の懐の深さを実感いたしました。
大好きなPythonで楽しく時系列分析を学びます。

6.データサイエンスっぽいことを綴る

統計、データ分析、AI、機械学習、Pythonのコラムを不定期に綴っています。
統計・データサイエンス書籍にまつわる記事が多いです。
「統計」「Python」「数学とPython」「R」のシリーズが生まれています。

7.Python機械学習プログラミング実践記

書籍「Python機械学習プログラミング PyTorch & scikit-learn編」を学んだときのさまざまな思いを記事にしました。
この書籍は、scikit-learnとPyTorchの教科書です。
よかったらぜひ、お試しくださいませ。

最後までお読みいただきまして、ありがとうございました。

この記事が参加している募集

この記事が気に入ったらサポートをしてみませんか?