From 92c8035864a5ab9fe45c020d76e4cc4c61618d78 Mon Sep 17 00:00:00 2001
From: Stella Laurenzo
Date: Thu, 21 Dec 2023 21:52:45 -0800
Subject: [PATCH] Add AOT support.

Registers every instantiated CustomOp in ALL_CUSTOM_OP_REGS and adds an
ExpandCustomOpsPass that finds the corresponding torch.operator calls in an
exported program and inlines the IR that each op's generate() produces.
KernelSelection is split into an abstract base with eager
(EagerKernelSelection) and AOT (AOTKernelSelection) implementations.
---
 python/shark_turbine/aot/support/ir_utils.py  |   4 +-
 .../shark_turbine/dynamo/type_conversion.py   |  23 +-
 python/shark_turbine/ops/iree.py              |   4 +-
 python/shark_turbine/runtime/device.py        |   2 +-
 python/shark_turbine/runtime/op_reg/base.py   | 127 +++++++---
 python/shark_turbine/support/conversions.py   |  22 ++
 .../transforms/general/custom_op_expansion.py | 238 ++++++++++++++++++
 tests/ops/iree_test.py                        |   1 -
 tests/runtime/op_reg/kernel_aot_test.py       |  62 +++++
 9 files changed, 436 insertions(+), 47 deletions(-)
 create mode 100644 python/shark_turbine/transforms/general/custom_op_expansion.py
 create mode 100644 tests/runtime/op_reg/kernel_aot_test.py

diff --git a/python/shark_turbine/aot/support/ir_utils.py b/python/shark_turbine/aot/support/ir_utils.py
index 5dc9e9c79..5dfe4e435 100644
--- a/python/shark_turbine/aot/support/ir_utils.py
+++ b/python/shark_turbine/aot/support/ir_utils.py
@@ -246,8 +246,10 @@ def create_tensor_global(
         array = np.array(detached_tensor)
         # We know that a Numpy array is a ReadableBuffer so ignore type error.
         contents = memoryview(array)  # type: ignore
+        shape_desc = "_".join([str(d) for d in t.shape])
+        blob_name = f"torch_tensor_{shape_desc}_{str(t.dtype)}"
         elements_attr = DenseResourceElementsAttr.get_from_buffer(
-            contents, "from_py", tensor_type
+            contents, blob_name, tensor_type
         )
         ir_attrs["initial_value"] = elements_attr
 
diff --git a/python/shark_turbine/dynamo/type_conversion.py b/python/shark_turbine/dynamo/type_conversion.py
index c6d332447..79fb3ea43 100644
--- a/python/shark_turbine/dynamo/type_conversion.py
+++ b/python/shark_turbine/dynamo/type_conversion.py
@@ -46,7 +46,7 @@ def __init__(self, context: Context):
             self.torch_type_to_native
         )
 
-    def torch_type_to_native(self, torch_type: IrType) -> IrType:
+    def torch_type_to_native(self, torch_type: IrType, signless: bool = True) -> IrType:
         """Converts a presumed torch type to a corresponding native type.
 
         This mirrors the type conversion in torch-mlir's BackendTypeConversion.cpp.
@@ -56,6 +56,8 @@ def torch_type_to_native(self, torch_type: IrType) -> IrType:
           !torch.float -> f64
           !torch.bool -> i1
           !torch.vtensor -> tensor
+
+        If `signless=False`, then integer types will retain their signs.
         """
         # We don't presently have API support for introspecting torch type,
         # and even if we did, it is likely that this is more efficient.
@@ -66,7 +68,11 @@ def torch_type_to_native(self, torch_type: IrType) -> IrType:
         if name == "bool":
             return IntegerType.get_signless(1)
         if name == "int":
-            return IntegerType.get_signless(64)
+            return (
+                IntegerType.get_signless(64)
+                if signless
+                else IntegerType.get_signed(64)
+            )
         elif name == "float":
             return F64Type.get()
         elif name == "vtensor":
@@ -75,7 +81,7 @@ def torch_type_to_native(self, torch_type: IrType) -> IrType:
             dim_list_str, dtype_str = tm.groups()
             dim_list = parse_tensor_dim_list(dim_list_str)
             dtype = self.convert_torch_element_type_to_native(
-                IrType.parse(dtype_str)
+                IrType.parse(dtype_str), signless=signless
            )
             # TODO: Eliminate RankedTensorType dependence on Location.
             # See: https://github.com/nod-ai/SHARK-Turbine/issues/145
@@ -83,14 +89,17 @@ def torch_type_to_native(self, torch_type: IrType) -> IrType:
                 return RankedTensorType.get(dim_list, dtype)
         raise TypeError(f"Unsupported torch type conversion for {torch_type}")
 
-    def convert_torch_element_type_to_native(self, torch_type: IrType) -> IrType:
+    def convert_torch_element_type_to_native(
+        self, torch_type: IrType, signless: bool = True
+    ) -> IrType:
         # Torch uses the builtin type hierarchy of IntegerType and FloatType
         # to represent dtypes. These are mostly the same, but it always uses
         # signed IntegerTypes which we must convert to signless for the native
         # type system.
-        if IntegerType.isinstance(torch_type):
-            signed_int_type = IntegerType(torch_type)
-            return IntegerType.get_signless(signed_int_type.width)
+        if signless:
+            if IntegerType.isinstance(torch_type):
+                signed_int_type = IntegerType(torch_type)
+                return IntegerType.get_signless(signed_int_type.width)
         return torch_type
 
     def materialize_native_to_torch(
diff --git a/python/shark_turbine/ops/iree.py b/python/shark_turbine/ops/iree.py
index 77cdfa047..feabdc23b 100644
--- a/python/shark_turbine/ops/iree.py
+++ b/python/shark_turbine/ops/iree.py
@@ -44,7 +44,7 @@ def _emit_tensor_trace(kb: KernelBuilder, key: str, ts: list[Value]):
 @CustomOp.register(library=IREE_LIBRARY)
 class trace_tensor(CustomOp):
     name = "trace_tensor"
-    signature = "(str trace_key, Tensor t) -> ()"
+    signature = "(str trace_key, Tensor tensor) -> ()"
 
     def select(self, ksel: KernelSelection):
         ksel.attr_str(0)
@@ -58,7 +58,7 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder):
 @CustomOp.register(library=IREE_LIBRARY)
 class trace_tensors(CustomOp):
     name = "trace_tensors"
-    signature = "(str trace_key, Tensor[] self) -> ()"
+    signature = "(str trace_key, Tensor[] tensors) -> ()"
 
     def select(self, ksel: KernelSelection):
         ksel.attr_str(0)
diff --git a/python/shark_turbine/runtime/device.py b/python/shark_turbine/runtime/device.py
index a07e139cf..aef747ba9 100644
--- a/python/shark_turbine/runtime/device.py
+++ b/python/shark_turbine/runtime/device.py
@@ -263,7 +263,7 @@ def _device_import_torch_tensor_cpu(device: Device, t: torch.Tensor) -> HalBufferView:
         memory_type=MemoryType.DEVICE_LOCAL,
         allowed_usage=BufferUsage.DEFAULT,
         device=hal_device,
-        buffer=t.numpy(),
+        buffer=t.detach().numpy(),
         element_type=element_type,
     )
     return bv
diff --git a/python/shark_turbine/runtime/op_reg/base.py b/python/shark_turbine/runtime/op_reg/base.py
index a0832269c..e258aa5b8 100644
--- a/python/shark_turbine/runtime/op_reg/base.py
+++ b/python/shark_turbine/runtime/op_reg/base.py
@@ -8,7 +8,7 @@
 dispatcher.
 """
 
-from typing import Any, Callable, Optional, Type, Union
+from typing import Any, Callable, Optional, Sequence, Type, Union
 
 from abc import ABC, abstractmethod, abstractproperty
 import functools
@@ -67,6 +67,14 @@ def def_library(ns) -> torch.library.Library:
     return torch.library.Library(ns, "DEF")
 
 
+def default_dispatch_keys() -> list[str]:
+    # TODO: Dynamically determine what devices to register against.
+    # Note that we have to register against specific keys instead of the
+    # fallback, as fallback is too broad and breaks certain elements of
+    # fx tracing.
+    return ["CPU"]
+
+
 # All such custom kernels are registered in the 'turbine' library/namespace.
 # We also allow extending existing libraries outside of this, but that is
 # the non default case.
@@ -81,7 +89,7 @@ def register(
         op_class: Optional[Type["CustomOp"]] = None,
         *,
         library: torch.library.Library = TURBINE_LIBRARY,
-        dispatch_key: str = "",
+        dispatch_key: Union[str, Sequence[str], None] = None,
         register_meta: bool = True,
         register_impl: bool = True,
     ) -> Callable:
@@ -120,7 +128,7 @@ def __init__(
         self,
         *,
         library: torch.library.Library,
-        dispatch_key: str,
+        dispatch_key: Union[str, Sequence[str], None],
         register_meta: bool,
         register_impl: bool,
     ):
@@ -138,7 +146,15 @@ def __init__(
             library.impl(name, _get_meta_impl(self), "Meta")
 
         if register_impl:
-            library.impl(name, _create_impl_trampoline(self), dispatch_key)
+            if dispatch_key is None:
+                dispatch_key = default_dispatch_keys()
+            elif isinstance(dispatch_key, str):
+                dispatch_key = [dispatch_key]
+            for k in dispatch_key:
+                library.impl(name, _create_impl_trampoline(self), k)
+
+        fq_name = f"{library.ns}.{name}"
+        ALL_CUSTOM_OP_REGS[fq_name] = self
 
     @abstractproperty
     def name(self) -> str:
@@ -190,7 +206,12 @@ def generate(self, ksel: "KernelSelection", kb: "KernelBuilder"):
         ...
 
 
-class KernelSelection:
+# All instantiated CustomOp instances, keyed by fully qualified name. This is
+# used by the AOT compiler to expand custom ops that were captured in a trace.
+ALL_CUSTOM_OP_REGS: dict[str, CustomOp] = {}
+
+
+class KernelSelection(ABC):
     """Represents a selected kernel based on a concrete signature.
 
     The `CustomOp.select` method must yield an instance of this, and
@@ -204,17 +225,15 @@ class KernelSelection:
     """
 
     __slots__ = [
-        "args",
         "arg_descs",
         "op",
         "result_descs",
         "variant",
     ]
 
-    def __init__(self, op: CustomOp, args: list[Any]):
+    def __init__(self, op: CustomOp, arg_arity: int):
         self.op = op
-        self.args = args
-        self.arg_descs: list[Optional[ArgDescriptor]] = len(args) * [None]
+        self.arg_descs: list[Optional[ArgDescriptor]] = arg_arity * [None]
         self.result_descs: list[ArgDescriptor] = []
         self.variant: str = "default"
 
@@ -237,8 +256,11 @@ def __repr__(self):
 
     def generate_meta_returns(self) -> Any:
         results = [d.generate_meta() for d in self.result_descs]
-        if len(results) == 1:
+        arity = len(results)
+        if arity == 1:
             return results[0]
+        elif arity == 0:
+            return None
         else:
             return tuple(results)
 
@@ -255,6 +277,7 @@ def spec_key(self) -> str:
                 f"Error generating spec_key from:\n{textwrap.indent(repr(self), '  ')}"
             ) from e
 
+    @abstractmethod
     def arg_tensor(self, arg: int) -> "TensorArg":
         """Declares an argument to allow any ranked tensor and to specialize for each rank
         and dtype.
@@ -262,6 +285,59 @@ def arg_tensor(self, arg: int) -> "TensorArg":
         Returns the argument descriptor, which can be used to further inspect or constrain
         the selection. It will default to allowing all dimensions to be dynamic.
         """
+        ...
+
+    @abstractmethod
+    def arg_tensor_list(self, arg: int) -> "TensorListArg":
+        """Declares an argument to accept a list of tensors which will be specialized
+        for the list size and each rank/dtype.
+
+        Returns the argument descriptor, which can be used to further inspect or constrain
+        the selection. It will default to allowing all dimensions to be dynamic.
+        """
+        ...
+
+    @abstractmethod
+    def arg_int(self, arg: int) -> "IntArg":
+        """Declares an argument to be an integer value that can take any value.
+
+        Returns the argument descriptor, which can be used to further inspect or constrain
+        the selection.
+        """
+        ...
+
+    @abstractmethod
+    def attr_str(self, arg: int) -> "AttrArg":
+        """Declares an argument to be a string attribute.
+
+        Such arguments are not materialized in the IR as Values but may be used to
+        generate the IR. In AOT contexts, they must be derived from static values.
+        """
+        ...
+
+    @abstractmethod
+    def return_tensor(self, t: Tensor) -> "TensorArg":
+        """Marks the next return value as a Tensor.
+
+        By default, it will be rank and dtype specialized but have completely dynamic
+        dimensions. Dimensions can be further constrained by modifying the returned
+        descriptor.
+        """
+        ...
+
+
+class EagerKernelSelection(KernelSelection):
+    """Kernel selection specialized for eager arguments."""
+
+    __slots__ = [
+        "args",
+    ]
+
+    def __init__(self, op: CustomOp, args: list[Any]):
+        super().__init__(op, len(args))
+        self.args = args
+
+    def arg_tensor(self, arg: int) -> "TensorArg":
         arg_descs = self.arg_descs
         arg_value = self.args[arg]
         assert arg_descs[arg] is None, f"Already constrained argument {arg}"
@@ -272,12 +348,6 @@ def arg_tensor(self, arg: int) -> "TensorArg":
         return desc
 
     def arg_tensor_list(self, arg: int) -> "TensorListArg":
-        """Declares an argument to accept a list of tensors which will be specialized
-        for the list size and each rank/dtype.
-
-        Returns the argument descriptor, which can be used to further inspect or constrain
-        the selection. It will default to allowing all dimensions to be dynamic.
-        """
         arg_descs = self.arg_descs
         arg_value = self.args[arg]
         assert arg_descs[arg] is None, f"Already constrained argument {arg}"
@@ -288,11 +358,6 @@ def arg_tensor_list(self, arg: int) -> "TensorListArg":
         return desc
 
     def arg_int(self, arg: int) -> "IntArg":
-        """Declares an argument to be an integer value that can take any value.
-
-        Returns the argument descriptor, which can be used to further inspect or constrain
-        the selection.
-        """
         arg_descs = self.arg_descs
         arg_value = self.args[arg]
         assert arg_descs[arg] is None, f"Already constrained argument {arg}"
@@ -303,11 +368,6 @@ def arg_int(self, arg: int) -> "IntArg":
         return desc
 
     def attr_str(self, arg: int) -> "AttrArg":
-        """Declares an argument to be a string attribute.
-
-        Such arguments are not materialized in the IR as Values but may be used to
-        generate the IR. In AOT contexts, they must be derived from static values.
-        """
         arg_descs = self.arg_descs
         arg_value = self.args[arg]
         assert arg_descs[arg] is None, f"Already constrained argument {arg}"
@@ -318,12 +378,6 @@ def attr_str(self, arg: int) -> "AttrArg":
         return desc
 
     def return_tensor(self, t: Tensor) -> "TensorArg":
-        """Marks the next return value as a Tensor.
-
-        By default, it will be rank and dtype specialized but have completely dynamic
-        dimensions. Dimensions can be further constrained by modifying the returned
-        descriptor.
-        """
         desc = TensorArg(t)
         self.result_descs.append(desc)
         return desc
@@ -538,6 +592,7 @@ def __init__(
         self.ip = ip
         self.module_body = module_body
         self.symbol_table = symbol_table
+        self.yielded = False
 
     def arg_value(self, index: int) -> Union[list[Value], Value]:
         """Gets the concrete IR `Value` for the argument at `index`.
@@ -629,7 +684,7 @@ def __init__(
         for desc in ksel.arg_descs:
             arity = desc.ir_arity
             if not desc.is_list:
-                if desc.ir_arity == 1:
+                if arity == 1:
                     arg_bindings.append(block_arguments[block_arg_index])
                     block_arg_index += 1
                 else:
@@ -671,8 +726,10 @@ def create_module(
 
     def yield_results(self, *results: Value):
         """Yields results of the kernel computation."""
+        assert not self.yielded, "yield_results has already been called"
         with self.ip, Location.unknown():
             func_d.ReturnOp(results)
+        self.yielded = True
 
 
 ###############################################################################
@@ -687,7 +744,7 @@ def _get_library_op(library: torch.library.Library, name: str) -> Any:
 
 def _get_meta_impl(op: CustomOp):
     def meta(*args):
-        sel = KernelSelection(op, args)
+        sel = EagerKernelSelection(op, args)
         op.select(sel)
         if logger.isEnabledFor(logging.DEBUG):
             logging.debug(
@@ -706,7 +763,7 @@ def _create_impl_trampoline(op: CustomOp):
     )
 
     def handler(*args):
-        ksel = KernelSelection(op, args)
+        ksel = EagerKernelSelection(op, args)
         op.select(ksel)
         if logger.isEnabledFor(logging.DEBUG):
             logging.debug(
diff --git a/python/shark_turbine/support/conversions.py b/python/shark_turbine/support/conversions.py
index 132cfd428..c1a9fd2f4 100644
--- a/python/shark_turbine/support/conversions.py
+++ b/python/shark_turbine/support/conversions.py
@@ -53,6 +53,28 @@
     torch.complex128: lambda: ComplexType.get(F64Type.get()),
 }
 
+TORCH_DTYPE_TO_SIGNED_MLIR_TYPE_ASM = {
+    torch.float16: "f16",
+    torch.bfloat16: "bf16",
+    torch.float32: "f32",
+    torch.float64: "f64",
+    torch.uint8: "ui8",
+    torch.int8: "si8",
+    torch.int16: "si16",
+    torch.int32: "si32",
+    torch.int64: "si64",
+    torch.bool: "i1",
+    torch.qint8: "si8",
+    torch.quint8: "ui8",
+    torch.complex32: "complex<f16>",
+    torch.complex64: "complex<f32>",
+    torch.complex128: "complex<f64>",
+}
+
+SIGNED_MLIR_TYPE_ASM_TO_TORCH_DTYPE = dict(
+    (v, k) for k, v in TORCH_DTYPE_TO_SIGNED_MLIR_TYPE_ASM.items()
+)
+
 TORCH_DTYPE_TO_IREE_TYPE_ASM = {
     torch.float16: "f16",
     torch.bfloat16: "bf16",
diff --git a/python/shark_turbine/transforms/general/custom_op_expansion.py b/python/shark_turbine/transforms/general/custom_op_expansion.py
new file mode 100644
index 000000000..0caa01493
--- /dev/null
+++ b/python/shark_turbine/transforms/general/custom_op_expansion.py
@@ -0,0 +1,238 @@
+# Copyright 2023 Advanced Micro Devices, Inc
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import torch
+from torch import Tensor
+
+from ...dynamo.type_conversion import (
+    NativeTypeConverter,
+)
+
+from ...runtime.op_reg.base import (
+    ALL_CUSTOM_OP_REGS,
+    AttrArg,
+    IntArg,
+    CustomOp,
+    KernelBuilder,
+    KernelSelection,
+    TensorArg,
+    TensorListArg,
+)
+
+from ...support.conversions import (
+    MLIR_TYPE_ASM_TO_TORCH_DTYPE,
+)
+
+from ...support.ir_imports import (
+    Block,
+    InsertionPoint,
+    OpResult,
+    Operation,
+    RankedTensorType,
+    StringAttr,
+    SymbolTable,
+    Value,
+)
+
+from ..rewriter import (
+    Pass,
+)
+
+
+class ExpandCustomOpsPass(Pass):
+    def __init__(self, root_op: Operation):
+        super().__init__(root_op)
+        # Track pending deletions in a dict to preserve order and uniqueness.
+        self.ops_to_delete: dict[Operation, None] = {}
+        self.type_converter = NativeTypeConverter(root_op.context)
+        self.symbol_table = SymbolTable(root_op)
+
+    def delete_op(self, op):
+        self.ops_to_delete[op.operation] = None
+
+    def run(self):
+        for mr in self.funcs:
+            self.expand_func(mr.op)
+        for op in self.ops_to_delete.keys():
+            self.erase_unused_op(op)
+
+    def expand_func(self, func_op: Operation):
+        """Expands custom ops in a traced torch function.
+
+        This finds operations of the form:
+          %0 = torch.operator "torch.ns.op"
+
+        and looks them up in ALL_CUSTOM_OP_REGS. If an operation originated from
+        one of those registered ops, it is expanded in place.
+        """
+        name_prefix = "torch."
+
+        for block in func_op.regions[0].blocks:
+            for op in block.operations:
+                if op.name == "torch.operator":
+                    custom_op_name = StringAttr(op.attributes["name"]).value
+                    if custom_op_name.startswith(name_prefix):
+                        local_name = custom_op_name[len(name_prefix) :]
+                        custom_op_reg = ALL_CUSTOM_OP_REGS.get(local_name)
+                        if custom_op_reg is not None:
+                            self.expand_custom_op(custom_op_reg, op)
+
+    def expand_custom_op(self, op_reg: CustomOp, op: Operation):
+        original_operands: list[Value] = list(op.operands)
+        ksel = AOTKernelSelection(
+            op_reg, original_operands, list(op.results), self.type_converter
+        )
+        op_reg.select(ksel)
+
+        module_body = self.root_op.regions[0].blocks[0]
+        kb = InlineKernelBuilder(
+            ksel,
+            op,
+            type_converter=self.type_converter,
+            module_body=module_body,
+            symbol_table=self.symbol_table,
+        )
+        with kb.ip, kb.location:
+            op_reg.generate(ksel, kb)
+        assert kb.yielded, "Custom op generation did not yield_results()"
+
+        self.delete_op(op)
+
+
+class AOTKernelSelection(KernelSelection):
+    __slots__ = [
+        "operands",
+        "results",
+        "type_converter",
+    ]
+
+    def __init__(
+        self,
+        op: CustomOp,
+        operands: list[Value],
+        results: list[Value],
+        type_converter: NativeTypeConverter,
+    ):
+        super().__init__(op, len(operands))
+        self.operands = operands
+        self.results = results
+        self.type_converter = type_converter
+
+    def arg_tensor(self, arg: int) -> TensorArg:
+        # This is annoying: We have to go from the Torch MLIR type system to the
+        # original torch.tensor Python type system. We do this by way of the native
+        # type converter because it has the mapping pathway we need. This is one of
+        # the few places in the code where we have to go this way to preserve the
+        # facade. Everywhere else is going from Torch -> IREE native.
+        arg_descs = self.arg_descs
+        assert arg_descs[arg] is None, f"Already constrained argument {arg}"
+        operand = self.operands[arg]
+        signed_native_type = self.type_converter.torch_type_to_native(operand.type)
+        try:
+            rtt = RankedTensorType(signed_native_type)
+            # TODO: We need to do the FakeMode/ShapeEnv dance to create a symbolic
+            # fake tensor here.
+        except TypeError as e:
+            raise TypeError(
+                f"Argument type mismatch from Torch IR for arg {arg}: Expected ranked tensor, got {signed_native_type}"
+            ) from e
+        assert not any(
+            rtt.is_dynamic_dim(i) for i in range(rtt.rank)
+        ), "NYI: Dynamic shape tensors in custom op AOT mode"
+        element_type_asm = str(rtt.element_type)
+        try:
+            dtype = MLIR_TYPE_ASM_TO_TORCH_DTYPE[element_type_asm]
+        except KeyError as e:
+            raise AssertionError(
+                f"Could not find dtype mapping for {element_type_asm} in MLIR_TYPE_ASM_TO_TORCH_DTYPE"
+            ) from e
+        t = torch.empty(rtt.shape, dtype=dtype, device="meta")
+        arg_descs[arg] = desc = TensorArg(t)
+        return desc
+
+    def arg_tensor_list(self, arg: int) -> TensorListArg:
+        raise NotImplementedError("NYI: AOT arg_tensor_list")
+
+    def arg_int(self, arg: int) -> IntArg:
+        raise NotImplementedError("NYI: AOT arg_int")
+
+    def attr_str(self, arg: int) -> AttrArg:
+        arg_descs = self.arg_descs
+        assert arg_descs[arg] is None, f"Already constrained argument {arg}"
+        operand = self.operands[arg]
+        ty = operand.type
+        assert (
+            str(ty) == "!torch.str"
+        ), f"Argument type mismatch from Torch IR for {arg}: Expected !torch.str, got {ty}"
+        str_value = _get_constant_str_from_value(operand)
+        arg_descs[arg] = desc = AttrArg(str_value)
+        return desc
+
+    def return_tensor(self, t: Tensor) -> TensorArg:
+        raise NotImplementedError("NYI: return_tensor")
+
+
+def _get_constant_str_from_value(v: Value) -> str:
+    """Given a constant str producer, return the str.
+
+    Example: %str = torch.constant.str "TEST"
+    """
+    constant_op = OpResult(v).owner
+    assert (
+        constant_op.name == "torch.constant.str"
+    ), f"Expected constant !torch.str to be produced by a torch.constant.str op but got: {constant_op}"
+    return StringAttr(constant_op.attributes["value"]).value
+
+
+class InlineKernelBuilder(KernelBuilder):
+    def __init__(
+        self,
+        ksel: KernelSelection,
+        torch_op: Operation,
+        *,
+        type_converter: NativeTypeConverter,
+        module_body: Block,
+        symbol_table: SymbolTable,
+    ):
+        location = torch_op.location
+        ip = InsertionPoint(torch_op)
+        with ip, location:
+            operands = list(torch_op.operands)
+            arg_bindings = []
+            for desc, operand in zip(ksel.arg_descs, operands):
+                arity = desc.ir_arity
+                if not desc.is_list:
+                    if arity == 1:
+                        arg_bindings.append(
+                            type_converter.materialize_torch_to_native(operand)
+                        )
+                    else:
+                        arg_bindings.append(None)
+                else:
+                    # arg_bindings.extend(native_operands)
+                    raise NotImplementedError("NYI: AOT custom op list arguments")
+
+        super().__init__(
+            ksel,
+            arg_bindings=arg_bindings,
+            ip=ip,
+            module_body=module_body,
+            symbol_table=symbol_table,
+        )
+        self.location = location
+        self.torch_op = torch_op
+
+    def yield_results(self, *results: Value):
+        """Yields results of the kernel computation."""
+        assert not self.yielded, "yield_results has already been called"
+        with self.ip, self.location:
+            torch_op_results: list[Value] = list(self.torch_op.results)
+            assert len(results) == len(
+                torch_op_results
+            ), "Mismatched yield_results with custom op results"
+            for new_result, old_result in zip(results, torch_op_results):
+                old_result.replace_all_uses_with(new_result)
+        self.yielded = True
diff --git a/tests/ops/iree_test.py b/tests/ops/iree_test.py
index befb5a749..f10643026 100644
--- a/tests/ops/iree_test.py
+++ b/tests/ops/iree_test.py
@@ -22,7 +22,6 @@ def testTraceList(self):
         t2 = torch.randn(1, 8)
         ops.iree.trace_tensors("TEST 2", [t1, t2])
         ops.iree.trace_tensors("TEST 1", [t1])
-        ops.iree.trace_tensors("TEST 0", [])
 
 
 if __name__ == "__main__":
diff --git a/tests/runtime/op_reg/kernel_aot_test.py b/tests/runtime/op_reg/kernel_aot_test.py
new file mode 100644
index 000000000..853c69486
--- /dev/null
+++ b/tests/runtime/op_reg/kernel_aot_test.py
@@ -0,0 +1,62 @@
+# Copyright 2023 Advanced Micro Devices, Inc
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import logging
+import unittest
+
+import torch
+import torch.nn as nn
+
+import shark_turbine.aot as aot
+import shark_turbine.ops as ops
+
+from shark_turbine.transforms.general.custom_op_expansion import ExpandCustomOpsPass
+
+
+class MLP(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.layer0 = nn.Linear(8, 8, bias=True)
+        self.layer1 = nn.Linear(8, 4, bias=True)
+        self.layer2 = nn.Linear(4, 2, bias=True)
+        self.layer3 = nn.Linear(2, 2, bias=True)
+
+    def forward(self, x: torch.Tensor):
+        x = self.layer0(x)
+        x = torch.sigmoid(x)
+        ops.iree.trace_tensor("LAYER0", x)
+        x = self.layer1(x)
+        x = torch.sigmoid(x)
+        ops.iree.trace_tensor("LAYER1", x)
+        x = self.layer2(x)
+        x = torch.sigmoid(x)
+        ops.iree.trace_tensor("LAYER2", x)
+        x = self.layer3(x)
+        ops.iree.trace_tensor("LAYER3", x)
+        return x
+
+
+class KernelRegTest(unittest.TestCase):
+    def testTrace(self):
+        mlp = MLP()
+        prog = aot.export(mlp, torch.empty(97, 8, dtype=torch.float32))
+        # print("ORIGINAL EXPORTED:")
+        # print(prog.print_readable())
+
+        p = ExpandCustomOpsPass(prog.mlir_module)
+        p.run()
+
+        print("CUSTOM OP CONVERTED:")
+        print(prog.mlir_module)
+
+    def testEager(self):
+        mlp = MLP()
+        mlp.forward(torch.empty(97, 8, dtype=torch.float32))
+
+
+if __name__ == "__main__":
+    logging.basicConfig(level=logging.DEBUG)
+    unittest.main()
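
--
Usage note (illustration only, not part of the diff): the flow this patch
enables mirrors tests/runtime/op_reg/kernel_aot_test.py above. A registered
CustomOp such as ops.iree.trace_tensor is captured by aot.export() as a
torch.operator call, and ExpandCustomOpsPass looks its qualified name up in
ALL_CUSTOM_OP_REGS and inlines the IR produced by the op's generate() at each
call site. A minimal sketch, assuming IREE_LIBRARY's namespace is "iree":

    import torch
    import shark_turbine.aot as aot
    import shark_turbine.ops as ops
    from shark_turbine.transforms.general.custom_op_expansion import (
        ExpandCustomOpsPass,
    )

    class Tiny(torch.nn.Module):
        def forward(self, x):
            # Captured in the trace as: torch.operator "torch.iree.trace_tensor"
            ops.iree.trace_tensor("X", x)
            return x * 2.0

    prog = aot.export(Tiny(), torch.empty(4, 8, dtype=torch.float32))
    # Expands the traced op in place using the same select()/generate() hooks
    # that serve the eager dispatch trampoline.
    ExpandCustomOpsPass(prog.mlir_module).run()
    print(prog.mlir_module)

Because the eager trampoline and the AOT pass share CustomOp.select() and
CustomOp.generate(), a kernel author writes those two hooks once and gets both
execution modes.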