SSブログ

自前JAGS呼び出し関数 2011-08-21版 [統計]

並列化にも対応した(はずの)自前のJAGS呼び出し関数の2011-08-21版(前の版)。JAGS 3でもとりあえずエラーは出なくなったが、例によってエラー処理などはいい加減。

ライセンスはGPL(バージョン2、またはそれ以降の任意のバージョン)としてみた。

##  Copyright (C) 2007-2011  ITÔ Hiroki 
##
##  This program is free software; you can redistribute it and/or modify
##  it under the terms of the GNU General Public License as published by
##  the Free Software Foundation; either version 2 of the License, or
##  any later version.
##
##  This program is distributed in the hope that it will be useful,
##  but WITHOUT ANY WARRANTY; without even the implied warranty of
##  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
##  GNU General Public License for more details.
##
##  You should have received a copy of the GNU General Public License
##  along with this program; if not, write to the Free Software
##  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

library(coda)
library(snow)

jags.makedata <- function(data, file = "jags-data.R") {
  ## Write a list of variables to `file` in the R dump format that JAGS
  ## reads for data (jags.makeinit also uses it for initial values).
  ##
  ## One assignment is written per list element, named after the element:
  ##   vector of length > 1   "name" <-  c(v1, v2, ...)
  ##   matrix or array        "name" <-  structure(c(...), .Dim=c( d1,d2 ))
  ##   anything else          "name" <-  value
  ## Character values are wrapped in double quotes in every form.
  ## Unnamed lists are named by position ("1", "2", ...).
  ##
  ## data: a list (a data.frame is written column by column).
  ## file: path of the file to create; an existing file is overwritten.
  ##   An empty list produces an empty file.
  ## Returns NULL invisibly; called for its side effect.
  if (!is.list(data)) {
    stop("jags.makedata: first parameter must be a list.")
  }
  quote.chr <- function(x) {
    if (is.character(x)) paste0("\"", x, "\"") else x
  }
  labs <- names(data)
  if (is.null(labs)) {
    labs <- as.character(seq_along(data))
  }
  lines <- vector("list", length(data))
  for (i in seq_along(data)) {
    dat <- data[[i]]
    if (length(dat) > 1 && is.vector(dat)) {
      value <- paste0("c(", paste(quote.chr(dat), collapse = ", "), ")")
    } else if (is.matrix(dat) || is.array(dat)) {
      ## Elements go out in the order as.vector() returns them; paste with
      ## collapse also handles single-element arrays correctly.
      value <- paste0("structure(c(",
                      paste(quote.chr(as.vector(dat)), collapse = ", "),
                      "), .Dim=c( ", paste(dim(dat), collapse = ","), " ))")
    } else {
      value <- as.character(quote.chr(dat))
    }
    lines[[i]] <- paste(paste0("\"", labs[i], "\" <- "), value)
  }
  ## as.character() keeps writeLines() happy when `data` is empty.
  writeLines(as.character(unlist(lines)), file)
  invisible(NULL)
}

jags.makeinit <- function(inits, file = "jags-init.R") {
  ## Write one chain's initial values to `file`. Initial values use the same
  ## R dump format as data, so the work is delegated to jags.makedata
  ## (which stops unless `inits` is a list).
  jags.makedata(data = inits, file = file)
}

