RStanで状態空間モデル:時変係数モデル

Rでデータサイエンス

状態空間モデル:時変係数モデル

パッケージ読み込み等

# Load rstan (R interface to Stan) and dplyr (supplies the %>% pipe used below).
library(rstan)
library(dplyr)
# Ask rstan to save compiled models to disk so that an unchanged model
# does not have to be recompiled.
rstan_options(auto_write = TRUE)
# Allow rstan to run chains in parallel, using as many cores as are detected.
options(mc.cores = parallel::detectCores())
# Print the installed rstan version for the record.
packageVersion("rstan")
[1] '2.32.6'

サンプルデータ

# Simulate n observations from a time-varying coefficient model:
#   mu[i] ~ Normal(mu[i - 1], s_w)   random-walk intercept
#   b[i]  ~ Normal(b[i - 1], s_t)    random-walk slope (the time-varying coefficient)
#   a[i]  = mu[i] + b[i] * x[i]      noise-free signal
#   y[i]  ~ Normal(a[i], s_v)        observation
# The simulated series are also saved to "sample-10.csv" in the current
# working directory.
# NOTE: no RNG seed is set, so every run produces a different data set.
n <- 30
x <- seq_len(n)
s_w <- 5
s_t <- 1
s_v <- 4
# Preallocate the state vectors instead of growing zero-length vectors in a loop.
mu <- numeric(n)
b <- numeric(n)
mu[1] <- 5
b[1] <- 1
# seq_len(n)[-1] is empty when n == 1, whereas 2:n would count down (2, 1).
# The two draws stay interleaved so the random-number stream is consumed in
# the same order as before.
for (i in seq_len(n)[-1]) {
    mu[i] <- rnorm(n = 1, mean = mu[i - 1], sd = s_w)
    b[i] <- rnorm(n = 1, mean = b[i - 1], sd = s_t)
}
# Signal and observations are computed in one vectorized step each.
a <- mu + b * x
y <- rnorm(n = n, mean = a, sd = s_v)
write.csv(
    data.frame(mu, b, a, y),
    file = "sample-10.csv",
    fileEncoding = "UTF-8",
    row.names = FALSE
)
# Line chart over time of the simulated observations y and the noise-free signal a.
ggplot(mapping = aes(x = x)) +
    geom_line(aes(y = y, colour = "y")) +
    geom_line(aes(y = a, colour = "a")) +
    theme(legend.title = element_blank())
Figure 1
# Intercept mu plotted against slope b, as a line plus points.
# NOTE(review): geom_line() joins points in order of the x variable (b), not in
# time order; confirm that is the intended reading of this chart.
ggplot(mapping = aes(x = b, y = mu, colour = "b")) +
    geom_line() +
    geom_point() +
    theme(legend.title = element_blank())
Figure 2
# Random-walk intercept mu over time, as a line plus points.
ggplot(mapping = aes(x = x, y = mu, colour = "mu")) +
    geom_line() +
    geom_point() +
    theme(legend.title = element_blank())
Figure 3

Stanコード

# Stan program for the time-varying coefficient state-space model, kept as a
# string so it can be passed to stan_model(model_code = ...).
#   data:        n (series length), x (regressor), y (observations)
#   parameters:  mu (intercept path), b (coefficient path), and the
#                non-negative standard deviations s_w, s_t, s_v
#   transformed: a[i] = mu[i] + b[i] * x[i]
#   model:       state equations   mu[i] ~ normal(mu[i-1], s_w)
#                                  b[i]  ~ normal(b[i-1], s_t)   for i = 2..n
#                observation eq.   y[i]  ~ normal(a[i], s_v)     for i = 1..n
# The program gives no sampling statements for mu[1], b[1], s_w, s_t or s_v.
model_code <- "
data {
  int n;
  vector[n] x;
  vector[n] y;
}

parameters {
  vector[n] mu;
  vector[n] b;
  real<lower=0> s_w;
  real<lower=0> s_t;
  real<lower=0> s_v;
}

transformed parameters {
  vector[n] a;
  for(i in 1:n) {
    a[i] = mu[i] + b[i] * x[i];
  }

}

model {
  // 状態方程式
  for(i in 2:n) {
    mu[i] ~ normal(mu[i-1], s_w);
    b[i] ~ normal(b[i-1], s_t);
  }
  // 観測方程式
  for(i in 1:n) {
    y[i] ~ normal(a[i], s_v);
  }

}"

