Skip to content

Commit

Permalink
On NixOS find cudatoolkit via which nvcc, fixes Rust-GPU#92.
Browse files Browse the repository at this point in the history
  • Loading branch information
ralfbiedert committed Oct 25, 2022
1 parent 8a6cb73 commit a2f823f
Showing 1 changed file with 21 additions and 0 deletions.
21 changes: 21 additions & 0 deletions crates/find_cuda_helper/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! Tiny crate for common logic for finding and including CUDA.
use std::process::Command;
use std::{
env,
path::{Path, PathBuf},
Expand Down Expand Up @@ -57,6 +58,7 @@ pub fn find_cuda_root() -> Option<PathBuf> {

#[cfg(target_os = "windows")]
pub fn find_cuda_lib_dirs() -> Vec<PathBuf> {
let x = detect_cuda_root_via_which_nvcc();
if let Some(root_path) = find_cuda_root() {
// To do this the right way, we check to see which target we're building for.
let target = env::var("TARGET")
Expand Down Expand Up @@ -131,6 +133,7 @@ pub fn find_cuda_lib_dirs() -> Vec<PathBuf> {
candidates.push(e)
}
candidates.push(PathBuf::from("/usr/lib/cuda"));
candidates.push(detect_cuda_root_via_which_nvcc());

let mut valid_paths = vec![];
for base in &candidates {
Expand All @@ -150,6 +153,24 @@ pub fn find_cuda_lib_dirs() -> Vec<PathBuf> {
valid_paths
}

#[cfg(not(target_os = "windows"))]
fn detect_cuda_root_via_which_nvcc() -> PathBuf {
let output = Command::new("which")
.arg("nvcc")
.output()
.expect("Command `which` must be available on *nix like systems.")
.stdout;

let path: PathBuf = String::from_utf8(output)
.expect("Result must be valid UTF-8")
.trim()
.to_string()
.into();

// The above finds `CUDASDK/bin/nvcc`, so we have to go 2 up for the SDK root.
path.parent().unwrap().parent().unwrap().to_path_buf()
}

#[cfg(target_os = "windows")]
pub fn find_optix_root() -> Option<PathBuf> {
// the optix SDK installer sets OPTIX_ROOT_DIR whenever it installs.
Expand Down

0 comments on commit a2f823f

Please sign in to comment.