jags.makecmd <- function(parameters.to.save,
                         cmd.file = "jags.cmd",
                         model.file = "jags.bug",
                         data.file = "jags-data.R",
                         init.file = "jags-init.R",
                         out = "",
                         n.chains = 3,
                         n.iter = 2000,
                         n.burnin = floor(n.iter/2),
                         n.thin = max(1, (n.iter - n.burnin) %/% 1000)) {
  ## Write a JAGS batch script to `cmd.file`, overwriting it. The script:
  ##   1. loads `model.file` and `data.file`;
  ##   2. compiles `n.chains` chains and reads each chain's initial values;
  ##   3. runs `n.burnin` burn-in iterations;
  ##   4. monitors every name in `parameters.to.save` with thinning `n.thin`;
  ##   5. runs the remaining `n.iter - n.burnin` iterations;
  ##   6. dumps CODA output with stem `out` and exits.
  ##
  ## With one chain the initial values are read from `init.file` itself.
  ## With several, chain i reads `init.file` with "-i" inserted before the
  ## extension (e.g. "jags-init.R" -> "jags-init-2.R"). The name is left
  ## unchanged unless it contains exactly one dot.
  ## Returns NULL invisibly.

  ## Iteration counts are written in fixed notation; plain format() would
  ## turn e.g. 1e5 into "1e+05".
  fixed <- function(x) format(x, scientific = FALSE)
  chains <- seq_len(n.chains)
  if (n.chains == 1) {
    init.files <- init.file
  } else {
    init.files <- vapply(chains, function(i) {
      gsub("^([^\\.]*)\\.([^\\.]*$)", paste0("\\1-", i, ".\\2"), init.file)
    }, character(1))
  }
  script <- c(
    paste0("model in \"", model.file, "\""),
    paste0("data in \"", data.file, "\""),
    paste0("compile, nchains(", n.chains, ")"),
    ## sprintf() returns no lines at all for zero-length input.
    sprintf("parameters in \"%s\", chain(%d)", init.files, chains),
    "initialize",
    paste("update", fixed(n.burnin)),
    sprintf("monitor set %s,thin(%s)", parameters.to.save, fixed(n.thin)),
    paste("update", fixed(n.iter - n.burnin)),
    paste0("coda *,stem(\"", out, "\")"),
    "exit"
  )
  writeLines(script, cmd.file)
  invisible(NULL)
}

jags.run <- function(data, inits, parameters.to.save,
                     model.file = "jags.bug",
                     data.file = "jags-data.R",
                     init.file = "jags-init.R",
                     cmd.file = "jags.cmd",
                     out = "",
                     jags = "/usr/local/bin/jags",
                     n.chains = 3,
                     n.iter = 2000,
                     n.burnin = floor(n.iter/2),
                     n.thin = max(1, (n.iter - n.burnin) %/% 1000)) {
  ## Run JAGS once: write the data, initial values and command script, run
  ## the JAGS executable on the script, and read the CODA output back.
  ##
  ## inits: for n.chains == 1, the list of initial values itself, written
  ##   to `init.file`. For n.chains >= 2, a list of n.chains such lists;
  ##   chain i is written to `init.file` with "-i" inserted before the
  ##   extension.
  ## jags: command that invokes JAGS. It is handed to the shell as is, so it
  ##   may carry extra arguments; `cmd.file` is shell-quoted.
  ## out: CODA stem; results are read from <out>chain<i>.txt and
  ##   <out>index.txt.
  ## Returns an mcmc.list with one mcmc object per chain. Stops if the
  ## model file is missing, the chain count is inconsistent, or the JAGS
  ## command exits with a non-zero status.
  if (!file.exists(model.file)) {
    stop(paste(model.file, "does not exist."))
  }
  if (n.chains < 1) {
    stop("n.chains must be larger than 0.")
  }
  if (n.chains >= 2 && n.chains != length(inits)) {
    stop("n.chains != length(inits).")
  }
  chains <- seq_len(n.chains)
  jags.makedata(data = data, file = data.file)
  if (n.chains == 1) {
    jags.makeinit(inits = inits, file = init.file)
  } else {
    for (i in chains) {
      ## Same renaming rule jags.makecmd uses when listing the init files.
      init.i <- gsub("^([^\\.]*)\\.([^\\.]*$)",
                     paste0("\\1-", i, ".\\2"), init.file)
      jags.makeinit(inits = inits[[i]], file = init.i)
    }
  }
  jags.makecmd(parameters.to.save = parameters.to.save,
               cmd.file = cmd.file,
               model.file = model.file,
               data.file = data.file,
               init.file = init.file,
               out = out,
               n.chains = n.chains,
               n.iter = n.iter,
               n.burnin = n.burnin,
               n.thin = n.thin)
  ## Quote the script path so that paths containing spaces survive the shell.
  status <- system(paste(jags, shQuote(cmd.file)), wait = TRUE)
  if (status != 0) {
    stop("calculation failed!")
  }
  index.file <- paste0(out, "index.txt")
  mcmc.list(lapply(chains, function(i) {
    read.coda(paste0(out, "chain", i, ".txt"), index.file)
  }))
}

