From 1c7378c8d62664ea0705c8cdc28e96824e659383 Mon Sep 17 00:00:00 2001 From: Lev Kozlov Date: Tue, 17 Sep 2024 16:31:29 +0900 Subject: [PATCH] fix: try cuda first --- benchmarks/run_benchmark.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/benchmarks/run_benchmark.py b/benchmarks/run_benchmark.py index d8c7572..69951af 100644 --- a/benchmarks/run_benchmark.py +++ b/benchmarks/run_benchmark.py @@ -42,7 +42,7 @@ class ColabPaths: torch.manual_seed(0) -N_ENVS_SWEEP = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768] +N_ENVS_SWEEP = [1, 2, 4, 8, 16, 32, 64, 128, 256] # , 512, 1024, 2048, 4096, 8192, 16384, 32768] N_EVALS = 20 # Load functions for CUDA benchmarking @@ -112,20 +112,26 @@ def run_jaxadi_benchmark(fn, inputs): print(f"Running Jaxadi benchmark for {n_envs} environments with function {fn_name}...") inputs = sample_jaxadi_input(fn, n_envs) + # warmup + vmapped_fn(*inputs) + for j in range(N_EVALS): results[fn_name][i, j] = run_jaxadi_benchmark(vmapped_fn, inputs) + # remove the compiled function from the memory and inputs + del inputs + return results def main(): - jaxadi_results = run_jaxadi_benchmarks() - np.savez(f"{cur_dir}/jaxadi_benchmark_results.npz", **jaxadi_results) - if PathsProvider.RUN_CUSADI: cuda_results = run_cuda_benchmarks() np.savez(f"{cur_dir}/cuda_benchmark_results.npz", **cuda_results) + jaxadi_results = run_jaxadi_benchmarks() + np.savez(f"{cur_dir}/jaxadi_benchmark_results.npz", **jaxadi_results) + print("Benchmark results saved.")