library(rstan)
library(dplyr)
library(ggplot2)
# Sampler configuration:
# - auto_write = TRUE asks rstan to save compiled models so they can be reused
#   instead of recompiled on every run;
# - mc.cores lets chains run in parallel on all detected cores.
rstan::rstan_options(auto_write = TRUE)
options(list(mc.cores = parallel::detectCores()))
# Report the installed rstan version.
utils::packageVersion("rstan")
#> [1] '2.32.6'
# Rでデータサイエンス (Data science with R)
# Simulate a regression whose intercept and slope drift over time:
#   state:       mu[i] ~ Normal(mu[i - 1], s_w),  b[i] ~ Normal(b[i - 1], s_t)
#   observation: y[i]  ~ Normal(a[i], s_v),  a[i] = mu[i] + b[i] * x[i]
# The simulated series are written to "sample-10.csv" in the working directory.
# NOTE(review): no set.seed() is called, so each run produces different data;
# call set.seed() beforehand if the draws must be reproducible.
n <- 30
x <- seq_len(n)
s_w <- 5 # sd of the intercept random walk
s_t <- 1 # sd of the slope random walk
s_v <- 4 # observation noise sd

# Preallocate instead of growing the vectors inside the loop.
mu <- numeric(n)
b <- numeric(n)
mu[1] <- 5
b[1] <- 1
# seq_len(n)[-1] is empty when n == 1, whereas 2:n would count down to 1
# and read mu[0].
for (i in seq_len(n)[-1]) {
  # Draw mu then b at each step, keeping the random-number draw order.
  mu[i] <- rnorm(n = 1, mean = mu[i - 1], sd = s_w)
  b[i] <- rnorm(n = 1, mean = b[i - 1], sd = s_t)
}

# Vectorized: rnorm() draws element by element in the same order as a loop
# of single draws, so the result matches the element-wise version.
a <- mu + b * x
y <- rnorm(n = n, mean = a, sd = s_v)

write.csv(
  data.frame(mu, b, a, y),
  file = "sample-10.csv",
  fileEncoding = "UTF-8",
  row.names = FALSE
)
# Overlay the simulated observations y and the latent mean a against x.
ggplot(mapping = aes(x = x)) +
  geom_line(mapping = aes(y = y, colour = "y")) +
  geom_line(mapping = aes(y = a, colour = "a")) +
  theme(legend.title = element_blank())
# Stan program: regression with a time-varying intercept mu and slope b, each
# following a Gaussian random walk.
#   state equations:      mu[i] ~ normal(mu[i-1], s_w), b[i] ~ normal(b[i-1], s_t)
#   observation equation: y[i]  ~ normal(mu[i] + b[i] * x[i], s_v)
# mu[1], b[1] and the three scales have no explicit prior statement here.
# Written with vectorized statements, which Stan evaluates more efficiently
# than element-wise loops; the log density is unchanged. head()/tail() return
# empty vectors when n == 1, so the random-walk terms simply drop out.
model_code <- "
data {
  int<lower=1> n;
  vector[n] x;
  vector[n] y;
}
parameters {
  vector[n] mu;
  vector[n] b;
  real<lower=0> s_w;
  real<lower=0> s_t;
  real<lower=0> s_v;
}
transformed parameters {
  vector[n] a = mu + b .* x;
}
model {
  // 状態方程式
  tail(mu, n - 1) ~ normal(head(mu, n - 1), s_w);
  tail(b, n - 1) ~ normal(head(b, n - 1), s_t);
  // 観測方程式
  y ~ normal(a, s_v);
}"
# Load the sampler output CSV files for this model and build a summary table
# with one row per parameter (parameter name in the `pars` column).
# The directory is listed with full paths instead of calling setwd(), so the
# working directory is not changed as a side effect.
# `stan_output` (the directory holding the CSVs) is defined outside this
# section — presumably earlier in the script.
csvfiles <- dir(
  path = stan_output,
  pattern = "sample_file_stan_code_10",
  full.names = TRUE
)
if (length(csvfiles) == 0L) {
  stop(
    "No files matching 'sample_file_stan_code_10' found in ", stan_output,
    call. = FALSE
  )
}
fit_sample <- read_stan_csv(csvfiles = csvfiles)
# check.names = FALSE keeps quantile column names such as "2.5%" intact;
# rownames_to_column() moves the parameter names into `pars` and drops the
# row names.
result_summary <- fit_sample %>%
  summary() %>%
  .$summary %>%
  data.frame(check.names = FALSE) %>%
  tibble::rownames_to_column(var = "pars")