jags.parrun <- function(data, inits, parameters.to.save,
                        model.file = "jags.bug",
                        data.prefix = "jags-data",
                        init.prefix = "jags-init",
                        cmd.prefix = "jags",
                        out = "",
                        jags = "/usr/local/bin/jags",
                        n.chains = 3,
                        n.iter = 2000,
                        n.burnin = floor(n.iter/2),
                        n.thin = max(1, (n.iter - n.burnin) %/% 1000),
                        spec = rep("localhost", n.chains)) {
  ## Run jags in parallel: each of the n.chains chains is an independent
  ## single-chain jags.run call on a snow socket cluster. The results are
  ## combined into one mcmc.list.
  ##
  ## Chain i uses its own files: <data.prefix>i.R, <init.prefix>i.R,
  ## <cmd.prefix>i.cmd, and CODA output with stem <out>iCODA.
  ## inits: for n.chains >= 2, a list of n.chains initial-value lists. For
  ##   n.chains == 1, either the initial-value list itself (as jags.run
  ##   accepts) or a one-element list wrapping it.
  ## spec: host specification passed to makeCluster(type = "SOCK").
  ## Returns an mcmc.list with one mcmc object per chain. The cluster is
  ## shut down even if a chain fails.
  if (!file.exists(model.file)) {
    stop(paste(model.file, "does not exist."))
  }
  if (n.chains < 1) {
    stop("n.chains must be larger than 0.")
  }
  if (n.chains >= 2 && n.chains != length(inits)) {
    stop("n.chains != length(inits).")
  }
  if (!is.list(data)) {
    stop("data must be a list.")
  }
  ## For a single chain, a bare init list is wrapped so that inits[[1]] is
  ## the whole list, as in the multi-chain case.
  if (n.chains == 1 && !(length(inits) == 1 && is.list(inits[[1]]))) {
    inits <- list(inits)
  }
  ## Evaluate everything the workers use now. parLapply ships the worker
  ## function to the nodes together with this enclosing frame.
  dp <- data.prefix
  ip <- init.prefix
  cm <- cmd.prefix
  ot <- out
  jg <- jags
  par <- parameters.to.save
  iter <- n.iter
  burnin <- n.burnin
  thin <- n.thin
  exports <- c("jags.run", "mcmc", "mcmc.list", "read.coda",
               "jags.makedata", "jags.makeinit", "jags.makecmd")
  cl <- makeCluster(spec, type = "SOCK")
  on.exit(stopCluster(cl), add = TRUE)
  clusterExport(cl, exports)
  runs <- parLapply(cl, seq_len(n.chains), function(i) {
    jags.run(data, inits[[i]], par,
             model.file = model.file,
             data.file = paste0(dp, i, ".R"),
             init.file = paste0(ip, i, ".R"),
             cmd.file = paste0(cm, i, ".cmd"),
             out = paste0(ot, i, "CODA"),
             jags = jg,
             n.chains = 1,
             n.iter = iter,
             n.burnin = burnin,
             n.thin = thin)
  })
  ## Each run is a one-chain mcmc.list; keep its single mcmc object.
  mcmc.list(lapply(runs, function(run) run[[1]]))
}

nice!(0)  コメント(0)  トラックバック(1) 
共通テーマ:パソコン・インターネット

nice! 0

コメント 0

コメントを書く

お名前:
URL:
コメント:
画像認証:
下の画像に表示されている文字を入力してください。

Facebook コメント

トラックバック 1