Skip to content

Commit

Permalink
fix: try cuda first
Browse files Browse the repository at this point in the history
  • Loading branch information
lvjonok committed Sep 17, 2024
1 parent 8fce298 commit 1c7378c
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions benchmarks/run_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class ColabPaths:

torch.manual_seed(0)

N_ENVS_SWEEP = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768]
N_ENVS_SWEEP = [1, 2, 4, 8, 16, 32, 64, 128, 256] # , 512, 1024, 2048, 4096, 8192, 16384, 32768]
N_EVALS = 20

# Load functions for CUDA benchmarking
Expand Down Expand Up @@ -112,20 +112,26 @@ def run_jaxadi_benchmark(fn, inputs):
print(f"Running Jaxadi benchmark for {n_envs} environments with function {fn_name}...")
inputs = sample_jaxadi_input(fn, n_envs)

# warmup
vmapped_fn(*inputs)

for j in range(N_EVALS):
results[fn_name][i, j] = run_jaxadi_benchmark(vmapped_fn, inputs)

# remove the compiled function from the memory and inputs
del inputs

return results


def main():
jaxadi_results = run_jaxadi_benchmarks()
np.savez(f"{cur_dir}/jaxadi_benchmark_results.npz", **jaxadi_results)

if PathsProvider.RUN_CUSADI:
cuda_results = run_cuda_benchmarks()
np.savez(f"{cur_dir}/cuda_benchmark_results.npz", **cuda_results)

jaxadi_results = run_jaxadi_benchmarks()
np.savez(f"{cur_dir}/jaxadi_benchmark_results.npz", **jaxadi_results)

print("Benchmark results saved.")


Expand Down

0 comments on commit 1c7378c

Please sign in to comment.