[mpact][test] add a count-equal idiom (for sparse consideration) (#73)
The equal operator currently does not sparsify under
PyTorch, but if it did, this would be a great candidate
for further optimization: computing the sum() without
materializing the intermediate result!
aartbik authored Aug 27, 2024
1 parent a5c7bfa commit 664f828
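
The optimization the commit message hopes for can be sketched in plain Python. This is only an illustration of the idea, not MPACT or PyTorch code: `count_eq_dense` and `count_eq_sparse` are hypothetical helper names, and the "sparse" input is assumed to be just the list of stored nonzero values plus the matrix shape. The dense version builds the full boolean mask before summing; the fused version walks the stored values once and accounts for the implicit zeros arithmetically.

```python
# Sketch only: hypothetical helpers, not part of MPACT or PyTorch.

def count_eq_dense(rows, s):
    """Reference version: materialize the boolean mask, then sum it."""
    mask = [[v == s for v in row] for row in rows]  # intermediate result
    return sum(sum(r) for r in mask)


def count_eq_sparse(shape, stored_values, s):
    """Fused version: one pass over stored values, no intermediate mask."""
    total = shape[0] * shape[1]
    matches = sum(1 for v in stored_values if v == s)
    if s == 0:
        # Every implicit (unstored) entry is a zero, so it matches too.
        matches += total - len(stored_values)
    return matches


# Same matrix as in test/python/counteq.py.
A = [
    [0.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 2.0],
    [0.0, 0.0, 1.0, 1.0],
    [3.0, 0.0, 3.0, 0.0],
]
stored = [v for row in A for v in row if v != 0]  # the nonzero entries only

for s in range(5):
    assert count_eq_dense(A, s) == count_eq_sparse((4, 4), stored, s)
```

The fused version does work proportional to the number of stored entries rather than to the full matrix size, which is the benefit a sparsified equal-plus-sum would aim for.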
Showing 2 changed files with 62 additions and 0 deletions.
6 changes: 6 additions & 0 deletions python/mpact/models/kernels.py
@@ -36,6 +36,12 @@ def forward(self, x):
        return (x * x).sum()


class CountEq(torch.nn.Module):
    def forward(self, x, s):
        nums = (x == s).sum()
        return nums


class FeatureScale(torch.nn.Module):
    def forward(self, F):
        sum_vector = torch.sum(F, dim=1)
56 changes: 56 additions & 0 deletions test/python/counteq.py
@@ -0,0 +1,56 @@
# RUN: %PYTHON %s | FileCheck %s

import torch
import numpy as np

from mpact.mpactbackend import mpact_jit

from mpact.models.kernels import CountEq


net = CountEq()

# Construct dense and sparse matrices.
A = torch.tensor(
    [
        [0.0, 1.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 2.0],
        [0.0, 0.0, 1.0, 1.0],
        [3.0, 0.0, 3.0, 0.0],
    ],
    dtype=torch.float32,
)

# TODO: very interesting idiom to sparsify (collapse the sum
# into the eq for full sparsity), but needs PyTorch support
S = A
# S = A.to_sparse()
# S = A.to_sparse_csr()

#
# CHECK: pytorch
# CHECK: 10
# CHECK: 3
# CHECK: 1
# CHECK: 2
# CHECK: 0
# CHECK: mpact
# CHECK: 10
# CHECK: 3
# CHECK: 1
# CHECK: 2
# CHECK: 0
#

# Run it with PyTorch.
print("pytorch")
for i in range(5):
    target = torch.tensor(i)
    res = net(S, target).item()
    print(res)

print("mpact")
for i in range(5):
    target = torch.tensor(i)
    res = mpact_jit(net, S, target)
    print(res)
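
The CHECK lines in the test above expect the counts 10, 3, 1, 2, 0 for the targets 0 through 4. As a quick sanity check, those numbers can be derived from the matrix `A` without PyTorch or MPACT. The snippet below is a standalone sketch using only the standard library; it is not part of the commit.

```python
from collections import Counter

# Same values as the tensor A in test/python/counteq.py.
A = [
    [0.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 2.0],
    [0.0, 0.0, 1.0, 1.0],
    [3.0, 0.0, 3.0, 0.0],
]

# Tally how often each value occurs across all 16 entries.
counts = Counter(v for row in A for v in row)

# One count per target 0..4, in the order the test prints them.
expected = [counts[float(i)] for i in range(5)]
print(expected)  # [10, 3, 1, 2, 0]
```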

1 comment on commit 664f828

@github-actions

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 1.20.

| Benchmark suite | Current: 664f828 | Previous: a5c7bfa | Ratio |
|---|---|---|---|
| benchmark/python/benchmarks/regression_benchmark.py::test_nop_sparse | 921316.0419683264 iter/sec (stddev: 3.040267017295168e-7) | 1129071.6661179548 iter/sec (stddev: 8.14334988631514e-8) | 1.23 |

This comment was automatically generated by workflow using github-action-benchmark.

CC: @reidtatge
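
The reported ratio can be reproduced from the two throughput numbers in the alert. The sketch below assumes that, for a higher-is-better metric such as iter/sec, the ratio is the previous value divided by the current one; that direction is inferred from the numbers in the alert, not taken from the tool's source.

```python
# Values copied from the alert; the ratio direction is an assumption.
current = 921316.0419683264    # iter/sec at 664f828
previous = 1129071.6661179548  # iter/sec at a5c7bfa
threshold = 1.20               # alert threshold quoted in the comment

ratio = previous / current
is_regression = ratio > threshold
print(f"{ratio:.2f}", is_regression)  # 1.23 True
```

Note that the flagged benchmark is `test_nop_sparse`, while this commit only adds a new kernel class and a new test file.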