MCMC

# Compile the Stan program and draw posterior samples.
# Three chains of 10000 iterations with 2500 warmup iterations each are
# requested; refresh = 0 silences the progress output.
# Per-chain sample and diagnostic files are written into `stan_output`
# (defined outside this section) via explicit paths, so the working directory
# is not changed as a side effect.
data_list <- list(n = n, y = y, x = x)
# data_list
stan_model_compile <- stan_model(model_code = model_code)
fit_sample <- sampling(
    object = stan_model_compile,
    data = data_list,
    iter = 10000,
    # Spelled out in full: `chain = 3` only worked through partial argument matching.
    chains = 3,
    warmup = 2500,
    refresh = 0,
    sample_file = file.path(stan_output, "sample_file_stan_code_10.csv"),
    diagnostic_file = file.path(stan_output, "diagnostic_file_stan_code_10.txt")
)

結果の確認

# Rebuild the fit from the sample CSV files written during sampling and
# tabulate the posterior summary.
# The files are located with full paths inside `stan_output` (defined outside
# this section), so the working directory is not changed.
csvfiles <- list.files(
    path = stan_output,
    pattern = "sample_file_stan_code_10",
    full.names = TRUE
)
fit_sample <- read_stan_csv(csvfiles = csvfiles)
# summary()$summary is a matrix with one row per parameter. Move the parameter
# names into a `pars` column, keep column names such as "2.5%" unmangled
# (check.names = FALSE) and fall back to plain integer row names.
result_summary <- summary(fit_sample)$summary
result_summary <- data.frame(
    pars = rownames(result_summary),
    result_summary,
    check.names = FALSE,
    row.names = NULL
)
result_summary[, c("pars", "mean", "2.5%", "50%", "97.5%", "Rhat")]
     pars          mean         2.5%           50%      97.5%      Rhat
