見出し画像

第16章「あなたの英語,大丈夫?」のベイズモデリングをPyMC Ver.5 で

この記事は、テキスト「たのしいベイズモデリング」の第16章「あなたの英語,大丈夫?」のベイズモデルを用いて、PyMC Ver.5で「実験的」に実装する様子を描いた統計ドキュメンタリーです。

この章では、特定の状況下で発せられる外国語表現が適切か不適切かを判断する能力について、信号検出モデルを用いた階層ベイズモデルで推論します。

アメリカの信号のイラスト:「いらすとや」さんより

今回のモデルは久しぶりのクリーンヒット!
PyMCでテキストを再現できたのです。嬉しいです!
さあ、PyMCのベイズモデリングを楽しみましょう。

テキストの紹介、引用表記、シリーズまえがき、PyMC等のバージョン情報は、このリンクの記事をご参照ください。

テキストで使用するデータは、R・Stan等のサンプルスクリプトとともに、出版社サイトからダウンロードして取得できます。


サマリー


テキストの概要

執筆者   : 鬼田崇作 先生、草薙邦広 先生
モデル難易度: ★★★・・ (ふつう)

自己評価

評点

$$
\begin{array}{c:c:c}
実装精度 & ★★★★★& GoooD! \\
結果再現度 & ★★★★★& 最高✨ \\
楽しさ & ★★★★★& 楽しい! \\
\end{array}
$$

評価ポイント

  • テキストのモデルを再現できました!やったね!

工夫・喜び・反省

  • 実は、1周目の時点では$${\hat{R}}$$が異常な値となり、収束を果たせませんでした。
    ブログ化をきっかけにして、再度、テキストとRスクリプトを見返して、モデル式の誤りを見つけて、なんとか結果を出せました。
    諦めないこと、そして、少し時間を置いて「新たな目」で解き直すことの大切さを噛み締めました。

インターネット上で歓喜する人達のイラスト:「いらすとや」さんより

モデルの概要


テキストの調査・実験の概要

■ 語用論的能力
著者は、我々はどの程度、社会的状況に応じたことばの使い分けを「外国語」でできるだろうか、と疑問を投げかけています。

 適切な発話を正しく適切と判断できるだろうか。
 適切な発話を不適切と誤って判断したりしないだろうか。

テキストより引用

社会的なことばの使用に関わる能力を「語用論的能力」と呼んでいます。

■ 信号検出モデル

 適切な発話を適切と判断したり、不適切と誤って判断する。
 または、不適切な発話を不適切と判断したり、適切と誤って判断する。

テキストより引用

テキストでは、このような判断を心理学の「刺激の弁別」として捉えて、信号検出モデルを適用しています。
正解と被験者の判断の関係=信号検出理論の反応テーブルは次のとおりです。

$$
\begin{array}{c|cc}
& S:正解=適切 & N:正解=不適切 \\
\hline
R_1:判断=適切 & ヒット(\text{Hit}) & 誤警報(\text{False Alarm}) \\
R_2:判断=不適切 & ミス(\text{Miss}) & 正棄却(\text{Correct Rejection})\\
\end{array}
$$

モデルでは、ヒット率$${P(R_1 \mid S)}$$と誤警報率$${P(R_1 \mid N)}$$に着目します。

■ 実験の概要
テキストでは、日本語による場面説明と英語による会話文章について、「適切な表現を15個」「不適切な表現を15個」用意して、大学生から回答データを取得、101人のデータを分析しています。

実験の概要を図示します。

テキストのモデリング

■目的変数と関心のあるパラメータ
目的変数は、ヒット数$${H}$$と誤警報数$${FA}$$です。
関心のあるパラメータは、適切な表現を適切と判断し、不適切な表現を不適切と判断する能力「弁別力」を示す$${d}$$、判断が適切または不適切に偏りやすい傾向=バイアスを示す$${c}$$です。

■ 二項分布とヒット率・誤警報率
ヒット数は、試行回数=適切な文章数$${n_S}$$と成功確率=ヒット率$${\theta_H}$$をパラメータにもつ二項分布に従います。
誤警報数は、試行回数=不適切な文章数$${n_N}$$と成功確率=誤警報率$${\theta_F}$$をパラメータにもつ二項分布に従います。

