From adff010e6e55a38c2bfc3944101b5a506e50b32e Mon Sep 17 00:00:00 2001 From: yinying-lisa-li Date: Wed, 26 Jun 2024 18:04:19 +0000 Subject: [PATCH] check output type --- benchmark/python/utils/benchmark_utils.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/benchmark/python/utils/benchmark_utils.py b/benchmark/python/utils/benchmark_utils.py index ec29309..ac9aa63 100644 --- a/benchmark/python/utils/benchmark_utils.py +++ b/benchmark/python/utils/benchmark_utils.py @@ -65,12 +65,15 @@ def run_benchmark( ): """Run benchmark with specified backends.""" output = [] + output_type = None with torch.no_grad(): for backend in backends: match backend: case Backends.TORCH_SPARSE_EAGER: - output.append(torch_net(*sparse_inputs)) + sparse_out = torch_net(*sparse_inputs) + output_type = sparse_out.layout + output.append(sparse_out) runtime_results.append( timer( "torch_net(*sparse_inputs)", @@ -133,8 +136,14 @@ def run_benchmark( output.append( torch.sparse_csr_tensor(*sp_out, size=dense_out.shape) ) + # Check MPACT and torch eager both return sparse csr output. + if output_type: + assert(output_type == torch.sparse_csr) else: output.append(torch.from_numpy(sp_out)) + # Check MPACT and torch eager both return dense output. + if output_type: + assert(output_type == torch.strided) invoker, f = mpact_jit_compile(torch_net, *sparse_inputs) compile_time_results.append( timer(