Stan: ガウス過程をためしてみたメモ(2) [統計]
こんどは、2次元の入力に対するクラス分類をためしてみました。
Stanコードです。マニュアルのモデルを合成して つくりました。
// This model is derived from Chapter 18 of // Stan Modeling Language User's Guide Reference Manual // by Stan Development Team // http://mc-stan.org/ functions { matrix L_cov_exp_quad_ARD(vector[] x, real alpha, vector rho, real delta) { int N = size(x); matrix[N, N] K; real sq_alpha = square(alpha); for (i in 1:(N - 1)) { K[i, i] = sq_alpha + delta; for (j in (i + 1):N) { K[i, j] = sq_alpha * exp(-0.5 * dot_self((x[i] - x[j]) ./ rho)); K[j, i] = K[i, j]; } } K[N, N] = sq_alpha + delta; return cholesky_decompose(K); } } data { int<lower=1> N1; int<lower=1> D; vector[D] X1[N1]; int Y1[N1]; int<lower=1> N2; vector[D] X2[N2]; } transformed data { real delta = 1e-9; int<lower=1> N = N1 + N2; vector[D] x[N]; for (n1 in 1:N1) x[n1] = X1[n1]; for (n2 in 1:N2) x[N1 + n2] = X2[n2]; } parameters { vector<lower=0>[D] rho; real<lower=0> alpha; real a; vector[N] eta; } transformed parameters { vector[N] f; { matrix[N, N] L_K = L_cov_exp_quad_ARD(x, alpha, rho, delta); f = L_K * eta; } } model { rho ~ inv_gamma(5, 5); alpha ~ normal(0, 1); a ~ normal(0, 1); eta ~ normal(0, 1); Y1 ~ bernoulli_logit(a + f[1:N1]); } generated quantities { int y2[N2]; for (n2 in 1:N2) y2[n2] = bernoulli_logit_rng(a + f[N1 + n2]); }
テストに用意したデータです。
library(dplyr) library(rstan) options(mc.cores = parallel::detectCores()) rstan_options(auto_write = TRUE) model_file <- "GP3.stan" set.seed(109) N <- 64 x <- matrix(runif(N * 2, -1, 1), ncol = 2) y <- xor(x[, 1] > 0, x[, 2] > 0) data.frame(x1 = x[, 1], x2 = x[, 2], y = y) %>% ggplot() + geom_point(aes(x = x1, y = x2, colour = y)) + coord_fixed()
それぞれの象限に新データをおいて、予測をおこないます。
x_new <- matrix(c(0.5, 0.5, -0.5, -0.5, 0.5, -0.5, 0.5, -0.5), ncol = 2) fit <- stan(model_file, data = list(N1 = N, D = 2, X1 = x, Y1 = as.integer(y), N2 = nrow(x_new), X2 = x_new)) print(fit, pars = c("rho", "alpha", "a", "y2", "lp__"))
結果です。
Inference for Stan model: GP3. 4 chains, each with iter=2000; warmup=1000; thin=1; post-warmup draws per chain=1000, total post-warmup draws=4000. mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat rho[1] 0.68 0.00 0.17 0.42 0.56 0.65 0.77 1.06 2452 1 rho[2] 0.69 0.00 0.17 0.44 0.57 0.66 0.78 1.08 3174 1 alpha 2.79 0.01 0.58 1.75 2.37 2.75 3.17 4.00 4000 1 a 0.01 0.02 0.86 -1.65 -0.58 0.02 0.58 1.67 2909 1 y2[1] 0.06 0.00 0.25 0.00 0.00 0.00 0.00 1.00 3939 1 y2[2] 0.89 0.00 0.31 0.00 1.00 1.00 1.00 1.00 4000 1 y2[3] 0.94 0.00 0.24 0.00 1.00 1.00 1.00 1.00 4000 1 y2[4] 0.04 0.00 0.20 0.00 0.00 0.00 0.00 1.00 4000 1 lp__ -66.28 0.14 5.90 -78.79 -70.08 -65.89 -62.22 -55.47 1714 1
タグ:STAn
コメント 0