■等分散ガウス信号検出モデルとヒット率・誤警報率
「不適切な文章に対する心理量」が標準正規分布$${\text{Normal}(0,\ 1)}$$に従い、「適切な文章に対する心理量」が標準正規分布の平均パラメータに$${d}$$を加えた(右側に$${d}$$シフトした)正規分布$${\text{Normal}(d,\ 1)}$$に従うとしています。
下図の2つの正規分布では弁別力$${d=3}$$としています。
2つの正規分布の平均は$${d}$$ずれています。

しきい値$${k}$$を超えると「適切」な文章に対して、判断は「適切」となります。
下図の青塗りの部分がヒット率$${\theta_H}$$です。

同じくしきい値$${k}$$を超えて、「不適切」な文章に対して判断を「適切」としてしまうオレンジ塗りの部分が誤警報率$${\theta_F}$$です。

バイアス$${c=k-d/2}$$を導入して、かつ、「適切」の正規分布を標準正規分布に変換すると、ヒット率と誤警報率は、標準正規分布の累積分布関数で表現できます。

まず「適切」から。青塗りがヒット率です。
青塗り部分は上の図と比べて左右反転させています。
赤い点線:境界線$${d/2-c}$$がポイントです。
青塗り部分は、標準正規分布の累積分布関数$${\phi(d/2-c)=\theta_H}$$なのです。

続いて「不適切」。オレンジ塗りが誤警報率です。
オレンジ塗り部分は上の図と比べて左右反転させています。
赤い点線:境界線$${-d/2-c}$$がポイントです。
オレンジ塗り部分は、標準正規分布の累積分布関数$${\phi(-d/2-c)=\theta_F}$$なのです。

役者が揃いました。

■ テキストの階層ベイズモデル
目的変数であるヒット数$${H}$$・誤警報数$${FA}$$は二項分布に従います。
ヒット率$${\theta_H}$$・誤警報率$${\theta_F}$$は標準正規分布の累積分布関数$${\phi(\cdot)}$$です。
上のグラフで標準正規分布の累積分布関数で表現したヒット率・誤警報率の境界線が引数です。
添字$${i}$$は実験参加者のIDです。

$$
\begin{align*}
H &\sim \text{Binomial}\ (\theta_H,\ n_S) \\
FA &\sim \text{Binomial}\ (\theta_F,\ n_N) \\
\theta_H &= \phi\ \left(\cfrac{d_i}{2} -c_i\right) \\
\theta_F &= \phi\ \left(-\cfrac{d_i}{2} -c_i\right) \\
d_i &\sim \text{Normal}\ (\mu_d,\ \sigma_d) \\
c_i &\sim \text{Normal}\ (\mu_c,\ \sigma_c) \\
\mu_d,  \mu_c &\sim \text{Normal}\ (0,\ \sqrt{1000}) \\
\sigma_d,  \sigma_c &= 1/\sqrt{\lambda_d},\ 1/\sqrt{\lambda_c} \\
\lambda_d,  \lambda_c &\sim \text{Gamma}\ (1/1000,\ 1/1000) \\
\end{align*}
$$

テキストの記述を一部改変して引用

■分析・分析結果
分析の仕方や分析数値はテキストの記述が正確だと思いますので、テキストの読み込みをおすすめします。
PyMCの自己流モデルの推論値を利用した分析は「PyMC実装」章の【分析】をご覧ください。

PyMC実装


Let's enjoy PyMC & Python !

準備・データ確認

1.インポート

### インポート

# ユーティリティ
import pickle

# 数値・確率計算
import pandas as pd
import numpy as np
import scipy.stats as stats

# PyMC
import pymc as pm
import pytensor.tensor as pt
import arviz as az

# 描画
import matplotlib.pyplot as plt
plt.rcParams['font.family'] = 'Meiryo'

# ワーニング表示の抑制
import warnings
warnings.simplefilter('ignore')

2.データの読み込み
csvファイルをpandasのデータフレームに読み込みます。

### データの読み込み

# データの読み込み
# 適切な例15、不適切な例15、合計30例のセリフに対する反応数
# 1列目H:ヒット,2列目FA:誤警報,3列目M:ミス,4列目CR:正棄却
data = pd.read_csv('SDT_data.csv')

# データの表示
display(data)

