Skip to content

Commit

Permalink
[mpact][benchmark] check output type (#50)
Browse files Browse the repository at this point in the history
  • Loading branch information
yinying-lisa-li authored Jun 27, 2024
1 parent 385c8b6 commit 0459510
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion benchmark/python/utils/benchmark_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,15 @@ def run_benchmark(
):
"""Run benchmark with specified backends."""
output = []
output_type = None

with torch.no_grad():
for backend in backends:
match backend:
case Backends.TORCH_SPARSE_EAGER:
output.append(torch_net(*sparse_inputs))
sparse_out = torch_net(*sparse_inputs)
output_type = sparse_out.layout
output.append(sparse_out)
runtime_results.append(
timer(
"torch_net(*sparse_inputs)",
Expand Down Expand Up @@ -133,8 +136,16 @@ def run_benchmark(
output.append(
torch.sparse_csr_tensor(*sp_out, size=dense_out.shape)
)
# Check MPACT and torch eager both return sparse csr output
# only when torch sparse eager has been run.
if output_type:
assert output_type == torch.sparse_csr
else:
output.append(torch.from_numpy(sp_out))
# Check MPACT and torch eager both return dense output
# only when torch sparse eager has been run.
if output_type:
assert output_type == torch.strided
invoker, f = mpact_jit_compile(torch_net, *sparse_inputs)
compile_time_results.append(
timer(
Expand Down

0 comments on commit 0459510

Please sign in to comment.