From 0d45767a37086c8ff1addc3b9da1c1a1ed9042dd Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Thu, 21 Dec 2023 17:44:49 -0800 Subject: [PATCH] Start adding the library of custom ops. Defines a couple of IREE builtins: * `ops.iree.trace_tensor` * `ops.iree.trace_tensors` Extends the infra for better support: * Adds support for `Tensor[]` arguments to custom ops. --- python/shark_turbine/ops/__init__.py | 7 + python/shark_turbine/ops/iree.py | 71 +++++ python/shark_turbine/runtime/op_reg/base.py | 286 +++++++++++++++--- .../shark_turbine/runtime/op_reg/compiler.py | 9 +- python/shark_turbine/runtime/op_reg/eager.py | 82 +++-- tests/ops/iree_test.py | 30 ++ 6 files changed, 420 insertions(+), 65 deletions(-) create mode 100644 python/shark_turbine/ops/__init__.py create mode 100644 python/shark_turbine/ops/iree.py create mode 100644 tests/ops/iree_test.py diff --git a/python/shark_turbine/ops/__init__.py b/python/shark_turbine/ops/__init__.py new file mode 100644 index 000000000..3f4a4554e --- /dev/null +++ b/python/shark_turbine/ops/__init__.py @@ -0,0 +1,7 @@ +# Copyright 2023 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from . import iree diff --git a/python/shark_turbine/ops/iree.py b/python/shark_turbine/ops/iree.py new file mode 100644 index 000000000..77cdfa047 --- /dev/null +++ b/python/shark_turbine/ops/iree.py @@ -0,0 +1,71 @@ +# Copyright 2023 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ..support.ir_imports import ( + RankedTensorType, + StringAttr, + Value, + flow_d, + tensor_d, +) + +from ..runtime.op_reg import ( + CustomOp, + KernelBuilder, + KernelSelection, + def_library, +) + +__all__ = [ + "trace", +] + +IREE_LIBRARY = def_library("iree") + + +################################################################################ +# trace_tensor / trace_tensors +################################################################################ + + +def _emit_tensor_trace(kb: KernelBuilder, key: str, ts: list[Value]): + dynamic_dims = [] + for t in ts: + rtt = RankedTensorType(t.type) + for i in range(rtt.rank): + if rtt.is_dynamic_dim(i): + dynamic_dims.append(tensor_d.dim(t, kb.constant_index(i))) + flow_d.TensorTraceOp(StringAttr.get(key), ts, dynamic_dims) + + +@CustomOp.register(library=IREE_LIBRARY) +class trace_tensor(CustomOp): + name = "trace_tensor" + signature = "(str trace_key, Tensor t) -> ()" + + def select(self, ksel: KernelSelection): + ksel.attr_str(0) + ksel.arg_tensor(1) + + def generate(self, ksel: KernelSelection, kb: KernelBuilder): + _emit_tensor_trace(kb, ksel.arg_descs[0].v, [kb.arg_bindings[1]]) + kb.yield_results() + + +@CustomOp.register(library=IREE_LIBRARY) +class trace_tensors(CustomOp): + name = "trace_tensors" + signature = "(str trace_key, Tensor[] self) -> ()" + + def select(self, ksel: KernelSelection): + ksel.attr_str(0) + ksel.arg_tensor_list(1) + + def generate(self, ksel: KernelSelection, kb: KernelBuilder): + ts = kb.arg_bindings[1] + if len(ts) >= 1: + _emit_tensor_trace(kb, ksel.arg_descs[0].v, ts) + kb.yield_results() diff --git a/python/shark_turbine/runtime/op_reg/base.py b/python/shark_turbine/runtime/op_reg/base.py index 6702d745f..a0832269c 100644 --- 
a/python/shark_turbine/runtime/op_reg/base.py +++ b/python/shark_turbine/runtime/op_reg/base.py @@ -13,6 +13,7 @@ from abc import ABC, abstractmethod, abstractproperty import functools import logging +import textwrap import torch from torch import Tensor @@ -21,12 +22,15 @@ Block, Context, FunctionType, + IndexType, InsertionPoint, + IntegerAttr, Location, StringAttr, SymbolTable, IrType, Value, + arith_d, builtin_d, func_d, ) @@ -43,6 +47,7 @@ "KernelBuilder", "KernelSelection", "TensorArg", + "def_library", ] logger = logging.getLogger("turbine.runtime.op_reg") @@ -51,10 +56,21 @@ # Op library management ############################################################################### + +def def_library(ns) -> torch.library.Library: + """Creates a new 'DEF' library which contains custom ops. + + Custom op libraries must be created through this helper so that they are + registered with the compiler infrastructure, allowing it to operate over + all known custom ops. + """ + return torch.library.Library(ns, "DEF") + + # All such custom kernels are registered in the 'turbine' library/namespace. # We also allow extending existing libraries outside of this, but that is # the non default case. -TURBINE_LIBRARY = torch.library.Library("turbine", "DEF") +TURBINE_LIBRARY = def_library("turbine") class CustomOp(ABC): @@ -62,7 +78,7 @@ class CustomOp(ABC): @staticmethod def register( - op_class: Optional[Type["CustomOp"]], + op_class: Optional[Type["CustomOp"]] = None, *, library: torch.library.Library = TURBINE_LIBRARY, dispatch_key: str = "", @@ -165,7 +181,7 @@ def generate(self, ksel: "KernelSelection", kb: "KernelBuilder"): This method should generate IR into the given `KernelBuilder`. It can do so by consulting any state set on the `KernelSelection`. Each `KernelSelection.args` corresponds to `KernelBuilder.args`. - Unless if the argument was set as `is_ir_arg=False`, the argument + Unless the argument was set as `ir_arity=0`, the argument will be a `Value`. Otherwise, it will be `None`. It is recommended to use `KernelBuilder.arg(n)` to access.
@@ -202,6 +218,23 @@ def __init__(self, op: CustomOp, args: list[Any]): self.result_descs: list[ArgDescriptor] = [] self.variant: str = "default" + def __repr__(self): + lines = [ + "KernelSelection<", + f" op = '{self.op.name}',", + f" variant = '{self.variant}',", + " arg_descs = [", + ] + for arg_desc in self.arg_descs: + lines.append(f" {arg_desc},") + lines.append(" ],") + lines.append(" result_descs = [") + for result_desc in self.result_descs: + lines.append(f" {result_desc},") + lines.append(" ]") + lines.append(">") + return "\n".join(lines) + def generate_meta_returns(self) -> Any: results = [d.generate_meta() for d in self.result_descs] if len(results) == 1: @@ -211,9 +244,16 @@ def generate_meta_returns(self) -> Any: @property def spec_key(self) -> str: - arg_keys = ",".join(d.spec_key for d in self.arg_descs) - return_keys = ",".join(d.spec_key for d in self.result_descs) - return f"{self.op.cache_key_base}::{self.variant}({arg_keys})->({return_keys})" + try: + arg_keys = ",".join(d.spec_key for d in self.arg_descs) + return_keys = ",".join(d.spec_key for d in self.result_descs) + return ( + f"{self.op.cache_key_base}::{self.variant}({arg_keys})->({return_keys})" + ) + except Exception as e: + raise AssertionError( + f"Error generating spec_key from:\n{textwrap.indent(repr(self), ' ')}" + ) from e def arg_tensor(self, arg: int) -> "TensorArg": """Declares an argument to allow any ranked tensor and to specialize for each rank @@ -231,6 +271,22 @@ def arg_tensor(self, arg: int) -> "TensorArg": arg_descs[arg] = desc = TensorArg(arg_value) return desc + def arg_tensor_list(self, arg: int) -> "TensorListArg": + """Declares an argument to accept a list of tensors which will be specialized + for the list size and each rank/dtype. + + Returns the argument descriptor, which can be used to further inspect or constrain + the selection. It will default to allowing all dimensions to be dynamic. + """ + arg_descs = self.arg_descs + arg_value = self.args[arg] + assert arg_descs[arg] is None, f"Already constrained argument {arg}" + assert isinstance( + arg_value, list + ), f"Argument type mismatch from Torch for {arg}: Expected list of tensors, got {type(arg_value)}" + arg_descs[arg] = desc = TensorListArg(arg_value) + return desc + def arg_int(self, arg: int) -> "IntArg": """Declares an argument to be an integer value that can take any value. @@ -246,6 +302,21 @@ def arg_int(self, arg: int) -> "IntArg": arg_descs[arg] = desc = IntArg(arg_value) return desc + def attr_str(self, arg: int) -> "AttrArg": + """Declares an argument to be a string attribute. + + Such arguments are not materialized in the IR as Values but may be used to + generate the IR. In AOT contexts, they must be derived from static values. + """ + arg_descs = self.arg_descs + arg_value = self.args[arg] + assert arg_descs[arg] is None, f"Already constrained argument {arg}" + assert isinstance( + arg_value, str + ), f"Argument type mismatch from Torch for {arg}: Expected str, got {type(arg_value)}" + arg_descs[arg] = desc = AttrArg(arg_value) + return desc + def return_tensor(self, t: Tensor) -> "TensorArg": """Marks the next return value as a Tensor. @@ -258,25 +329,100 @@ def return_tensor(self, t: Tensor) -> "TensorArg": return desc +class AttrArg: + ir_arity: int = 0 + maybe_tensor_value: Optional[Tensor] = None + is_list: bool = False + + __slots__ = [ + "v", + "spec_value", + ] + + def __init__(self, v: object): + self.v = v + # We specialize on every distinct value.
+ self.spec_value: Optional[Any] = v + + def __repr__(self): + return f"AttrArg(<{self.spec_value}>)" + + def generate_meta(self) -> object: + return self.v + + @property + def spec_key(self) -> str: + """Generates a key that will be the same for all specializations.""" + return f"attr<{self.spec_value}>" + + @property + def mlir_type_asm(self) -> str: + raise AssertionError("Cannot resolve `mlir_type_asm` for an AttrArg") + + +class IntArg: + __slots__ = [ + "ir_arity", + "spec_value", + "v", + ] + + # All descriptors have an attribute to indicate their value + # as a tensor, and those that aren't are fixated to None. + # This is to enable fast lookup in the hot path of determining + # how to dispatch. + maybe_tensor_value: Optional[Tensor] = None + is_list: bool = False + + def __init__(self, v: int): + self.v = v + self.spec_value: Optional[Any] = None + self.ir_arity: int = 1 + + def __repr__(self): + return f"IntArg({self.v}, spec_value={self.spec_value}, ir_arity={self.ir_arity})" + + def generate_meta(self) -> int: + return self.v + + @property + def spec_key(self) -> str: + """Generates a key that will be the same for all specializations.""" + return f"int<{self.spec_value}>" + + @property + def mlir_type_asm(self) -> str: + # TODO: We can have individual kernels constrain this to a narrower + # type. + return "i64" + + class TensorArg: __slots__ = [ "t", "spec_dims", - "is_ir_arg", "maybe_tensor_value", ] + ir_arity: int = 1 + is_list: bool = False + def __init__(self, t: Tensor): self.t = t # Any static dims that we are specializing. Defaults to all dynamic. self.spec_dims: list[Optional[int]] = len(t.shape) * [None] - self.is_ir_arg = True # All descriptors have an attribute to indicate their value # as a tensor, and those that aren't are fixated to None. # This is to enable fast lookup in the hot path of determining # how to dispatch. self.maybe_tensor_value: Tensor = t + def __repr__(self): + return ( + f"TensorArg(shape={self.t.shape}, dtype={self.t.dtype}, " + f"spec_dims={self.spec_dims})" + ) + def generate_meta(self) -> Tensor: t = self.t if t.device == "meta": @@ -305,40 +451,69 @@ def mlir_type_asm(self) -> str: return f"tensor<{spec}>" -class IntArg: +class TensorListArg: __slots__ = [ - "is_ir_arg", - "v", - "spec_value", + "ts", + "spec_dims", + "ir_arity", "maybe_tensor_value", ] - def __init__(self, v: int): - self.v = v - self.spec_value: Optional[Any] = None - self.is_ir_arg = True + is_list: bool = True + + def __init__(self, ts: list[Tensor]): + self.ts = ts + self.ir_arity = len(ts) + # Any static dims that we are specializing. Defaults to all dynamic. + self.spec_dims: list[list[Optional[int]]] = [len(t.shape) * [None] for t in ts] # All descriptors have an attribute to indicate their value # as a tensor, and those that aren't are fixated to None. # This is to enable fast lookup in the hot path of determining # how to dispatch.
- self.maybe_tensor_value: Optional[Tensor] = None + self.maybe_tensor_value: list[Tensor] = ts - def generate_meta(self) -> int: - return self.v + def __repr__(self): + return ( + f"TensorListArg(shape={[t.shape for t in self.ts]}, " + f"dtype={[t.dtype for t in self.ts]}, " + f"spec_dims={self.spec_dims}, ir_arity={self.ir_arity})" + ) + + def generate_meta(self) -> list[Tensor]: + metas = [] + for t in self.ts: + if t.device == "meta": + metas.append(t) + else: + metas.append(t.clone().detach().to("meta")) + return metas @property def spec_key(self) -> str: """Generates a key that will be the same for all specializations.""" - return f"int<{self.spec_value}>" + return ( + f"tensor[{[len(t.shape) for t in self.ts]}" + f":{[str(t.dtype) for t in self.ts]}]<{self.spec_dims}>" + ) @property - def mlir_type_asm(self) -> str: - # TODO: We can have individual kernels constrain this to a narrower - # type. - return "i64" - - -ArgDescriptor = Union[TensorArg, IntArg] + def mlir_type_asm(self) -> list[str]: + asms = [] + for t, spec_dims in zip(self.ts, self.spec_dims): + try: + dtype_asm = TORCH_DTYPE_TO_IREE_TYPE_ASM[t.dtype] + except KeyError as e: + raise KeyError( + f"Unknown mapping of torch dtype {t.dtype} to MLIR " + f"(possibly missing in TORCH_DTYPE_TO_IREE_TYPE_ASM table)" + ) from e + dim_asm = "x".join(["?" if d is None else str(d) for d in spec_dims]) + spec = f"{dim_asm}x{dtype_asm}" if dim_asm else dtype_asm + asms.append(f"tensor<{spec}>") + return asms + + +ArgDescriptor = Union[AttrArg, IntArg, TensorArg, TensorListArg] ############################################################################### # KernelBuilder ############################################################################### @@ -352,7 +527,7 @@ class KernelBuilder(ABC): def __init__( self, ksel: KernelSelection, - arg_bindings: list[Value], + arg_bindings: list[Union[Value, list[Value]]], *, ip: InsertionPoint, module_body: Block, @@ -364,10 +539,10 @@ def __init__( self.module_body = module_body self.symbol_table = symbol_table - def arg_value(self, index: int) -> Value: + def arg_value(self, index: int) -> Union[list[Value], Value]: """Gets the concrete IR `Value` for the argument at `index`. - This will assert if the corresponding argument was set as `is_ir_arg=False` + This will assert if the corresponding argument was set as `ir_arity=0` during kernel selection. """ try: @@ -386,6 +561,10 @@ def yield_results(self, *results: Value): """Yields results of the kernel computation.""" ... + def constant_index(self, i: int) -> Value: + """Builds a constant index value.""" + return arith_d.constant(IntegerAttr.get(IndexType.get(), i)) + class FreeFuncKernelBuilder(KernelBuilder): """Kernel builder that emits the body of the kernel into a free function. @@ -409,10 +588,33 @@ def __init__( if func_name is None: func_name = ksel.op.name with context, Location.unknown(), InsertionPoint(module_body): - arg_types = [ - IrType.parse(d.mlir_type_asm) for d in ksel.arg_descs if d.is_ir_arg - ] - result_types = [IrType.parse(d.mlir_type_asm) for d in ksel.result_descs] + # Assemble arg types. + arg_types = [] + for d in ksel.arg_descs: + arity = d.ir_arity + if not d.is_list: + if arity == 1: + arg_types.append(IrType.parse(d.mlir_type_asm)) + else: + continue + else: + for i in range(arity): + arg_types.append(IrType.parse(d.mlir_type_asm[i])) + + # Assemble result types.
+ result_types = [] + for d in ksel.result_descs: + if not d.is_list: + if d.ir_arity == 1: + result_types.append(IrType.parse(d.mlir_type_asm)) + else: + continue + else: + # for i in range(arity): + # result_types.append(IrType.parse(d.mlir_type_asm[i])) + raise AssertionError("NYI: arity > 1 results") + + # Create the func. ftype = FunctionType.get(arg_types, result_types) func_op = func_d.FuncOp(func_name, ftype) if not is_public: @@ -422,13 +624,21 @@ def __init__( # Map inputs to arg bindings, lining up with arguments that are elided. block_arguments = list(entry_block.arguments) - block_arguments.reverse() + block_arg_index = 0 arg_bindings: list[Optional[Value]] = [] for desc in ksel.arg_descs: - if desc.is_ir_arg: - arg_bindings.append(block_arguments.pop()) + arity = desc.ir_arity + if not desc.is_list: + if desc.ir_arity == 1: + arg_bindings.append(block_arguments[block_arg_index]) + block_arg_index += 1 + else: + arg_bindings.append(None) else: - arg_bindings.append(None) + arg_bindings.append( + block_arguments[block_arg_index : block_arg_index + arity] + ) + block_arg_index += arity super().__init__( ksel, diff --git a/python/shark_turbine/runtime/op_reg/compiler.py b/python/shark_turbine/runtime/op_reg/compiler.py index c15ead608..4d9595a34 100644 --- a/python/shark_turbine/runtime/op_reg/compiler.py +++ b/python/shark_turbine/runtime/op_reg/compiler.py @@ -24,6 +24,10 @@ GeneralError, ) +from ...support.ir_imports import ( + Location, +) + from ...support.logging import ( runtime_logger as logger, ) @@ -87,10 +91,11 @@ def compile_standalone_kernel( start = default_timer() config = KernelCompileConfig(cache_key, list(device.compile_target_flags)) kb = FreeFuncKernelBuilder.create_module(ksel, func_name=func_name) - ksel.op.generate(ksel, kb) + with kb.ip, Location.unknown(): + ksel.op.generate(ksel, kb) kb.module_op.verify() module_asm = kb.module_op.get_asm( - binary=True, enable_debug_info=True, assume_verified=True + binary=True, enable_debug_info=True, print_generic_op_form=True ) generation_time = default_timer() - start diff --git a/python/shark_turbine/runtime/op_reg/eager.py b/python/shark_turbine/runtime/op_reg/eager.py index 629118096..472e91a9e 100644 --- a/python/shark_turbine/runtime/op_reg/eager.py +++ b/python/shark_turbine/runtime/op_reg/eager.py @@ -50,15 +50,27 @@ def eager_dispatch(ksel: KernelSelection): device: Optional[Device] = None torch_device: Optional[torch.device] = None for arg_desc in ksel.arg_descs: - if not arg_desc.is_ir_arg: - continue - tensor_arg = arg_desc.maybe_tensor_value - if tensor_arg is None: - continue - torch_device = tensor_arg.device - device = lookup_device_from_torch(torch_device) - if device is not None: - break + if not arg_desc.is_list: + if arg_desc.ir_arity: + # Non-list arg: maybe_tensor_value is a single tensor or None (common case). + tensor_arg = arg_desc.maybe_tensor_value + if tensor_arg is None: + continue + torch_device = tensor_arg.device + device = lookup_device_from_torch(torch_device) + if device is not None: + break + else: + continue + else: + # List arg: maybe_tensor_value is a list of tensors (uncommon case). + for tensor_arg in arg_desc.maybe_tensor_value: + if tensor_arg is None: + continue + torch_device = tensor_arg.device + device = lookup_device_from_torch(torch_device) + if device is not None: + break # Default to CPU. if device is None: @@ -72,21 +84,16 @@ def eager_dispatch(ksel: KernelSelection): # Build the concrete args, issuing device movement as necessary.
arg_list = VmVariantList(len(ksel.arg_descs)) - for arg_desc in ksel.arg_descs: - if not arg_desc.is_ir_arg: - continue - tensor_arg = arg_desc.maybe_tensor_value - # Handle non-tensor args. - if tensor_arg is None: - scalar_value = arg_desc.v - if isinstance(scalar_value, int): - arg_list.push_int(scalar_value) - elif isinstance(scalar_value, float): - arg_list.push_float(scalar_value) - else: - raise UnsupportedTypeError(type(scalar_value)) - continue - # Tensor arg. + + def push_scalar(scalar_value): + if isinstance(scalar_value, int): + arg_list.push_int(scalar_value) + elif isinstance(scalar_value, float): + arg_list.push_float(scalar_value) + else: + raise UnsupportedTypeError(type(scalar_value)) + + def push_tensor(tensor_arg): if tensor_arg.device != torch_device: # TODO: If the source and target device are both known to us, # we can do this "in house" vs asking torch to do it. @@ -101,6 +108,28 @@ def eager_dispatch(ksel: KernelSelection): # import_torch_tensor. arg_list.push_ref(device.import_torch_tensor(tensor_arg)) + for arg_desc in ksel.arg_descs: + arity = arg_desc.ir_arity + if not arg_desc.is_list: + # Non-list. + if arity == 1: + tensor_arg = arg_desc.maybe_tensor_value + if tensor_arg is not None: + push_tensor(tensor_arg) + else: + push_scalar(arg_desc.v) + else: + continue + else: + # List. Uncommon case. + tensor_arg = arg_desc.maybe_tensor_value + if tensor_arg is not None: + for i in range(arity): + push_tensor(tensor_arg[i]) + else: + for i in range(arity): + push_scalar(arg_desc.v[i]) + if config.async_invocations: raise NotImplementedError("Async execution not yet implemented") @@ -113,8 +142,9 @@ def eager_dispatch(ksel: KernelSelection): # Unpack results. results = [] for i, result_desc in enumerate(ksel.result_descs): + arity = result_desc.ir_arity + assert arity == 1, "NYI: Optional and result lists" meta_tensor_value = result_desc.maybe_tensor_value if meta_tensor_value is None: # Scalar return. @@ -128,5 +158,7 @@ def eager_dispatch(ksel: KernelSelection): if len(results) == 1: return results[0] + elif len(results) == 0: + return None else: return tuple(results) diff --git a/tests/ops/iree_test.py b/tests/ops/iree_test.py new file mode 100644 index 000000000..befb5a749 --- /dev/null +++ b/tests/ops/iree_test.py @@ -0,0 +1,30 @@ +# Copyright 2023 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +import unittest + +import torch + +import shark_turbine.ops as ops + + +class KernelRegTest(unittest.TestCase): + def testTrace(self): + t = torch.randn(3, 4) + ops.iree.trace_tensor("TEST", t) + + def testTraceList(self): + t1 = torch.randn(3, 4) + t2 = torch.randn(1, 8) + ops.iree.trace_tensors("TEST 2", [t1, t2]) + ops.iree.trace_tensors("TEST 1", [t1]) + ops.iree.trace_tensors("TEST 0", []) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main()
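
Usage sketch (illustration only, not part of the patch): a minimal example of how a downstream custom op could be defined on top of the extended registry, using the APIs exercised in this patch (`def_library`, `CustomOp.register`, `attr_str`, `arg_tensor`, `return_tensor`, `arg_bindings`, `yield_results`). The op name `tagged_identity` and the `example` library namespace are hypothetical.

from shark_turbine.runtime.op_reg import (
    CustomOp,
    KernelBuilder,
    KernelSelection,
    def_library,
)

# Hypothetical namespace; any namespace other than "turbine"/"iree" is handled the same way.
EXAMPLE_LIBRARY = def_library("example")


@CustomOp.register(library=EXAMPLE_LIBRARY)
class tagged_identity(CustomOp):
    # Schema: one string attribute plus one tensor, returning a tensor.
    name = "tagged_identity"
    signature = "(str tag, Tensor t) -> Tensor"

    def select(self, ksel: KernelSelection):
        # The string is captured as an attribute (no IR Value is materialized);
        # the tensor is specialized by rank/dtype and mirrored as the result.
        ksel.attr_str(0)
        ta = ksel.arg_tensor(1)
        ksel.return_tensor(ta.t)

    def generate(self, ksel: KernelSelection, kb: KernelBuilder):
        # Pass the tensor through unchanged; the tag is only visible at
        # generation time via ksel.arg_descs[0].v.
        kb.yield_results(kb.arg_bindings[1])


# Eager invocation mirrors tests/ops/iree_test.py:
#   out = tagged_identity("dbg", torch.randn(3, 4))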