diff --git a/build.rs b/build.rs index 53b83a8..c7b9a29 100644 --- a/build.rs +++ b/build.rs @@ -1,26 +1,35 @@ -#[cfg(unix)] fn main() { - cpp_build::Config::new() - .include("/usr/local/cuda/include") - .build("src/lib.rs"); - println!("cargo:rustc-link-search=/usr/local/cuda/lib64"); - println!("cargo:rustc-link-lib=cudart"); - #[cfg(feature = "npp")] - link_npp_libraries(); -} + let cuda_path = std::env::var("CUDA_PATH").map(std::path::PathBuf::from); + + #[cfg(not(windows))] + let cuda_path = cuda_path.unwrap_or_else(|_| std::path::PathBuf::from("/usr/local/cuda")); + #[cfg(windows)] + let cuda_path = cuda_path.expect("Missing environment variable `CUDA_PATH`."); + + let cuda_include_path = std::env::var("CUDA_INCLUDE_PATH") + .map(std::path::PathBuf::from) + .unwrap_or_else(|_| cuda_path.join("include")); + + let cuda_lib_path = std::env::var("CUDA_LIB_PATH") + .map(std::path::PathBuf::from) + .unwrap_or_else(|_| { + #[cfg(not(windows))] + { + cuda_path.join("lib64") + } + #[cfg(windows)] + { + cuda_path.join("lib").join("x64") + } + }); -#[cfg(windows)] -fn main() { - let cuda_path = std::env::var("CUDA_PATH").expect("Missing environment variable `CUDA_PATH`."); - let cuda_path = std::path::Path::new(&cuda_path); cpp_build::Config::new() - .include(cuda_path.join("include")) + .include(cuda_include_path) .build("src/lib.rs"); - println!( - "cargo:rustc-link-search={}", - cuda_path.join("lib").join("x64").display() - ); + + println!("cargo:rustc-link-search={}", cuda_lib_path.display()); println!("cargo:rustc-link-lib=cudart"); + #[cfg(feature = "npp")] link_npp_libraries(); }