Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compile model methods once and reuse for all models #894

Open
wants to merge 25 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ CmdStanFit$set("public", name = "init", value = init)
#' @param seed (integer) The random seed to use when initializing the model.
#' @param verbose (logical) Whether to show verbose logging during compilation.
#' @param hessian (logical) Whether to expose the (experimental) hessian method.
#' @param force_recompile (logical) Whether to recompile cached model methods.
#'
#' @examples
#' \dontrun{
Expand All @@ -332,25 +333,26 @@ CmdStanFit$set("public", name = "init", value = init)
#' [unconstrain_variables()], [unconstrain_draws()], [variable_skeleton()],
#' [hessian()]
#'
init_model_methods <- function(seed = 0, verbose = FALSE, hessian = FALSE) {
init_model_methods <- function(seed = 0, verbose = FALSE, hessian = FALSE, force_recompile = FALSE) {
if (os_is_wsl()) {
stop("Additional model methods are not currently available with ",
"WSL CmdStan and will not be compiled",
call. = FALSE)
}
require_suggested_package("Rcpp")
if (length(private$model_methods_env_$hpp_code_) == 0) {
if (length(private$model_methods_env_$hpp_code_) == 0 && (
is.null(private$model_methods_env_$obj_file_) ||
!file.exists(private$model_methods_env_$obj_file_))) {
stop("Model methods cannot be used with a pre-compiled Stan executable, ",
"the model must be compiled again", call. = FALSE)
}
if (hessian) {
message("The hessian method relies on higher-order autodiff ",
"which is still experimental. Please report any compilation ",
"errors that you encounter")
warning("The hessian argument is deprecated and will be removed in a future release.\n",
"The hessian method is now exposed by default.")
}
message("Compiling additional model methods...")
if (is.null(private$model_methods_env_$model_ptr)) {
expose_model_methods(private$model_methods_env_, verbose, hessian)
expose_model_methods(private$model_methods_env_, verbose = verbose,
force_recompile = force_recompile)
}
initialize_model_pointer(private$model_methods_env_, self$data_file(), seed)
invisible(NULL)
Expand Down
2 changes: 2 additions & 0 deletions R/install.R
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,8 @@ build_cmdstan <- function(dir,
clean_cmdstan <- function(dir = cmdstan_path(),
cores = getOption("mc.cores", 2),
quiet = FALSE) {
unlink(file.path(dir, "model_methods.o"))
unlink(file.path(dir, "model_methods.cpp"))
withr::with_path(
c(
toolchain_PATH_env_var(),
Expand Down
18 changes: 11 additions & 7 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -405,8 +405,8 @@ CmdStanModel <- R6::R6Class(
#' @param compile_model_methods (logical) Compile additional model methods
#' (`log_prob()`, `grad_log_prob()`, `constrain_variables()`,
#' `unconstrain_variables()`).
#' @param compile_hessian_method (logical) Should the (experimental) `hessian()` method be
#' be compiled with the model methods?
#' @param compile_hessian_method (logical) Deprecated and will be removed in a future release.
#' The hessian method is now compiled by default.
#' @param compile_standalone (logical) Should functions in the Stan model be
#' compiled for use in R? If `TRUE` the functions will be available via the
#' `functions` field in the compiled model object. This can also be done after
Expand Down Expand Up @@ -504,6 +504,10 @@ compile <- function(quiet = TRUE,
warning("'threads' is deprecated. Please use 'cpp_options = list(stan_threads = TRUE)' instead.")
cpp_options[["stan_threads"]] <- TRUE
}
if (isTRUE(compile_hessian_method)) {
warning("'compile_hessian_method' is deprecated. The hessian method is now compiled by default.")
compile_hessian_method <- FALSE
}

if (length(self$exe_file()) == 0) {
if (is.null(dir)) {
Expand Down Expand Up @@ -655,9 +659,10 @@ compile <- function(quiet = TRUE,
run_log <- wsl_compatible_run(
command = make_cmd(),
args = c(wsl_safe_path(tmp_exe),
cpp_options_to_compile_flags(cpp_options),
cpp_options_to_compile_flags(c(cpp_options, list("KEEP_OBJECT"="true"))),
stancflags_val),
wd = cmdstan_path(),
env = c("current", "CXXFLAGS" = "-fPIC"),
echo = !quiet || is_verbose_mode(),
echo_cmd = is_verbose_mode(),
spinner = quiet && rlang::is_interactive() && !identical(Sys.getenv("IN_PKGDOWN"), "true"),
Expand Down Expand Up @@ -708,6 +713,7 @@ compile <- function(quiet = TRUE,
file.remove(exe)
}
file.copy(tmp_exe, exe, overwrite = TRUE)
private$model_methods_env_$obj_file_ <- paste0(temp_file_no_ext, ".o")
if (os_is_wsl()) {
res <- processx::run(
command = "wsl",
Expand All @@ -726,11 +732,9 @@ compile <- function(quiet = TRUE,
private$precompile_stanc_options_ <- NULL
private$precompile_include_paths_ <- NULL

if(!dry_run) {
if (!dry_run) {
if (compile_model_methods) {
expose_model_methods(env = private$model_methods_env_,
verbose = !quiet,
hessian = compile_hessian_method)
expose_model_methods(private$model_methods_env_, verbose = !quiet)
}
}
invisible(self)
Expand Down
117 changes: 94 additions & 23 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -728,45 +728,116 @@ get_cmdstan_flags <- function(flag_name) {
paste(flags, collapse = " ")
}

rcpp_source_stan <- function(code, env, verbose = FALSE) {
with_cmdstan_flags <- function(expr, model_methods = FALSE) {
cxxflags <- get_cmdstan_flags("CXXFLAGS")
cmdstanr_includes <- system.file("include", package = "cmdstanr", mustWork = TRUE)
cmdstanr_includes <- paste0(" -I\"", cmdstanr_includes,"\"")
cmdstanr_includes <- paste0("-I", shQuote(cmdstanr_includes))

r_includes <- paste(
paste0("-I", shQuote(system.file("include", package = "Rcpp", mustWork = TRUE))),
paste0("-I", shQuote(R.home(component = "include")))
)

libs <- c("LDLIBS", "LIBSUNDIALS", "TBB_TARGETS", "LDFLAGS_TBB")
libs <- paste(sapply(libs, get_cmdstan_flags), collapse = " ")
if (.Platform$OS.type == "windows") {
if (os_is_windows()) {
libs <- paste(libs, "-fopenmp")
}
lib_paths <- c("/stan/lib/stan_math/lib/tbb/",
"/stan/lib/stan_math/lib/sundials_6.1.1/lib/")
withr::with_path(paste0(cmdstan_path(), lib_paths),
withr::with_makevars(
c(
USE_CXX14 = 1,
PKG_CPPFLAGS = ifelse(cmdstan_version() <= "2.30.1", "-DCMDSTAN_JSON", ""),
PKG_CXXFLAGS = paste0(cxxflags, cmdstanr_includes, collapse = " "),
PKG_LIBS = libs
),
Rcpp::sourceCpp(code = code, env = env, verbose = verbose)
new_makevars <- c(
PKG_CPPFLAGS = ifelse(cmdstan_version() <= "2.30.1", "-DCMDSTAN_JSON", ""),
PKG_CXXFLAGS = paste(cxxflags, cmdstanr_includes, r_includes, collapse = " "),
PKG_LIBS = libs
)
if (os_is_windows() && model_methods) {
new_makevars <- c(
new_makevars,
SHLIB_LD = paste0(rtools4x_toolchain_path(),"/gcc"),
LOCAL_CPPFLAGS = paste0("-I'",rtools4x_toolchain_path(),"/../include'"),
LOCAL_LIBS = paste0("-L'",rtools4x_toolchain_path(),"/../lib' -lstdc++"),
BINPREF = paste0(rtools4x_toolchain_path(), "/")
)
}
withr::with_path(
c(
paste0(cmdstan_path(), lib_paths),
toolchain_PATH_env_var()
),
withr::with_makevars(new_makevars, expr)
)
}

rcpp_source_stan <- function(code, env, verbose = FALSE) {
with_cmdstan_flags(Rcpp::sourceCpp(code = code, env = env, verbose = verbose))
invisible(NULL)
}

expose_model_methods <- function(env, verbose = FALSE, hessian = FALSE) {
code <- c(env$hpp_code_,
readLines(system.file("include", "model_methods.cpp",
package = "cmdstanr", mustWork = TRUE)))
initialize_method_functions <- function(env, so_name) {
env$model_ptr <-
function(...) { .Call("model_ptr_", ..., PACKAGE = so_name) }
env$log_prob <-
function(...) { .Call("log_prob_", ..., PACKAGE = so_name) }
env$grad_log_prob <-
function(...) { .Call("grad_log_prob_", ..., PACKAGE = so_name) }
env$hessian <-
function(...) { .Call("hessian_", ..., PACKAGE = so_name) }
env$get_num_upars <-
function(...) { .Call("get_num_upars_", ..., PACKAGE = so_name) }
env$get_param_metadata <-
function(...) { .Call("get_param_metadata_", ..., PACKAGE = so_name) }
env$unconstrain_variables <-
function(...) { .Call("unconstrain_variables_", ..., PACKAGE = so_name) }
env$constrain_variables <-
function(...) { .Call("constrain_variables_", ..., PACKAGE = so_name) }
env$unconstrained_param_names <-
function(...) { .Call("unconstrained_param_names_", ..., PACKAGE = so_name) }
env$constrained_param_names <-
function(...) { .Call("constrained_param_names_", ..., PACKAGE = so_name) }
}

if (hessian) {
code <- c("#include <stan/math/mix.hpp>",
code,
readLines(system.file("include", "hessian.cpp",
package = "cmdstanr", mustWork = TRUE)))
expose_model_methods <- function(env, force_recompile = FALSE, verbose = FALSE) {
precomp_methods_file <- file.path(cmdstan_path(), "model_methods.o")
if (file.exists(precomp_methods_file) && force_recompile) {
unlink(precomp_methods_file)
}
model_methods_cpp <- system.file("include", "model_methods.cpp",
package = "cmdstanr", mustWork = TRUE)
source_file <- paste0(strip_ext(precomp_methods_file), ".cpp")
file.copy(model_methods_cpp, source_file, overwrite = FALSE)

model_obj_file <- env$obj_file_
if (!file.exists(model_obj_file)) {
if (rlang::is_interactive()) {
message("Model object file not found, recompiling model...")
}
temp_hpp_file <- tempfile()
writeLines(env$hpp_code_, con = paste0(temp_hpp_file, ".cpp"))
model_obj_file <- paste0(temp_hpp_file, ".o")
}

if (!file.exists(precomp_methods_file) && rlang::is_interactive()) {
message("Compiling and caching additional model methods...")
}
if (rlang::is_interactive()) {
message("Linking precompiled model methods to model object file...")
}

methods_dll <- tempfile(fileext = .Platform$dynlib.ext)
with_cmdstan_flags(
processx::run(
command = file.path(R.home(component = "bin"), "R"),
args = c("CMD", "SHLIB", repair_path(model_obj_file), repair_path(precomp_methods_file),
"-o", repair_path(methods_dll)),
echo = verbose || is_verbose_mode(),
echo_cmd = is_verbose_mode(),
error_on_status = FALSE
),
model_methods = TRUE
)

code <- paste(code, collapse = "\n")
rcpp_source_stan(code, env, verbose)
env$methods_dll_info <- with_cmdstan_flags(dyn.load(methods_dll, local = TRUE, now = TRUE))
initialize_method_functions(env, strip_ext(basename(methods_dll)))
invisible(NULL)
}

Expand Down
41 changes: 0 additions & 41 deletions inst/include/hessian.cpp

This file was deleted.

Loading
Loading