Stan: 常微分方程式をつかったモデルのパラメーター推定 [統計]
前回のつづきで、常微分方程式をつかったモデルのパラメーターを推定してみました。
ロジスティック式のrとKを推定してみます。以下のようなデータを生成しました。
set.seed(123) r <- 0.1 K <- 1000 ts <- seq(0, 30) N <- vector("integer", length(ts)) N[1] <- 100 for (t in seq_along(ts)[-1]) { N[t] <- rpois(1, N[t -1] + r * N[t - 1] * (K - N[t - 1]) / K) }
このようなデータになります。
library(ggplot2) p <- ggplot(data.frame(ts, N)) p + geom_point(aes(x = ts, y = N)) + xlab("Time")
このデータからrとKを推定するStanコードです。
functions { real[] logistic(real t, real[] y, real[] theta, real[] x_r, int[] x_i) { real dNdt[1]; real r; real K; real N; r <- theta[1]; K <- theta[2]; N <- y[1]; dNdt[1] <- r * N * (K - N) / K; return dNdt; } } data { int<lower=1> T; int<lower=0> N[T]; real N0; real t0; real ts[T]; } transformed data { real x_r[0]; int x_i[0]; real y0[1]; y0[1] <- N0; } parameters { real theta[2]; // r <- theta[1]; K <- theta[2]; } model { real y_hat[T, 1]; y_hat <- integrate_ode(logistic, y0, t0, ts, theta, x_r, x_i); for (t in 1:T) { N[t] ~ poisson(y_hat[t, 1]); } // priors theta[1] ~ normal(0, 3); theta[2] ~ uniform(0, 1.0e+6); }
通常、rの値は0からおおきく離れることはないので、事前分布はNormal(0, 32)にしました。
計算を実行するRコードです。
library(rstan) library(parallel) inits <- list() inits[[1]] <- list(theta = c(0.01, 1500)) inits[[2]] <- list(theta = c(0.1, 500)) inits[[3]] <- list(theta = c(0.05, 2000)) inits[[4]] <- list(theta = c(0.2, 1000)) seeds <- c(3, 31, 314, 3141) model <- stan_model(file = "logistic2.stan"); sflist <- mclapply(1:4, mc.cores = 4, function(i) sampling(model, data = list(T = length(ts) - 1, N0 = N[1], N = N[-1], t0 = ts[1], ts = ts[-1]), init = list(inits[[i]]), chain_id = i, seed = seeds[i], chains = 1, iter = 2000)) fit <- sflist2stanfit(sflist)
初期値によっては非常に時間がかかるようなので、だいたいそれらしい値をあたえるようにしています。
結果です。
print(fit) r_hat <- get_posterior_mean(fit)[1] K_hat <- get_posterior_mean(fit)[2] x <- seq(0, 100, 1) y <- vector("numeric", 100) y[1] <- N[1] for (t in seq_along(x)[-1]) { y[t] <- y[t - 1] + r_hat * y[t - 1] * (K_hat - y[t - 1]) / K_hat } library(ggplot2) p <- ggplot(data.frame(ts, N)) p + geom_point(aes(x = ts, y = N)) + geom_line(data = data.frame(x, y), aes(x = x, y = y)) + xlab("Time")
Inference for Stan model: logistic2. 4 chains, each with iter=2000; warmup=1000; thin=1; post-warmup draws per chain=1000, total post-warmup draws=4000. mean se_mean sd 2.5% 25% 50% 75% 97.5% theta[1] 0.09 0.00 0.00 0.09 0.09 0.09 0.10 0.10 theta[2] 971.37 4.84 79.63 843.63 914.22 959.91 1019.93 1153.48 lp__ 49905.04 0.04 0.98 49902.36 49904.62 49905.35 49905.76 49906.02 n_eff Rhat theta[1] 293 1.01 theta[2] 271 1.01 lp__ 613 1.01 Samples were drawn using NUTS(diag_e) at Fri Oct 24 23:14:39 2014. For each parameter, n_eff is a crude measure of effective sample size, and Rhat is the potential scale reduction factor on split chains (at convergence, Rhat=1).
コメント 0