Skip to content

Commit

Permalink
[mpact][compiler] enable norm and scale test with MPACT JIT (#56)
Browse files Browse the repository at this point in the history
Note:
We still have an empty first row bug lingering
but this PR at least enables the JITting of the
two kernels with MPACT
  • Loading branch information
aartbik authored Jun 28, 2024
1 parent afd4aa1 commit e56115a
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 10 deletions.
18 changes: 13 additions & 5 deletions test/python/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,18 @@
# CHECK: [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000],
# CHECK: [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000],
# CHECK: [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000]{{\]}})
#
# TODO: first row?
#
# CHECK: mpact
# CHECK: {{\[}}[0. 0. 0. 0. 0. 0. 0. 0. ]
# CHECK: [0. 0.25 0. 0. 0.25 0. 0. 0. ]
# CHECK: [0. 0. 1. 0. 0. 0. 0. 0. ]
# CHECK: [0. 0. 0. 0.25 0.25 0. 0. 0. ]
# CHECK: [0. 0. 0. 0.25 0.25 0. 0. 0. ]
# CHECK: [0. 0. 0. 0. 0. 1. 0. 0. ]
# CHECK: [0. 0. 0. 0. 0. 0. 1. 0. ]
# CHECK: [0. 0. 0. 0. 0. 0. 0. 1. ]{{\]}}
#

# Run it with PyTorch.
Expand All @@ -37,9 +48,6 @@
print(res)

# Run it with MPACT.
#
# TODO: make this work, crashes in TORCH-MLIR
#
print("mpact")
# res = mpact_jit(net, adj_mat)
# print(res)
res = mpact_jit(net, adj_mat)
print(res)
24 changes: 19 additions & 5 deletions test/python/scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,24 @@
# CHECK: [0.0774, 0.1561, 0.1275, 0.3896, 0.0735, 0.1128, 0.0630],
# CHECK: [0.0093, 0.0611, 0.2731, 0.2124, 0.2180, 0.1546, 0.0716],
# CHECK: [0.2026, 0.0115, 0.0481, 0.0839, 0.2826, 0.2749, 0.0964]{{\]}})
#
# TODO: first row?
#
# CHECK: mpact
# CHECK: {{\[}}[0. 0. 0. 0. 0. 0.
# CHECK: 0. ]
# CHECK: [0.30635384 0.15570773 0.21608633 0.11923195 0.13728413 0.00762967
# CHECK: 0.05770639]
# CHECK: [0.08555716 0.15095268 0.20310582 0.23290026 0.04687909 0.08217437
# CHECK: 0.19843069]
# CHECK: [0.22065267 0.09574053 0.2107584 0.10111907 0.13330552 0.22970453
# CHECK: 0.00871931]
# CHECK: [0.07743214 0.15609969 0.12754099 0.3896042 0.07353575 0.11279855
# CHECK: 0.06298868]
# CHECK: [0.00931544 0.06112389 0.2730649 0.2123639 0.21801054 0.15456341
# CHECK: 0.07155795]
# CHECK: [0.20259099 0.01148908 0.04807246 0.08394676 0.28260148 0.2748705
# CHECK: 0.09642864]]
#

# Run it with PyTorch.
Expand All @@ -31,9 +48,6 @@
print(res)

# Run it with MPACT.
#
# TODO: make this work, crashes in TORCH-MLIR
#
print("mpact")
# res = mpact_jit(net, features)
# print(res)
res = mpact_jit(net, features)
print(res)

0 comments on commit e56115a

Please sign in to comment.