Skip to content

Commit

Permalink
set cuda env vars for tensorflow
Browse files Browse the repository at this point in the history
  • Loading branch information
t-kalinowski committed Aug 1, 2023
1 parent 56a7dbc commit 3cca61b
Showing 1 changed file with 61 additions and 1 deletion.
62 changes: 61 additions & 1 deletion R/package.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,46 @@ tf_v2 <- function() {
.globals <- new.env(parent = emptyenv())
.globals$tensorboard <- NULL

set_cuda_env_vars <- function(python_too = FALSE) {

# For TF 2.13, this assumes that someone already has cudn 11-8 installed,
# e.g., on ubuntu:
# sudo apt install cuda-toolkit-11-8

python <- py_discover_config()$python
cudnn_module_path <- suppressWarnings(system2(python, c("-c",
shQuote("import nvidia.cudnn;print(nvidia.cudnn.__file__)")),
stdout = TRUE, stderr = TRUE))
if (!is.null(attr(cudnn_module_path, "status")) ||
!is_string(cudnn_module_path) ||
!file.exists(cudnn_module_path))
return()

cudnn_path <- dirname(cudnn_module_path)
ld_library_path <- paste0(c(paste0(cudnn_path, "/lib"),
Sys.getenv("LD_LIBRARY_PATH")),
collapse = ":")

vars <- list(CUDNN_PATH = cudnn_path,
LD_LIBRARY_PATH = ld_library_path)

out <- Sys.getenv(names(vars), unset = NA, names = TRUE)

do.call(Sys.setenv, vars)

if(py_available()) {
# none of these 3 approaches seem to do anything for tensorflow
import("os", convert = FALSE)$environ$update(vars)
# environ <- (os <- import("os", convert = FALSE))$environ
# for(i in seq_along(vars)) {
# py_set_item(environ, names(vars)[[i]], vars[[i]])
# os$putenv(names(vars)[[i]], vars[[i]])
# }
}

invisible(out)
}

.onLoad <- function(libname, pkgname) {

# if TENSORFLOW_PYTHON is defined then forward it to RETICULATE_PYTHON
Expand All @@ -44,13 +84,30 @@ tf_v2 <- function() {
if (!is.null(cpp_log_opt))
Sys.setenv(TF_CPP_MIN_LOG_LEVEL = max(min(cpp_log_opt, 1), 0))

if (is_linux()) {
if (!py_available())
options("reticulate.python.beforeInitialized" = set_cuda_env_vars)
else {
# This is likely too little too late: it would seem that LD_LIBRARY_PATH
# *must* be set before python symbols are loaded (that is, before reticulate
# initialized python). calling os.putenv()/os.environ.set() after python
# has initialized but before importing the tensorflow module (e.g., from a
# before_load() module import hook) don't seem to get the job done.
set_cuda_env_vars()
}
}

# delay load tensorflow
tf <<- import("tensorflow", delay_load = list(

priority = 5,
priority = 5, # keras sets priority = 10

environment = "r-tensorflow",

# before_load = function() {
#
# },

on_load = function() {

# register warning suppression handler
Expand Down Expand Up @@ -116,6 +173,9 @@ tf_v2 <- function() {
}


is_string <- function(x) {
is.character(x) && length(x) == 1L && !is.na(x)
}

#' TensorFlow configuration information
#'
Expand Down

0 comments on commit 3cca61b

Please sign in to comment.