見出し画像

機械学習:R言語(Tidymodels)チュートリアルの和訳② ブートストラッピング

第2弾!!


元の記事はこちら
https://www.tidymodels.org/learn/statistics/bootstrap/

Bootstrap resampling and tidy regression models

library(tidymodels)

ggplot(mtcars, aes(mpg, wt)) + 
    geom_point()


可視化
mtcarsデータセットに対して非線形モデルを適合させる方法を示します。
nlsfit <- nls(mpg ~ k / wt + b, mtcars, start = list(k = 1, b = 0))
summary(nlsfit)
#>
#> Formula: mpg ~ k/wt + b
#>
#> Parameters:
#>   Estimate Std. Error t value Pr(>|t|)
#> k   45.829      4.249  10.786 7.64e-12 ***
#> b    4.386      1.536   2.855  0.00774 **
#> ---
#> Signif. codes:  0 '' 0.001 '' 0.01 '' 0.05 '.' 0.1 ' ' 1
#>
#> Residual standard error: 2.774 on 30 degrees of freedom
#>
#> Number of iterations to convergence: 1
#> Achieved convergence tolerance: 6.813e-09
ggplot(mtcars, aes(wt, mpg)) +
geom_point() +
geom_line(aes(y = predict(nlsfit)))


モデルと実態の可視化

ブートストラッピングの導入

ブートストラッピングを使用することで、信頼区間と予測をより現実的に評価できます。以下に、ブートストラッピングを使用して非線形モデルを適合させる方法を示します。

set.seed(27)
boots <- bootstraps(mtcars, times = 2000, apparent = TRUE)
boots
#> # Bootstrap sampling with apparent sample
#> # A tibble: 2,001 × 2
#>    splits          id
#>    <list>          <chr>
#>  1 <split [32/13]> Bootstrap0001
#>  2 <split [32/10]> Bootstrap0002
#>  3 <split [32/13]> Bootstrap0003
#>  4 <split [32/11]> Bootstrap0004
#>  5 <split [32/9]>  Bootstrap0005
#>  6 <split [32/10]> Bootstrap0006
#>  7 <split [32/11]> Bootstrap0007
#>  8 <split [32/13]> Bootstrap0008
#>  9 <split [32/11]> Bootstrap0009
#> 10 <split [32/11]> Bootstrap0010
#> # ℹ 1,991 more rows

各ブートストラップサンプルに対してnls()モデルを適合させるヘルパー関数を作成し、この関数をpurrr::map()を使って全てのブートストラップサンプルに一度に適用します。同様に、unnestを使って整然とした係数情報の列を作成します。

fit_nls_on_bootstrap <- function(split) {
nls(mpg ~ k / wt + b, analysis(split), start = list(k = 1, b = 0))
}
boot_models <-
boots %>%
mutate(model = map(splits, fit_nls_on_bootstrap),
coef_info = map(model, tidy))
boot_coefs <-
boot_models %>%
unnest(coef_info)
boot_coefs
#> # A tibble: 4,002 × 8
#>    splits          id          model term  estimate std.error statistic  p.value
#>    <list>          <chr>       <lis> <chr>    <dbl>     <dbl>     <dbl>    <dbl>
#>  1 <split [32/13]> Bootstrap0… <nls> k        42.1       4.05     10.4  1.91e-11
#>  2 <split [32/13]> Bootstrap0… <nls> b         5.39      1.43      3.78 6.93e- 4
#>  3 <split [32/10]> Bootstrap0… <nls> k        49.9       5.66      8.82 7.82e-10
#>  4 <split [32/10]> Bootstrap0… <nls> b         3.73      1.92      1.94 6.13e- 2
#>  5 <split [32/13]> Bootstrap0… <nls> k        37.8       2.68     14.1  9.01e-15
#>  6 <split [32/13]> Bootstrap0… <nls> b         6.73      1.17      5.75 2.78e- 6
#>  7 <split [32/11]> Bootstrap0… <nls> k        45.6       4.45     10.2  2.70e-11
#>  8 <split [32/11]> Bootstrap0… <nls> b         4.75      1.62      2.93 6.38e- 3
#>  9 <split [32/9]>  Bootstrap0… <nls> k        43.6       4.63      9.41 1.85e-10
#> 10 <split [32/9]>  Bootstrap0… <nls> b         5.89      1.68      3.51 1.44e- 3
#> # ℹ 3,992 more rows

