From 664f828a95fd68221dc33c459af603ba867101c6 Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Tue, 27 Aug 2024 13:28:53 -0700 Subject: [PATCH] [mpact][test] add a count-equal idiom (for sparse consideration) (#73) The equal operator currently does not sparsify under PyTorch, but if it were, this would be a great candidate to further optimize with doing the sum() without materializing the intermediate result! --- python/mpact/models/kernels.py | 6 ++++ test/python/counteq.py | 56 ++++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+) create mode 100644 test/python/counteq.py diff --git a/python/mpact/models/kernels.py b/python/mpact/models/kernels.py index 36e2394..14f505c 100644 --- a/python/mpact/models/kernels.py +++ b/python/mpact/models/kernels.py @@ -36,6 +36,12 @@ def forward(self, x): return (x * x).sum() +class CountEq(torch.nn.Module): + def forward(self, x, s): + nums = (x == s).sum() + return nums + + class FeatureScale(torch.nn.Module): def forward(self, F): sum_vector = torch.sum(F, dim=1) diff --git a/test/python/counteq.py b/test/python/counteq.py new file mode 100644 index 0000000..3cdd90a --- /dev/null +++ b/test/python/counteq.py @@ -0,0 +1,56 @@ +# RUN: %PYTHON %s | FileCheck %s + +import torch +import numpy as np + +from mpact.mpactbackend import mpact_jit + +from mpact.models.kernels import CountEq + + +net = CountEq() + +# Construct dense and sparse matrices. +A = torch.tensor( + [ + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 2.0], + [0.0, 0.0, 1.0, 1.0], + [3.0, 0.0, 3.0, 0.0], + ], + dtype=torch.float32, +) + +# TODO: very interesting idiom to sparsify (collapse the sum +# into the eq for full sparsity), but needs PyTorch support +S = A +# S = A.to_sparse() +# S = A.to_sparse_csr() + +# +# CHECK: pytorch +# CHECK: 10 +# CHECK: 3 +# CHECK: 1 +# CHECK: 2 +# CHECK: 0 +# CHECK: mpact +# CHECK: 10 +# CHECK: 3 +# CHECK: 1 +# CHECK: 2 +# CHECK: 0 +# + +# Run it with PyTorch. +print("pytorch") +for i in range(5): + target = torch.tensor(i) + res = net(S, target).item() + print(res) + +print("mpact") +for i in range(5): + target = torch.tensor(i) + res = mpact_jit(net, S, target) + print(res)