diff --git a/tests/cuda_empty_cache_tests.rs b/tests/cuda_empty_cache_tests.rs new file mode 100644 index 00000000..38007514 --- /dev/null +++ b/tests/cuda_empty_cache_tests.rs @@ -0,0 +1,14 @@ +use tch::{Device, Kind, Tensor}; +use torch_sys::c10_cuda::empty_cuda_cache; + +#[test] +fn cuda_empty_cache() { + println!("Cuda empty cache test started..."); + println!("Create tensor..."); + let tensor = Tensor::randn([1024, 1024, 1024], (Kind::Float, Device::Cpu)); + println!("Push tensor to cuda"); + let tensor_cuda = tensor.to_kind(Kind::Half).to_device(Device::Cuda(0)); + println!("Empty cuda cache"); + drop(tensor_cuda); + let _ = empty_cuda_cache(); +} diff --git a/torch-sys/build.rs b/torch-sys/build.rs index dacc5405..eb597f34 100644 --- a/torch-sys/build.rs +++ b/torch-sys/build.rs @@ -351,7 +351,7 @@ impl SystemInfo { } } - fn make(&self) { + fn make(&self, use_cuda: bool) { println!("cargo:rerun-if-changed=libtch/torch_python.cpp"); println!("cargo:rerun-if-changed=libtch/torch_python.h"); println!("cargo:rerun-if-changed=libtch/torch_api_generated.cpp"); @@ -361,7 +361,18 @@ impl SystemInfo { println!("cargo:rerun-if-changed=libtch/stb_image_write.h"); println!("cargo:rerun-if-changed=libtch/stb_image_resize.h"); println!("cargo:rerun-if-changed=libtch/stb_image.h"); + let mut cuda_include_dirs: Vec = vec![]; + if use_cuda { + let cuda_home = env::var("CUDA_HOME").expect("Cannot found CUDA_HOME!!!"); + let cuda_include = PathBuf::from(format!("{cuda_home}/include")); + let _ = cuda_include_dirs.push(cuda_include); + println!("cargo:rerun-if-changed=libtch/torch_c10_cuda_api.cpp"); + println!("cargo:rerun-if-changed=libtch/torch_c10_cuda_api.h"); + } let mut c_files = vec!["libtch/torch_api.cpp", "libtch/torch_api_generated.cpp"]; + if use_cuda { + c_files.push("libtch/torch_c10_cuda_api.cpp") + } if cfg!(feature = "python-extension") { c_files.push("libtch/torch_python.cpp") } @@ -377,6 +388,7 @@ impl SystemInfo { .pic(true) .warnings(false) .includes(&self.libtorch_include_dirs) + .includes(&cuda_include_dirs) .flag(format!("-Wl,-rpath={}", self.libtorch_lib_dir.display())) .flag("-std=c++17") .flag(format!("-D_GLIBCXX_USE_CXX11_ABI={}", self.cxx11_abi)) @@ -393,6 +405,7 @@ impl SystemInfo { .pic(true) .warnings(false) .includes(&self.libtorch_include_dirs) + .includes(&cuda_include_dirs) .flag("/std:c++17") .flag("/p:DefineConstants=GLOG_USE_GLOG_EXPORT") .files(&c_files) @@ -449,11 +462,12 @@ fn main() -> anyhow::Result<()> { si_lib.join("libtorch_hip.so").exists() || si_lib.join("torch_hip.dll").exists(); println!("cargo:rustc-link-search=native={}", si_lib.display()); - system_info.make(); + system_info.make(use_cuda); println!("cargo:rustc-link-lib=static=tch"); if use_cuda { - system_info.link("torch_cuda") + system_info.link("torch_cuda"); + system_info.link("c10_cuda"); } if use_cuda_cu { system_info.link("torch_cuda_cu") diff --git a/torch-sys/libtch/torch_c10_cuda_api.cpp b/torch-sys/libtch/torch_c10_cuda_api.cpp new file mode 100644 index 00000000..20e6211d --- /dev/null +++ b/torch-sys/libtch/torch_c10_cuda_api.cpp @@ -0,0 +1,7 @@ + +#include "torch_c10_cuda_api.h" + + +void emptyCache(){ + c10::cuda::CUDACachingAllocator::emptyCache(); +} \ No newline at end of file diff --git a/torch-sys/libtch/torch_c10_cuda_api.h b/torch-sys/libtch/torch_c10_cuda_api.h new file mode 100644 index 00000000..5609e34f --- /dev/null +++ b/torch-sys/libtch/torch_c10_cuda_api.h @@ -0,0 +1,10 @@ +#ifndef __TORCH_C10_CUDA_API_H__ +#define __TORCH_C10_CUDA_API_H__ + +#include + +extern "C" { + void emptyCache(); +} + +#endif \ No newline at end of file diff --git a/torch-sys/src/c10_cuda.rs b/torch-sys/src/c10_cuda.rs new file mode 100644 index 00000000..59da11ca --- /dev/null +++ b/torch-sys/src/c10_cuda.rs @@ -0,0 +1,8 @@ +extern "C" { + /// empty cuda cache + pub fn emptyCache(); +} + +pub fn empty_cuda_cache() { + unsafe { emptyCache() }; +} \ No newline at end of file diff --git a/torch-sys/src/lib.rs b/torch-sys/src/lib.rs index 96c8eb09..8afb5ad7 100644 --- a/torch-sys/src/lib.rs +++ b/torch-sys/src/lib.rs @@ -161,6 +161,7 @@ extern "C" { } pub mod c_generated; +pub mod c10_cuda; extern "C" { pub fn get_and_reset_last_err() -> *mut c_char;