【実行結果】
全101行です。1行が実験参加者1人分のデータです。
データ項目は次のとおりです。
・H:ヒット数(適切な表現を適切と判断)
・FA:誤警報数(不適切な表現を適切と判断)
・M:ミス数(適切な表現を不適切と判断)
・CR:正棄却数(不適切な表現を不適切と判断)

2.データの外観の確認
データの要約統計量とヒストグラムでデータの外観を確認します。
まずは要約統計量から。

### 要約統計量の表示
display(data.describe().round(2))

【実行結果】
適切な文章・不適切な文章はそれぞれ15個あるので、中間値は7.5です。
平均値で見ると、適切な文章に対する判断は悪く(H:ヒット<7.5)、不適切な文章に対する判断は良い(CR:正棄却>7.5)ようです。

続いてヒストグラム。pandas の hist を利用しました。

### データの可視化:箱ひげ図
fig, ax = plt.subplots(figsize=(10, 5))
ax = data.hist(ax=ax);
fig.supxlabel('判断数')
fig.supylabel('度数:人数');

【実行結果】
4つの分布はなんとなく正規分布に似ているような、そうでないような、という感じです。

モデル構築

モデルの数式表現
目指したいPyMCのモデルの雰囲気を混ぜた「なんちゃって数式」表記です。

$$
\begin{align*}
n_S,\ n_N &= 15 \\
\mu_d,\ \mu_c &\sim \text{Normal}\ (\text{mu}=0,\ \text{sigma}=1/\sqrt{1/1000}) \\
\lambda_d,\ \lambda_c &\sim \text{Gamma}\ (\text{alpha}=1/1000,\ \text{beta}=1/1000) \\
\sigma_d &= 1/\sqrt{\lambda_d} \\
\sigma_c &= 1/\sqrt{\lambda_c} \\
d_i &\sim \text{Normal}\ (\text{mu}=\mu_d,\ \text{sigma}=\sigma_d) \\
c_i &\sim \text{Normal}\ (\text{mu}=\mu_c,\ \text{sigma}=\sigma_c) \\
\theta_H &= \text{Normal.CDF}\ (d_i/2-c_i,\ \text{mu}=0,\ \text{sigma}=1) \\
\theta_F &= \text{Normal.CDF}\ (-d_i/2-c_i,\ \text{mu}=0,\ \text{sigma}=1) \\
H &\sim \text{Binomial}\ (\text{n}=n_S,\ \text{p}=\theta_H) \\
FA &\sim \text{Binomial}\ (\text{n}=n_N,\ \text{p}=\theta_F) \\
\end{align*}
$$

1.モデルの定義
数式表現を実直にモデル記述します。

### モデルの定義

# 初期値設定
nS = 15      # 適切な文の数15
nN = 15      # 不適切な文の数15

# モデルの定義
with pm.Model() as model:
    
    ### データ関連定義
    # coordの定義
    model.add_coord('data', data.index, mutable=True)    
    # dataの定義
    hitData = pm.ConstantData('hitData', value=data['H'], dims='data')
    faData = pm.ConstantData('faData', value=data['FA'], dims='data')
    
    ### 事前分布
    
    # muC,D:集団の平均、c,dの平均muの事前分布 stanはsigma=inv_sqrt(1/1000)
    muC = pm.Normal('muC', mu=0, sigma=1/np.sqrt(1/1000))
    muD = pm.Normal('muD', mu=0, sigma=1/np.sqrt(1/1000))
    # λ:cとdの精度の事前分布、σに変換される
    lambdaC = pm.Gamma('lambdaC', alpha=1/1000, beta=1/1000)
    lambdaD = pm.Gamma('lambdaD', alpha=1/1000, beta=1/1000)
    # σ:集団の標準偏差、λから変換 stanはsigmaC=inv_sqrt(lambdaC)
    sigmaC = pm.Deterministic('sigmaC', 1/pt.sqrt(lambdaC))
    sigmaD = pm.Deterministic('sigmaD', 1/pt.sqrt(lambdaD))
    
    # 個人のcとd、集団の平均と標準偏差をパラメータにもつ正規分布に従う
    c = pm.Normal('c', mu=muC, sigma=sigmaC, dims='data')
    d = pm.Normal('d', mu=muD, sigma=sigmaD, dims='data')

    
    # テキストではφを標準正規分布の累積分布関数の逆関数と呼んでいるが
    # stanでは Phi関数、つまり標準正規分布の累積分布関数を使用
    # またCDF関数が不明のため、logCDFをexp関数に通してCDFに変換
    # ヒット率
    normDistH = pm.Normal.dist(mu=0, sigma=1)
    thetaH = pt.exp(pm.logcdf(normDistH, d/2 - c))
    # 誤警報率
    normDistF = pm.Normal.dist(mu=0, sigma=1)
    thetaF = pt.exp(pm.logcdf(normDistF, -d/2 - c))

    ### 尤度
    # 個人のついての観測値:ヒット
    hit = pm.Binomial('hit', n=nS, p=thetaH, observed=hitData, dims='data')
    # 個人のついての観測値:誤警報
    fa = pm.Binomial('fa', n=nN, p=thetaF, observed=faData, dims='data')

