Skip to content

Commit

Permalink
[mpact][compiler] extract linalg module import into own method
Browse files Browse the repository at this point in the history
  • Loading branch information
aartbik committed Sep 9, 2024
1 parent c21ae86 commit 2683803
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 4 deletions.
12 changes: 8 additions & 4 deletions python/mpact/mpactbackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,9 +319,8 @@ def export_and_import(f, *args, **kwargs):
return fx_importer.module


def mpact_jit_compile(f, *args, opt_level=2, use_sp_it=False, **kwargs):
"""This method compiles the given callable using the MPACT backend."""
# Import module and lower into Linalg IR.
def mpact_linalg(f, *args, **kwargs):
"""Imports a function as module and lowers it into Linalg IR."""
module = export_and_import(f, *args, **kwargs)
run_pipeline_with_repro_report(
module,
Expand All @@ -333,7 +332,12 @@ def mpact_jit_compile(f, *args, opt_level=2, use_sp_it=False, **kwargs):
"Lowering TorchFX IR -> Linalg IR",
enable_ir_printing=False,
)
# Compile with MPACT backend compiler.
return module


def mpact_jit_compile(f, *args, opt_level=2, use_sp_it=False, **kwargs):
"""This method compiles the given callable using the MPACT backend."""
module = mpact_linalg(f, *args, **kwargs)
backend = MpactBackendCompiler(opt_level=opt_level, use_sp_it=use_sp_it)
compiled = backend.compile(module)
invoker = backend.load(compiled)
Expand Down
31 changes: 31 additions & 0 deletions test/python/mm_print.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# RUN: %PYTHON %s | FileCheck %s

import torch
import numpy as np

from mpact.mpactbackend import mpact_linalg

from mpact.models.kernels import MMNet


net = MMNet()

X = torch.arange(0, 16, dtype=torch.float32).view(4, 4)
Y = torch.arange(16, 32, dtype=torch.float32).view(4, 4)

#
# CHECK: module {
# CHECK: func.func @main(%[[A0:.*]]: tensor<4x4xf32>, %[[A1:.*]]: tensor<4x4xf32>) -> tensor<4x4xf32> {
# CHECK: %[[C0:.*]] = arith.constant 0.000000e+00 : f32
# CHECK: %[[T0:.*]] = tensor.empty() : tensor<4x4xf32>
# CHECK: %[[T1:.*]] = linalg.fill ins(%[[C0]] : f32) outs(%[[T0]] : tensor<4x4xf32>) -> tensor<4x4xf32>
# CHECK: %[[T2:.*]] = linalg.matmul
# CHECK-SAME: ins(%[[A0]], %[[A1]] : tensor<4x4xf32>, tensor<4x4xf32>)
# CHECK-SAME: outs(%[[T1]] : tensor<4x4xf32>) -> tensor<4x4xf32>
# CHECK: return %2 : tensor<4x4xf32>
# CHECK: }
# CHECK: }
#

linalg = mpact_linalg(net, X, Y)
print(linalg)

0 comments on commit 2683803

Please sign in to comment.