From 229471c73a2864af0a5934bf1af9a022d31428cc Mon Sep 17 00:00:00 2001
From: wenhaozhao
Date: Thu, 5 Sep 2024 12:14:51 +0800
Subject: [PATCH 1/2] support empty_cuda_cache

(cherry picked from commit 78008d966b81c740c0857dc88c1eeec97e27ebc7)
---
Review notes (below the "---" marker, so not part of the commit message):

 * This patch text was recovered from a copy whose line breaks had been
   collapsed and whose angle-bracket contents had been dropped. Indentation
   and blank context lines were rebuilt from the hunk line counts;
   `Vec<PathBuf>` in build.rs and the `#include <...>` target in
   torch_c10_cuda_api.h were reconstructed from how they are used. Confirm
   both against commit 78008d96 before applying. The author e-mail on the
   "From:" line is missing as well.
 * build.rs panics through `expect` when `use_cuda` is set and CUDA_HOME is
   not defined in the environment.
 * `let _ = cuda_include_dirs.push(...)` and `let _ = empty_cuda_cache()`
   only discard a unit value; the `let _ =` can be dropped.
 * The new test creates a [1024, 1024, 1024] Float tensor (2^30 elements) on
   the CPU and moves it to Device::Cuda(0), so it needs a CUDA device and
   several GiB of memory to pass.
 * torch-sys/src/c10_cuda.rs declares `emptyCache` unconditionally, while
   after PATCH 2/2 the C++ definition is only compiled when `use_cuda` is
   true; non-CUDA builds that reference `empty_cuda_cache()` (the new test
   does) presumably fail to link -- TODO confirm.
 * `emptyCache()` in torch_c10_cuda_api.cpp calls into c10 without a
   try/catch, so a C++ exception would cross the `extern "C"` boundary --
   TODO confirm whether the allocator call can throw.

 tests/cuda_empty_cache_tests.rs         | 14 ++++++++++++++
 torch-sys/build.rs                      | 17 +++++++++++++++--
 torch-sys/libtch/torch_c10_cuda_api.cpp |  7 +++++++
 torch-sys/libtch/torch_c10_cuda_api.h   | 10 ++++++++++
 torch-sys/src/c10_cuda.rs               |  8 ++++++++
 torch-sys/src/lib.rs                    |  1 +
 6 files changed, 55 insertions(+), 2 deletions(-)
 create mode 100644 tests/cuda_empty_cache_tests.rs
 create mode 100644 torch-sys/libtch/torch_c10_cuda_api.cpp
 create mode 100644 torch-sys/libtch/torch_c10_cuda_api.h
 create mode 100644 torch-sys/src/c10_cuda.rs