【モデル注釈】

  • coordの定義
    座標に名前を付けたり、その座標が取りうる値を設定できます。
    今回は次の1つを設定しました。
    ・行の座標:名前「data」、値「行インデックス」

  • dataの定義
    目的変数であるヒット数$${hitData}$$、誤警報数$${faData}$$を設定しました。

  • パラメータの事前分布

    • テキストにならって設定しました。
      ◯◯の逆関数・inv_◯◯~関数は「1/◯◯」のように設定しました。

    • 累積分布関数$${CDF}$$の扱い
      PyMCの累積分布関数のクラス・関数・メソッドがよくわからなかったので、次のように少々回りくどい設定にしました。

      • pm.Normal.dist()で正規分布関数を設定

      • pm.logcdf()で上記正規分布関数を指定した対数-累積分布関数を設定

      • pt.exp()で上記対数-累積分布関数のeを底とするべき乗をとって、累積分布関数化

  • 尤度
    ヒット率$${\theta_H}$$・誤警報率$${\theta_F}$$、適切な文章数$${n_S}$$・不適切な文章数$${n_N}$$をパラメータにする2つの二項分布を定義します。

2.モデルの外観の確認

### モデルの表示
model

【実行結果】
いろんな確率分布クラスを用いています。

# モデルの可視化
pm.model_to_graphviz(model)

【実行結果】
弁別力$${d}$$・バイアス$${c}$$・ヒット数$${hit}$$・誤警報数$${fa}$$の繋がりをざっくり掴めます。

3.事後分布からのサンプリング
乱数生成数はテキストが1チェーン、9000ドローですが、この記事では、4チェーン、各2500ドロー(合計10000)にしました。
発散防止の目的でtarget_acceptを高い値にしています。
nuts_sampler='numpyro'とすることで、numpyroをNUTSサンプラーに利用できます。
処理時間はおよそ1分でした。

# 事後分布からのサンプリング 1分
# テキストは draws=9000, tune=1000, chains=1
with model:
    idata = pm.sample(draws=2500, tune=1000, chains=4, target_accept=0.995,
                      nuts_sampler='numpyro', random_seed=123)

【実行結果】省略

4.サンプリングデータの確認
$${\hat{R}}$$、事後分布の要約統計量、トレースプロットを確認します。
事後分布の収束確認はテキストにならって$${\hat{R} \leq 1.1}$$としています。

### r_hat>1.1の確認
rhat_idata = az.rhat(idata)
(rhat_idata > 1.1).sum()

【実行結果】
$${\hat{R} > 1.1}$$のパラメータは「0」件です。
全てのパラメータが$${\hat{R} \leq 1.1}$$であることを確認できました。

ざっくり事後分布サンプリングデータの要約統計量とトレースプロットを確認します。

# 推論データの要約統計量
pm.summary(idata, hdi_prob=0.95)

【実行結果】

# トレースプロットの描画
pm.plot_trace(idata)
plt.tight_layout();

【実行結果】
だいたい収束している感じがします。
sigmaDは各チェーンの峰が少々乱れている感じです。

弓道のイラスト(女性):「いらすとや」さんより

5.分析~テキストにならって
興味のあるパラメータ$${\mu_d, \mu_c, \sigma_d, \sigma_c}$$を見ます。
まずはpm.summary()でざっと眺めます。

