diff --git a/benchmark/python/utils/benchmark_utils.py b/benchmark/python/utils/benchmark_utils.py index ec29309..ac9aa63 100644 --- a/benchmark/python/utils/benchmark_utils.py +++ b/benchmark/python/utils/benchmark_utils.py @@ -65,12 +65,15 @@ def run_benchmark( ): """Run benchmark with specified backends.""" output = [] + output_type = None with torch.no_grad(): for backend in backends: match backend: case Backends.TORCH_SPARSE_EAGER: - output.append(torch_net(*sparse_inputs)) + sparse_out = torch_net(*sparse_inputs) + output_type = sparse_out.layout + output.append(sparse_out) runtime_results.append( timer( "torch_net(*sparse_inputs)", @@ -133,8 +136,14 @@ def run_benchmark( output.append( torch.sparse_csr_tensor(*sp_out, size=dense_out.shape) ) + # Check MPACT and torch eager both return sparse csr output. + if output_type: + assert(output_type == torch.sparse_csr) else: output.append(torch.from_numpy(sp_out)) + # Check MPACT and torch eager both return dense output. + if output_type: + assert(output_type == torch.strided) invoker, f = mpact_jit_compile(torch_net, *sparse_inputs) compile_time_results.append( timer(