diff --git a/tests/cuda_empty_cache_tests.rs b/tests/cuda_empty_cache_tests.rs
new file mode 100644
index 00000000..38007514
--- /dev/null
+++ b/tests/cuda_empty_cache_tests.rs
@@ -0,0 +1,14 @@
+use tch::{Device, Kind, Tensor};
+use torch_sys::c10_cuda::empty_cuda_cache;
+
+#[test]
+fn cuda_empty_cache() {
+    println!("Cuda empty cache test started...");
+    println!("Create tensor...");
+    let tensor = Tensor::randn([1024, 1024, 1024], (Kind::Float, Device::Cpu));
+    println!("Push tensor to cuda");
+    let tensor_cuda = tensor.to_kind(Kind::Half).to_device(Device::Cuda(0));
+    println!("Empty cuda cache");
+    drop(tensor_cuda);
+    let _ = empty_cuda_cache();
+}
diff --git a/torch-sys/build.rs b/torch-sys/build.rs
index 16d8278e..191f9f3d 100644
--- a/torch-sys/build.rs
+++ b/torch-sys/build.rs
@@ -365,8 +365,18 @@ impl SystemInfo {
         println!("cargo:rerun-if-changed=libtch/stb_image_write.h");
         println!("cargo:rerun-if-changed=libtch/stb_image_resize.h");
         println!("cargo:rerun-if-changed=libtch/stb_image.h");
+
+        let mut cuda_include_dirs: Vec<PathBuf> = vec![];
+        if use_cuda {
+            let cuda_home = env::var("CUDA_HOME").expect("Cannot found CUDA_HOME!!!");
+            let cuda_include = PathBuf::from(format!("{cuda_home}/include"));
+            let _ = cuda_include_dirs.push(cuda_include);
+            println!("cargo:rerun-if-changed=libtch/torch_c10_cuda_api.cpp");
+            println!("cargo:rerun-if-changed=libtch/torch_c10_cuda_api.h");
+        }
+
         let mut c_files =
-            vec!["libtch/torch_api.cpp", "libtch/torch_api_generated.cpp", cuda_dependency];
+            vec!["libtch/torch_api.cpp", "libtch/torch_api_generated.cpp", "libtch/torch_c10_cuda_api.cpp", cuda_dependency];
         if cfg!(feature = "python-extension") {
             c_files.push("libtch/torch_python.cpp")
         }
@@ -382,6 +392,7 @@ impl SystemInfo {
                     .pic(true)
                     .warnings(false)
                     .includes(&self.libtorch_include_dirs)
+                    .includes(&cuda_include_dirs)
                     .flag(&format!("-Wl,-rpath={}", self.libtorch_lib_dir.display()))
                     .flag("-std=c++17")
                     .flag(&format!("-D_GLIBCXX_USE_CXX11_ABI={}", self.cxx11_abi))
@@ -398,6 +409,7 @@ impl SystemInfo {
                     .pic(true)
                     .warnings(false)
                     .includes(&self.libtorch_include_dirs)
+                    .includes(&cuda_include_dirs)
                     .flag("/std:c++17")
                     .flag("/p:DefineConstants=GLOG_USE_GLOG_EXPORT")
                     .files(&c_files)
@@ -455,7 +467,8 @@ fn main() -> anyhow::Result<()> {
 
         println!("cargo:rustc-link-lib=static=tch");
         if use_cuda {
-            system_info.link("torch_cuda")
+            system_info.link("torch_cuda");
+            system_info.link("c10_cuda");
         }
         if use_cuda_cu {
             system_info.link("torch_cuda_cu")
diff --git a/torch-sys/libtch/torch_c10_cuda_api.cpp b/torch-sys/libtch/torch_c10_cuda_api.cpp
new file mode 100644
index 00000000..20e6211d
--- /dev/null
+++ b/torch-sys/libtch/torch_c10_cuda_api.cpp
@@ -0,0 +1,7 @@
+
+#include "torch_c10_cuda_api.h"
+
+
+void emptyCache(){
+    c10::cuda::CUDACachingAllocator::emptyCache();
+}
\ No newline at end of file
diff --git a/torch-sys/libtch/torch_c10_cuda_api.h b/torch-sys/libtch/torch_c10_cuda_api.h
new file mode 100644
index 00000000..5609e34f
--- /dev/null
+++ b/torch-sys/libtch/torch_c10_cuda_api.h
@@ -0,0 +1,10 @@
+#ifndef __TORCH_C10_CUDA_API_H__
+#define __TORCH_C10_CUDA_API_H__
+
+#include <c10/cuda/CUDACachingAllocator.h>
+
+extern "C" {
+    void emptyCache();
+}
+
+#endif
\ No newline at end of file
diff --git a/torch-sys/src/c10_cuda.rs b/torch-sys/src/c10_cuda.rs
new file mode 100644
index 00000000..59da11ca
--- /dev/null
+++ b/torch-sys/src/c10_cuda.rs
@@ -0,0 +1,8 @@
+extern "C" {
+    /// empty cuda cache
+    pub fn emptyCache();
+}
+
+pub fn empty_cuda_cache() {
+    unsafe { emptyCache() };
+}
\ No newline at end of file
diff --git a/torch-sys/src/lib.rs b/torch-sys/src/lib.rs
index 254d4148..28b945ef 100644
--- a/torch-sys/src/lib.rs
+++ b/torch-sys/src/lib.rs
@@ -161,6 +161,7 @@ extern "C" {
 }
 
 pub mod c_generated;
+pub mod c10_cuda;
 
 extern "C" {
     pub fn get_and_reset_last_err() -> *mut c_char;

From 098a242a5c9ab7c283b755d689d849da7d87eef4 Mon Sep 17 00:00:00 2001
From: wenhaozhao
Date: Thu, 5 Sep 2024 14:52:16 +0800
Subject: [PATCH 2/2] fix build error on macos

(cherry picked from commit 9a36e619c5af7b111689c712e5c4f4d763e74a20)
---
Review notes (below the "---" marker, so not part of the commit message):

 * With this change torch_c10_cuda_api.cpp is only added to `c_files` when
   `use_cuda` is true, matching the `use_cuda` guard PATCH 1/2 already put
   around the CUDA include directory and the rerun-if-changed lines.

 torch-sys/build.rs | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/torch-sys/build.rs b/torch-sys/build.rs
index 191f9f3d..657b205e 100644
--- a/torch-sys/build.rs
+++ b/torch-sys/build.rs
@@ -376,7 +376,10 @@ impl SystemInfo {
         }
 
         let mut c_files =
-            vec!["libtch/torch_api.cpp", "libtch/torch_api_generated.cpp", "libtch/torch_c10_cuda_api.cpp", cuda_dependency];
+            vec!["libtch/torch_api.cpp", "libtch/torch_api_generated.cpp", cuda_dependency];
+        if use_cuda {
+            c_files.push("libtch/torch_c10_cuda_api.cpp")
+        }
         if cfg!(feature = "python-extension") {
             c_files.push("libtch/torch_python.cpp")
         }