### 推論データの要約統計量 ※表16.2に対応
var_names = ['muD', 'muC', 'sigmaD', 'sigmaC']
pm.summary(idata, hdi_prob=0.95, var_names=var_names).round(2)

【実行結果】

テキストの表16.2と同じ統計量を算出します。
事後分布統計量計算の関数は繰り返し利用します。

### パラメータμ,σの要約統計量の計算 ※表16.2に相当

## 事後分布統計量計算の関数定義
def calc_stats(calc_vars):
    mean = calc_vars.mean()
    std = calc_vars.std()
    quantiles = np.quantile(calc_vars, [0.025, 0.25, 0.5, 0.75, 0.975])
    return np.hstack([mean, std, quantiles])

## 初期値設定
# 変数名のリスト
var_list = ['muD', 'muC', 'sigmaD', 'sigmaC']
# データフレームの項目名のリスト
index_list = ['事後期待値', '事後標準偏差', '2.5%', '25%', '50%', '75%', '97.5%']
# データフレームの初期化
stats_df = pd.DataFrame()

## 事後分布統計量を算定してデータフレームを作成
# 変数名リストの変数名ごとに繰り返し処理
for var in var_list:
    # 推論データから指定変数の事後分布サンプリングデータを取り出し
    vars = idata.posterior[var].stack(x=('chain', 'draw'))
    # 事後分布統計量を関数で計算して、データフレームに結合
    stats_df = pd.concat(
        [stats_df, pd.Series(calc_stats(vars), name=f'{var}')], axis=1)

# データフレームを整える
stats_df = stats_df.set_axis(index_list, axis=0).T
display(stats_df.round(2))

【実行結果】
ほぼテキストの統計値と一致しています。
PyMCモデル自体の信頼性はテキストと同様と思われます。

あわせて、テキスト177~178ページの各パラメータの95%HDI(最高密度区間)を算出します。
arviz の hdi() を使えば一発で計算できます!

### 各母数の最高密度区間の算出 ※テキスト177~178ページの「最高密度区間」
var_names = ['muD', 'muC', 'sigmaD', 'sigmaC']
az.hdi(idata, hdi_prob=0.95, var_names=var_names).to_dataframe().T.round(2)

【実行結果】

さらに、テキストの図16.3と同様の各パラメータの事後分布プロットを描画します。
pm.plot_posterior() でサクッとプロットです!

### 推論データの事後分布プロットの描画 ※図16.3に相当
var_names = ['muD', 'sigmaD', 'muC', 'sigmaC']
pm.plot_posterior(idata, hdi_prob=0.95, var_names=var_names);

【実行結果】
sigmaDの峰がやや崩れています。

【分析:英語表現を適切に判断できる力について】
外国語の表現の適否を見分ける弁別力について、集団平均$${\mu_d}$$は0.24、また95%HDIは[0.14, 0.34]であり0をまたがない=正の能力という値です。
実験参加者が実験によって発揮した結果に基づくと、参加者全体としては英語の弁別力があると判断できそうです。

子供向け英語教室のイラスト:「いらすとや」さんより

さらにさらに、テキスト178ページの「負の弁別力を示すメンバーの割合」を算出しましょう。
$${\mu_d}$$・$${\sigma_d}$$の事後分布サンプリングデータを分布のパラメータとする「正規分布」に基づいて、弁別力がマイナスの確率=正規分布の確率変数が0以下になる確率「 rate 」を累積分布関数で求めます。
scipy.stats の norm.cdf() を利用します。

### 負の弁別力を示すメンバーの割合 ※表16.3に相当

# 事後分布サンプルデータからmu_dとsigma_dを取り出す
mu_d = idata.posterior.muD.stack(sample=('chain', 'draw')).data
sigma_d = idata.posterior.sigmaD.stack(sample=('chain', 'draw')).data

# mu_d, sigma_dをパラメータとする正規分布において、点0に対応する累積分布関数の算出
# scipy.statsのnorm()で累積分布関数cdfを計算
rate = stats.norm.cdf(0, loc=mu_d, scale=sigma_d)

# 統計量算出~事後分布統計量計算の関数を利用
rate_df = pd.DataFrame(calc_stats(rate), index=index_list, columns=['値']).T
display(rate_df.round(3))

【実行結果】

