SSブログ

RからCmdStanをつかう関数 [統計]

コマンドラインのStanはCmdStanとよばれるようになったようだ。RStanより速いみたいなので、Rから並列化して実行する関数をつくってみた。

(注) OS Xでしか動作確認していません(が、R GUIからではなぜかうまく動きません。ESSやRStudioからは動作します)。エラー処理はいいかげんです。

library(coda)
library(parallel)

# function definition
mcstan <- function(file, data, stan.home,
                   chains = 4, 
                   iter = 2000, warmup = iter / 2,
                   thin = max(1, floor(iter - warmup) / 1000),
                   init = 2, seed = -1, 
                   mc.cores = getOption("mc.cores", 2L)) {
  # parameters:
  #   file: a character string file name of Stan model file
  #   data: a list of data passed to Stan
  #   stan.home: a character string directory name of Stan directory
  #              (where "makefile" exists)
  
  # function: chkdir
  chkdir <- function(dirname) {
    if (!file.exists(dirname)) {
      stop(paste(dirname, "not found."))
    } else if (!file.info(dirname)$isdir) {
      stop(paste(dirname, "is not a directory."))
    }
  }
  
  # function: runstan
  runstan <- function(chain, program.dir, program, data.file,
                      init = 2, seed = -1) {
    result.file <- file.path(program.dir,
                             paste(program, ".chain", chain, ".csv", sep = ""))
    ini <- ""
    if (is.numeric(init) & init >= 0) {
      ini <- paste(" init=", init, sep = "")
    } else if (is.character(init)) {
      init <- file.path(program.dir, init)
      if (file.exists(init)) {
        ini <- paste(" init=", init, sep = "")
      } else {
        stop(paster(init, "not found."))
      }
    }
    cmd <- paste(file.path(program.dir, program), 
                 " sample",
                 " num_samples=", iter,
                 " num_warmup=", warmup,
                 " thin=", thin,
                 " data file=", file.path(program.dir, data.file),
                 ini,
                 " random seed=", seed, 
                 " id=", chain,
                 " output file=", result.file,
                 " refresh=", iter,
                 sep = "")
#    print(cmd)
    system(cmd)
    
    if (!file.exists(result.file)) {
      stop(paste(result.file, "not found"))
    } else {
      s <- read.csv(result.file, comment.char = "#")
      mcmc(s, start = warmup + 1, thin = thin)
    }
  }
  
  # check model file
  if (!file.exists(file)) {
    stop(paste(file, "not found."))
  } else {
    program.dir <- normalizePath(dirname(file))
    program <- sub("\\.stan$", "", basename(file))
  }
  
  # save current directory
  cur.dir <- getwd()

  # dump data
  chkdir(program.dir)
  setwd(program.dir)
  data.file <- paste(program, "Rdump", sep = ".")
  attach(data)
  dump(ls(data), data.file)
  detach(data)
  
  # compile
  chkdir(stan.home)
  setwd(stan.home)
  # print(paste("make", file.path(program.dir, program)))
  system(paste("make", file.path(program.dir, program)))
  setwd(cur.dir)

  # initialization
  if (is.numeric(init)) {
    inits <- rep(init[1], chains)
  } else if (is.list(init)) {
    if (length(init) >= chains) {
      chkdir(program.dir)
      setwd(program.dir)
      inits <- sapply(1:chains, function(i) paste(program, "init", i, "Rdump", sep = "."))
      for (i in 1:chains) {
        attach(init[[i]])
        dump(ls(init[[i]]), inits[i])
        detach(init[[i]])
      }
    }
  }
  
  # random number seed
  len.seed <- length(seed)
  seeds <- vector("numeric", chains)
  if (len.seed == 1) {
    seeds <- rep(seed, chains)
  } else if (len.seed >= chains) {
    seeds <- seed[1:chains]
  } else {
    stop("invalid number of seeds:", len.seed)
  }
  
  # call stan with mclapply
  as.mcmc.list(mclapply(1:chains,
                        function(i) runstan(i, program.dir, program, data.file,
                                            init = inits[i], seed = seeds[i]),
                        mc.cores = mc.cores))

  # for debug
#  as.mcmc.list(lapply(1:chains,
#                      function(i) runstan(i, program.dir, program, data.file,
#                                          init = inits[i], seed = seeds[i])))
}

実行例

d <- read.csv(url("http://hosho.ees.hokudai.ac.jp/~kubo/stat/iwanamibook/fig/hbm/data7a.csv"))

stan.d <- "~/Documents/src/stan-2.0.1"
seeds <- c(123, 1234, 12345, 123456)
inits <- lapply(1:4, function(i) list(beta = rnorm(1, 0, 10),
                                      sigma = runif(1, 0, 10),
                                      r = rnorm(nrow(d), 0, 2)))
samples <- mcstan(file = "kubo10.stan",
                  data = list(N = nrow(d), Y = d$y),
                  stan.home = stan.d, chains = 4,
                  iter = 10000, warmup = 100, thin = 10,
                  init = inits, seed = seeds,
                  mc.cores = 4)
summary(samples[, c("beta", "sigma")])

Stanファイル: kubo10.stan

// Kubo Book Chapter 10
data {
  int<lower=0> N;     // sample size
  int<lower=0> Y[N];  // response variable
}
parameters {
  real beta;
  real r[N];
  real<lower=0> sigma;
}
transformed parameters {
  real q[N];

  for (i in 1:N) {
    q[i] <- inv_logit(beta + r[i]); // 生存確率
  }
}
model {
  for (i in 1:N) {
		Y[i] ~ binomial(8, q[i]); // 二項分布
  }
  beta ~ normal(0, 100);      // 無情報事前分布
  r ~ normal(0, sigma);       // 階層事前分布
  sigma ~ uniform(0, 1.0e+4); // 無情報事前分布
}

結果

> summary(samples[, c("beta", "sigma")])
Iterations = 101:10091
Thinning interval = 10 
Number of chains = 4 
Sample size per chain = 1000 

1. Empirical mean and standard deviation for each variable,
   plus standard error of the mean:

         Mean     SD Naive SE Time-series SE
beta  0.03747 0.3285 0.005194       0.005811
sigma 3.06428 0.3683 0.005823       0.006018

2. Quantiles for each variable:

         2.5%     25%     50%    75%  97.5%
beta  -0.5929 -0.1814 0.03811 0.2546 0.6995
sigma  2.4155  2.8126 3.03395 3.2955 3.8564


タグ:R STAn
nice!(2)  コメント(0)  トラックバック(0) 
共通テーマ:日記・雑感

nice! 2

コメント 0

コメントを書く

お名前:
URL:
コメント:
画像認証:
下の画像に表示されている文字を入力してください。

Facebook コメント

トラックバック 0