Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support empty_cuda_cache #888

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions tests/cuda_empty_cache_tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
use tch::{Device, Kind, Tensor};
use torch_sys::c10_cuda::empty_cuda_cache;

#[test]
fn cuda_empty_cache() {
println!("Cuda empty cache test started...");
println!("Create tensor...");
let tensor = Tensor::randn([1024, 1024, 1024], (Kind::Float, Device::Cpu));
println!("Push tensor to cuda");
let tensor_cuda = tensor.to_kind(Kind::Half).to_device(Device::Cuda(0));
println!("Empty cuda cache");
drop(tensor_cuda);
let _ = empty_cuda_cache();
}
20 changes: 17 additions & 3 deletions torch-sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ impl SystemInfo {
}
}

fn make(&self) {
fn make(&self, use_cuda: bool) {
println!("cargo:rerun-if-changed=libtch/torch_python.cpp");
println!("cargo:rerun-if-changed=libtch/torch_python.h");
println!("cargo:rerun-if-changed=libtch/torch_api_generated.cpp");
Expand All @@ -361,7 +361,18 @@ impl SystemInfo {
println!("cargo:rerun-if-changed=libtch/stb_image_write.h");
println!("cargo:rerun-if-changed=libtch/stb_image_resize.h");
println!("cargo:rerun-if-changed=libtch/stb_image.h");
let mut cuda_include_dirs: Vec<PathBuf> = vec![];
if use_cuda {
let cuda_home = env::var("CUDA_HOME").expect("Cannot found CUDA_HOME!!!");
let cuda_include = PathBuf::from(format!("{cuda_home}/include"));
let _ = cuda_include_dirs.push(cuda_include);
println!("cargo:rerun-if-changed=libtch/torch_c10_cuda_api.cpp");
println!("cargo:rerun-if-changed=libtch/torch_c10_cuda_api.h");
}
let mut c_files = vec!["libtch/torch_api.cpp", "libtch/torch_api_generated.cpp"];
if use_cuda {
c_files.push("libtch/torch_c10_cuda_api.cpp")
}
if cfg!(feature = "python-extension") {
c_files.push("libtch/torch_python.cpp")
}
Expand All @@ -377,6 +388,7 @@ impl SystemInfo {
.pic(true)
.warnings(false)
.includes(&self.libtorch_include_dirs)
.includes(&cuda_include_dirs)
.flag(format!("-Wl,-rpath={}", self.libtorch_lib_dir.display()))
.flag("-std=c++17")
.flag(format!("-D_GLIBCXX_USE_CXX11_ABI={}", self.cxx11_abi))
Expand All @@ -393,6 +405,7 @@ impl SystemInfo {
.pic(true)
.warnings(false)
.includes(&self.libtorch_include_dirs)
.includes(&cuda_include_dirs)
.flag("/std:c++17")
.flag("/p:DefineConstants=GLOG_USE_GLOG_EXPORT")
.files(&c_files)
Expand Down Expand Up @@ -449,11 +462,12 @@ fn main() -> anyhow::Result<()> {
si_lib.join("libtorch_hip.so").exists() || si_lib.join("torch_hip.dll").exists();
println!("cargo:rustc-link-search=native={}", si_lib.display());

system_info.make();
system_info.make(use_cuda);

println!("cargo:rustc-link-lib=static=tch");
if use_cuda {
system_info.link("torch_cuda")
system_info.link("torch_cuda");
system_info.link("c10_cuda");
}
if use_cuda_cu {
system_info.link("torch_cuda_cu")
Expand Down
7 changes: 7 additions & 0 deletions torch-sys/libtch/torch_c10_cuda_api.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@

#include "torch_c10_cuda_api.h"


void emptyCache(){
c10::cuda::CUDACachingAllocator::emptyCache();
}
10 changes: 10 additions & 0 deletions torch-sys/libtch/torch_c10_cuda_api.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#ifndef __TORCH_C10_CUDA_API_H__
#define __TORCH_C10_CUDA_API_H__

#include<torch/csrc/cuda/CUDAPluggableAllocator.h>

extern "C" {
void emptyCache();
}

#endif
8 changes: 8 additions & 0 deletions torch-sys/src/c10_cuda.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
extern "C" {
/// empty cuda cache
pub fn emptyCache();
}

pub fn empty_cuda_cache() {
unsafe { emptyCache() };
}
1 change: 1 addition & 0 deletions torch-sys/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ extern "C" {
}

pub mod c_generated;
pub mod c10_cuda;

extern "C" {
pub fn get_and_reset_last_err() -> *mut c_char;
Expand Down