From 9d5901098d6933001dfdf9059ec8a9c45313d260 Mon Sep 17 00:00:00 2001 From: Gerwin van der Lugt Date: Tue, 16 Jan 2024 13:44:33 +0100 Subject: [PATCH] respect `CUDA_PATH` and friends in `build.rs` (with windows support) --- build.rs | 45 ++++++++++++++++++++++----------------------- 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/build.rs b/build.rs index 44edd74..c7b9a29 100644 --- a/build.rs +++ b/build.rs @@ -1,40 +1,39 @@ -#[cfg(unix)] fn main() { - let cuda_path = std::env::var("CUDA_PATH").unwrap_or_else(|_| "/usr/local/cuda".to_string()); + let cuda_path = std::env::var("CUDA_PATH").map(std::path::PathBuf::from); - let cuda_include_path = - std::env::var("CUDA_INCLUDE_PATH").unwrap_or_else(|_| format!("{cuda_path}/include")); + #[cfg(not(windows))] + let cuda_path = cuda_path.unwrap_or_else(|_| std::path::PathBuf::from("/usr/local/cuda")); + #[cfg(windows)] + let cuda_path = cuda_path.expect("Missing environment variable `CUDA_PATH`."); - let cuda_lib_path = - std::env::var("CUDA_LIB_PATH").unwrap_or_else(|_| format!("{cuda_path}/lib64")); + let cuda_include_path = std::env::var("CUDA_INCLUDE_PATH") + .map(std::path::PathBuf::from) + .unwrap_or_else(|_| cuda_path.join("include")); + + let cuda_lib_path = std::env::var("CUDA_LIB_PATH") + .map(std::path::PathBuf::from) + .unwrap_or_else(|_| { + #[cfg(not(windows))] + { + cuda_path.join("lib64") + } + #[cfg(windows)] + { + cuda_path.join("lib").join("x64") + } + }); cpp_build::Config::new() .include(cuda_include_path) .build("src/lib.rs"); - println!("cargo:rustc-link-search={cuda_lib_path}"); + println!("cargo:rustc-link-search={}", cuda_lib_path.display()); println!("cargo:rustc-link-lib=cudart"); #[cfg(feature = "npp")] link_npp_libraries(); } -#[cfg(windows)] -fn main() { - let cuda_path = std::env::var("CUDA_PATH").expect("Missing environment variable `CUDA_PATH`."); - let cuda_path = std::path::Path::new(&cuda_path); - cpp_build::Config::new() - .include(cuda_path.join("include")) - .build("src/lib.rs"); - println!( - "cargo:rustc-link-search={}", - cuda_path.join("lib").join("x64").display() - ); - println!("cargo:rustc-link-lib=cudart"); - #[cfg(feature = "npp")] - link_npp_libraries(); -} - #[cfg(feature = "npp")] fn link_npp_libraries() { println!("cargo:rustc-link-lib=nppc");