信頼区間の計算

percentile_intervals <- int_pctl(boot_models, coef_info)

percentile_intervals
#> # A tibble: 2 × 6
#>   term   .lower .estimate .upper .alpha .method
#>   <chr>   <dbl>     <dbl>  <dbl>  <dbl> <chr>
#> 1 b      0.0475      4.12   7.31   0.05 percentile
#> 2 k     37.6        46.7   59.8    0.05 percentile
ggplot(boot_coefs, aes(estimate)) +
geom_histogram(bins = 30) +
facet_wrap( ~ term, scales = "free") +
geom_vline(aes(xintercept = .lower), data = percentile_intervals, col = "blue") +
geom_vline(aes(xintercept = .upper), data = percentile_intervals, col = "blue")


可視化

モデルの適合例

augment()を使って、適合曲線の不確実性を視覚化できます。ブートストラップサンプルが非常に多いため、可視化にはモデル適合のサンプルのみを表示します。

boot_aug <-
boot_models %>%
sample_n(200) %>%
mutate(augmented = map(model, augment)) %>%
unnest(augmented)
boot_aug
#> # A tibble: 6,400 × 8
#>    splits          id            model  coef_info   mpg    wt .fitted .resid
#>    <list>          <chr>         <list> <list>    <dbl> <dbl>   <dbl>  <dbl>
#>  1 <split [32/11]> Bootstrap1644 <nls>  <tibble>   16.4  4.07    15.6  0.829
#>  2 <split [32/11]> Bootstrap1644 <nls>  <tibble>   19.7  2.77    21.9 -2.21
#>  3 <split [32/11]> Bootstrap1644 <nls>  <tibble>   19.2  3.84    16.4  2.84
#>  4 <split [32/11]> Bootstrap1644 <nls>  <tibble>   21.4  2.78    21.8 -0.437
#>  5 <split [32/11]> Bootstrap1644 <nls>  <tibble>   26    2.14    27.8 -1.75
#>  6 <split [32/11]> Bootstrap1644 <nls>  <tibble>   33.9  1.84    32.0  1.88
#>  7 <split [32/11]> Bootstrap1644 <nls>  <tibble>   32.4  2.2     27.0  5.35
#>  8 <split [32/11]> Bootstrap1644 <nls>  <tibble>   30.4  1.62    36.1 -5.70
#>  9 <split [32/11]> Bootstrap1644 <nls>  <tibble>   21.5  2.46    24.4 -2.86
#> 10 <split [32/11]> Bootstrap1644 <nls>  <tibble>   26    2.14    27.8 -1.75
#> # ℹ 6,390 more rows
ggplot(boot_aug, aes(wt, mpg)) +
geom_line(aes(y = .fitted, group = id), alpha = .2, col = "blue") +
geom_point()


可視化

わずかな変更を加えるだけで、他の種類の予測モデルや仮説検定モデルに対しても簡単にブートストラッピングを行うことができます。なぜなら、tidy()およびaugment()関数は多くの統計出力に対応しているからです。別の例として、smooth.spline()を使用することもできます。これはデータに対して三次平滑スプラインを適合させるものです。

fit_spline_on_bootstrap <- function(split) {
data <- analysis(split)
smooth.spline(data$wt, data$mpg, df = 4)
}
boot_splines <-
boots %>%
sample_n(200) %>%
mutate(spline = map(splits, fit_spline_on_bootstrap),
aug_train = map(spline, augment))
splines_aug <-
boot_splines %>%
unnest(aug_train)
ggplot(splines_aug, aes(x, y)) +
geom_line(aes(y = .fitted, group = id), alpha = 0.2, col = "blue") +
geom_point()


可視化

END

この記事が気に入ったらサポートをしてみませんか?