From 7c88219cb60ff74de750fa9aa88619a78a1a8cb7 Mon Sep 17 00:00:00 2001 From: Lev Kozlov Date: Tue, 17 Sep 2024 16:47:49 +0900 Subject: [PATCH] fix: try with memory cleaning --- benchmarks/run_benchmark.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/benchmarks/run_benchmark.py b/benchmarks/run_benchmark.py index 69951af..ebf6f7c 100644 --- a/benchmarks/run_benchmark.py +++ b/benchmarks/run_benchmark.py @@ -1,3 +1,4 @@ +import gc import os import time @@ -73,6 +74,8 @@ def run_cusadi_benchmark(fn, inputs): results["N_EVALS"] = N_EVALS for fn in benchmark_fns: + with torch.no_grad(): + torch.cuda.empty_cache() fn_name = fn.name() for i, n_envs in enumerate(N_ENVS_SWEEP): print(f"Running CUDA benchmark for {n_envs} environments with function {fn_name}...") @@ -128,6 +131,7 @@ def main(): if PathsProvider.RUN_CUSADI: cuda_results = run_cuda_benchmarks() np.savez(f"{cur_dir}/cuda_benchmark_results.npz", **cuda_results) + gc.collect() jaxadi_results = run_jaxadi_benchmarks() np.savez(f"{cur_dir}/jaxadi_benchmark_results.npz", **jaxadi_results)