SSブログ

Stan: 欠測値の推定 [統計]

Stanのマニュアルを見ながら、欠測値の推定をやってみたのでメモ。

下のようなデータがあったとする。

              x1         x2           y
 [1,] -3.0150087  0.8224304  -3.5474533
 [2,] -2.0796367  1.0398335  -1.4325317
 [3,] -2.2329870  1.1210478  -0.9401105
 [4,] -2.8172679 -1.5561365 -10.7433671
 [5,] -1.2279092  0.5832502  -0.6014049
 [6,] -2.1656119  1.5512307   0.3979848
 [7,] -1.0271256  1.2240323   1.4461063
 [8,] -0.2834660  1.1681719   3.3905226
 [9,] -1.7447630  2.0518301   3.1254418
[10,] -1.6334189  1.3834487   0.6055146
[11,] -0.8192108  2.2543968   5.4048861
[12,] -1.3568079  1.6684795   1.3956139
[13,] -0.7046781  1.7936807   3.1889774
[14,] -1.8120819  3.2153371   4.3616880
[15,] -0.4087949  3.8920564  10.9359402
[16,] -2.0551791  2.0741935   1.9299091
[17,] -1.1615289  3.7516962   7.7151888
[18,] -1.8406299  1.7685126   1.7925101
[19,] -1.3740456  2.5434525   4.5620398
[20,] -1.3664153  1.0109986   1.2107255
[21,] -1.3189724  2.3155315   3.3852880
[22,] -2.6820334  4.4423275   8.2837741
[23,] -2.7232567  2.5496929   2.7953275
[24,] -0.3264740  1.9707566   5.1704778
[25,] -2.5957556  1.1692166  -2.3269244
[26,] -0.8401562  3.2464344   8.2907162
[27,] -1.8825776  0.6242447  -1.4520879
[28,] -1.7407786  1.6688916   1.5948284
[29,] -1.6176379  2.9809115   5.5199313
[30,] -2.7114817  4.1907706   7.1706177

これが、下のようにいくつかが欠測になったとする。

              x1         x2           y
 [1,] -3.0150087  0.8224304  -3.5474533
 [2,]         NA  1.0398335  -1.4325317
 [3,] -2.2329870  1.1210478  -0.9401105
 [4,] -2.8172679 -1.5561365 -10.7433671
 [5,] -1.2279092  0.5832502  -0.6014049
 [6,] -2.1656119  1.5512307   0.3979848
 [7,] -1.0271256  1.2240323   1.4461063
 [8,] -0.2834660  1.1681719   3.3905226
 [9,] -1.7447630  2.0518301   3.1254418
[10,]         NA         NA   0.6055146
[11,] -0.8192108  2.2543968   5.4048861
[12,] -1.3568079  1.6684795   1.3956139
[13,]         NA  1.7936807   3.1889774
[14,]         NA  3.2153371   4.3616880
[15,] -0.4087949         NA  10.9359402
[16,] -2.0551791  2.0741935   1.9299091
[17,] -1.1615289  3.7516962   7.7151888
[18,] -1.8406299  1.7685126   1.7925101
[19,]         NA  2.5434525   4.5620398
[20,] -1.3664153         NA   1.2107255
[21,] -1.3189724         NA   3.3852880
[22,] -2.6820334  4.4423275   8.2837741
[23,]         NA  2.5496929   2.7953275
[24,] -0.3264740         NA   5.1704778
[25,] -2.5957556  1.1692166  -2.3269244
[26,]         NA  3.2464344   8.2907162
[27,] -1.8825776  0.6242447  -1.4520879
[28,] -1.7407786  1.6688916   1.5948284
[29,] -1.6176379  2.9809115   5.5199313
[30,] -2.7114817  4.1907706   7.1706177

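このようなデータで、欠測部分をStanで推定してみる。Reference Manual の 7.1節に例題があるが、Stanではデータに欠測値(NA)を含められないので、欠測値の部分をパラメーターとして宣言し、ほかのパラメーターと一緒に推定する。Stanには観測された値だけを渡せるように、R側で行を欠測のパターン(両方観測、x2のみ欠測、x1のみ欠測、両方欠測)の順に並べ替えている。RとStanのコードは以下のとおり。なお、x1_miss と x2_miss に事前分布を指定しないと収束しない。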

# Data
## Simulate N observations from y = 2 * x1 + 3 * x2 + e, e ~ N(0, 0.5^2).
## NOTE(review): R >= 3.6.0 changed the default algorithm behind sample()
## (sample.kind = "Rejection"), so the NA positions drawn below will likely
## differ from the tables in this post, which predates that change.
## set.seed(17, sample.kind = "Rounding") should reproduce them.
set.seed(17)
N <- 30
x1 <- rnorm(N, -2, 1)
x2 <- rnorm(N, 2, 1)
y <- 2 * x1 + 3 * x2 + rnorm(N, 0, 0.5)

# save original values (the true values, compared with the imputations later)
ox1 <- x1
ox2 <- x2
oy <- y

# Missing data: delete 7 values of x1 and 5 values of x2 at random
x1[sample(N, 7)] <- NA
x2[sample(N, 5)] <- NA

print(cbind(x1, x2, y))

## sort
## Group the rows by missingness pattern (original order kept within each
## group) so that the observed values can be handed to Stan as dense arrays.
idx0 <- which(!is.na(x1) & !is.na(x2))  # both observed
idx1 <- which(!is.na(x1) & is.na(x2))   # x2 missing only
idx2 <- which(is.na(x1) & !is.na(x2))   # x1 missing only
idx3 <- which(is.na(x1) & is.na(x2))    # both missing

## Apply the same row permutation to the observed and the original columns.
ord <- c(idx0, idx1, idx2, idx3)
x1 <- x1[ord]
x2 <- x2[ord]
y <- y[ord]
ox1 <- ox1[ord]
ox2 <- ox2[ord]
oy <- oy[ord]

print(cbind(x1, x2, y))

## Stan
library(rstan)
library(parallel)

## Stan program: linear regression y ~ N(beta[1] + beta[2] * x1 + beta[3] * x2,
## sigma) where some x1 / x2 values are missing.
##
## Stan cannot take NA in its data, so every missing covariate value is
## declared as a parameter (x1_miss, x2_miss) and sampled together with beta
## and sigma. The R code sorts the rows by missingness pattern beforehand:
##   rows 1 .. M0                       x1 and x2 both observed
##   rows M0+1 .. M0+M1                 x1 observed, x2 missing
##   rows M0+M1+1 .. M0+M1+M2           x1 missing,  x2 observed
##   rows M0+M1+M2+1 .. N  (M3 rows)    x1 and x2 both missing
## The data array x1 holds the M0 + M1 observed x1 values (rows 1..M0+M1);
## x2 holds the M0 + M2 observed x2 values (rows 1..M0, then the M2 rows of
## the "x1 missing" group). x1_miss[1..M2] / x2_miss[1..M1] fill the rows
## with one missing covariate; x1_miss[M2+i] / x2_miss[M1+i] fill the i-th
## row where both are missing.
## beta and sigma have no prior statement (improper flat priors). The missing
## values get a normal prior centred on the mean / sd of the observed values;
## according to the post, the model does not converge without it.
## NOTE(review): written for 2015-era Stan. The `real x[N]` array declarations
## and `<-` assignment are deprecated and, as far as I know, removed in Stan
## 2.33+, which expects `array[N] real x` and `=` instead. Confirm against
## the installed rstan/Stan version.
stancode <- "
data {
  int<lower=0> N;
  int<lower=0> M0;
  int<lower=0> M1;
  int<lower=0> M2;
  real x1[M0 + M1];
  real x2[M0 + M2];
  real y[N];
}
transformed data {
  int M3;
  M3 <- N - M0 - M1 - M2;
}
parameters {
  real x1_miss[M2 + M3];
  real x2_miss[M1 + M3];
  real beta[3];
  real<lower=0> sigma;
}
model {
  for (i in 1:M0) {
    y[i] ~ normal(beta[1] +
                    beta[2] * x1[i] +
                    beta[3] * x2[i],
                  sigma);
  }
  for (i in (M0 + 1):(M0 + M1)) {
    y[i] ~ normal(beta[1] +
                    beta[2] * x1[i] +
                    beta[3] * x2_miss[i - M0],
                  sigma);
  }
  for (i in (M0 + 1):(M0 + M2)) {
    y[M1 + i] ~ normal(beta[1] +
                         beta[2] * x1_miss[i - M0] +
                         beta[3] * x2[i],
                       sigma);
  }
  for (i in 1:M3) {
    y[M0 + M1 + M2 + i] ~ normal(beta[1] +
                                   beta[2] * x1_miss[M2 + i] +
                                   beta[3] * x2_miss[M1 + i],
                                 sigma);
  }
  x1_miss ~ normal(mean(x1), sd(x1));
  x2_miss ~ normal(mean(x2), sd(x2));
}
"

## Compile the Stan program once; every chain below reuses the compiled model.
model <- stan_model(model_code = stancode)

## Data for the Stan program. Because the rows were sorted by missingness
## pattern, dropping the NAs leaves x1 as the observed values of rows
## 1..(M0 + M1) and x2 as the observed values of rows 1..M0 followed by rows
## (M0 + M1 + 1)..(M0 + M1 + M2), which is how `stancode` indexes them.
data <- list(N = N,
             M0 = length(idx0),
             M1 = length(idx1),
             M2 = length(idx2),
             x1 = x1[!is.na(x1)],
             x2 = x2[!is.na(x2)],
             y = y)

## One seed per chain; the number of chains is taken from length(seeds).
seeds <- c(1, 2, 3, 5)
n_chains <- length(seeds)

## Run one single-chain sampler per process. mclapply() works by forking,
## which Windows does not support (mc.cores must be 1 there), so run the
## chains sequentially on that platform.
n_cores <- if (.Platform$OS.type == "windows") 1L else n_chains
sflist <- mclapply(seq_len(n_chains),
                   function(i) {
                     sampling(model,
                              data = data, seed = seeds[i],
                              chains = 1, iter = 2000)
                   },
                   mc.cores = n_cores)

## mclapply() does not propagate errors: a failed chain comes back as a
## "try-error" object (or NULL if its process died). Stop with a clear
## message instead of letting sflist2stanfit() fail on it obscurely.
failed <- !vapply(sflist, inherits, logical(1), what = "stanfit")
if (any(failed)) {
  stop("Sampling failed for chain(s): ",
       paste(which(failed), collapse = ", "), call. = FALSE)
}

## Merge the single-chain fits into one stanfit and show the summary table
## (the output reproduced later in this post).
fit <- sflist2stanfit(sflist)
print(fit)

## Compare the imputed values with the values that were deleted
library(ggplot2)

## Posterior draws of the imputed values only.
samp <- extract(fit, pars = c("x1_miss", "x2_miss"))

## For every missing value (column of the draws): its 2.5%, 50% and 97.5%
## posterior quantiles. Rows of each result = quantiles, columns = values.
q.x1_miss <- apply(samp$x1_miss, MARGIN = 2, FUN = quantile,
                   probs = c(0.025, 0.5, 0.975))
q.x2_miss <- apply(samp$x2_miss, MARGIN = 2, FUN = quantile,
                   probs = c(0.025, 0.5, 0.975))

## One row per missing value: the true (deleted) value next to its posterior
## median and 95% interval. Because the rows were sorted by missingness
## pattern, ox1[is.na(x1)] lines up with x1_miss[1], x1_miss[2], ... and
## ox2[is.na(x2)] with x2_miss[1], x2_miss[2], ...
df <- local({
  n1 <- ncol(q.x1_miss)
  n2 <- ncol(q.x2_miss)
  q_both <- cbind(q.x1_miss, q.x2_miss)
  data.frame(var = rep(c("x1_miss", "x2_miss"), times = c(n1, n2)),
             x = as.factor(c(seq_len(n1), seq_len(n2))),
             orig = c(ox1[is.na(x1)], ox2[is.na(x2)]),
             low = q_both[1, ],
             med = q_both[2, ],
             high = q_both[3, ])
})

## Red: posterior median with 95% credible interval; black cross: true value.
p <- ggplot(df) +
  geom_pointrange(aes(x = x, y = med, ymax = high, ymin = low),
                  colour = "red", size = 1, alpha = 0.8) +
  geom_point(aes(x = x, y = orig),
             shape = 3, size = 4, colour = "black", alpha = 1) +
  xlab("x") +
  ylab("value") +
  facet_wrap(~ var, scales = "free_x")
print(p)

結果

Inference for Stan model: stancode.
4 chains, each with iter=2000; warmup=1000; thin=1; 
post-warmup draws per chain=1000, total post-warmup draws=4000.

            mean se_mean   sd  2.5%   25%   50%   75% 97.5% n_eff Rhat
x1_miss[1] -2.19    0.00 0.27 -2.71 -2.37 -2.19 -2.01 -1.65  4000    1
x1_miss[2] -1.12    0.00 0.26 -1.62 -1.29 -1.12 -0.96 -0.62  2912    1
x1_miss[3] -2.50    0.00 0.27 -3.03 -2.67 -2.50 -2.34 -1.97  4000    1
x1_miss[4] -1.52    0.00 0.26 -2.03 -1.68 -1.52 -1.36 -1.00  4000    1
x1_miss[5] -2.33    0.00 0.26 -2.84 -2.50 -2.32 -2.15 -1.81  4000    1
x1_miss[6] -0.79    0.00 0.28 -1.35 -0.96 -0.78 -0.61 -0.24  3839    1
x1_miss[7] -1.80    0.02 0.76 -3.32 -2.30 -1.81 -1.30 -0.32  2520    1
x2_miss[1]  3.91    0.00 0.21  3.51  3.78  3.91  4.05  4.33  3482    1
x2_miss[2]  1.34    0.00 0.18  0.97  1.22  1.34  1.46  1.68  4000    1
x2_miss[3]  2.03    0.00 0.18  1.67  1.91  2.02  2.14  2.40  4000    1
x2_miss[4]  1.95    0.00 0.20  1.56  1.83  1.95  2.08  2.34  3049    1
x2_miss[5]  1.43    0.01 0.53  0.38  1.08  1.44  1.78  2.47  2391    1
beta[1]     0.00    0.01 0.36 -0.72 -0.23 -0.01  0.23  0.70  1946    1
beta[2]     2.01    0.00 0.16  1.70  1.91  2.02  2.11  2.34  2209    1
beta[3]     2.98    0.00 0.09  2.81  2.92  2.98  3.04  3.15  3013    1
sigma       0.52    0.00 0.10  0.37  0.45  0.51  0.58  0.76  1383    1
lp__        0.65    0.14 3.86 -8.15 -1.71  1.03  3.49  6.76   795    1

Samples were drawn using NUTS(diag_e) at Sat Mar  7 14:59:25 2015.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).

Rplot001.png
赤の点と線が事後分布の中央値および95%信用区間、黒の十字が欠測にする前の元の値。

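だいたいうまく推定できているが、x1とx2の両方が欠測になったx1_miss[7]とx2_miss[5]では信用区間が広くなっている。この行ではyしか観測されておらず、yからわかるのはおおよそ beta[2] * x1 + beta[3] * x2 の値だけで、x1とx2を個別には決められないので、事前分布の影響が大きくなるのだと考えられる。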


タグ:RStan
nice!(2)  コメント(0)  トラックバック(0) 
共通テーマ:日記・雑感

nice! 2

コメント 0

コメントを書く

お名前:
URL:
コメント:
画像認証:
下の画像に表示されている文字を入力してください。

Facebook コメント

トラックバック 0