result_summary[, c("pars", "mean", "2.5%", "50%", "97.5%", "Rhat")]
pars mean 2.5% 50% 97.5% Rhat
1 mu[1] 6.79339994 -7.3416200 6.29819500 24.255822 1.0025112
2 mu[2] 7.00525842 -9.8262373 6.62467500 25.906515 1.0020121
3 mu[3] 10.00420574 -10.2934800 9.54871000 32.884777 1.0020391
4 mu[4] 14.17552251 -9.8871502 13.22500000 43.529005 1.0021440
5 mu[5] 15.79423690 -11.5091975 14.19170000 51.024470 1.0018605
6 mu[6] 14.91282994 -16.2165400 13.47330000 53.875820 1.0013985
7 mu[7] 11.78185670 -23.1517175 10.66245000 52.491470 1.0014437
8 mu[8] 8.98947734 -30.5464300 8.34256000 51.795440 1.0013861
9 mu[9] 10.22753437 -31.1475950 9.22867000 57.110805 1.0014782
10 mu[10] 5.17739316 -43.3538025 5.61581000 52.208208 1.0016759
11 mu[11] 8.68064794 -39.9813275 8.10738000 62.456967 1.0012023
12 mu[12] 11.21037595 -38.6699800 9.77248000 70.220530 1.0009209
13 mu[13] 9.12604321 -45.1523150 8.48092000 68.925915 1.0011709
14 mu[14] 9.66003927 -46.7209850 8.69359500 72.393600 1.0009930
15 mu[15] 8.34799586 -51.6467225 7.86142000 73.750967 1.0009285
16 mu[16] 8.91463812 -54.1950950 8.25848000 78.426770 1.0011341
17 mu[17] 15.92161841 -44.0820250 12.55655000 96.735045 1.0010939
18 mu[18] 13.50239461 -51.0584775 11.14935000 92.396042 1.0009911
19 mu[19] 12.16042959 -55.0423100 10.30665000 91.387732 1.0008457
20 mu[20] 8.59927382 -65.8660850 8.07129500 85.420040 1.0009710
21 mu[21] 8.67308463 -66.9528800 8.13816500 88.386950 1.0008589
22 mu[22] 10.68095676 -64.7539925 9.11640000 94.511585 1.0008425
23 mu[23] 11.66738769 -63.8011000 9.71258500 99.416305 1.0007653
24 mu[24] 10.67252275 -67.7587900 9.26283500 99.550175 1.0006854
25 mu[25] 4.46247493 -85.6077700 5.96726000 88.362342 1.0010638
26 mu[26] 5.25664790 -86.2436500 6.23650500 91.415665 1.0009723
27 mu[27] 6.94913898 -84.6390625 7.18897500 97.852452 1.0007838
28 mu[28] 5.94109156 -88.8497000 6.96365500 97.520570 1.0009333
29 mu[29] 8.67991302 -83.9789425 8.38005000 106.197325 1.0007988
30 mu[30] 7.01167893 -90.5994425 7.55640500 104.117425 1.0009464
31 b[1] 1.36764948 -5.8865112 1.29819500 8.891522 1.0026289
32 b[2] 1.36460049 -5.5808450 1.33045000 8.410540 1.0026196
33 b[3] 1.56827138 -5.3015877 1.57881500 8.130886 1.0023412
34 b[4] 1.86788110 -5.0311797 2.05318000 8.000933 1.0030966
35 b[5] 1.83777685 -4.8585223 2.07953000 7.469443 1.0029382
36 b[6] 1.47925292 -4.8162015 1.74511000 6.737251 1.0022140
37 b[7] 0.80813643 -4.9444172 0.94043400 5.865995 1.0016619
38 b[8] 0.26672859 -4.9908265 0.31786450 5.105644 1.0012669
39 b[9] 0.44543545 -4.7480005 0.55949100 5.127882 1.0016059
40 b[10] -0.52228114 -5.2743660 -0.58261000 4.174606 1.0018584
41 b[11] 0.13803157 -4.7166420 0.17938550 4.546472 1.0011370
42 b[12] 0.58114766 -4.3175907 0.69467800 4.828782 1.0012055
43 b[13] 0.05076165 -4.5406767 0.09411200 4.216238 1.0010304
44 b[14] 0.12922702 -4.4065887 0.21097200 4.123219 1.0010046
45 b[15] -0.20061872 -4.5424440 -0.17193550 3.772313 1.0008660
46 b[16] -0.16695237 -4.4475340 -0.13276500 3.717166 1.0010267
47 b[17] 1.39034738 -3.5003240 1.59733500 5.089817 1.0017051
48 b[18] 0.66610697 -3.7353512 0.80470550 4.261122 1.0010616
49 b[19] 0.33497191 -3.8738380 0.44761550 3.907629 1.0009365
50 b[20] -0.59397875 -4.4238105 -0.56222150 3.049575 1.0009200
51 b[21] -0.56724585 -4.3629465 -0.54418500 2.975355 1.0008284
52 b[22] -0.10427863 -3.9380277 -0.03215285 3.346544 1.0009489
53 b[23] 0.08185365 -3.7895763 0.16795800 3.418827 1.0008668
54 b[24] -0.16649849 -3.9677252 -0.10876850 3.189701 1.0008024
55 b[25] -1.86883094 -5.2553737 -1.93369000 1.758043 1.0010995
56 b[26] -1.56867865 -4.9025745 -1.61181000 1.952822 1.0009070
57 b[27] -1.09695901 -4.4646460 -1.09950000 2.280139 1.0009128
58 b[28] -1.43786386 -4.7214505 -1.47439000 1.970848 1.0010786
59 b[29] -0.62766118 -4.0268208 -0.60636100 2.610024 1.0008054
60 b[30] -1.17175580 -4.4185780 -1.19117000 2.106571 1.0008829
61 s_w 6.76339584 1.1301642 5.93060000 16.779705 1.0099947
62 s_t 0.91081977 0.2405074 0.92336100 1.597050 1.0067681
63 s_v 6.08949798 0.9193153 5.38633000 14.772415 1.0055988
64 a[1] 8.16104926 -2.5735067 7.36538500 22.393072 1.0013117
65 a[2] 9.73445937 0.6827111 8.95648500 21.735502 1.0005524
66 a[3] 14.70902053 5.5092655 14.40930000 24.551005 0.9998855
67 a[4] 21.64704795 10.3910675 22.39545000 30.503155 1.0004367
68 a[5] 24.98311939 12.1851750 26.08885000 33.934725 1.0010392
69 a[6] 23.78834761 11.8425625 24.69140000 32.935485 1.0011276
70 a[7] 17.43881123 7.5571200 17.34180000 27.471297 1.0001178
71 a[8] 11.12330693 2.1418478 10.29525000 22.781458 1.0001303
72 a[9] 14.23645295 2.0781945 15.16450000 23.252695 1.0006582
73 a[10] -0.04541684 -9.6991952 -1.76255000 16.427677 1.0016772
74 a[11] 10.19899572 -0.4909817 10.26570000 20.986375 1.0005853
75 a[12] 18.18414826 4.6872148 19.37585000 27.600042 1.0009173
76 a[13] 9.78594373 -0.6640661 9.31811500 21.861302 1.0007833
77 a[14] 11.46921675 -0.4750549 11.85605000 22.136315 1.0000242
78 a[15] 5.33871341 -4.7941455 4.35400500 19.046980 1.0004375
79 a[16] 6.24340022 -3.7336053 4.59835500 22.285680 1.0012242
80 a[17] 39.55752451 16.5498625 42.34905000 50.683530 1.0024190
81 a[18] 25.49232068 12.4753350 25.78990000 36.995102 1.0000393
82 a[19] 18.52489469 4.9040028 19.36305000 29.341365 1.0002043
83 a[20] -3.28030167 -13.9790150 -4.86875000 13.244600 1.0005497
84 a[21] -3.23907620 -14.2522075 -4.26937500 11.835370 1.0001652
85 a[22] 8.38682602 -5.0095197 8.78599000 19.860980 1.0001829
86 a[23] 13.55002116 -3.2304617 14.76350000 25.039542 1.0006606
87 a[24] 6.67655854 -11.5911250 8.61564500 17.530750 1.0012242
88 a[25] -42.25829628 -53.9786375 -45.01530000 -18.936123 1.0021766
89 a[26] -35.52899830 -47.5835650 -36.30360000 -20.318730 1.0000664
90 a[27] -22.66875370 -36.9523150 -21.98885000 -11.307388 1.0007899
91 a[28] -34.31909410 -45.6921000 -35.65730000 -17.773935 1.0010248
92 a[29] -9.52225973 -27.9028200 -7.80200000 1.896585 1.0006758
93 a[30] -28.14099366 -40.8638200 -28.77065000 -13.139335 1.0000575
94 lp__ -130.12457493 -168.8321500 -133.28550000 -78.587435 1.0420706
# Posterior draws of every parameter, as a named list of arrays.
mcmc_sample <- rstan::extract(fit_sample)
# Median and 95% interval of each slope b[i]: quantiles per column, then
# transposed so each time point is one row.
t(apply(mcmc_sample[["b"]], 2, quantile, probs = c(0.025, 0.5, 0.975)))
2.5% 50% 97.5%
[1,] -5.886511 1.29819500 8.891522
[2,] -5.580845 1.33045000 8.410540
[3,] -5.301588 1.57881500 8.130886
[4,] -5.031180 2.05318000 8.000933
[5,] -4.858522 2.07953000 7.469443
[6,] -4.816202 1.74511000 6.737251
[7,] -4.944417 0.94043400 5.865995
[8,] -4.990827 0.31786450 5.105644
[9,] -4.748000 0.55949100 5.127882
[10,] -5.274366 -0.58261000 4.174606
[11,] -4.716642 0.17938550 4.546472
[12,] -4.317591 0.69467800 4.828782
[13,] -4.540677 0.09411200 4.216238
[14,] -4.406589 0.21097200 4.123219
[15,] -4.542444 -0.17193550 3.772313
[16,] -4.447534 -0.13276500 3.717166
[17,] -3.500324 1.59733500 5.089817
[18,] -3.735351 0.80470550 4.261122
[19,] -3.873838 0.44761550 3.907629
[20,] -4.423811 -0.56222150 3.049575
[21,] -4.362946 -0.54418500 2.975355
[22,] -3.938028 -0.03215285 3.346544
[23,] -3.789576 0.16795800 3.418827
[24,] -3.967725 -0.10876850 3.189701
[25,] -5.255374 -1.93369000 1.758043
[26,] -4.902575 -1.61181000 1.952822
[27,] -4.464646 -1.09950000 2.280139
[28,] -4.721450 -1.47439000 1.970848
[29,] -4.026821 -0.60636100 2.610024
[30,] -4.418578 -1.19117000 2.106571
# Compare the true slope b[i] used in the simulation with its posterior mean
# and median at each time point, as a dodged bar chart.
# NOTE(review): setwd() mutates global state; it is kept because code that
# follows may read files relative to the working directory. Reading via
# file.path(stan_output, "sample-10.csv") would avoid it.
setwd(stan_output)
sampledf <- read.csv(file = "sample-10.csv")
b <- sampledf$b
b_hat_mean <- colMeans(mcmc_sample[["b"]])
b_hat_median <- apply(X = mcmc_sample[["b"]], MARGIN = 2, FUN = median)
# Index time points from the data that was read, not from the global `n`.
b_df <- data.frame(n = factor(seq_along(b)), b, b_hat_mean, b_hat_median)
# gather() is superseded; pivot_longer() yields the same key/value columns
# (row order differs, which does not affect the plot). Namespaced so tidyr is
# not attached, which would mask rstan::extract.
tidydf <- b_df %>%
  tidyr::pivot_longer(
    cols = c("b", "b_hat_mean", "b_hat_median"),
    names_to = "key",
    values_to = "value"
  )
ggplot(data = tidydf, mapping = aes(x = n, y = value, fill = key)) +
  geom_col(position = "dodge") +
  theme(axis.title = element_blank(), legend.title = element_blank())
# Posterior median and 95% band of the latent mean a[i] at each time point,
# plotted together with the observed series y.
result_df <- mcmc_sample[["a"]] %>%
  apply(MARGIN = 2, FUN = quantile, probs = c(0.025, 0.5, 0.975)) %>%
  t() %>%
  data.frame() %>%
  setNames(c("lwr", "fit_median", "upr"))
result_df[["time"]] <- seq_len(n)
result_df[["obs"]] <- y
ggplot(data = result_df, mapping = aes(x = time)) +
  geom_line(mapping = aes(y = fit_median, color = "fit_median")) +
  geom_ribbon(mapping = aes(ymin = lwr, ymax = upr), alpha = 0.1) +
  geom_point(mapping = aes(x = time, y = obs, color = "obs")) +
  theme_minimal() +
  theme(axis.title = element_blank(), legend.title = element_blank())