-
Notifications
You must be signed in to change notification settings - Fork 0
/
torchscript_e2e_config.py
63 lines (49 loc) · 2.21 KB
/
torchscript_e2e_config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# Copyright 2021 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import iree.runtime as ireert
import iree.compiler as ireec
from torch_mlir_e2e_test.linalg_on_tensors_backends.abc import LinalgOnTensorsBackend
from torch_mlir_e2e_test.torchscript.configs import LinalgOnTensorsBackendTestConfig
import sys
class IREEInvoker:
    """Expose the functions of a loaded IREE module as Python attributes.

    Reading any attribute that is not found through normal lookup yields a
    callable; calling it looks the attribute name up in the wrapped module
    (via ``module[name]``) and forwards the positional arguments to it.
    """

    def __init__(self, iree_module):
        # Object indexable by function name; each entry must be callable.
        self._iree_module = iree_module

    def __getattr__(self, function_name: str):
        # The lookup is intentionally deferred until the returned callable is
        # invoked, so merely accessing an unknown name never raises here.
        def invoke(*positional):
            target = self._iree_module[function_name]
            return target(*positional)

        return invoke
class IREELinalgOnTensorsBackend(LinalgOnTensorsBackend):
    """Linalg-on-tensors backend that compiles and runs modules with IREE."""

    def __init__(self):
        super().__init__()

    def compile(self, imported_module):
        """Compiles an imported module, with a flat list of functions.

        The module is expected to be in linalg-on-tensors + scalar code form.
        As a side effect, the textual form of the module is written to
        ``../hbc_verification/HBC.mlir`` (resolved against the current
        working directory) before compilation.

        TODO: More clearly define the backend contract. Generally this will
        extend to support globals, lists, and other stuff.

        Args:
            imported_module: The MLIR module to compile. Only its ``str()``
                form is used here.

        Returns:
            An opaque, backend specific compiled artifact object that can be
            passed to `load`.

        Raises:
            OSError: If the dump file cannot be opened for writing (for
                example when ``../hbc_verification`` does not exist).
        """
        # Render the module once and reuse the text for the dump and compile.
        module_text = str(imported_module)
        # Print straight into the file rather than temporarily rebinding
        # sys.stdout: the global swap was not exception-safe and would also
        # capture unrelated output emitted while it was in effect.
        with open('../hbc_verification/HBC.mlir', 'w') as f:
            print(module_text, file=f)
        return ireec.compile_str(module_text,
                                 target_backends=["dylib-llvm-aot"])

    def load(self, flatbuffer) -> IREEInvoker:
        """Loads a compiled artifact into the runtime.

        Args:
            flatbuffer: The compiled artifact returned by `compile`.

        Returns:
            An `IREEInvoker` wrapping the module registered in a fresh
            runtime context that uses the "dylib" driver.
        """
        vm_module = ireert.VmModule.from_flatbuffer(flatbuffer)
        config = ireert.Config(driver_name="dylib")
        ctx = ireert.SystemContext(config=config)
        ctx.add_vm_module(vm_module)
        return IREEInvoker(ctx.modules.module)
# Module-level config object wrapping an IREE backend instance; presumably
# picked up by the torch-mlir e2e test runner — confirm against the caller.
config = LinalgOnTensorsBackendTestConfig(IREELinalgOnTensorsBackend())
# NOTE(review): COMMON_TORCH_MLIR_LOWERING_XFAILS is not imported anywhere in
# the visible part of this file, so this line would raise NameError at import
# time unless the name is provided some other way — confirm and add the
# missing import.
xfail_set = COMMON_TORCH_MLIR_LOWERING_XFAILS