From eca7917e14dd523e7f048d9e53bd35a55bbf5283 Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Tue, 30 Jul 2024 16:32:45 -0700 Subject: [PATCH] [mpact][file-formats] add matrix market and extended frostt utils (#66) * [mpact][file-formats] add matrix market and extended frostt utils * add mm back * add benchmark util dep to test --- benchmark/python/utils/tensor_generator.py | 53 +++++++++++++++++++ test/CMakeLists.txt | 1 + test/python/file_formats.py | 61 ++++++++++++++++++++++ 3 files changed, 115 insertions(+) create mode 100644 test/python/file_formats.py diff --git a/benchmark/python/utils/tensor_generator.py b/benchmark/python/utils/tensor_generator.py index c98a68b..61e69a9 100644 --- a/benchmark/python/utils/tensor_generator.py +++ b/benchmark/python/utils/tensor_generator.py @@ -75,3 +75,56 @@ def generate_tensor( result = np.reshape(flat_output, shape).astype(dtype) return torch.from_numpy(result) + + +def print_matrix_market_format(tensor: torch.Tensor): + """Prints the matrix market format for a sparse matrix. + + Args: + tensor: sparse matrix (real type) + """ + if len(tensor.shape) != 2: + raise ValueError("Unexpected rank : %d (matrices only)" % len(tensor.shape)) + if tensor.dtype != torch.float32 and tensor.dtype != torch.float64: + raise ValueError("Unexpected type : %s" % tensor.dtype) + + h = tensor.shape[0] + w = tensor.shape[1] + nnz = sum([1 if tensor[i, j] != 0 else 0 for i in range(h) for j in range(w)]) + density = (100.0 * nnz) / tensor.numel() + print("%%MatrixMarket matrix coordinate real general") + print("% https://math.nist.gov/MatrixMarket/formats.html") + print("%") + print("%% density = %4.2f%%" % density) + print("%") + print(h, w, nnz) + for i in range(h): + for j in range(w): + if tensor[i, j] != 0: + print(i + 1, j + 1, tensor[i, j].item()) + + +def print_extended_frostt_format(tensor: torch.Tensor): + """Prints the Extended FROSTT format for a sparse tensor. + + Args: + tensor: sparse tensor + """ + a = tensor.numpy() + nnz = sum([1 if x != 0 else 0 for x in np.nditer(a)]) + density = (100.0 * nnz) / tensor.numel() + print("# Tensor in Extended FROSTT file format") + print("# http://frostt.io/tensors/file-formats.html") + print("# extended with two metadata lines:") + print("# rank nnz") + print("# dims (one per rank)") + print("#") + print("# density = %4.2f%%" % density) + print("#") + print(len(tensor.shape), nnz) + print(*tensor.shape, sep=" ") + it = np.nditer(a, flags=["multi_index"]) + for x in it: + if x != 0: + print(*[i + 1 for i in it.multi_index], sep=" ", end=" ") + print(x) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 019d820..43c3ab9 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -11,6 +11,7 @@ configure_lit_site_cfg( set(MPACT_TEST_DEPENDS FileCheck count not + MPACTBenchmarkPythonModules MPACTPythonModules TorchMLIRPythonModules torch-mlir-opt diff --git a/test/python/file_formats.py b/test/python/file_formats.py new file mode 100644 index 0000000..6736015 --- /dev/null +++ b/test/python/file_formats.py @@ -0,0 +1,61 @@ +# RUN: %PYTHON %s | FileCheck %s + +import numpy as np + +from mpact_benchmark.utils.tensor_generator import ( + generate_tensor, + print_matrix_market_format, + print_extended_frostt_format, +) + +x = generate_tensor( + seed=0, shape=(4, 7), sparsity=0.5, dtype=np.float32, drange=(4.0, 4.0) +) + +# CHECK: %%MatrixMarket matrix coordinate real general +# CHECK: % https://math.nist.gov/MatrixMarket/formats.html +# CHECK: % +# CHECK: % density = 50.00% +# CHECK: % +# CHECK: 4 7 14 +# CHECK: 1 2 4.0 +# CHECK: 1 3 4.0 +# CHECK: 1 6 4.0 +# CHECK: 2 4 4.0 +# CHECK: 2 5 4.0 +# CHECK: 2 7 4.0 +# CHECK: 3 1 4.0 +# CHECK: 3 3 4.0 +# CHECK: 3 4 4.0 +# CHECK: 3 7 4.0 +# CHECK: 4 2 4.0 +# CHECK: 4 4 4.0 +# CHECK: 4 5 4.0 +# CHECK: 4 7 4.0 +print_matrix_market_format(x) + +# CHECK: # Tensor in Extended FROSTT file format +# CHECK: # http://frostt.io/tensors/file-formats.html +# CHECK: # extended with two metadata lines: +# CHECK: # rank nnz +# CHECK: # dims (one per rank) +# CHECK: # +# CHECK: # density = 50.00% +# CHECK: # +# CHECK: 2 14 +# CHECK: 4 7 +# CHECK: 1 2 4.0 +# CHECK: 1 3 4.0 +# CHECK: 1 6 4.0 +# CHECK: 2 4 4.0 +# CHECK: 2 5 4.0 +# CHECK: 2 7 4.0 +# CHECK: 3 1 4.0 +# CHECK: 3 3 4.0 +# CHECK: 3 4 4.0 +# CHECK: 3 7 4.0 +# CHECK: 4 2 4.0 +# CHECK: 4 4 4.0 +# CHECK: 4 5 4.0 +# CHECK: 4 7 4.0 +print_extended_frostt_format(x)