Skip to content

Commit

Permalink
cmdstanr version
Browse files Browse the repository at this point in the history
  • Loading branch information
sbfnk committed Oct 23, 2023
1 parent 5eff03c commit 74cf477
Show file tree
Hide file tree
Showing 36 changed files with 279 additions and 251 deletions.
8 changes: 7 additions & 1 deletion .github/workflows/render-readme.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ on:

jobs:
render-readme:
runs-on: ubuntu-latest
runs-on: macos-latest
env:
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
steps:
Expand All @@ -22,6 +22,12 @@ jobs:
with:
extra-packages: any::pkgdown, local::.

- name: Install cmdstan
run: |
cmdstanr::check_cmdstan_toolchain(fix = TRUE)
cmdstanr::install_cmdstan(cores = 2, quiet = TRUE)
shell: Rscript {0}

- name: Compile the readme
run: |
rmarkdown::render("README.Rmd")
Expand Down
8 changes: 6 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ Imports:
rlang (>= 0.4.7),
rstan (>= 2.26.0),
rstantools (>= 2.2.0),
cmdstanr,
runner,
scales,
stats,
Expand Down Expand Up @@ -142,6 +143,8 @@ LinkingTo:
RcppParallel (>= 5.0.1),
rstan (>= 2.26.0),
StanHeaders (>= 2.26.0)
Remotes:
stan-dev/cmdstanr
Biarch: true
Config/testthat/edition: 3
Encoding: UTF-8
Expand All @@ -150,6 +153,7 @@ LazyData: true
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.3
NeedsCompilation: yes
SystemRequirements: GNU make
C++17
SystemRequirements: GNU make,
C++17,
CmdStan (>= 2.29)
VignetteBuilder: knitr
6 changes: 5 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ importFrom(data.table,rbindlist)
importFrom(data.table,setDT)
importFrom(data.table,setDTthreads)
importFrom(data.table,setcolorder)
importFrom(data.table,setkey)
importFrom(data.table,setnames)
importFrom(data.table,setorder)
importFrom(data.table,setorderv)
Expand Down Expand Up @@ -169,6 +170,10 @@ importFrom(lifecycle,deprecate_warn)
importFrom(lubridate,days)
importFrom(lubridate,wday)
importFrom(patchwork,plot_layout)
importFrom(posterior,ess_bulk)
importFrom(posterior,ess_tail)
importFrom(posterior,mcse_mean)
importFrom(posterior,rhat)
importFrom(progressr,progressor)
importFrom(progressr,with_progress)
importFrom(purrr,compact)
Expand Down Expand Up @@ -218,4 +223,3 @@ importFrom(truncnorm,rtruncnorm)
importFrom(utils,capture.output)
importFrom(utils,head)
importFrom(utils,tail)
useDynLib(EpiNow2, .registration=TRUE)
8 changes: 0 additions & 8 deletions R/EpiNow2-package.R
Original file line number Diff line number Diff line change
@@ -1,11 +1,3 @@
#' @keywords internal
#' @import Rcpp
#' @import methods
#' @import rstantools
#' @importFrom rstan sampling extract
#' @useDynLib EpiNow2, .registration=TRUE
"_PACKAGE"