1   mu[1]    6.79339994   -7.3416200    6.29819500  24.255822 1.0025112
2   mu[2]    7.00525842   -9.8262373    6.62467500  25.906515 1.0020121
3   mu[3]   10.00420574  -10.2934800    9.54871000  32.884777 1.0020391
4   mu[4]   14.17552251   -9.8871502   13.22500000  43.529005 1.0021440
5   mu[5]   15.79423690  -11.5091975   14.19170000  51.024470 1.0018605
6   mu[6]   14.91282994  -16.2165400   13.47330000  53.875820 1.0013985
7   mu[7]   11.78185670  -23.1517175   10.66245000  52.491470 1.0014437
8   mu[8]    8.98947734  -30.5464300    8.34256000  51.795440 1.0013861
9   mu[9]   10.22753437  -31.1475950    9.22867000  57.110805 1.0014782
10 mu[10]    5.17739316  -43.3538025    5.61581000  52.208208 1.0016759
11 mu[11]    8.68064794  -39.9813275    8.10738000  62.456967 1.0012023
12 mu[12]   11.21037595  -38.6699800    9.77248000  70.220530 1.0009209
13 mu[13]    9.12604321  -45.1523150    8.48092000  68.925915 1.0011709
14 mu[14]    9.66003927  -46.7209850    8.69359500  72.393600 1.0009930
15 mu[15]    8.34799586  -51.6467225    7.86142000  73.750967 1.0009285
16 mu[16]    8.91463812  -54.1950950    8.25848000  78.426770 1.0011341
17 mu[17]   15.92161841  -44.0820250   12.55655000  96.735045 1.0010939
18 mu[18]   13.50239461  -51.0584775   11.14935000  92.396042 1.0009911
19 mu[19]   12.16042959  -55.0423100   10.30665000  91.387732 1.0008457
20 mu[20]    8.59927382  -65.8660850    8.07129500  85.420040 1.0009710
21 mu[21]    8.67308463  -66.9528800    8.13816500  88.386950 1.0008589
22 mu[22]   10.68095676  -64.7539925    9.11640000  94.511585 1.0008425
23 mu[23]   11.66738769  -63.8011000    9.71258500  99.416305 1.0007653
24 mu[24]   10.67252275  -67.7587900    9.26283500  99.550175 1.0006854
25 mu[25]    4.46247493  -85.6077700    5.96726000  88.362342 1.0010638
26 mu[26]    5.25664790  -86.2436500    6.23650500  91.415665 1.0009723
27 mu[27]    6.94913898  -84.6390625    7.18897500  97.852452 1.0007838
28 mu[28]    5.94109156  -88.8497000    6.96365500  97.520570 1.0009333
29 mu[29]    8.67991302  -83.9789425    8.38005000 106.197325 1.0007988
30 mu[30]    7.01167893  -90.5994425    7.55640500 104.117425 1.0009464
31   b[1]    1.36764948   -5.8865112    1.29819500   8.891522 1.0026289
32   b[2]    1.36460049   -5.5808450    1.33045000   8.410540 1.0026196
33   b[3]    1.56827138   -5.3015877    1.57881500   8.130886 1.0023412
34   b[4]    1.86788110   -5.0311797    2.05318000   8.000933 1.0030966
35   b[5]    1.83777685   -4.8585223    2.07953000   7.469443 1.0029382
36   b[6]    1.47925292   -4.8162015    1.74511000   6.737251 1.0022140
37   b[7]    0.80813643   -4.9444172    0.94043400   5.865995 1.0016619
38   b[8]    0.26672859   -4.9908265    0.31786450   5.105644 1.0012669
39   b[9]    0.44543545   -4.7480005    0.55949100   5.127882 1.0016059
40  b[10]   -0.52228114   -5.2743660   -0.58261000   4.174606 1.0018584
41  b[11]    0.13803157   -4.7166420    0.17938550   4.546472 1.0011370
42  b[12]    0.58114766   -4.3175907    0.69467800   4.828782 1.0012055
43  b[13]    0.05076165   -4.5406767    0.09411200   4.216238 1.0010304
44  b[14]    0.12922702   -4.4065887    0.21097200   4.123219 1.0010046
45  b[15]   -0.20061872   -4.5424440   -0.17193550   3.772313 1.0008660
46  b[16]   -0.16695237   -4.4475340   -0.13276500   3.717166 1.0010267
47  b[17]    1.39034738   -3.5003240    1.59733500   5.089817 1.0017051
48  b[18]    0.66610697   -3.7353512    0.80470550   4.261122 1.0010616
49  b[19]    0.33497191   -3.8738380    0.44761550   3.907629 1.0009365
50  b[20]   -0.59397875   -4.4238105   -0.56222150   3.049575 1.0009200
51  b[21]   -0.56724585   -4.3629465   -0.54418500   2.975355 1.0008284
52  b[22]   -0.10427863   -3.9380277   -0.03215285   3.346544 1.0009489
53  b[23]    0.08185365   -3.7895763    0.16795800   3.418827 1.0008668
54  b[24]   -0.16649849   -3.9677252   -0.10876850   3.189701 1.0008024
55  b[25]   -1.86883094   -5.2553737   -1.93369000   1.758043 1.0010995
56  b[26]   -1.56867865   -4.9025745   -1.61181000   1.952822 1.0009070
57  b[27]   -1.09695901   -4.4646460   -1.09950000   2.280139 1.0009128
58  b[28]   -1.43786386   -4.7214505   -1.47439000   1.970848 1.0010786
59  b[29]   -0.62766118   -4.0268208   -0.60636100   2.610024 1.0008054
60  b[30]   -1.17175580   -4.4185780   -1.19117000   2.106571 1.0008829
61    s_w    6.76339584    1.1301642    5.93060000  16.779705 1.0099947
62    s_t    0.91081977    0.2405074    0.92336100   1.597050 1.0067681
63    s_v    6.08949798    0.9193153    5.38633000  14.772415 1.0055988
64   a[1]    8.16104926   -2.5735067    7.36538500  22.393072 1.0013117
65   a[2]    9.73445937    0.6827111    8.95648500  21.735502 1.0005524
66   a[3]   14.70902053    5.5092655   14.40930000  24.551005 0.9998855
67   a[4]   21.64704795   10.3910675   22.39545000  30.503155 1.0004367
68   a[5]   24.98311939   12.1851750   26.08885000  33.934725 1.0010392
69   a[6]   23.78834761   11.8425625   24.69140000  32.935485 1.0011276
70   a[7]   17.43881123    7.5571200   17.34180000  27.471297 1.0001178
71   a[8]   11.12330693    2.1418478   10.29525000  22.781458 1.0001303
72   a[9]   14.23645295    2.0781945   15.16450000  23.252695 1.0006582
73  a[10]   -0.04541684   -9.6991952   -1.76255000  16.427677 1.0016772
74  a[11]   10.19899572   -0.4909817   10.26570000  20.986375 1.0005853
75  a[12]   18.18414826    4.6872148   19.37585000  27.600042 1.0009173
76  a[13]    9.78594373   -0.6640661    9.31811500  21.861302 1.0007833
77  a[14]   11.46921675   -0.4750549   11.85605000  22.136315 1.0000242
78  a[15]    5.33871341   -4.7941455    4.35400500  19.046980 1.0004375
79  a[16]    6.24340022   -3.7336053    4.59835500  22.285680 1.0012242
80  a[17]   39.55752451   16.5498625   42.34905000  50.683530 1.0024190
81  a[18]   25.49232068   12.4753350   25.78990000  36.995102 1.0000393
82  a[19]   18.52489469    4.9040028   19.36305000  29.341365 1.0002043
83  a[20]   -3.28030167  -13.9790150   -4.86875000  13.244600 1.0005497
84  a[21]   -3.23907620  -14.2522075   -4.26937500  11.835370 1.0001652
85  a[22]    8.38682602   -5.0095197    8.78599000  19.860980 1.0001829
86  a[23]   13.55002116   -3.2304617   14.76350000  25.039542 1.0006606
87  a[24]    6.67655854  -11.5911250    8.61564500  17.530750 1.0012242
88  a[25]  -42.25829628  -53.9786375  -45.01530000 -18.936123 1.0021766
89  a[26]  -35.52899830  -47.5835650  -36.30360000 -20.318730 1.0000664
90  a[27]  -22.66875370  -36.9523150  -21.98885000 -11.307388 1.0007899
91  a[28]  -34.31909410  -45.6921000  -35.65730000 -17.773935 1.0010248
92  a[29]   -9.52225973  -27.9028200   -7.80200000   1.896585 1.0006758
93  a[30]  -28.14099366  -40.8638200  -28.77065000 -13.139335 1.0000575
94   lp__ -130.12457493 -168.8321500 -133.28550000 -78.587435 1.0420706

