SSブログ

Stanで隠れマルコフモデル (3) Multistate model [統計]

隠れマルコフモデルでBPA 9章のMultistate model風のモデルをStanに実装してみました。

# Stan users mailing listにて、“Successful unsupervised hidden Markov model”という議論が去年あったことに気がつきました。

データは、動物の標識再捕獲を想定しています。この動物は、サイトAとサイトBの2つのサイト間で移動しており、サイトA、Bでの生存率がそれぞれphi_A、phi_B、サイトAからBへの移動率がpsi_AB、サイトBからAへの移動率がpsi_BAとします。また、サイトA、Bでの捕獲率をそれぞれp_A、p_Bとします。観測データは各測定ごとに、1、2、3の値をとり、1がサイトAでの捕獲、2がサイトBでの捕獲、3が捕獲されないことをしめします。潜在状態も1、2、3の値をとり、1がサイトAに滞在、2がサイトBに滞在、3が死亡をしめします。以下のようなRコードでデータを生成しました。

set.seed(123)
Ns <- 80
Nt <- 20
phi_A <- 0.7;
phi_B <- 0.8;
psi_AB <- 0.5;
psi_BA <- 0.1;
p_A <- 0.9;
p_B <- 0.6;
Mt <- matrix(c(phi_A * (1 - psi_AB), phi_A * psi_AB, 1 - phi_A,
               phi_B * psi_BA, phi_B * (1 - psi_BA), 1 - phi_B,
               0.0, 0.0, 1.0), ncol = 3, byrow = TRUE)
Me <- matrix(c(p_A, 0.0, 1.0 - p_A,
               0.0, p_B, 1.0 - p_B,
               0.0, 0.0, 1.0), ncol = 3, byrow = TRUE)

y <- z <- matrix(0, nrow = Ns, ncol = Nt)
z[, 1] <- rbinom(Ns, 1, 0.5) + 1
for (i in 1:Ns) {
  for (t in 2:Nt)
    z[i, t] <- grep(1, rmultinom(1, 1, Mt[z[i, t - 1], ]))
  for (t in 1:Nt)      
    y[i, t] <- grep(1, rmultinom(1, 1, Me[z[i, t], ]))
}

Stanコードは以下のようになります。前回までとおなじく、Stan Modeling Language User's Guide and Reference Manualの隠れマルコフモデルのコードをつかっていますが、遷移確率行列・出力確率行列が0のところもあるので、gamma、accは対数にはせず、最後にまとめて対数確率を計算するようにしています。

data {
  int<lower=1> Ns;                      // number of series
  int<lower=1> Nt;                      // number of occasions
  int<lower=1,upper=3> y[Ns, Nt];       // observed values
}

parameters {
  real<lower=0,upper=1> phi_A;
  real<lower=0,upper=1> phi_B;
  real<lower=0,upper=1> psi_AB;
  real<lower=0,upper=1> psi_BA;
  real<lower=0,upper=1> p_A;
  real<lower=0,upper=1> p_B;
}

transformed parameters {
  simplex[3] ps[3];
  simplex[3] po[3];

  ps[1, 1] <- phi_A * (1.0 - psi_AB);
  ps[1, 2] <- phi_A * psi_AB;
  ps[1, 3] <- 1.0 - phi_A;
  ps[2, 1] <- phi_B * psi_BA;
  ps[2, 2] <- phi_B * (1.0 - psi_BA);
  ps[2, 3] <- 1.0 - phi_B;
  ps[3, 1] <- 0.0;
  ps[3, 2] <- 0.0;
  ps[3, 3] <- 1.0;
  po[1, 1] <- p_A;
  po[1, 2] <- 0.0;
  po[1, 3] <- 1.0 - p_A;
  po[2, 1] <- 0.0;
  po[2, 2] <- p_B;
  po[2, 3] <- 1.0 - p_B;
  po[3, 1] <- 0.0;
  po[3, 2] <- 0.0;
  po[3, 3] <- 1.0;
}

model {
  real acc[3];
  real gamma[Nt, 3];
  real log_gamma[3];

  // priors
  phi_A ~ uniform(0, 1);
  phi_B ~ uniform(0, 1);
  psi_AB ~ uniform(0, 1);
  psi_BA ~ uniform(0, 1);
  p_A ~ uniform(0, 1);
  p_B ~ uniform(0, 1);

  // likelihood
  for (i in 1:Ns) {
    for (k in 1:3)
      gamma[1, k] <- po[k, y[i, 1]];
    for (t in 2:Nt) {
      for (k in 1:3) {
        for (j in 1:3)
          acc[j] <- gamma[t - 1, j] * ps[j, k] *
                    po[k, y[i, t]];
        gamma[t, k] <- sum(acc);
      }
    }
    for (j in 1:3)
      log_gamma[j] <- log(gamma[Nt, j]);
    increment_log_prob(log_sum_exp(log_gamma));
  }
}

実行するRコードです。

library(rstan)
rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())

model_file <- "hmm_test2.stan"
data <- list(Ns = Ns, Nt = Nt, y = y)
n_chains <- 4
inits <- lapply(1:n_chains, function(i)
    list(phi_A = runif(1, 0.7, 0.9),
         phi_B = runif(1, 0.7, 0.9),
         psi_AB = runif(1, 0.3, 0.7),
         psi_BA = runif(1, 0.3, 0.7),
         p_A = runif(1, 0.3, 0.7),
         p_B = runif(1, 0.3, 0.7)))
params <- c("phi_A", "phi_B", "psi_AB", "psi_BA", "p_A", "p_B")
fit <- stan(model_file, data = data, init = inits,
            pars = params, chains = n_chains,
            iter = 2000, warmup = 1000, thin = 1,
            control = list(adapt_delta = 0.8))

結果です。

Inference for Stan model: hmm_test2.
4 chains, each with iter=2000; warmup=1000; thin=1; 
post-warmup draws per chain=1000, total post-warmup draws=4000.

          mean se_mean   sd    2.5%     25%     50%     75%   97.5% n_eff Rhat
phi_A     0.70    0.00 0.06    0.58    0.66    0.70    0.73    0.81  2875    1
phi_B     0.77    0.00 0.03    0.71    0.75    0.77    0.79    0.82  2980    1
psi_AB    0.64    0.00 0.07    0.50    0.60    0.64    0.69    0.78  3123    1
psi_BA    0.10    0.00 0.02    0.06    0.09    0.10    0.12    0.16  2582    1
p_A       0.93    0.00 0.06    0.79    0.90    0.94    0.97    1.00  1787    1
p_B       0.64    0.00 0.04    0.57    0.62    0.65    0.67    0.72  2690    1
lp__   -392.11    0.05 1.80 -396.44 -393.08 -391.72 -390.78 -389.62  1354    1

Samples were drawn using NUTS(diag_e) at Wed Feb 10 18:26:39 2016.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).

おおよそのところデータを生成した値を再現できました。


タグ:STAn
nice!(2)  コメント(0)  トラックバック(0) 
共通テーマ:日記・雑感

nice! 2

コメント 0

コメントを書く

お名前:
URL:
コメント:
画像認証:
下の画像に表示されている文字を入力してください。

Facebook コメント

トラックバック 0