SSブログ

JAGSをsnowで並列化 [統計]

syou6162さんのところの解説を参考にして、snowパッケージをつかってJAGSを並列化してみた。これはいいかも。

classic-bugs/vol1/blockerでやってみる。
library(snow)
library(rjags)

read.jagsdata <- function(file)
{
  e <- new.env()
  eval(parse(file), e)
  return(as.list(e))
}

doBlocker <- function(file, data, inits, variable.names,
                      n.iter, n.adapt, thin) {
  m <- jags.model(file, data, inits,
                  n.chains = 1, n.adapt = n.adapt)
  coda.samples(m, variable.names, n.iter, thin)
}

## set data
file <- "blocker.bug"
data <- read.jagsdata("blocker-data.R")

## set inits
inits <- list()
inits[[1]] <- inits[[2]] <- read.jagsdata("blocker-init.R")

## set RNGs
inits[[1]]$.RNG.name <- "base::Mersenne-Twister"
inits[[1]]$.RNG.seed <- 123
inits[[2]]$.RNG.name <- "base::Mersenne-Twister"
inits[[2]]$.RNG.seed <- 12345

## set parameters
vars <- c("d", "delta.new", "sigma")
n.iter <- 30000
n.adapt <- 3000
thin <- 10
n.chains <- 2

params <- list("doBlocker", "jags.model", "coda.samples",
               "mcmc", "mcmc.list",
               "file", "data", "inits", "vars",
               "n.iter", "n.adapt", "thin")
cl <- makeCluster(2, type = "SOCK")
clusterExport(cl, params)
r <- parLapply(cl, 1:n.chains,
               function(x) doBlocker(file, data, inits[[x]],
                                     vars, n.iter, n.adapt, thin))
stopCluster(cl)

post <- mcmc.list(lapply(1:n.chains, function(x) r[[x]][[1]]))

なにも指定しないと、同時に実行するチェーンで乱数系列が同じになってしまうので、乱数のシードを指定しておく。ついでなので、RNGにはMersenne-Twisterを指定しておいた。

計算時間を比較してみる。MacBook Pro (Intel Core 2 Duo)使用。
> system.time(r <- parLapply(cl, 1:n.chains,
+                function(x) doBlocker(file, data, inits[[x]],
+                                      vars, n.iter, n.adapt, thin)))
   ユーザ   システム       経過  
     0.040      0.010      8.334 


シリアルに実行すると、
> system.time(r <- lapply(1:n.chains,
+                function(x) doBlocker(file, data, inits[[x]],
+                                      vars, n.iter, n.adapt, thin)))
   ユーザ   システム       経過  
    16.285      0.072     16.275 

n.chains = 2では、2倍近い速さになっている。n.chains = 3のときは、makeCluster(3, type = "SOCK")でよい。

むりやりGCD対応にしたのよりも、こっちの方がよさそうだ。



タグ:rjags MCMC jags R
nice!(3)  コメント(0)  トラックバック(0) 
共通テーマ:パソコン・インターネット

nice! 3

コメント 0

コメントを書く

お名前:
URL:
コメント:
画像認証:
下の画像に表示されている文字を入力してください。

Facebook コメント

トラックバック 0