\(\hat{\textrm{R}}\)の確認

# Extract the R-hat value of every parameter and plot them, with the
# coordinates flipped and the x-axis labels rotated by 90 degrees.
rhat <- bayesplot::rhat(fit_sample)
bayesplot::mcmc_rhat(rhat = rhat) +
    bayesplot::yaxis_text(hjust = 1) +
    coord_flip() +
    theme(axis.text.x = element_text(angle = 90))
Figure 4

トレースプロット

# Combined MCMC diagnostic plots for the three standard-deviation parameters.
# permuted = FALSE (spelled out; `F` can be reassigned) returns the draws
# without merging the chains.
bayesplot::mcmc_combo(
    x = rstan::extract(fit_sample, permuted = FALSE),
    pars = c("s_w", "s_t", "s_v")
)
Figure 5
# Combined MCMC diagnostic plots for the initial states mu[1] and b[1].
# permuted = FALSE (spelled out; `F` can be reassigned) returns the draws
# without merging the chains.
bayesplot::mcmc_combo(
    x = rstan::extract(fit_sample, permuted = FALSE),
    pars = c("mu[1]", "b[1]")
)
Figure 6

推定された時変係数

# Pull the posterior draws into a named list, then print the 2.5%, 50% and
# 97.5% quantiles of the coefficient b, one row per time point.
mcmc_sample <- rstan::extract(fit_sample)
t(
    apply(
        X = mcmc_sample[["b"]],
        MARGIN = 2,
        FUN = quantile,
        probs = c(0.025, 0.5, 0.975)
    )
)
       
             2.5%         50%    97.5%
   [1,] -5.886511  1.29819500 8.891522
   [2,] -5.580845  1.33045000 8.410540
   [3,] -5.301588  1.57881500 8.130886
   [4,] -5.031180  2.05318000 8.000933
   [5,] -4.858522  2.07953000 7.469443
   [6,] -4.816202  1.74511000 6.737251
   [7,] -4.944417  0.94043400 5.865995
   [8,] -4.990827  0.31786450 5.105644
   [9,] -4.748000  0.55949100 5.127882
  [10,] -5.274366 -0.58261000 4.174606
  [11,] -4.716642  0.17938550 4.546472
  [12,] -4.317591  0.69467800 4.828782
  [13,] -4.540677  0.09411200 4.216238
  [14,] -4.406589  0.21097200 4.123219
  [15,] -4.542444 -0.17193550 3.772313
  [16,] -4.447534 -0.13276500 3.717166
  [17,] -3.500324  1.59733500 5.089817
  [18,] -3.735351  0.80470550 4.261122
  [19,] -3.873838  0.44761550 3.907629
  [20,] -4.423811 -0.56222150 3.049575
  [21,] -4.362946 -0.54418500 2.975355
  [22,] -3.938028 -0.03215285 3.346544
  [23,] -3.789576  0.16795800 3.418827
  [24,] -3.967725 -0.10876850 3.189701
  [25,] -5.255374 -1.93369000 1.758043
  [26,] -4.902575 -1.61181000 1.952822
  [27,] -4.464646 -1.09950000 2.280139
  [28,] -4.721450 -1.47439000 1.970848
  [29,] -4.026821 -0.60636100 2.610024
  [30,] -4.418578 -1.19117000 2.106571