# The following block is used by usethis to automatically manage
# roxygen namespace tags. Modify with care!
## usethis namespace: start
Expand Down
6 changes: 3 additions & 3 deletions R/dist.R
Original file line number Diff line number Diff line change
Expand Up @@ -271,9 +271,9 @@ dist_fit <- function(values = NULL, samples = 1000, cores = 1,
fit <- rstan::sampling(
model,
data = data,
iter = samples + 1000,
warmup = 1000,
control = list(adapt_delta = adapt_delta),
iter_sampling = samples,
iter_warmup = 1000,
adapt_delta = adapt_delta,
chains = chains,
cores = cores,
refresh = ifelse(verbose, 50, 0)
Expand Down
18 changes: 12 additions & 6 deletions R/estimate_infections.R
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@
#' generation_time = generation_time_opts(generation_time),
#' delays = delay_opts(incubation_period + reporting_delay),
#' rt = rt_opts(prior = list(mean = 2, sd = 0.1)),
#' stan = stan_opts(control = list(adapt_delta = 0.95))
#' stan = stan_opts(adapt_delta = 0.95)
#' )
#' # real time estimates
#' summary(def)
Expand Down Expand Up @@ -316,7 +316,8 @@ init_cumulative_fit <- function(args, samples = 50, warmup = 50,
cores = 2,
open_progress = FALSE,
show_messages = FALSE,
control = list(adapt_delta = 0.9, max_treedepth = 13),
adapt_delta = 0.9,
max_treedepth = 13,
refresh = ifelse(verbose, 50, -1)
)
# change observations to be cumulative in order to protect against noise and
Expand Down Expand Up @@ -396,10 +397,12 @@ fit_model_with_nuts <- function(args, future = FALSE, max_execution_time = Inf,

fit_chain <- function(chain, stan_args, max_time, catch = FALSE) {
stan_args$chain_id <- chain
model <- stan_args$object
stan_args$object <- NULL
if (catch) {
fit <- tryCatch(
withCallingHandlers(
R.utils::withTimeout(do.call(rstan::sampling, stan_args),
R.utils::withTimeout(do.call(model$sample, stan_args),
timeout = max_time,
onTimeout = "silent"
),
Expand All @@ -422,7 +425,7 @@ fit_model_with_nuts <- function(args, future = FALSE, max_execution_time = Inf,
}
)
} else {
fit <- R.utils::withTimeout(do.call(rstan::sampling, stan_args),
fit <- R.utils::withTimeout(do.call(model$sample, stan_args),
timeout = max_time,
onTimeout = "silent"
)
Expand All @@ -436,7 +439,7 @@ fit_model_with_nuts <- function(args, future = FALSE, max_execution_time = Inf,
}

if (!future) {
fit <- fit_chain(1,
fit <- fit_chain(seq_len(args$chains),
stan_args = args, max_time = max_execution_time,
catch = !id %in% c("estimate_infections", "epinow")
)
Expand Down Expand Up @@ -482,6 +485,7 @@ fit_model_with_nuts <- function(args, future = FALSE, max_execution_time = Inf,
)
}
}

fit <- rstan::sflist2stanfit(fit)
}
}
Expand Down Expand Up @@ -522,7 +526,9 @@ fit_model_with_vb <- function(args, future = FALSE, id = "stan") {
}

fit_vb <- function(stan_args) {
fit <- do.call(rstan::vb, stan_args)
model <- stan_args$object
stan_args$object <- NULL
fit <- do.call(model$variational, stan_args)

if (length(names(fit)) == 0) {
return(NULL)
Expand Down
24 changes: 9 additions & 15 deletions R/estimate_secondary.R
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@
#' @inheritParams estimate_infections
#' @inheritParams update_secondary_args
#' @inheritParams calc_CrIs
#' @importFrom rstan sampling
#' @importFrom lubridate wday
#' @importFrom data.table as.data.table merge.data.table
#' @examples
Expand Down Expand Up @@ -183,9 +182,9 @@ estimate_secondary <- function(reports,
)
# fit
if (is.null(model)) {
model <- stanmodels$estimate_secondary
model <- epinow2_model("estimate_secondary")
}
fit <- rstan::sampling(model,
fit <- model$sample(
data = data,
init = inits,
refresh = ifelse(verbose, 50, 0),
Expand Down Expand Up @@ -619,12 +618,8 @@ forecast_secondary <- function(estimate,
updated_primary <- primary

## extract samples from given stanfit object
draws <- rstan::extract(estimate$fit,
pars = c(
"sim_secondary", "log_lik",
"lp__", "secondary"
),
include = FALSE
draws <- extract_samples(
estimate$fit
)
# extract data from stanfit
data <- estimate$data
Expand Down Expand Up @@ -661,7 +656,7 @@ forecast_secondary <- function(estimate,

# load model
if (is.null(model)) {
model <- stanmodels$simulate_secondary
model <- epinow2_model("simulate_secondary")
}

# allocate empty parameters
Expand All @@ -671,16 +666,15 @@ forecast_secondary <- function(estimate,
)
data$all_dates <- as.integer(all_dates)
## simulate
sims <- rstan::sampling(
object = model,
data = data, chains = 1, iter = 1,
algorithm = "Fixed_param",
sims <- model$sample(
data = data, chains = 1,
iter_sampling = 1, fixed_param = TRUE,
refresh = 0
)

# extract samples and organise
dates <- unique(primary_fit$date)
samples <- rstan::extract(sims, "sim_secondary")$sim_secondary
samples <- extract_samples(sims, variables = "sim_secondary")$sim_secondary
samples <- as.data.table(samples)
colnames(samples) <- c("iterations", "sample", "time", "value")
samples <- samples[, c("iterations", "time") := NULL]
Expand Down
14 changes: 7 additions & 7 deletions R/estimate_truncation.R
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ estimate_truncation <- function(obs, max_truncation, trunc_max = 10,
obs_start <- max(nrow(obs) - trunc_max - sum(is.na(obs$`1`)) + 1, 1)
obs_dist <- purrr::map_dbl(2:(ncol(obs)), ~ sum(is.na(obs[[.]])))
obs_data <- obs[, -1][, purrr::map(.SD, ~ ifelse(is.na(.), 0, .))]
obs_data <- obs_data[obs_start:.N]
obs_data <- as.matrix(obs_data[obs_start:.N])

# convert to stan list
data <- list(
Expand Down Expand Up @@ -252,9 +252,9 @@ estimate_truncation <- function(obs, max_truncation, trunc_max = 10,

# fit
if (is.null(model)) {
model <- stanmodels$estimate_truncation
model <- epinow2_model("estimate_truncation")
}
fit <- rstan::sampling(model,
fit <- model$sample(
data = data,
init = init_fn,
refresh = ifelse(verbose, 50, 0),
Expand All @@ -264,10 +264,10 @@ estimate_truncation <- function(obs, max_truncation, trunc_max = 10,
out <- list()
# Summarise fit truncation distribution for downstream usage
out$dist <- dist_spec(
mean = round(rstan::summary(fit, pars = "delay_mean")$summary[1], 3),
mean_sd = round(rstan::summary(fit, pars = "delay_mean")$summary[3], 3),
sd = round(rstan::summary(fit, pars = "delay_sd")$summary[1], 3),
sd_sd = round(rstan::summary(fit, pars = "delay_sd")$summary[3], 3),
mean = round(fit$summary(variables = "delay_mean")[[2]], 3),
mean_sd = round(fit$summary(variables = "delay_mean")[[4]], 3),
sd = round(fit$summary(variables = "delay_sd")[[2]], 3),
sd_sd = round(fit$summary(variables = "delay_sd")[[4]], 3),
max = truncation$max
)
out$dist$dist <- truncation$dist
Expand Down
Loading

0 comments on commit 74cf477

Please sign in to comment.