テキスト178ページの rate の95%HDIも算出しましょう。

### 上記生成量rateの95%HDIの計算 ※テキスト178ページの「生成量の最高密度区間」に相当
print(az.hdi(rate, hdi_prob=0.95).round(2))

【実行結果】
rate の 95%HDIは[0.00,  0.24]です。

【分析:英語表現を適切に判断できない可能性のある参加者の割合】
(おそらく95%HDIを念頭において)数%から20%台ほどの割合の参加者は、誤った判断、つまり、適切な文章を不適切と判断したり、不適切な文章を適切と判断する傾向を示す、可能性があるとするテキストに同感です!

著者はベイズ推論の結果から、次のようにまとめています。

外国語において語用論的に適切な表現と不適切な表現を弁別することは、非常に難しいことであることがわかる。

テキストより引用

続く著者の知見は一見の価値ありですので、ぜひ、テキストをお読み下さいね!

eラーニングのイラスト(英語・女性):「いらすとや」さんより

6.推論データ(idata)の保存
推論データを再利用する場合に備えてファイルに保存しましょう。
idataをpickleで保存します。

### idataの保存 pickle
file = r'idata_ch16.pkl'
with open(file, 'wb') as f:
    pickle.dump(idata, f)

読み込みコードは次のとおりです。

### idataの読み込み pickle
file = r'idata_ch16.pkl'
with open(file, 'rb') as f:
    idata_load = pickle.load(f)

以上で第16章は終了です。

おわりに


外国語の表現理解

身の回りで外国語のことばをよく体験するのは、Pythonパッケージ群の公式サイトの技術情報です。
主に関数やメソッドの機能概要、パラメータ説明を知りたいシーンで英語に直面します。

英語に詳しくないため、ほぼ日本語翻訳ツールに頼っています。
Chromeブラウザではページ全体をGoogle翻訳で日本語化したり、一部の文章を範囲指定してDeepLで翻訳したり。
英語の意味が理解できない場合であっても、コードを動かして機能を知ることもできます(英語レス)。

ニュース、情報記事、小説などの英語は久方読んだり聞いたりしていません(汗)
海外の方と英語でコミュニケーションを取ることもほとんど無くなりました。

ただ、生成AIの機能拡充に伴って、さまざまな媒体で同時自動翻訳が行われ、言語の違いがコミュニケーションの壁にならない時代は、もうすぐだと思います。
このことを信じて、これからも、英語特化型の学習はしないでおこうと思います!(英語から顔を背ける・・・)

翻訳機を使う人のイラスト:「いらすとや」さんより



シリーズの記事

次の記事

前の記事

目次


ブログの紹介


note で4つのシリーズ記事を書いています。
ぜひ覗いていってくださいね!

1.のんびり統計
統計検定2級の問題集を手がかりにして、確率・統計をざっくり掘り下げるブログです。
雑談感覚で大丈夫です。ぜひ覗いていってくださいね。
統計検定2級公式問題集CBT対応版に対応しています。

2.RとStanではじめる心理学のための時系列分析入門 を PythonとPyMC Ver.5 で
書籍「RとStanではじめる心理学のための時系列分析入門」の時系列分析トピックを PythonとPyMC Ver.5で取り組みます。
豊富なテーマ(お題)を実践することによって、PythonとPyMCの基礎体力づくりにつながる、そう信じています。
日々、Web検索に勤しみ、時系列モデルの理解、Pythonパッケージの把握、R・Stanコードの翻訳に励んでいます!
このシリーズがPython時系列分析の入門者の参考になれば幸いです🍀

3.Python機械学習プログラミング実践記
書籍「Python機械学習プログラミング PyTorch & scikit-learn編」を学んだときのさまざまな思いを記事にしました。
この書籍は、scikit-learn と PyTorch の教科書です。
よかったらぜひ、お試しくださいませ。

4.データサイエンスっぽいことを綴る
統計、データ分析、AI、機械学習、Python のコラムを不定期に綴っています。
「統計」「Python」「数学とPython」「R」のシリーズが生まれています。
ベイズ書籍の実践記録も掲載中です。

最後までお読みいただきまして、ありがとうございました。

この記事が参加している募集

今年やりたい10のこと

この記事が気に入ったらサポートをしてみませんか?