# Dimensions of the draws matrix for b: rows are draws, columns are time points.
dim(mcmc_sample$b)
[1] 22500    30

時変係数の比較

# Compare the true simulated coefficient b with its posterior mean and median
# at every time point, as a dodged bar chart.
# The true values are re-read from "sample-10.csv" inside `stan_output`
# (defined outside this section) using an explicit path instead of setwd().
# NOTE(review): the simulation step writes "sample-10.csv" to the working
# directory that is current at that point; confirm that is `stan_output`,
# otherwise this may read a file from a different run.
sampledf <- read.csv(file = file.path(stan_output, "sample-10.csv"))
b <- sampledf$b
b_hat_mean <- apply(X = mcmc_sample[["b"]], MARGIN = 2, FUN = mean)
b_hat_median <- apply(X = mcmc_sample[["b"]], MARGIN = 2, FUN = median)
b_df <- data.frame(n = factor(seq(n)), b, b_hat_mean, b_hat_median)
# pivot_longer() replaces the superseded gather(); every column except the
# time index `n` is stacked into key/value pairs.
tidydf <- tidyr::pivot_longer(
    b_df,
    cols = c("b", "b_hat_mean", "b_hat_median"),
    names_to = "key",
    values_to = "value"
)
ggplot(data = tidydf, mapping = aes(x = n, y = value, fill = key)) +
    geom_bar(stat = "identity", position = "dodge") +
    theme(axis.title = element_blank(), legend.title = element_blank())
Figure 7

時系列チャート

# Build a per-time-point table holding the 2.5% / 50% / 97.5% posterior
# quantiles of the fitted signal a, together with the time index and the
# observed y.
result_df <- setNames(
    data.frame(
        t(apply(X = mcmc_sample[["a"]], MARGIN = 2, quantile, probs = c(0.025, 0.5, 0.975)))
    ),
    c("lwr", "fit_median", "upr")
)
result_df <- cbind(result_df, time = seq(n), obs = y)
# Posterior median as a line, the 2.5%-97.5% band as a ribbon, and the
# observations as points (x is inherited from the plot-level mapping).
ggplot(data = result_df, mapping = aes(x = time)) +
    geom_line(aes(y = fit_median, colour = "fit_median")) +
    geom_ribbon(aes(ymin = lwr, ymax = upr), alpha = 0.1) +
    geom_point(aes(y = obs, colour = "obs")) +
    theme_minimal() +
    theme(axis.title = element_blank(), legend.title = element_blank())
Figure 8

参考引用資料

最終更新

# Print the current date-time as the "last updated" stamp of this document.
Sys.time()
[1] "2024-03-19 12:58:13 JST"

R、Quarto、Package

# Print the version string of the running R installation.
R.Version()$version.string
[1] "R version 4.3.2 Patched (2023-12-27 r85754 ucrt)"
# Print the Quarto version reported by the quarto package.
quarto::quarto_version()
[1] '1.4.542'
# Print the installed tidyverse version.
packageVersion(pkg = "tidyverse")
[1] '2.0.0'

著者