ベイズ統計モデリング 〜周期性モデルのリファクタリング〜
背景
前回、ベイズ統計モデリングを始めて実施した話をしました。
こちらのデータは周期性がなさそうだったので、そのモデリングを無視していました。実際、周期性のあるデータの場合も十分考えられるので、そちらを行いたいと思います。
ただし、単純な周期性モデルを行うだけならば上記「ベイズ統計モデリングによるデータ分析入門」に書いてあります。こちらのコードでも最低限動かすことは可能です。
こちらの著者が掲載しているサイトで収束を良くするコツが書かれています。
こちらのパラメータの外出しを行いたいと思います。
データ
データは、上記の本のGithubにある「5-6-1-sales-ts-4.csv」を利用します。
sales_df_4 <- read.csv("5-6-1-sales-ts-4.csv")
sales_df_4$date <- as.POSIXct(sales_df_4$date)
手法
準備として、以下の実行をします。plotSSM.R は本のgithub上にあるものを利用しました。
library(rstan)
library(bayesplot)
library(ggfortify)
library(gridExtra)
rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())
source("plotSSM.R", encoding = "utf-8")
手法1
最初にgithubにある5-6-1-basic-structual-time-series.stanを利用します。
data_list <- list(
y = sales_df_4$sales,
T = nrow(sales_df_4)
)
start_time = Sys.time()
basic_structual <- stan(
file = "5-6-1-basic-structual-time-series.stan",
data = data_list,
seed = 1
)
diff_time = Sys.time() - start_time
これで収束結果と経過時間を確認します。
手法: パラメータの外だし
次にパラメータを外だししたものを実行します。方法は以下のページをベースに 5-6-1-basic-structual-time-series.stan を変更します。
data {
int T; // データ取得期間の長さ
vector[T] y; // 観測値
}
parameters {
vector[T] mu_err;
vector[T] gamma_err;
real<lower=0> s_z;
real<lower=0> s_v;
real<lower=0> s_s;
}
transformed parameters {
vector[T] alpha;
vector[T] mu;
vector[T] gamma;
for(t in 1:2) {
mu[t] = mu_err[t];
}
for(t in 3:T) {
mu[t] = 2 * mu[t-1] - mu[t-2] + s_z * mu_err[t];
}
for(t in 1:6) {
gamma[t] = gamma_err[t];
}
for(t in 7:T) {
gamma[t] = -(sum(gamma[t-6: t-1])) + s_s * gamma_err[t];
}
alpha = mu + gamma;
}
model {
mu_err[2:T] ~ normal(0, 1);
gamma_err[2:T] ~ normal(0, 1);
y ~ normal(alpha, s_v);
}
元のテキスト通り、mu_err[1]とgamma_err[1]は無情報事前分布を想定しています。
start_time = Sys.time()
reparametorize <- stan(
file = "5-6-1-basic-structual-time-series-reparameter.stan",
data = data_list,
seed = 2)
diff_time = Sys.time() - start_time
print(diff_time)
まずはこれと元々のコードで比較します。
※ seed について、1だとかなり収束が悪かったので2 と 3 と繰り替えました。(もっとやれという話もありますが。)2と3の結果の方が近く、こちらの方がメインの結果かなと思い、そちらを利用しました。初期値依存しているのかパラメータをいじらないことがstanだと厳しいのか勉強になりました。
結果
計算結果を確認します。
オリジナル
オリジナルの方法です。収束をよくするためのパラメータ設定は無しで行います。
計算終了後の warningについて
警告メッセージ:
1: There were 1 divergent transitions after warmup. See
http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
to find out why this is a problem and how to eliminate them.
2: There were 589 transitions after warmup that exceeded the maximum treedepth. Increase max_treedepth above 10. See
http://mc-stan.org/misc/warnings.html#maximum-treedepth-exceeded
3: There were 4 chains where the estimated Bayesian Fraction of Missing Information was low. See
http://mc-stan.org/misc/warnings.html#bfmi-low
4: Examine the pairs() plot to diagnose sampling problems
5: Bulk Effective Samples Size (ESS) is too low, indicating posterior means and medians may be unreliable.
Running the chains for more iterations may help. See
http://mc-stan.org/misc/warnings.html#bulk-ess
6: Tail Effective Samples Size (ESS) is too low, indicating posterior variances and tail quantiles may be unreliable.
Running the chains for more iterations may help. See
http://mc-stan.org/misc/warnings.html#tail-ess
経過時間は
Time difference of 1.767173 mins
となりました。数値の確認もしていきます。
print(basic_structual,
par = c("s_z", "s_v", "s_s"),
probs = c(0.025, 0.5, 0.975))
これで、
Inference for Stan model: 5-6-1-basic-structual-time-series.
4 chains, each with iter=2000; warmup=1000; thin=1;
post-warmup draws per chain=1000, total post-warmup draws=4000.
mean se_mean sd 2.5% 50% 97.5% n_eff Rhat
s_z 0.21 0.01 0.09 0.09 0.19 0.44 77 1.04
s_v 7.37 0.04 0.98 5.53 7.32 9.35 614 1.00
s_s 4.22 0.06 0.96 2.46 4.16 6.23 289 1.01
となりました。
mcmc_combo(
rstan::extract(basic_structual, permuted = FALSE),
pars = c("s_z", "s_v", "s_s"))
s_z の結果がややよろしくないようです。
パラメータの外だし
次にパラメータを外だしした結果について
警告メッセージ:
1: There were 127 divergent transitions after warmup. See
http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
to find out why this is a problem and how to eliminate them.
2: There were 3696 transitions after warmup that exceeded the maximum treedepth. Increase max_treedepth above 10. See
http://mc-stan.org/misc/warnings.html#maximum-treedepth-exceeded
3: Examine the pairs() plot to diagnose sampling problems
warningsは、6つから3つに減りました。
Time difference of 2.355197 mins
計算時間はさらにかかっていますね。
結果を確認していくと、
print(reparametorize,
par = c("s_z", "s_v", "s_s"),
probs = c(0.025, 0.5, 0.975))
Inference for Stan model: 5-6-1-basic-structual-time-series-reparameter.
4 chains, each with iter=2000; warmup=1000; thin=1;
post-warmup draws per chain=1000, total post-warmup draws=4000.
mean se_mean sd 2.5% 50% 97.5% n_eff Rhat
s_z 1.39 0.01 0.50 0.62 1.32 2.57 1399 1
s_v 25.96 0.04 2.30 21.89 25.82 30.92 3134 1
s_s 8.29 0.03 1.47 5.76 8.18 11.57 2954 1
Samples were drawn using NUTS(diag_e) at Mon Feb 7 08:14:54 2022.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at
convergence, Rhat=1).
全てのパラメータがRhat = 1 になっています。
mcmc_combo(
rstan::extract(reparametorize, permuted = FALSE),
pars = c("s_z", "s_v", "s_s"))
こちらも綺麗な結果になっています。
推定結果を表示します。
mcmc_sample_re <- rstan::extract(reparametorize)
p_all <- plotSSM(
mcmc_sample = mcmc_sample_re,
time_vec = sales_df_4$date,
obs_vec = sales_df_4$sales,
state_name = "alpha",
graph_title = "all",
y_label = "sales")
p_trend <- plotSSM(
mcmc_sample = mcmc_sample_re,
time_vec = sales_df_4$date,
obs_vec = sales_df_4$sales,
state_name = "mu",
graph_title = "trend_except_seasonality",
y_label = "sales")
p_cycle <- plotSSM(
mcmc_sample = mcmc_sample_re,
time_vec = sales_df_4$date,
state_name = "gamma",
graph_title = "seasonality",
y_label = "gamma")
grid.arrange(p_all, p_trend, p_cycle)
こちらも見るとgamma と mu の初期値の設定が悪いのか、seasonality は1ヶ月程度、trend は 1週間程度、推定結果が悪そうです。ここは課題として残ってしまいました。
パラメータ調整
最後に元事記と同じように、エラーメッセージに合わせてstanのパラメータを調整します。
reparametorize_v2 <- stan(
file = "5-6-1-basic-structual-time-series-reparameter.stan",
data = data_list,
seed = 2,
control = list( max_treedepth = 15, adapt_delta = 0.99)
)
control に max_treedepth と adapt_delta を設定します。結果として、warnings は消えました。
まとめ
パラメータの外だししたことで、iteration の数を増やすことなく収束させることができました。時系列の初期の推定がうまくいっていないので、その調整が必要そうです。通常設定だと収束しない、最新時点の状態を把握したいという時なら使えそうです。
発展版として、ドリフト成分と水準成分を別にして実行したいなと思います。
この記事が気に入ったらサポートをしてみませんか?