diff --git a/python/shark_turbine/aot/support/ir_utils.py b/python/shark_turbine/aot/support/ir_utils.py index 5dc9e9c79..5dfe4e435 100644 --- a/python/shark_turbine/aot/support/ir_utils.py +++ b/python/shark_turbine/aot/support/ir_utils.py @@ -246,8 +246,10 @@ def create_tensor_global( array = np.array(detached_tensor) # We know that a Numpy array is a ReadableBuffer so ignore type error. contents = memoryview(array) # type: ignore + shape_desc = "_".join([str(d) for d in t.shape]) + blob_name = f"torch_tensor_{shape_desc}_{str(t.dtype)}" elements_attr = DenseResourceElementsAttr.get_from_buffer( - contents, "from_py", tensor_type + contents, blob_name, tensor_type ) ir_attrs["initial_value"] = elements_attr diff --git a/python/shark_turbine/dynamo/type_conversion.py b/python/shark_turbine/dynamo/type_conversion.py index c6d332447..79fb3ea43 100644 --- a/python/shark_turbine/dynamo/type_conversion.py +++ b/python/shark_turbine/dynamo/type_conversion.py @@ -46,7 +46,7 @@ def __init__(self, context: Context): self.torch_type_to_native ) - def torch_type_to_native(self, torch_type: IrType) -> IrType: + def torch_type_to_native(self, torch_type: IrType, signless: bool = True) -> IrType: """Converts a presumed torch type to a corresponding native type. This mirrors the type conversion in torch-mlir's BackendTypeConversion.cpp. @@ -56,6 +56,8 @@ def torch_type_to_native(self, torch_type: IrType) -> IrType: !torch.float -> f64 !torch.bool -> i1 !torch.vtensor -> tensor + + If `signless=False`, then integer types will retain their signs. """ # We don't presently have API support for introspecting torch type, # and even if we did, it is likely that this is more efficient. 
@@ -66,7 +68,11 @@ def torch_type_to_native(self, torch_type: IrType) -> IrType: if name == "bool": return IntegerType.get_signless(1) if name == "int": - return IntegerType.get_signless(64) + return ( + IntegerType.get_signless(64) + if signless + else IntegerType.get_signed(64) + ) elif name == "float": return F64Type.get() elif name == "vtensor": @@ -75,7 +81,7 @@ def torch_type_to_native(self, torch_type: IrType) -> IrType: dim_list_str, dtype_str = tm.groups() dim_list = parse_tensor_dim_list(dim_list_str) dtype = self.convert_torch_element_type_to_native( - IrType.parse(dtype_str) + IrType.parse(dtype_str), signless=signless ) # TODO: Eliminate RankedTensorType dependence on Location. # See: https://github.com/nod-ai/SHARK-Turbine/issues/145 @@ -83,14 +89,17 @@ def torch_type_to_native(self, torch_type: IrType) -> IrType: return RankedTensorType.get(dim_list, dtype) raise TypeError(f"Unsupported torch type conversion for {torch_type}") - def convert_torch_element_type_to_native(self, torch_type: IrType) -> IrType: + def convert_torch_element_type_to_native( + self, torch_type: IrType, signless: bool = True + ) -> IrType: # Torch uses the builtin type hierarchy of IntegerType and FloatType # to represent dtypes. These are mostly the same, but it always uses # signed IntegerTypes which we must convert to signless for the native # type system. 
- if IntegerType.isinstance(torch_type): - signed_int_type = IntegerType(torch_type) - return IntegerType.get_signless(signed_int_type.width) + if signless: + if IntegerType.isinstance(torch_type): + signed_int_type = IntegerType(torch_type) + return IntegerType.get_signless(signed_int_type.width) return torch_type def materialize_native_to_torch( diff --git a/python/shark_turbine/ops/__init__.py b/python/shark_turbine/ops/__init__.py new file mode 100644 index 000000000..3f4a4554e --- /dev/null +++ b/python/shark_turbine/ops/__init__.py @@ -0,0 +1,7 @@ +# Copyright 2023 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from . import iree diff --git a/python/shark_turbine/ops/iree.py b/python/shark_turbine/ops/iree.py new file mode 100644 index 000000000..2fb7485c0 --- /dev/null +++ b/python/shark_turbine/ops/iree.py @@ -0,0 +1,74 @@ +# Copyright 2023 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +"""Custom ops for built-in IREE functionality.""" + +from ..support.ir_imports import ( + RankedTensorType, + StringAttr, + Value, + flow_d, + tensor_d, +) + +from ..runtime.op_reg import ( + CustomOp, + KernelBuilder, + KernelSelection, + def_library, +) + +__all__ = [ + "trace", +] + +IREE_LIBRARY = def_library("iree") + + +################################################################################ +# trace_tensor / trace_tensors +# See the flow.tensor_trace op for details. 
In essence: +# * trace_key is a name to label tensors with (intended for log filtering) +# * tensor or tensors are values to log a value for +################################################################################ + + +def _emit_tensor_trace(kb: KernelBuilder, key: str, ts: list[Value]): + dynamic_dims = [] + for t in ts: + rtt = RankedTensorType(t.type) + for i in range(rtt.rank): + if rtt.is_dynamic_dim(i): + dynamic_dims.append(tensor_d.dim(t, kb.constant_index(i))) + flow_d.TensorTraceOp(StringAttr.get(key), ts, dynamic_dims) + + +@CustomOp.register(library=IREE_LIBRARY) +class trace_tensor(CustomOp): + signature = "trace_tensor(str trace_key, Tensor tensor) -> ()" + + def select(self, ksel: KernelSelection): + ksel.attr_str(0) + ksel.arg_tensor(1) + + def generate(self, ksel: KernelSelection, kb: KernelBuilder): + _emit_tensor_trace(kb, ksel.arg_descs[0].v, [kb.arg_bindings[1]]) + kb.yield_results() + + +@CustomOp.register(library=IREE_LIBRARY) +class trace_tensors(CustomOp): + signature = "trace_tensors(str trace_key, Tensor[] tensors) -> ()" + + def select(self, ksel: KernelSelection): + ksel.attr_str(0) + ksel.arg_tensor_list(1) + + def generate(self, ksel: KernelSelection, kb: KernelBuilder): + ts = kb.arg_bindings[1] + if len(ts) >= 1: + _emit_tensor_trace(kb, ksel.arg_descs[0].v, ts) + kb.yield_results() diff --git a/python/shark_turbine/runtime/device.py b/python/shark_turbine/runtime/device.py index a07e139cf..aef747ba9 100644 --- a/python/shark_turbine/runtime/device.py +++ b/python/shark_turbine/runtime/device.py @@ -263,7 +263,7 @@ def _device_import_torch_tensor_cpu(device: Device, t: torch.Tensor) -> HalBuffe memory_type=MemoryType.DEVICE_LOCAL, allowed_usage=BufferUsage.DEFAULT, device=hal_device, - buffer=t.numpy(), + buffer=t.detach().numpy(), element_type=element_type, ) return bv diff --git a/python/shark_turbine/runtime/op_reg/base.py b/python/shark_turbine/runtime/op_reg/base.py index 6702d745f..77b29c0e7 100644 --- 
a/python/shark_turbine/runtime/op_reg/base.py +++ b/python/shark_turbine/runtime/op_reg/base.py @@ -8,11 +8,13 @@ dispatcher. """ -from typing import Any, Callable, Optional, Type, Union +from typing import Any, Callable, Optional, Sequence, Type, Union from abc import ABC, abstractmethod, abstractproperty import functools import logging +import re +import textwrap import torch from torch import Tensor @@ -21,12 +23,15 @@ Block, Context, FunctionType, + IndexType, InsertionPoint, + IntegerAttr, Location, StringAttr, SymbolTable, IrType, Value, + arith_d, builtin_d, func_d, ) @@ -43,6 +48,7 @@ "KernelBuilder", "KernelSelection", "TensorArg", + "def_library", ] logger = logging.getLogger("turbine.runtime.op_reg") @@ -51,10 +57,29 @@ # Op library management ############################################################################### + +def def_library(ns) -> torch.library.Library: + """Creates a new 'DEF' library which contains custom ops. + + It is necessary to create such custom op libraries in this way since + the library is registered with the compiler in such a way that it can + operate over all known custom ops. + """ + return torch.library.Library(ns, "DEF") + + +def default_dispatch_keys() -> list[str]: + # TODO: Dynamically determine what devices to register against. + # Note that we have to register against specific keys instead of the + # fallback, as fallback is too broad and breaks certain elements of + # fx tracing. + return ["CPU"] + + # All such custom kernels are registered in the 'turbine' library/namespace. # We also allow extending existing libraries outside of this, but that is # the non default case. 
-TURBINE_LIBRARY = torch.library.Library("turbine", "DEF") +TURBINE_LIBRARY = def_library("turbine") class CustomOp(ABC): @@ -62,10 +87,10 @@ class CustomOp(ABC): @staticmethod def register( - op_class: Optional[Type["CustomOp"]], + op_class: Optional[Type["CustomOp"]] = None, *, library: torch.library.Library = TURBINE_LIBRARY, - dispatch_key: str = "", + dispatch_key: Union[str, Sequence[str], None] = None, register_meta: bool = True, register_impl: bool = True, ) -> Callable: @@ -104,12 +129,12 @@ def __init__( self, *, library: torch.library.Library, - dispatch_key: str, + dispatch_key: Union[str, Sequence[str], None], register_meta: bool, register_impl: bool, ): - name = self.name - fq_schema = f"{name}{self.signature}" + fq_schema = self.signature + name = _extract_name_from_signature(fq_schema) library.define(fq_schema) self.library = library self.cache_key_base = f"{library.ns}.{library.kind}::{name}" @@ -122,19 +147,25 @@ def __init__( library.impl(name, _get_meta_impl(self), "Meta") if register_impl: - library.impl(name, _create_impl_trampoline(self), dispatch_key) + if dispatch_key is None: + dispatch_key = default_dispatch_keys() + elif isinstance(dispatch_key, str): + dispatch_key = [dispatch_key] + for k in dispatch_key: + library.impl(name, _create_impl_trampoline(self), k) - @abstractproperty - def name(self) -> str: - """Name of the operation.""" - ... + fq_name = f"{library.ns}.{name}" + ALL_CUSTOM_OP_REGS[fq_name] = self @abstractproperty def signature(self) -> str: """PyTorch function signature. - This excludes the name, which will come from the `name` property - and be prepended to make a full PyTorch schema. + This is in the normal PyTorch kernel registration form. For example: + + ``` + my_op(Tensor t) -> Tensor + ``` """ ... @@ -165,7 +196,7 @@ def generate(self, ksel: "KernelSelection", kb: "KernelBuilder"): This method should generate IR into the given `KernelBuilder`. It can do so by consulting any state set on the `KernelSelection`. 
Each `KernelSelection.args` corresponds to `KernelBuilder.args`. - Unless if the argument was set as `is_ir_arg=False`, the argument + Unless if the argument was set as `ir_arity=0`, the argument will be a `Value`. Otherwise, it will be `None`. It is recommended to use `KernelBuilder.arg(n)` to access. @@ -174,7 +205,12 @@ def generate(self, ksel: "KernelSelection", kb: "KernelBuilder"): ... -class KernelSelection: +# All instantiated CustomOp instances, keyed by fully qualified name. This is +# used by the AOT compiler to expand custom ops that were captured in a trace. +ALL_CUSTOM_OP_REGS: dict[str, CustomOp] = {} + + +class KernelSelection(ABC): """Represents a selected kernel based on a concrete signature. The `CustomOp.select` method must yield an instance of this, and @@ -188,33 +224,59 @@ class KernelSelection: """ __slots__ = [ - "args", "arg_descs", "op", "result_descs", "variant", ] - def __init__(self, op: CustomOp, args: list[Any]): + def __init__(self, op: CustomOp, arg_arity: int): self.op = op - self.args = args - self.arg_descs: list[Optional[ArgDescriptor]] = len(args) * [None] + self.arg_descs: list[Optional[ArgDescriptor]] = arg_arity * [None] self.result_descs: list[ArgDescriptor] = [] self.variant: str = "default" + def __repr__(self): + lines = [ + "KernelSelection<", + f" op = '{self.op.name}',", + f" variant = '{self.variant}',", + " arg_descs = [", + ] + for arg_desc in self.arg_descs: + lines.append(f" {arg_desc},") + lines.append(" ],") + lines.append(" result_descs = [") + for result_desc in self.result_descs: + lines.append(f" {result_desc},") + lines.append(" ]") + lines.append(">") + return "\n".join(lines) + def generate_meta_returns(self) -> Any: results = [d.generate_meta() for d in self.result_descs] - if len(results) == 1: + arity = len(results) + if arity == 1: return results[0] + elif arity == 0: + return None else: return tuple(results) @property def spec_key(self) -> str: - arg_keys = ",".join(d.spec_key for d in 
self.arg_descs) - return_keys = ",".join(d.spec_key for d in self.result_descs) - return f"{self.op.cache_key_base}::{self.variant}({arg_keys})->({return_keys})" + try: + arg_keys = ",".join(d.spec_key for d in self.arg_descs) + return_keys = ",".join(d.spec_key for d in self.result_descs) + return ( + f"{self.op.cache_key_base}::{self.variant}({arg_keys})->({return_keys})" + ) + except Exception as e: + raise AssertionError( + f"Error generating spec_key from:\n{textwrap.indent(repr(self), ' ')}" + ) from e + @abstractmethod def arg_tensor(self, arg: int) -> "TensorArg": """Declares an argument to allow any ranked tensor and to specialize for each rank and dtype. @@ -222,6 +284,59 @@ def arg_tensor(self, arg: int) -> "TensorArg": Returns the argument descriptor, which can be used to further inspect or constrain the selection. It will default to allowing all dimensions to be dynamic. """ + ... + + @abstractmethod + def arg_tensor_list(self, arg: int) -> "TensorListArg": + """Declares an argument to accept a list of tensors which will be specialized + for the list size and each rank/dtype. + + Returns the argument descriptor, which can be used to further inspect or constrain + the selection. It will default to allowing all dimensions to be dynamic. + """ + ... + + @abstractmethod + def arg_int(self, arg: int) -> "IntArg": + """Declares an argument to be an integer value that can take any value. + + Returns the argument descriptor, which can be used to further inspect or constrain + the selection. + """ + ... + + @abstractmethod + def attr_str(self, arg: int) -> "AttrArg": + """Declares an argument to be a string attribute. + + Such arguments are not materialized in the IR as Values but may be used to + generate the IR. In AOT contexts, they must be derived from static values. + """ + ... + + @abstractmethod + def return_tensor(self, t: Tensor) -> "TensorArg": + """Marks the next return value as a Tensor. 
+ + By default, it will be rank and dtype specialized but have completely dynamic + dimensions. Dimensions can be further constrained by modifying the returned + descriptor. + """ + ... + + +class EagerKernelSelection(KernelSelection): + """Kernel selection specialized for eager arguments.""" + + __slots__ = [ + "args", + ] + + def __init__(self, op: CustomOp, args: list[Any]): + super().__init__(op, len(args)) + self.args = args + + def arg_tensor(self, arg: int) -> "TensorArg": arg_descs = self.arg_descs arg_value = self.args[arg] assert arg_descs[arg] is None, f"Already constrained argument {arg}" @@ -231,12 +346,17 @@ def arg_tensor(self, arg: int) -> "TensorArg": arg_descs[arg] = desc = TensorArg(arg_value) return desc - def arg_int(self, arg: int) -> "IntArg": - """Declares an argument to be an integer value that can take any value. + def arg_tensor_list(self, arg: int) -> "TensorListArg": + arg_descs = self.arg_descs + arg_value = self.args[arg] + assert arg_descs[arg] is None, f"Already constrained argument {arg}" + assert isinstance( + arg_value, list + ), f"Argument type mismatch from Torch for {arg}: Expected list, got {type(arg_value)}" + arg_descs[arg] = desc = TensorListArg(arg_value) + return desc - Returns the argument descriptor, which can be used to further inspect or constrain - the selection. - """ + def arg_int(self, arg: int) -> "IntArg": arg_descs = self.arg_descs arg_value = self.args[arg] assert arg_descs[arg] is None, f"Already constrained argument {arg}" @@ -246,37 +366,116 @@ def arg_int(self, arg: int) -> "IntArg": arg_descs[arg] = desc = IntArg(arg_value) return desc - def return_tensor(self, t: Tensor) -> "TensorArg": - """Marks the next return value as a Tensor. 
+ def attr_str(self, arg: int) -> "AttrArg": + arg_descs = self.arg_descs + arg_value = self.args[arg] + assert arg_descs[arg] is None, f"Already constrained argument {arg}" + assert isinstance( + arg_value, str + ), f"Argument type mismatch from Torch for {arg}: Expected str, got {type(arg_value)}" + arg_descs[arg] = desc = AttrArg(arg_value) + return desc - By default, it will be rank and dtype specialized but have completely dynamic - dimensions. Dimensions can be further constrained by modifying the returned - descriptor. - """ + def return_tensor(self, t: Tensor) -> "TensorArg": desc = TensorArg(t) self.result_descs.append(desc) return desc +class AttrArg: + ir_arity: int = 0 + maybe_tensor_value: Optional[Tensor] = None + is_list: bool = False + + __slots__ = [ + "v", + "spec_value", + ] + + def __init__(self, v: object): + self.v = v + # We specialize on every distinct value. + self.spec_value: Optional[Any] = v + + def __repr__(self): + return f"AttrArg(<{self.spec_value}>)" + + def generate_meta(self) -> object: + return self.v + + @property + def spec_key(self) -> str: + """Generates a key that will be the same for all specializations.""" + return f"attr<{self.spec_value}>" + + @property + def mlir_type_asm(self) -> str: + raise AssertionError("Cannot resolve `mlir_type_asm` for an AttrArg") + + +class IntArg: + __slots__ = [ + "ir_arity", + "spec_value", + "v", + ] + + # All descriptors have an attribute to indicate their value + # as a tensor, and those that aren't are fixated to None. + # This is to enable fast lookup in the hot path of determining + # how to dispatch. 
+ maybe_tensor_value: Optional[Tensor] = None + is_list: bool = False + + def __init__(self, v: int): + self.v = v + self.spec_value: Optional[Any] = None + self.ir_arity: int = 1 + + def __repr__(self): + return f"IntArg({self.v}, spec_value={self.spec_value}, ir_arity={self.ir_arity})" + + def generate_meta(self) -> int: + return self.v + + @property + def spec_key(self) -> str: + """Generates a key that will be the same for all specializations.""" + return f"int<{self.spec_value}>" + + @property + def mlir_type_asm(self) -> str: + # TODO: We can have individual kernels constrain this to a narrower + # type. + return "i64" + + class TensorArg: __slots__ = [ "t", "spec_dims", - "is_ir_arg", "maybe_tensor_value", ] + ir_arity: int = 1 + is_list: bool = False + def __init__(self, t: Tensor): self.t = t # Any static dims that we are specializing. Defaults to all dynamic. self.spec_dims: list[Optional[int]] = len(t.shape) * [None] - self.is_ir_arg = True # All descriptors have an attribute to indicate their value # as a tensor, and those that aren't are fixated to None. # This is to enable fast lookup in the hot path of determining # how to dispatch. self.maybe_tensor_value: Tensor = t + def __repr__(self): + return ( + f"TensorArg(shape={self.t.shape}, dtype={self.t.dtype}, " + f"spec_dims={self.spec_dims})" + ) + def generate_meta(self) -> Tensor: t = self.t if t.device == "meta": @@ -305,40 +504,69 @@ def mlir_type_asm(self) -> str: return f"tensor<{spec}>" -class IntArg: +class TensorListArg: __slots__ = [ - "is_ir_arg", - "v", - "spec_value", + "ts", + "spec_dims", + "ir_arity", "maybe_tensor_value", ] - def __init__(self, v: int): - self.v = v - self.spec_value: Optional[Any] = None - self.is_ir_arg = True + is_list: bool = True + + def __init__(self, ts: list[Tensor]): + self.ts = ts + self.ir_arity = len(ts) + # Any static dims that we are specializing. Defaults to all dynamic. 
+ self.spec_dims: list[list[Optional[int]]] = [len(t.shape) * [None] for t in ts] # All descriptors have an attribute to indicate their value # as a tensor, and those that aren't are fixated to None. # This is to enable fast lookup in the hot path of determining # how to dispatch. - self.maybe_tensor_value: Optional[Tensor] = None + self.maybe_tensor_value: Tensor = ts - def generate_meta(self) -> int: - return self.v + def __repr__(self): + return ( + f"TensorListArg(shape={[t.shape for t in self.ts]}, " + f"dtype={[t.dtype for t in self.ts]}, " + f"spec_dims={self.spec_dims}, ir_arity={self.ir_arity})" + ) + + def generate_meta(self) -> list[Tensor]: + metas = [] + for t in self.ts: + if t.device == "meta": + metas.append(t) + else: + metas.append(t.clone().detach().to("meta")) + return metas @property def spec_key(self) -> str: """Generates a key that will be the same for all specializations.""" - return f"int<{self.spec_value}>" + return ( + f"tensor[{[len(t.shape) for t in self.ts]}" + f":{[str(t.dtype) for t in self.ts]}]<{self.spec_dims}>" + ) @property - def mlir_type_asm(self) -> str: - # TODO: We can have individual kernels constrain this to a narrower - # type. - return "i64" - - -ArgDescriptor = Union[TensorArg, IntArg] + def mlir_type_asm(self) -> list[str]: + asms = [] + for t, spec_dims in zip(self.ts, self.spec_dims): + try: + dtype_asm = TORCH_DTYPE_TO_IREE_TYPE_ASM[t.dtype] + except KeyError as e: + raise KeyError( + f"Unknown mapping of torch dtype {t.dtype} to MLIR " + f"(possibly missing in TORCH_DTYPE_TO_IREE_TYPE_ASM table)" + ) from e + dim_asm = "x".join(["?" 
if d is None else str(d) for d in spec_dims]) + spec = f"{dim_asm}x{dtype_asm}" if dim_asm else dtype_asm + asms.append(f"tensor<{spec}>") + return asms + + +ArgDescriptor = Union[AttrArg, IntArg, TensorArg, TensorListArg] ############################################################################### # KernelBuilder @@ -352,7 +580,7 @@ class KernelBuilder(ABC): def __init__( self, ksel: KernelSelection, - arg_bindings: list[Value], + arg_bindings: list[Union[Value, list[Value]]], *, ip: InsertionPoint, module_body: Block, @@ -363,11 +591,12 @@ def __init__( self.ip = ip self.module_body = module_body self.symbol_table = symbol_table + self.yielded = False - def arg_value(self, index: int) -> Value: + def arg_value(self, index: int) -> Union[list[Value], Value]: """Gets the concrete IR `Value` for the argument at `index`. - This will assert if the corresponding argument was set as `is_ir_arg=False` + This will assert if the corresponding argument was set as `ir_arity=0` during kernel selection. """ try: @@ -386,6 +615,10 @@ def yield_results(self, *results: Value): """Yields results of the kernel computation.""" ... + def constant_index(self, i: int) -> Value: + """Builds a constant index value.""" + return arith_d.constant(IndexType.get(), IntegerAttr.get(IndexType.get(), i)) + class FreeFuncKernelBuilder(KernelBuilder): """Kernel builder that emits the body of the kernel into a free function. @@ -409,10 +642,31 @@ def __init__( if func_name is None: func_name = ksel.op.name with context, Location.unknown(), InsertionPoint(module_body): - arg_types = [ - IrType.parse(d.mlir_type_asm) for d in ksel.arg_descs if d.is_ir_arg - ] - result_types = [IrType.parse(d.mlir_type_asm) for d in ksel.result_descs] + # Assemble arg types. 
+ arg_types = [] + for d in ksel.arg_descs: + arity = d.ir_arity + if not d.is_list: + if arity == 1: + arg_types.append(IrType.parse(d.mlir_type_asm)) + else: + continue + else: + for i in range(arity): + arg_types.append(IrType.parse(d.mlir_type_asm[i])) + + # Assemble result types. + result_types = [] + for d in ksel.result_descs: + if not d.is_list: + if d.ir_arity == 1: + result_types.append(IrType.parse(d.mlir_type_asm)) + else: + continue + else: + raise AssertionError("NYI: arity > 1 results") + + # Create the func. ftype = FunctionType.get(arg_types, result_types) func_op = func_d.FuncOp(func_name, ftype) if not is_public: @@ -422,13 +676,21 @@ def __init__( # Map inputs to arg bindings, lining up with arguments that are elided. block_arguments = list(entry_block.arguments) - block_arguments.reverse() + block_arg_index = 0 arg_bindings: list[Optional[Value]] = [] for desc in ksel.arg_descs: - if desc.is_ir_arg: - arg_bindings.append(block_arguments.pop()) + arity = desc.ir_arity + if not desc.is_list: + if arity == 1: + arg_bindings.append(block_arguments[block_arg_index]) + block_arg_index += 1 + else: + arg_bindings.append(None) else: - arg_bindings.append(None) + arg_bindings.append( + block_arguments[block_arg_index : block_arg_index + arity] + ) + block_arg_index += arity super().__init__( ksel, @@ -461,8 +723,10 @@ def create_module( def yield_results(self, *results: Value): """Yields results of the kernel computation.""" + assert not self.yielded, "yield_results has already been called" with self.ip, Location.unknown(): func_d.ReturnOp(results) + self.yielded = True ############################################################################### @@ -477,7 +741,7 @@ def _get_library_op(library: torch.library.Library, name: str) -> Any: def _get_meta_impl(op: CustomOp): def meta(*args): - sel = KernelSelection(op, args) + sel = EagerKernelSelection(op, args) op.select(sel) if logger.isEnabledFor(logging.DEBUG): logging.debug( @@ -496,7 +760,7 @@ def 
_create_impl_trampoline(op: CustomOp): ) def handler(*args): - ksel = KernelSelection(op, args) + ksel = EagerKernelSelection(op, args) op.select(ksel) if logger.isEnabledFor(logging.DEBUG): logging.debug( @@ -505,3 +769,13 @@ def handler(*args): return eager_dispatch(ksel) return handler + + +_SIGNATURE_NAME_PATTERN = re.compile(r"^([^(]+)") + + +def _extract_name_from_signature(sig: str) -> str: + m = re.match(_SIGNATURE_NAME_PATTERN, sig) + if not m: + raise ValueError(f"Expected signature of form `name() -> (). Got: {sig}") + return m.group(1) diff --git a/python/shark_turbine/runtime/op_reg/compiler.py b/python/shark_turbine/runtime/op_reg/compiler.py index c15ead608..4d9595a34 100644 --- a/python/shark_turbine/runtime/op_reg/compiler.py +++ b/python/shark_turbine/runtime/op_reg/compiler.py @@ -24,6 +24,10 @@ GeneralError, ) +from ...support.ir_imports import ( + Location, +) + from ...support.logging import ( runtime_logger as logger, ) @@ -87,10 +91,11 @@ def compile_standalone_kernel( start = default_timer() config = KernelCompileConfig(cache_key, list(device.compile_target_flags)) kb = FreeFuncKernelBuilder.create_module(ksel, func_name=func_name) - ksel.op.generate(ksel, kb) + with kb.ip, Location.unknown(): + ksel.op.generate(ksel, kb) kb.module_op.verify() module_asm = kb.module_op.get_asm( - binary=True, enable_debug_info=True, assume_verified=True + binary=True, enable_debug_info=True, print_generic_op_form=True ) generation_time = default_timer() - start diff --git a/python/shark_turbine/runtime/op_reg/eager.py b/python/shark_turbine/runtime/op_reg/eager.py index 629118096..00ac7a7df 100644 --- a/python/shark_turbine/runtime/op_reg/eager.py +++ b/python/shark_turbine/runtime/op_reg/eager.py @@ -50,15 +50,29 @@ def eager_dispatch(ksel: KernelSelection): device: Optional[Device] = None torch_device: Optional[torch.device] = None for arg_desc in ksel.arg_descs: - if not arg_desc.is_ir_arg: - continue - tensor_arg = arg_desc.maybe_tensor_value - if 
tensor_arg is None: - continue - torch_device = tensor_arg.device - device = lookup_device_from_torch(torch_device) - if device is not None: - break + if not arg_desc.is_list: + if arg_desc.ir_arity == 1: + # One arg has maybe_tensor_value as a single element (common case). + tensor_arg = arg_desc.maybe_tensor_value + if tensor_arg is None: + continue + torch_device = tensor_arg.device + device = lookup_device_from_torch(torch_device) + if device is not None: + break + else: + # Optional arg omitted. + assert arg_desc.ir_arity == 0 + continue + else: + # List. maybe_tensor_value is a list. Uncommon case. + for tensor_arg in arg_desc.maybe_tensor_value: + if tensor_arg is None: + continue + torch_device = tensor_arg.device + device = lookup_device_from_torch(torch_device) + if device is not None: + break # Default to CPU. if device is None: @@ -72,21 +86,16 @@ def eager_dispatch(ksel: KernelSelection): # Build the concrete args, issuing device movement as necessary. arg_list = VmVariantList(len(ksel.arg_descs)) - for arg_desc in ksel.arg_descs: - if not arg_desc.is_ir_arg: - continue - tensor_arg = arg_desc.maybe_tensor_value - # Handle non-tensor args. - if tensor_arg is None: - scalar_value = arg_desc.v - if isinstance(scalar_value, int): - arg_list.push_int(scalar_value) - elif isinstance(scalar_value, float): - arg_list.push_float(scalar_value) - else: - raise UnsupportedTypeError(type(scalar_value)) - continue - # Tensor arg. + + def push_scalar(scalar_value): + if isinstance(scalar_value, int): + arg_list.push_int(scalar_value) + elif isinstance(scalar_value, float): + arg_list.push_float(scalar_value) + else: + raise UnsupportedTypeError(type(scalar_value)) + + def push_tensor(tensor_arg): if tensor_arg.device != torch_device: # TODO: If the source and target device are both known to us, # we can do this "in house" vs asking torch to do it. @@ -101,6 +110,28 @@ def eager_dispatch(ksel: KernelSelection): # import_torch_tensor. 
arg_list.push_ref(device.import_torch_tensor(tensor_arg)) + for arg_desc in ksel.arg_descs: + arity = arg_desc.ir_arity + if not arg_desc.is_list: + # Non-list. + if arity == 1: + tensor_arg = arg_desc.maybe_tensor_value + if tensor_arg is not None: + push_tensor(tensor_arg) + else: + push_scalar(arg_desc.v) + else: + continue + else: + # List. Uncommon case. + tensor_arg = arg_desc.maybe_tensor_value + if tensor_arg is not None: + for i in range(arity): + push_tensor(tensor_arg[i]) + else: + for i in range(arity): + push_scalar(arg_desc.v[i]) + if config.async_invocations: raise NotImplementedError("Async execution not yet implemented") @@ -113,8 +144,9 @@ def eager_dispatch(ksel: KernelSelection): # Unpack results. results = [] - for i, result_desc in enumerate(ksel.result_descs): + arity = result_desc.ir_arity + assert arity == 1, "NYI: Optional and result lists" meta_tensor_value = result_desc.maybe_tensor_value if meta_tensor_value is None: # Scalar return. @@ -128,5 +160,7 @@ def eager_dispatch(ksel: KernelSelection): if len(results) == 1: return results[0] + elif len(results) == 0: + return None else: return tuple(results) diff --git a/python/shark_turbine/support/conversions.py b/python/shark_turbine/support/conversions.py index 132cfd428..c1a9fd2f4 100644 --- a/python/shark_turbine/support/conversions.py +++ b/python/shark_turbine/support/conversions.py @@ -53,6 +53,28 @@ torch.complex128: lambda: ComplexType.get(F64Type.get()), } +TORCH_DTYPE_TO_SIGNED_MLIR_TYPE_ASM = { + torch.float16: "f16", + torch.bfloat16: "bf16", + torch.float32: "f32", + torch.float64: "f64", + torch.uint8: "ui8", + torch.int8: "si8", + torch.int16: "si16", + torch.int32: "si32", + torch.int64: "si64", + torch.bool: "i1", + torch.qint8: "si8", + torch.quint8: "ui8", + torch.complex32: "complex", + torch.complex64: "complex", + torch.complex128: "complex", +} + +SIGNED_MLIR_TYPE_ASM_TO_TORCH_DTYPE = dict( + (v, k) for k, v in TORCH_DTYPE_TO_SIGNED_MLIR_TYPE_ASM.items() +) + 
TORCH_DTYPE_TO_IREE_TYPE_ASM = { torch.float16: "f16", torch.bfloat16: "bf16", diff --git a/python/shark_turbine/transforms/general/custom_op_expansion.py b/python/shark_turbine/transforms/general/custom_op_expansion.py new file mode 100644 index 000000000..88a4bb841 --- /dev/null +++ b/python/shark_turbine/transforms/general/custom_op_expansion.py @@ -0,0 +1,248 @@ +# Copyright 2023 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import torch +from torch import Tensor + +from ...dynamo.type_conversion import ( + NativeTypeConverter, +) + +from ...runtime.op_reg.base import ( + ALL_CUSTOM_OP_REGS, + AttrArg, + IntArg, + CustomOp, + KernelBuilder, + KernelSelection, + TensorArg, + TensorListArg, +) + +from ...support.conversions import ( + MLIR_TYPE_ASM_TO_TORCH_DTYPE, +) + +from ...support.ir_imports import ( + Block, + InsertionPoint, + OpResult, + Operation, + RankedTensorType, + StringAttr, + SymbolTable, + Value, +) + +from ..rewriter import ( + Pass, +) + + +class ExpandCustomOpsPass(Pass): + def __init__( + self, root_op: Operation, reg: dict[str, CustomOp] = ALL_CUSTOM_OP_REGS + ): + super().__init__(root_op) + self.reg = reg + # Track pending deletions in a dict to preserve order and unique. + self.ops_to_delete: dict[Operation, None] = {} + self.type_converter = NativeTypeConverter(root_op.context) + self.symbol_table = SymbolTable(root_op) + + def delete_op(self, op): + self.ops_to_delete[op.operation] = None + + def run(self): + for mr in self.funcs: + self.expand_func(mr.op) + for op in self.ops_to_delete.keys(): + self.erase_unused_op(op) + + def expand_func(self, func_op: Operation): + """Expands custom ops in a traced torch function. + + This finds operations of the form: + %0 = torch.operator "torch.ns.op" + + And looks them up in the reg dict. 
If it originated from one of those + registered ops, then it will be expanded in place. + """ + name_prefix = "torch." + + for block in func_op.regions[0].blocks: + for op in block.operations: + if op.name == "torch.operator": + custom_op_name = StringAttr(op.attributes["name"]).value + if custom_op_name.startswith(name_prefix): + local_name = custom_op_name[len(name_prefix) :] + custom_op_reg = self.reg.get(local_name) + if custom_op_reg is not None: + self.expand_custom_op(custom_op_reg, op) + + def expand_custom_op(self, op_reg: CustomOp, op: Operation): + original_operands: list[Value] = list(op.operands) + ksel = AOTKernelSelection( + op_reg, original_operands, list(op.results), self.type_converter + ) + op_reg.select(ksel) + + module_body = self.root_op.regions[0].blocks[0] + kb = InlineKernelBuilder( + ksel, + op, + type_converter=self.type_converter, + module_body=module_body, + symbol_table=self.symbol_table, + ) + with kb.ip, kb.location: + op_reg.generate(ksel, kb) + assert kb.yielded, "Custom op generation did not yield_results()" + + self.delete_op(op) + + +class AOTKernelSelection(KernelSelection): + __slots__ = [ + "operands", + "results", + "type_converter", + ] + + def __init__( + self, + op: CustomOp, + operands: list[Value], + results: list[Value], + type_converter: NativeTypeConverter, + ): + super().__init__(op, len(operands)) + self.operands = operands + self.results = results + self.type_converter = type_converter + + def arg_tensor(self, arg: int) -> TensorArg: + # This is annoying: We have to go from the Torch MLIR type system to the + # original torch.tensor Python type system. We do this by way of the native + # type converter because it has the mapping pathway we need. This is one of the + # only places in the code where we have to go this way to preserve the facade. + # Everywhere else is going from Torch -> IREE native. 
+ arg_descs = self.arg_descs + assert arg_descs[arg] is None, f"Already constrained argument {arg}" + operand = self.operands[arg] + signed_native_type = self.type_converter.torch_type_to_native(operand.type) + try: + rtt = RankedTensorType(signed_native_type) + # TODO: We need to do the FakeMode/ShapeEnv dance to create a symbolic + # fake tensor here. + except TypeError as e: + raise TypeError( + f"Argument type mismatch from Torch IR for arg {arg}: Expected ranked tensor, got {signed_native_type}" + ) from e + assert not any( + rtt.is_dynamic_dim(i) for i in range(rtt.rank) + ), "NYI: Dynamic shape tensors in custom op AOT mode" + element_type_asm = str(rtt.element_type) + try: + dtype = MLIR_TYPE_ASM_TO_TORCH_DTYPE[element_type_asm] + except KeyError as e: + raise AssertionError( + f"Could not find dtype mapping for {element_type_asm} in MLIR_TYPE_ASM_TO_TORCH_DTYPE" + ) + t = torch.empty(rtt.shape, dtype=dtype, device="meta") + arg_descs[arg] = desc = TensorArg(t) + return desc + + def arg_tensor_list(self, arg: int) -> TensorListArg: + raise NotImplementedError("NYI: AOT arg_tensor_list") + + def arg_int(self, arg: int) -> IntArg: + raise NotImplementedError("NYI: AOT arg_int") + + def attr_str(self, arg: int) -> AttrArg: + arg_descs = self.arg_descs + assert arg_descs[arg] is None, f"Already constrained argument {arg}" + operand = self.operands[arg] + ty = operand.type + assert ( + str(ty) == "!torch.str" + ), f"Argument type mismatch from Torch IR for {arg}: Expected !torch.str, got {ty}" + str_value = _get_constant_str_from_value(operand) + arg_descs[arg] = desc = AttrArg(str_value) + return desc + + def return_tensor(self, t: Tensor) -> TensorArg: + desc = TensorArg(t) + self.result_descs.append(desc) + return desc + + +def _get_constant_str_from_value(v: Value) -> str: + """Given a constant str producer, return the str. 
+ + Example: %str = torch.constant.str "TEST" + """ + constant_op = OpResult(v).owner + assert ( + constant_op.name == "torch.constant.str" + ), f"Expected constant !torch.str to be produced by a torch.constant.str op but got: {constant_op}" + return StringAttr(constant_op.attributes["value"]).value + + +class InlineKernelBuilder(KernelBuilder): + def __init__( + self, + ksel: KernelSelection, + torch_op: Operation, + *, + type_converter: NativeTypeConverter, + module_body: Block, + symbol_table: SymbolTable, + ): + location = torch_op.location + ip = InsertionPoint(torch_op) + with ip, location: + operands = list(torch_op.operands) + arg_bindings = [] + for desc, operand in zip(ksel.arg_descs, operands): + arity = desc.ir_arity + if not desc.is_list: + if arity == 1: + arg_bindings.append( + type_converter.materialize_torch_to_native(operand) + ) + else: + arg_bindings.append(None) + else: + # arg_bindings.extend(native_operands) + raise NotImplementedError("NYI: AOT custom op list arguments") + + super().__init__( + ksel, + arg_bindings=arg_bindings, + ip=ip, + module_body=module_body, + symbol_table=symbol_table, + ) + self.location = location + self.torch_op = torch_op + self.type_converter = type_converter + + def yield_results(self, *results: Value): + """Yields results of the kernel computation.""" + assert not self.yielded, "yield_results has already been called" + with self.ip, self.location: + torch_op_results: list[Value] = list(self.torch_op.results) + assert len(results) == len( + torch_op_results + ), f"Mismatched yield_results with custom op results" + for new_result, old_result in zip(results, torch_op_results): + torch_type = old_result.type + new_result = self.type_converter.materialize_native_to_torch( + new_result, torch_type + ) + old_result.replace_all_uses_with(new_result) + self.yielded = True diff --git a/tests/dynamo/type_conversion_test.py b/tests/dynamo/type_conversion_test.py index a4410d80b..dfc3de25b 100644 --- 
a/tests/dynamo/type_conversion_test.py +++ b/tests/dynamo/type_conversion_test.py @@ -24,15 +24,19 @@ def testPrimitives(self): self._compareNative("!torch.int", "i64") self._compareNative("!torch.float", "f64") + def testSigned(self): + self._compareNative("!torch.bool", "i1", signless=False) + self._compareNative("!torch.int", "si64", signless=False) + def testValueTensors(self): self._compareNative("!torch.vtensor<[2, 2],f32>", "tensor<2x2xf32>") self._compareNative("!torch.vtensor<[?, ?],f32>", "tensor") self._compareNative("!torch.vtensor<[],f32>", "tensor") - def _compareNative(self, torch_str: str, native_str: str): + def _compareNative(self, torch_str: str, native_str: str, *, signless: bool = True): with self.conv._context: torch_type = IrType.parse(torch_str) - native_type = self.conv.torch_type_to_native(torch_type) + native_type = self.conv.torch_type_to_native(torch_type, signless=signless) self.assertEqual(str(native_type), native_str) diff --git a/tests/ops/iree_test.py b/tests/ops/iree_test.py new file mode 100644 index 000000000..f10643026 --- /dev/null +++ b/tests/ops/iree_test.py @@ -0,0 +1,29 @@ +# Copyright 2023 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +import unittest + +import torch + +import shark_turbine.ops as ops + + +class KernelRegTest(unittest.TestCase): + def testTrace(self): + t = torch.randn(3, 4) + ops.iree.trace_tensor("TEST", t) + + def testTraceList(self): + t1 = torch.randn(3, 4) + t2 = torch.randn(1, 8) + ops.iree.trace_tensors("TEST 2", [t1, t2]) + ops.iree.trace_tensors("TEST 1", [t1]) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/tests/runtime/op_reg/kernel_aot_test.py b/tests/runtime/op_reg/kernel_aot_test.py new file mode 100644 index 000000000..48c7f59f1 --- /dev/null +++ b/tests/runtime/op_reg/kernel_aot_test.py @@ -0,0 +1,63 @@ +# Copyright 2023 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +import unittest + +import torch +import torch.nn as nn + +import shark_turbine.aot as aot +import shark_turbine.ops as ops + +from shark_turbine.transforms.general.custom_op_expansion import ExpandCustomOpsPass + + +class MLP(nn.Module): + def __init__(self): + super().__init__() + self.layer0 = nn.Linear(8, 8, bias=True) + self.layer1 = nn.Linear(8, 4, bias=True) + self.layer2 = nn.Linear(4, 2, bias=True) + self.layer3 = nn.Linear(2, 2, bias=True) + + def forward(self, x: torch.Tensor): + x = self.layer0(x) + x = torch.sigmoid(x) + ops.iree.trace_tensor("LAYER0", x) + x = self.layer1(x) + x = torch.sigmoid(x) + ops.iree.trace_tensor("LAYER1", x) + x = self.layer2(x) + x = torch.sigmoid(x) + ops.iree.trace_tensor("LAYER2", x) + x = self.layer3(x) + ops.iree.trace_tensor("LAYER3", x) + return x + + +class KernelRegTest(unittest.TestCase): + def testTrace(self): + mlp = MLP() + prog = aot.export(mlp, torch.empty(97, 8, dtype=torch.float32)) + + p = 
ExpandCustomOpsPass(prog.mlir_module) + p.run() + + print("CUSTOM OP CONVERTED:") + module_asm = str(prog.mlir_module) + self.assertIn('flow.tensor.trace "LAYER0"', module_asm) + self.assertIn('flow.tensor.trace "LAYER1"', module_asm) + self.assertIn('flow.tensor.trace "LAYER3"', module_asm) + + def testEager(self): + mlp = MLP() + mlp.forward(torch.empty(97, 8, dtype=torch.float32)) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/tests/runtime/op_reg/kernel_reg_test.py b/tests/runtime/op_reg/kernel_reg_test.py index 10662db14..75554b046 100644 --- a/tests/runtime/op_reg/kernel_reg_test.py +++ b/tests/runtime/op_reg/kernel_reg_test.py @@ -15,11 +15,10 @@ class KernelRegTest(unittest.TestCase): - def testSimple(self): + def testRegistrationDispatchAndCache(self): @CustomOp.register class identity(CustomOp): - name = "test_identity" - signature = "(Tensor self) -> Tensor" + signature = "test_identity(Tensor self) -> Tensor" def select(self, ksel: KernelSelection): x = ksel.arg_tensor(0) diff --git a/tests/transforms/general/custom_op_expansion_test.py b/tests/transforms/general/custom_op_expansion_test.py new file mode 100644 index 000000000..3ede7e9f7 --- /dev/null +++ b/tests/transforms/general/custom_op_expansion_test.py @@ -0,0 +1,112 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +from pathlib import Path +import torch +import unittest + +from shark_turbine.transforms.general.custom_op_expansion import ExpandCustomOpsPass +from shark_turbine.runtime.op_reg import ( + def_library, + CustomOp, + KernelBuilder, + KernelSelection, +) + +from shark_turbine.support.ir_imports import ( + Context, + Module, +) + + +class PassTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.lib = def_library("expand_custom_op_pass_test") + CustomOp.register(library=cls.lib)(IdentityOp) + CustomOp.register(library=cls.lib)(PrintStringAttrOp) + CustomOp.register(library=cls.lib)(IntArgOp) + + def testTensorArgReturn(self): + m = self.run_test_case("custom_op_simple.mlir") + m_asm = str(m) + self.assertNotIn("torch.operator", m_asm) + self.assertIn( + "%0 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[97,8],f32> -> tensor<97x8xf32>", + m_asm, + ) + self.assertIn( + "%1 = torch_c.from_builtin_tensor %0 : tensor<97x8xf32> -> !torch.vtensor<[97,8],f32>", + m_asm, + ) + print(m_asm) + + def testStringAttrArg(self): + global _TEST_STRING_ATTR + _TEST_STRING_ATTR = "" + m = self.run_test_case("custom_op_string_attr.mlir") + m_asm = str(m) + self.assertEqual(_TEST_STRING_ATTR, "TEST_VALUE") + self.assertNotIn("torch.operator", m_asm) + print(m_asm) + + def testIntArg(self): + global _TEST_STRING_ATTR + _TEST_STRING_ATTR = "" + with self.assertRaisesRegex(NotImplementedError, "arg_int"): + self.run_test_case("custom_op_int_arg.mlir") + + def run_test_case(self, file_name: str): + p = Path(__file__).resolve().parent / "testdata" / file_name + contents = p.read_text() + with Context() as ctx: + m = Module.parse(contents) + p = ExpandCustomOpsPass(m.operation) + p.run() + print(f"TEST CASE {file_name}:\n{m}") + m.operation.verify() + return m + + +class IdentityOp(CustomOp): + signature = "identity_tensor(Tensor t) -> Tensor" + + def select(self, ksel: KernelSelection): + x = 
ksel.arg_tensor(0) + ksel.return_tensor(x.t) + + def generate(self, ksel: KernelSelection, kb: KernelBuilder): + kb.yield_results(kb.arg_bindings[0]) + + +class PrintStringAttrOp(CustomOp): + signature = "print_string_attr(str key) -> ()" + + def select(self, ksel: KernelSelection): + ksel.attr_str(0) + + def generate(self, ksel: KernelSelection, kb: KernelBuilder): + global _TEST_STRING_ATTR + _TEST_STRING_ATTR = str(ksel.arg_descs[0].v) + print("CAPTURED STRING ATTR:", _TEST_STRING_ATTR) + kb.yield_results() + + +class IntArgOp(CustomOp): + signature = "int_arg(int t) -> ()" + + def select(self, ksel: KernelSelection): + x = ksel.arg_int(0) + ksel.return_int() + + def generate(self, ksel: KernelSelection, kb: KernelBuilder): + kb.yield_results(kb.arg_bindings[0]) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/tests/transforms/general/testdata/custom_op_int_arg.mlir b/tests/transforms/general/testdata/custom_op_int_arg.mlir new file mode 100644 index 000000000..1a17b9a6a --- /dev/null +++ b/tests/transforms/general/testdata/custom_op_int_arg.mlir @@ -0,0 +1,9 @@ +builtin.module { + +func.func @forward() { + %i = torch.constant.int 1000 + torch.operator "torch.expand_custom_op_pass_test.int_arg"(%i) : (!torch.int) -> () + return +} + +} diff --git a/tests/transforms/general/testdata/custom_op_simple.mlir b/tests/transforms/general/testdata/custom_op_simple.mlir new file mode 100644 index 000000000..b0a879f9b --- /dev/null +++ b/tests/transforms/general/testdata/custom_op_simple.mlir @@ -0,0 +1,8 @@ +builtin.module { + +func.func @forward(%arg0: !torch.vtensor<[97,8],f32>) -> !torch.vtensor<[97,8],f32> { + %0 = torch.operator "torch.expand_custom_op_pass_test.identity_tensor"(%arg0) : (!torch.vtensor<[97,8],f32>) -> (!torch.vtensor<[97,8],f32>) + return %0 : !torch.vtensor<[97,8],f32> +} + +} diff --git a/tests/transforms/general/testdata/custom_op_string_attr.mlir 
b/tests/transforms/general/testdata/custom_op_string_attr.mlir new file mode 100644 index 000000000..c534a0745 --- /dev/null +++ b/tests/transforms/general/testdata/custom_op_string_attr.mlir @@ -0,0 +1,9 @@ +builtin.module { + +func.func @forward() { + %str = torch.constant.str "TEST_VALUE" + torch.operator "torch.expand_custom_op_pass_test.print_string_attr"(%str) : (!torch.str) -> () + return +} + +}