SSブログ

Stan: 常微分方程式をつかったモデルのパラメーター推定 [統計]

前回のつづきで、常微分方程式をつかったモデルのパラメーターを推定してみました。

ロジスティック式のrKを推定してみます。以下のようなデータを生成しました。

set.seed(123)
r <- 0.1
K <- 1000
ts <- seq(0, 30)
N <- vector("integer", length(ts))
N[1] <- 100
for (t in seq_along(ts)[-1]) {
  N[t] <- rpois(1, N[t -1] + r * N[t - 1] * (K - N[t - 1]) / K)
}

このようなデータになります。

library(ggplot2)
p <- ggplot(data.frame(ts, N))
p + geom_point(aes(x = ts, y = N)) + xlab("Time")

Rplot04.png

このデータからrKを推定するStanコードです。

functions {
  real[] logistic(real t,
                  real[] y,
                  real[] theta,
                  real[] x_r,
                  int[] x_i) {
    real dNdt[1];
    real r;
    real K;
    real N;
    r <- theta[1];
    K <- theta[2];
    N <- y[1];
    dNdt[1] <- r * N * (K - N) / K;
    return dNdt;
  }
}
data {
  int<lower=1> T;
  int<lower=0> N[T];
  real N0;
  real t0;
  real ts[T];
}
transformed data {
  real x_r[0];
  int x_i[0];
  real y0[1];
  y0[1] <- N0;
}
parameters {
  real theta[2];	// r <- theta[1]; K <- theta[2];
}
model {
  real y_hat[T, 1];
  y_hat <- integrate_ode(logistic, y0, t0, ts, theta, x_r, x_i);
  for (t in 1:T) {
    N[t] ~ poisson(y_hat[t, 1]);
  }
  // priors
  theta[1] ~ normal(0, 3);
  theta[2] ~ uniform(0, 1.0e+6);
}

通常、rの値は0からおおきく離れることはないので、事前分布はNormal(0, 32)にしました。

計算を実行するRコードです。

library(rstan)
library(parallel)

inits <- list()
inits[[1]] <- list(theta = c(0.01, 1500))
inits[[2]] <- list(theta = c(0.1, 500))
inits[[3]] <- list(theta = c(0.05, 2000))
inits[[4]] <- list(theta = c(0.2, 1000))
seeds <- c(3, 31, 314, 3141)

model <- stan_model(file = "logistic2.stan");
sflist <- mclapply(1:4, mc.cores = 4,
                   function(i)
                     sampling(model,
                              data = list(T = length(ts) - 1,
                                          N0 = N[1],
                                          N = N[-1],
                                          t0 = ts[1],
                                          ts = ts[-1]),
                              init = list(inits[[i]]),
                              chain_id = i, seed = seeds[i],
                              chains = 1, iter = 2000))
fit <- sflist2stanfit(sflist)

初期値によっては非常に時間がかかるようなので、だいたいそれらしい値をあたえるようにしています。

結果です。

print(fit)

r_hat <- get_posterior_mean(fit)[1]
K_hat <- get_posterior_mean(fit)[2]
x <- seq(0, 100, 1)
y <- vector("numeric", 100)
y[1] <- N[1]
for (t in seq_along(x)[-1]) {
  y[t] <- y[t - 1] + r_hat * y[t - 1] * (K_hat - y[t - 1]) / K_hat
}

library(ggplot2)
p <- ggplot(data.frame(ts, N))
p + geom_point(aes(x = ts, y = N)) +
    geom_line(data = data.frame(x, y), aes(x = x, y = y)) +
    xlab("Time")
Inference for Stan model: logistic2.
4 chains, each with iter=2000; warmup=1000; thin=1; 
post-warmup draws per chain=1000, total post-warmup draws=4000.

             mean se_mean    sd     2.5%      25%      50%      75%    97.5%
theta[1]     0.09    0.00  0.00     0.09     0.09     0.09     0.10     0.10
theta[2]   971.37    4.84 79.63   843.63   914.22   959.91  1019.93  1153.48
lp__     49905.04    0.04  0.98 49902.36 49904.62 49905.35 49905.76 49906.02
         n_eff Rhat
theta[1]   293 1.01
theta[2]   271 1.01
lp__       613 1.01

Samples were drawn using NUTS(diag_e) at Fri Oct 24 23:14:39 2014.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).

Rplot05.png


タグ:RStan STAn
nice!(2)  コメント(0)  トラックバック(0) 
共通テーマ:日記・雑感

nice! 2

コメント 0

コメントを書く

お名前:
URL:
コメント:
画像認証:
下の画像に表示されている文字を入力してください。

Facebook コメント

トラックバック 0