From 68df3169f3ad40102e4cdb6f8f8e54a64c0eead3 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Tue, 19 Dec 2023 13:27:44 -0800 Subject: [PATCH] [custom ops] Begin the scaffolding for dispatch of PyTorch custom ops. (#270) This makes it possible for us to directly define regular torch ops in terms of generated MLIR. The resulting ops will be specialized and cached per requirements in their definition and will be compiled for any device that Turbine supports when dispatched against tensors on that device. It is left to a follow-up to also wire this mechanism in on the AOT side so that compiling programs that contain our own custom ops transparently includes them with no further glue. The scaffolding for this is in place, but this patch is big enough without touching AOT. This allows users to say something like: ``` @CustomOp.register class identity(CustomOp): name = "test_identity" signature = "(Tensor self) -> Tensor" def select(self, ksel: KernelSelection): x = ksel.arg_tensor(0) ksel.return_tensor(x.t) def generate(self, ksel: KernelSelection, kb: KernelBuilder): # This just yields the IR value of kernel input as the output. # Effectively in eager mode, this is a `return` from the kernel # function. kb.yield_results(kb.arg_bindings[0]) t = torch.tensor([[1, 2, 3]], dtype=torch.int32) result = identity(t) print("CPU result:", result) torch.testing.assert_close(result, t) ``` There will be dedicated `CustomOp` subclasses for our various DSLs that can be used for such things (for more sugar'd use than just open coding IR). --- .github/workflows/test.yml | 2 +- python/shark_turbine/aot/builtins/jittable.py | 24 +- python/shark_turbine/aot/compiled_module.py | 14 +- python/shark_turbine/aot/exporter.py | 9 +- python/shark_turbine/aot/support/ir_utils.py | 63 +-- .../aot/support/procedural/base.py | 2 +- .../aot/support/procedural/globals.py | 2 +- .../aot/support/procedural/iree_emitter.py | 7 +- .../aot/support/procedural/primitives.py | 2 +- .../aot/support/procedural/tracer.py | 2 +- python/shark_turbine/dynamo/__init__.py | 1 - python/shark_turbine/dynamo/backends/cpu.py | 2 +- python/shark_turbine/dynamo/device.py | 176 ------ python/shark_turbine/dynamo/executor.py | 2 +- python/shark_turbine/dynamo/tensor.py | 65 +-- python/shark_turbine/runtime/__init__.py | 8 + python/shark_turbine/runtime/device.py | 375 +++++++++++++ .../shark_turbine/runtime/op_reg/__init__.py | 7 + python/shark_turbine/runtime/op_reg/base.py | 507 ++++++++++++++++++ .../shark_turbine/runtime/op_reg/compiler.py | 137 +++++ python/shark_turbine/runtime/op_reg/eager.py | 132 +++++ python/shark_turbine/support/conversions.py | 118 ++++ python/shark_turbine/support/exceptions.py | 12 + .../{aot => }/support/ir_imports.py | 1 + python/shark_turbine/support/logging.py | 9 + tests/dynamo/tensor_test.py | 3 +- tests/{dynamo => runtime}/device_test.py | 40 +- tests/runtime/op_reg/kernel_reg_test.py | 87 +++ 28 files changed, 1490 insertions(+), 319 deletions(-) delete mode 100644 python/shark_turbine/dynamo/device.py create mode 100644 python/shark_turbine/runtime/__init__.py create mode 100644 python/shark_turbine/runtime/device.py create mode 100644 python/shark_turbine/runtime/op_reg/__init__.py create mode 100644 python/shark_turbine/runtime/op_reg/base.py create mode 100644 python/shark_turbine/runtime/op_reg/compiler.py create mode 100644 python/shark_turbine/runtime/op_reg/eager.py create mode 100644 python/shark_turbine/support/conversions.py rename python/shark_turbine/{aot => }/support/ir_imports.py (97%) create mode 100644 python/shark_turbine/support/logging.py rename tests/{dynamo => runtime}/device_test.py (61%) create mode 100644 tests/runtime/op_reg/kernel_reg_test.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 27a5dac66..6c56db1ca 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -46,7 +46,7 @@ jobs: - name: Run tests run: | - pytest tests/ + pytest -n 4 tests/ black: strategy: diff --git a/python/shark_turbine/aot/builtins/jittable.py b/python/shark_turbine/aot/builtins/jittable.py index af80ae291..d2c85b73f 100644 --- a/python/shark_turbine/aot/builtins/jittable.py +++ b/python/shark_turbine/aot/builtins/jittable.py @@ -31,6 +31,18 @@ FxImporter, ) +from ...support.ir_imports import ( + FlatSymbolRefAttr, + FunctionType, + Operation, + StringAttr, + SymbolTable, + TypeAttr, + Value, + func_d, + util_d, +) + from ..passes import ( functorch_functionalize, ) @@ -53,18 +65,6 @@ MaterializedGlobal, ) -from ..support.ir_imports import ( - FlatSymbolRefAttr, - FunctionType, - Operation, - StringAttr, - SymbolTable, - TypeAttr, - Value, - func_d, - util_d, -) - StringAttrOrStr = Union[StringAttr, str] diff --git a/python/shark_turbine/aot/compiled_module.py b/python/shark_turbine/aot/compiled_module.py index 304bc5ccc..9808ffeb4 100644 --- a/python/shark_turbine/aot/compiled_module.py +++ b/python/shark_turbine/aot/compiled_module.py @@ -17,13 +17,7 @@ from . import builtins -from .support.procedural import ( - GlobalsDef, - ProcedureTrace, - current_ir_trace, -) - -from .support.ir_imports import ( +from ..support.ir_imports import ( Context, Location, MLIRError, @@ -33,6 +27,12 @@ StringAttr, ) +from .support.procedural import ( + GlobalsDef, + ProcedureTrace, + current_ir_trace, +) + from .support.ir_utils import ( ModuleBuilder, ) diff --git a/python/shark_turbine/aot/exporter.py b/python/shark_turbine/aot/exporter.py index 0a05b12fe..e8c429583 100644 --- a/python/shark_turbine/aot/exporter.py +++ b/python/shark_turbine/aot/exporter.py @@ -19,16 +19,17 @@ Output, ) +from ..support.ir_imports import ( + Context, + Operation, +) + from .builtins import * from .compiled_module import ( CompiledModule, CompiledModuleMeta, ExportProcDef, ) -from .support.ir_imports import ( - Context, - Operation, -) from .support.procedural import ( AbstractTypedef, ) diff --git a/python/shark_turbine/aot/support/ir_utils.py b/python/shark_turbine/aot/support/ir_utils.py index 9168f67db..5dc9e9c79 100644 --- a/python/shark_turbine/aot/support/ir_utils.py +++ b/python/shark_turbine/aot/support/ir_utils.py @@ -5,7 +5,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from typing import Any, Callable, Dict, Generator, List, Optional, Sequence, Tuple +from typing import Any, Callable, Generator, List, Optional, Sequence, Tuple from pathlib import Path import tempfile @@ -15,7 +15,6 @@ from ...importers.fx_importer import ( ContextCache, - TORCH_DTYPE_TO_MLIR_TYPE_ASM, ) from ...importers.utils import ( @@ -26,12 +25,9 @@ NativeTypeConverter, ) -from .ir_imports import ( +from ...support.ir_imports import ( Attribute, - Block, - BlockArgument, BF16Type, - ComplexType, DenseElementsAttr, DenseResourceElementsAttr, F16Type, @@ -46,7 +42,6 @@ IrType, Location, MLIRError, - OpResult, Operation, RankedTensorType, StringAttr, @@ -59,61 +54,21 @@ tensor_d, ) +from ...support.conversions import ( + TORCH_DTYPE_TO_IREE_TYPE, +) + from .utils import ( RefTracker, logger, ) -############################################################################### -# Lookup tables -############################################################################### - -# We need the inverse of the TORCH_DTYPE_TO_MLIR_TYPE_ASM table. -MLIR_TYPE_ASM_TO_TORCH_DTYPE = {v: k for k, v in TORCH_DTYPE_TO_MLIR_TYPE_ASM.items()} - -# When emitting constants, we have to create native IREE types. -TORCH_DTYPE_TO_IREE_TYPE: Dict[torch.dtype, Callable[[], IrType]] = { - torch.float16: lambda: F16Type.get(), - torch.bfloat16: lambda: BF16Type.get(), - torch.float32: lambda: F32Type.get(), - torch.float64: lambda: F64Type.get(), - torch.uint8: lambda: IntegerType.get_signless(8), - torch.int8: lambda: IntegerType.get_signless(8), - torch.int16: lambda: IntegerType.get_signless(16), - torch.int32: lambda: IntegerType.get_signless(32), - torch.int64: lambda: IntegerType.get_signless(64), - torch.bool: lambda: IntegerType.get_signless(1), - torch.qint8: lambda: IntegerType.get_signless(8), - torch.quint8: lambda: IntegerType.get_signless(8), - torch.complex32: lambda: ComplexType.get(F16Type.get()), - torch.complex64: lambda: ComplexType.get(F32Type.get()), - torch.complex128: lambda: ComplexType.get(F64Type.get()), -} - -TORCH_DTYPE_TO_IREE_TYPE_ASM = { - torch.float16: "f16", - torch.bfloat16: "bf16", - torch.float32: "f32", - torch.float64: "f64", - torch.uint8: "i8", - torch.int8: "i8", - torch.int16: "i16", - torch.int32: "i32", - torch.int64: "i64", - torch.bool: "i1", - torch.qint8: "i8", - torch.quint8: "i8", - torch.complex32: "complex", - torch.complex64: "complex", - torch.complex128: "complex", -} - ############################################################################### # Configuration ############################################################################### # Maps a name to an altered name. If returns None, then the original -# name is used (this lets Dict.get serve as a NameMapCallback). +# name is used (this lets dict.get serve as a NameMapCallback). NameMapCallback = Callable[[str], Optional[str]] @@ -420,7 +375,7 @@ def build_index_attribute(value: int) -> IntegerAttr: def build_index_value( - value: int, constant_cache: Optional[Dict[int, Value]] = None + value: int, constant_cache: Optional[dict[int, Value]] = None ) -> Value: if constant_cache is not None and value in constant_cache: return constant_cache[value] @@ -431,7 +386,7 @@ def build_index_value( def build_tensor_dim_value( - t: Value, dim: int, constant_cache: Optional[Dict[int, Value]] = None + t: Value, dim: int, constant_cache: Optional[dict[int, Value]] = None ) -> Value: dim_value = build_index_value(dim, constant_cache=constant_cache) return tensor_d.DimOp(t, dim_value).result diff --git a/python/shark_turbine/aot/support/procedural/base.py b/python/shark_turbine/aot/support/procedural/base.py index 4bc7d0be9..7bcc8b7fc 100644 --- a/python/shark_turbine/aot/support/procedural/base.py +++ b/python/shark_turbine/aot/support/procedural/base.py @@ -17,7 +17,7 @@ import torch -from ..ir_imports import ( +from ....support.ir_imports import ( F32Type, F64Type, IndexType, diff --git a/python/shark_turbine/aot/support/procedural/globals.py b/python/shark_turbine/aot/support/procedural/globals.py index be882eb59..c186538d8 100644 --- a/python/shark_turbine/aot/support/procedural/globals.py +++ b/python/shark_turbine/aot/support/procedural/globals.py @@ -19,7 +19,7 @@ import torch -from ..ir_imports import ( +from ....support.ir_imports import ( IrType, Operation, Value, diff --git a/python/shark_turbine/aot/support/procedural/iree_emitter.py b/python/shark_turbine/aot/support/procedural/iree_emitter.py index b8f1b2734..5e951b0b1 100644 --- a/python/shark_turbine/aot/support/procedural/iree_emitter.py +++ b/python/shark_turbine/aot/support/procedural/iree_emitter.py @@ -12,7 +12,7 @@ import torch -from ..ir_imports import ( +from ....support.ir_imports import ( IndexType, IntegerType, IrType, @@ -23,8 +23,11 @@ flow_d, ) -from ..ir_utils import ( +from ....support.conversions import ( TORCH_DTYPE_TO_IREE_TYPE, +) + +from ..ir_utils import ( build_index_value, ) diff --git a/python/shark_turbine/aot/support/procedural/primitives.py b/python/shark_turbine/aot/support/procedural/primitives.py index 9e0240ea5..0e07d8f48 100644 --- a/python/shark_turbine/aot/support/procedural/primitives.py +++ b/python/shark_turbine/aot/support/procedural/primitives.py @@ -24,7 +24,7 @@ dynamic_dim, ) -from ..ir_imports import ( +from ....support.ir_imports import ( F32Type, IrType, RankedTensorType, diff --git a/python/shark_turbine/aot/support/procedural/tracer.py b/python/shark_turbine/aot/support/procedural/tracer.py index 0339e1ee4..d4cf954e1 100644 --- a/python/shark_turbine/aot/support/procedural/tracer.py +++ b/python/shark_turbine/aot/support/procedural/tracer.py @@ -14,7 +14,7 @@ Sequence, ) -from ..ir_imports import ( +from ....support.ir_imports import ( Location, StringAttr, Value, diff --git a/python/shark_turbine/dynamo/__init__.py b/python/shark_turbine/dynamo/__init__.py index aa6d60d96..b122f4621 100644 --- a/python/shark_turbine/dynamo/__init__.py +++ b/python/shark_turbine/dynamo/__init__.py @@ -5,7 +5,6 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from .device import Device from .tensor import ( enable, TurbineMode, diff --git a/python/shark_turbine/dynamo/backends/cpu.py b/python/shark_turbine/dynamo/backends/cpu.py index 61a9adca9..f92299c2e 100644 --- a/python/shark_turbine/dynamo/backends/cpu.py +++ b/python/shark_turbine/dynamo/backends/cpu.py @@ -7,7 +7,7 @@ import functools import sys -from ..device import ( +from ...runtime.device import ( DeviceState, ) diff --git a/python/shark_turbine/dynamo/device.py b/python/shark_turbine/dynamo/device.py deleted file mode 100644 index 07181e6a3..000000000 --- a/python/shark_turbine/dynamo/device.py +++ /dev/null @@ -1,176 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from functools import lru_cache -from typing import List, Optional, Sequence, Union -from threading import local, Lock - -from iree.runtime import ( - asdevicearray, - create_hal_module, - HalBufferView, - DeviceArray, - get_driver, - VmContext, - HalDevice, - HalDriver, - VmInstance, - VmModule, - VmVariantList, -) - -from ..support.exceptions import * - -__all__ = [ - "get_vm_instance", - "DeviceState", -] - -_CONFIG_LOCK = Lock() -_GLOBAL_VM_INSTANCE: Optional[VmInstance] = None - - -def get_vm_instance() -> VmInstance: - global _GLOBAL_VM_INSTANCE - if not _GLOBAL_VM_INSTANCE: - with _CONFIG_LOCK: - if not _GLOBAL_VM_INSTANCE: - _GLOBAL_VM_INSTANCE = VmInstance() - return _GLOBAL_VM_INSTANCE - - -class DeviceState: - """State for an instantiated HAL device. - - Note that the IREE runtime internally manages a global cache of drivers for - standard named-access (not custom-constructed) drivers. - """ - - __slots__ = [ - "device", - "driver", - "instance", - ] - - def __init__( - self, - *, - driver: Union[str, HalDriver], - device: Optional[HalDevice] = None, - vm_instance: Optional[VmInstance] = None, - ): - self.instance = vm_instance or get_vm_instance() - self.driver = driver if isinstance(driver, HalDriver) else get_driver(driver) - self.device = device if device else self.driver.create_default_device() - - @staticmethod - @lru_cache(maxsize=None) - def from_uri(uri: str) -> "DeviceState": - driver = get_driver(uri) - return DeviceState(driver=driver, device=driver.create_device_by_uri(uri)) - - -_CURRENT_THREAD = local() - - -class Device: - """Represents a low-level device (HalDriver/HalDevice) and scheduling data. - - This is the type that user's interact with as a 'Device'. Devices can be handled - loose-leaf or bound to a thread with a context manager. - """ - - __slots__ = [ - "_s", - "_main_timeline", - "_main_timepoint", - "_tx_timeline", - "_tx_timepoint", - "_fence_capacity", - ] - - def __new__( - cls, uri: Optional[str] = None, *, device_state: Optional[DeviceState] = None - ): - if uri is not None: - # Construction by URI is cached on the thread. - assert not device_state, "device_state= cannot be given with explicit URI" - try: - existing = _CURRENT_THREAD.device_by_uri[uri] - except (AttributeError, KeyError): - ... - else: - return existing - - # New instance. - device_state = DeviceState.from_uri(uri) - new_inst = super().__new__(cls) - new_inst._s = device_state - try: - _CURRENT_THREAD.device_by_uri[uri] = new_inst - except AttributeError: - _CURRENT_THREAD.device_by_uri = {uri: new_inst} - new_inst._initialize() - return new_inst - else: - # Explicit construction with a device_state is assumed that you know what you - # are doing and an uncached instance will be returned. This will be unsychronized - # relative to any cached instance. - assert device_state, "device_state= must be given if URI ommitted" - new_inst = super().__new__(cls) - new_inst._s = device_state - new_inst._initialize() - return new_inst - - def _initialize(self): - d = self._s.device - self._main_timeline = d.create_semaphore(0) - self._main_timepoint = 0 - self._tx_timeline = d.create_semaphore(0) - self._tx_timepoint = 0 - # Maximum number of semaphores the device uses. Can be increased if doing out of the - # ordinary scheduling. - self._fence_capacity = 2 - - @property - def hal_device(self) -> HalDevice: - return self._s.device - - def current() -> "Device": - try: - return _CURRENT_THREAD.stack[-1] - except (AttributeError, IndexError): - raise NoCurrentDeviceError() - - def set(self) -> "Device": - """Sets this device as the current device without a context manager.""" - try: - _CURRENT_THREAD.stack.append(self) - except AttributeError: - _CURRENT_THREAD.stack = [self] - - def clear(self): - """Clears the current device without a context manager.""" - try: - c = _CURRENT_THREAD.stack[-1] - if _CURRENT_THREAD.stack[-1] is self: - _CURRENT_THREAD.stack.pop() - return - except (AttributeError, IndexError): - ... - raise MismatchedDeviceSetClearError() - - def __repr__(self): - return f"" - - def __enter__(self): - try: - _CURRENT_THREAD.stack.append(self) - except AttributeError: - _CURRENT_THREAD.stack = [self] - - def __exit__(self, type, value, traceback): - _CURRENT_THREAD.stack.pop() diff --git a/python/shark_turbine/dynamo/executor.py b/python/shark_turbine/dynamo/executor.py index 417946845..5208210a6 100644 --- a/python/shark_turbine/dynamo/executor.py +++ b/python/shark_turbine/dynamo/executor.py @@ -30,7 +30,7 @@ from_numpy as torch_from_numpy, ) -from .device import Device, DeviceState +from ..runtime.device import Device, DeviceState @functools.lru_cache(maxsize=None) diff --git a/python/shark_turbine/dynamo/tensor.py b/python/shark_turbine/dynamo/tensor.py index f7070c07c..c85515c5a 100644 --- a/python/shark_turbine/dynamo/tensor.py +++ b/python/shark_turbine/dynamo/tensor.py @@ -10,11 +10,10 @@ zoo: https://github.com/albanD/subclass_zoo/blob/main/new_device.py """ -from typing import Any, Dict, List, Optional, Sequence, Tuple +from typing import Any, Optional, Sequence import functools import atexit -from array import array import numpy as np from types import BuiltinFunctionType @@ -22,10 +21,17 @@ import torch._dynamo as dynamo from torch.overrides import TorchFunctionMode -from .device import ( +from ..runtime.device import ( Device, DeviceState, ) + +from ..support.conversions import ( + DTYPE_TO_ELEMENT_TYPE, + dtype_to_element_type, + torch_dtype_to_numpy, +) + from .executor import EagerSpecializedExecutable from ..support import ( @@ -267,7 +273,7 @@ def buffer_view(self) -> HalBufferView: self._bv = HalBufferView( self._storage.buffer, shape=self.size(), - element_type=_dtype_to_element_type(self.dtype), + element_type=dtype_to_element_type(self.dtype), ) return self._bv @@ -453,7 +459,7 @@ def inner(*args, device: Device, **kwargs): @cpu_tensor_constructor def _arange(*args, dtype=None): if dtype is not None: - dtype = _torch_dtype_to_numpy(dtype) + dtype = torch_dtype_to_numpy(dtype) return torch.from_numpy(np.arange(*args, dtype=dtype)) @@ -589,53 +595,6 @@ def func_src_op(*args, **kwargs): # Conversions ############################################################################### -_DTYPE_TO_ELEMENT_TYPE: Dict[torch.dtype, HalElementType] = { - torch.float16: HalElementType.FLOAT_16, - torch.bfloat16: HalElementType.BFLOAT_16, - torch.float32: HalElementType.FLOAT_32, - torch.float64: HalElementType.FLOAT_64, - torch.uint8: HalElementType.UINT_8, - torch.int8: HalElementType.SINT_8, - torch.int16: HalElementType.SINT_16, - torch.int32: HalElementType.SINT_32, - torch.int64: HalElementType.SINT_64, - torch.bool: HalElementType.BOOL_8, - torch.qint8: HalElementType.OPAQUE_8, - torch.quint8: HalElementType.OPAQUE_8, - torch.complex64: HalElementType.COMPLEX_64, - torch.complex128: HalElementType.COMPLEX_128, -} - - -def _dtype_to_element_type(dtype) -> HalElementType: - try: - return _DTYPE_TO_ELEMENT_TYPE[dtype] - except KeyError: - raise UnknownDTypeError(dtype) - - -_TORCH_DTYPE_TO_NUMPY = { - torch.float16: np.float16, - torch.float32: np.float32, - torch.float64: np.float64, - torch.uint8: np.uint8, - torch.int8: np.int8, - torch.int16: np.int16, - torch.int32: np.int32, - torch.int64: np.int64, - torch.bool: np.bool_, - torch.complex64: np.complex64, - torch.complex128: np.complex128, -} - - -def _torch_dtype_to_numpy(torch_dtype: torch.dtype) -> Any: - try: - return _TORCH_DTYPE_TO_NUMPY[torch_dtype] - except KeyError: - raise UnknownDTypeError(torch_dtype) - - _ELEMENT_TYPE_TO_NUMPY_DTYPE = { HalElementType.FLOAT_16: np.float16, HalElementType.FLOAT_32: np.float32, @@ -653,7 +612,7 @@ def _torch_dtype_to_numpy(torch_dtype: torch.dtype) -> Any: def _element_type_to_numpy_dtype(element_type: HalElementType) -> Any: try: - return _DTYPE_TO_ELEMENT_TYPE[element_type] + return DTYPE_TO_ELEMENT_TYPE[element_type] except KeyError: raise UnknownDTypeError(element_type) diff --git a/python/shark_turbine/runtime/__init__.py b/python/shark_turbine/runtime/__init__.py new file mode 100644 index 000000000..29434c268 --- /dev/null +++ b/python/shark_turbine/runtime/__init__.py @@ -0,0 +1,8 @@ +# Copyright 2023 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .device import * +from . import op_reg diff --git a/python/shark_turbine/runtime/device.py b/python/shark_turbine/runtime/device.py new file mode 100644 index 000000000..a07e139cf --- /dev/null +++ b/python/shark_turbine/runtime/device.py @@ -0,0 +1,375 @@ +# Copyright 2023 Nod Labs, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from functools import lru_cache +from typing import Callable, Optional, Union +from threading import local, Lock + +import torch + +from iree.runtime import ( + BufferUsage, + HalBufferView, + HalDevice, + HalDriver, + MemoryType, + VmInstance, + VmModule, + create_hal_module, + get_driver, +) + +from ..support.conversions import ( + dtype_to_element_type, + torch_dtype_to_numpy, +) + +from ..support.exceptions import ( + NoCurrentDeviceError, + MismatchedDeviceSetClearError, + UnsupportedTorchDeviceError, +) + +from ..support.logging import runtime_logger as logger + +__all__ = [ + "get_vm_instance", + "Device", + "DeviceState", +] + +_CONFIG_LOCK = Lock() +_GLOBAL_VM_INSTANCE: Optional[VmInstance] = None +_CURRENT_THREAD = local() + +############################################################################### +# DeviceState ande Device classes. +# These associated shared VmInstance and HalDrivers with a concrete HalDevice. +# The Device class also adds other accounting needed for interop in PyTorch's +# eager environment (i.e. transfer and compute queue counters, etc). +############################################################################### + + +def get_vm_instance() -> VmInstance: + global _GLOBAL_VM_INSTANCE + if not _GLOBAL_VM_INSTANCE: + with _CONFIG_LOCK: + if not _GLOBAL_VM_INSTANCE: + _GLOBAL_VM_INSTANCE = VmInstance() + return _GLOBAL_VM_INSTANCE + + +class DeviceState: + """State for an instantiated HAL device. + + Note that the IREE runtime internally manages a global cache of drivers for + standard named-access (not custom-constructed) drivers. + """ + + __slots__ = [ + "device", + "driver", + "instance", + ] + + def __init__( + self, + *, + driver: Union[str, HalDriver], + device: Optional[HalDevice] = None, + vm_instance: Optional[VmInstance] = None, + ): + self.instance = vm_instance or get_vm_instance() + self.driver = driver if isinstance(driver, HalDriver) else get_driver(driver) + self.device = device if device else self.driver.create_default_device() + + @staticmethod + @lru_cache(maxsize=None) + def from_uri(uri: str) -> "DeviceState": + driver = get_driver(uri) + return DeviceState(driver=driver, device=driver.create_device_by_uri(uri)) + + +class Device: + """Represents a low-level device (HalDriver/HalDevice) and scheduling data. + + This is the type that user's interact with as a 'Device'. Devices can be handled + loose-leaf or bound to a thread with a context manager. + """ + + __slots__ = [ + "_s", + "_main_timeline", + "_main_timepoint", + "_tx_timeline", + "_tx_timepoint", + "_fence_capacity", + "compile_target_flags", + "export_torch_tensor", + "import_torch_tensor", + "instance_cache_key", + "type_cache_key", + ] + + _s: DeviceState + + # Each device will have a function attached to import a torch.tensor + # *that is already on that device* directly from device memory. + # This is unsafe and relatively unchecked. If criss-crossing devices, + # it is undefined behavior. + import_torch_tensor: Callable[[torch.Tensor], HalBufferView] + + # Devices can also export a torch tensor from a HalBufferView, given + # a meta tensor that describes it. + export_torch_tensor: Callable[[HalBufferView, torch.Tensor], torch.Tensor] + + # Cache key that uniquely identifies this device. + instance_cache_key: str + + # Cache key that uniquely identifies this type of device (currently + # based on its driver). + type_cache_key: str + + # Compiler flags to use to target this device. + # TODO: We should replace this with a target attribute but need an API + # to derive that. + compile_target_flags: tuple[str, ...] + + def __new__( + cls, uri: Optional[str] = None, *, device_state: Optional[DeviceState] = None + ): + if uri is not None: + # Construction by URI is cached on the thread. + assert not device_state, "device_state= cannot be given with explicit URI" + try: + existing = _CURRENT_THREAD.device_by_uri[uri] + except (AttributeError, KeyError): + ... + else: + return existing + + # New instance. + device_state = DeviceState.from_uri(uri) + new_inst = super().__new__(cls) + new_inst._s = device_state + try: + _CURRENT_THREAD.device_by_uri[uri] = new_inst + except AttributeError: + _CURRENT_THREAD.device_by_uri = {uri: new_inst} + new_inst._initialize() + return new_inst + else: + # Explicit construction with a device_state is assumed that you know what you + # are doing and an uncached instance will be returned. This will be unsychronized + # relative to any cached instance. + assert device_state, "device_state= must be given if URI ommitted" + new_inst = super().__new__(cls) + new_inst._s = device_state + new_inst._initialize() + return new_inst + + def _initialize(self): + d = self._s.device + self._main_timeline = d.create_semaphore(0) + self._main_timepoint = 0 + self._tx_timeline = d.create_semaphore(0) + self._tx_timepoint = 0 + # Maximum number of semaphores the device uses. Can be increased if doing out of the + # ordinary scheduling. + self._fence_capacity = 2 + + # Perform driver specific augmentations. + # TODO: Add a HalDriver.id property to get the driver name instead of parsing + # the device repr. + driver_id = repr(d) + colon_pos = driver_id.find(":") + if colon_pos >= 0: + driver_id = driver_id[0:colon_pos] + try: + import_fn = TORCH_TENSOR_IMPORTERS[driver_id] + export_fn = TORCH_TENSOR_EXPORTERS[driver_id] + self.import_torch_tensor = lambda t: import_fn(self, t) + self.export_torch_tensor = lambda bv, t: export_fn(self, bv, t) + self.compile_target_flags = DEVICE_TARGET_COMPILE_FLAGS[driver_id] + except KeyError as e: + raise AssertionError( + f"Unsupported TORCH_TENSOR_IMPORTERS for iree driver '{driver_id}'" + ) from e + + # Cache keys. + # TODO: The type cache key should actually be based on the driver id + # and device characteristics hash. + self.instance_cache_key = repr(d) + self.type_cache_key = driver_id + + @property + def hal_device(self) -> HalDevice: + return self._s.device + + @property + def vm_instance(self) -> VmInstance: + return self._s.instance + + def create_hal_module(self) -> VmModule: + s = self._s + return create_hal_module(s.instance, s.device) + + def current() -> "Device": + try: + return _CURRENT_THREAD.stack[-1] + except (AttributeError, IndexError): + raise NoCurrentDeviceError() + + def set(self) -> "Device": + """Sets this device as the current device without a context manager.""" + try: + _CURRENT_THREAD.stack.append(self) + except AttributeError: + _CURRENT_THREAD.stack = [self] + + def clear(self): + """Clears the current device without a context manager.""" + try: + c = _CURRENT_THREAD.stack[-1] + if _CURRENT_THREAD.stack[-1] is self: + _CURRENT_THREAD.stack.pop() + return + except (AttributeError, IndexError): + ... + raise MismatchedDeviceSetClearError() + + def __repr__(self): + return f"" + + def __enter__(self): + try: + _CURRENT_THREAD.stack.append(self) + except AttributeError: + _CURRENT_THREAD.stack = [self] + + def __exit__(self, type, value, traceback): + _CURRENT_THREAD.stack.pop() + + +def _device_import_torch_tensor_cpu(device: Device, t: torch.Tensor) -> HalBufferView: + hal_device = device.hal_device + element_type = dtype_to_element_type(t.dtype) + # TODO: In this case, we should be importing the raw buffer, but this is not + # generically exposed to Python in the IREE runtime. + bv = device.hal_device.allocator.allocate_buffer_copy( + memory_type=MemoryType.DEVICE_LOCAL, + allowed_usage=BufferUsage.DEFAULT, + device=hal_device, + buffer=t.numpy(), + element_type=element_type, + ) + return bv + + +def _device_export_torch_tensor_cpu( + device: Device, bv: HalBufferView, like: torch.Tensor +) -> torch.Tensor: + # TODO: Similar to import, we know that the buffer is in local CPU memory + # and could export it if we had Python API support for that. Until we have + # that, we do this very torturous indirection. + mapped_memory = bv.map() + shape = list(like.shape) + np_dtype = torch_dtype_to_numpy(like.dtype) + mapped_array = mapped_memory.asarray(shape, np_dtype) + return torch.from_numpy(mapped_array) + + +# Mapping of torch tensor importers keyed by driver name. +TORCH_TENSOR_IMPORTERS: dict[str, Callable[[Device, torch.Tensor], HalBufferView]] = { + "local-sync": _device_import_torch_tensor_cpu, + "local-task": _device_import_torch_tensor_cpu, +} + +TORCH_TENSOR_EXPORTERS: dict[ + str, Callable[[Device, HalBufferView, torch.Tensor], torch.Tensor] +] = { + "local-sync": _device_export_torch_tensor_cpu, + "local-task": _device_export_torch_tensor_cpu, +} + +DEVICE_TARGET_COMPILE_FLAGS: dict[str, tuple[str, ...]] = { + "local-task": ("--iree-hal-target-backends=llvm-cpu",), +} + +# Aliases. +DEVICE_TARGET_COMPILE_FLAGS["local-sync"] = DEVICE_TARGET_COMPILE_FLAGS["local-task"] + +# Make sure all tables have the same keys. +assert ( + TORCH_TENSOR_IMPORTERS.keys() == DEVICE_TARGET_COMPILE_FLAGS.keys() +), "Not all devices have the same configs" + +assert ( + TORCH_TENSOR_IMPORTERS.keys() == TORCH_TENSOR_EXPORTERS.keys() +), "Not all devices have the same configs" + +############################################################################### +# torch.device to Device mapping +############################################################################### + + +def lookup_device_from_torch( + torch_device: torch.device, *, create: bool = True +) -> Optional[Device]: + """Gets a shared Device corresponding to the given torch.device. + + This will return None if the device is wholly unsupported or if + create=False. Otherwise, faults in setting up the device are + reported as an appropriate exception. + """ + try: + mapping = _CURRENT_THREAD.device_by_torch_device + except AttributeError: + _CURRENT_THREAD.device_by_torch_device = mapping = {} + device = mapping.get(torch_device) + if device is not None or not create: + return device + logger.debug("Creating turbine device for torch.device = %r", torch_device) + device = _create_device_from_torch(torch_device) + if device is not None: + mapping[torch_device] = device + return device + + +def get_device_from_torch(torch_device: torch.device) -> Device: + """Gets a shared Device corresponding to the given torch.device. + + Raises an exception if the device cannot be created. + """ + device = lookup_device_from_torch(torch_device) + if device is None: + raise UnsupportedTorchDeviceError(torch_device) + return device + + +def _create_device_from_torch(torch_device: torch.device) -> Optional[Device]: + torch_type = torch_device.type + uri = None + if torch_type == "cpu": + uri = "local-task" + + if uri is None: + return None + + return Device(uri) + + +############################################################################### +# Utilities +############################################################################### + +# The nanobind leak checker doesn't interop well with the way that +# global state is managed for PyTorch. It isn't clear that this +# is a fully correctable state of affairs, so we just disable it +# for now. RIP nice things :( +from iree.runtime._binding import disable_leak_checker + +disable_leak_checker() diff --git a/python/shark_turbine/runtime/op_reg/__init__.py b/python/shark_turbine/runtime/op_reg/__init__.py new file mode 100644 index 000000000..c18790d33 --- /dev/null +++ b/python/shark_turbine/runtime/op_reg/__init__.py @@ -0,0 +1,7 @@ +# Copyright 2023 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .base import * diff --git a/python/shark_turbine/runtime/op_reg/base.py b/python/shark_turbine/runtime/op_reg/base.py new file mode 100644 index 000000000..6702d745f --- /dev/null +++ b/python/shark_turbine/runtime/op_reg/base.py @@ -0,0 +1,507 @@ +# Copyright 2023 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +"""Base classes for registering custom operations with the PyTorch +dispatcher. +""" + +from typing import Any, Callable, Optional, Type, Union + +from abc import ABC, abstractmethod, abstractproperty +import functools +import logging + +import torch +from torch import Tensor + +from ...support.ir_imports import ( + Block, + Context, + FunctionType, + InsertionPoint, + Location, + StringAttr, + SymbolTable, + IrType, + Value, + builtin_d, + func_d, +) + +from ...support.conversions import ( + TORCH_DTYPE_TO_IREE_TYPE_ASM, +) + +__all__ = [ + "ArgDescriptor", + "CustomOp", + "FreeFuncKernelBuilder", + "IntArg", + "KernelBuilder", + "KernelSelection", + "TensorArg", +] + +logger = logging.getLogger("turbine.runtime.op_reg") + +############################################################################### +# Op library management +############################################################################### + +# All such custom kernels are registered in the 'turbine' library/namespace. +# We also allow extending existing libraries outside of this, but that is +# the non default case. +TURBINE_LIBRARY = torch.library.Library("turbine", "DEF") + + +class CustomOp(ABC): + """Users subclass this in order to register a turbine custom op.""" + + @staticmethod + def register( + op_class: Optional[Type["CustomOp"]], + *, + library: torch.library.Library = TURBINE_LIBRARY, + dispatch_key: str = "", + register_meta: bool = True, + register_impl: bool = True, + ) -> Callable: + """Class decorator for `CustomOp` implementations. + + The decorator will instantiate the class and then replace it with + the callable operation that can be used to invoke the kernel. + + Typical usage: + + ``` + @CustomOp.register + class identity(CustomOp): + ... + + result = identity(torch.tensor(1, 2, 3)) + ``` + """ + if not op_class: + return functools.partial( + CustomOp.register, + library=library, + dispatch_key=dispatch_key, + register_meta=register_meta, + register_impl=register_impl, + ) + instance = op_class( + library=library, + dispatch_key=dispatch_key, + register_meta=register_meta, + register_impl=register_impl, + ) + return instance.op + + def __init__( + self, + *, + library: torch.library.Library, + dispatch_key: str, + register_meta: bool, + register_impl: bool, + ): + name = self.name + fq_schema = f"{name}{self.signature}" + library.define(fq_schema) + self.library = library + self.cache_key_base = f"{library.ns}.{library.kind}::{name}" + self.op = _get_library_op(library, name) + + # The meta kernel can be provided by the selection machinery and + # does not require a tie-in to the kernel generator, which layers + # on top. + if register_meta: + library.impl(name, _get_meta_impl(self), "Meta") + + if register_impl: + library.impl(name, _create_impl_trampoline(self), dispatch_key) + + @abstractproperty + def name(self) -> str: + """Name of the operation.""" + ... + + @abstractproperty + def signature(self) -> str: + """PyTorch function signature. + + This excludes the name, which will come from the `name` property + and be prepended to make a full PyTorch schema. + """ + ... + + @abstractmethod + def select(self, sel: "KernelSelection"): + """Performs kernel selection. + + This method has three purposes: + + 1. Selects which kernel specialization is needed based on + arguments. + 2. Returns the meta tensor results of the operation, effectively + completing the transfer function from argument types to + result types. + 3. Sets additional metadata that the generate method can use. + + The `device="meta"` kernel implementation is composed completely by + invoking `select`. For implementation devices, `select` is called + for each invocation. The `generate` will be called subsequently if + the kernel needs to be generated. + """ + ... + + @abstractmethod + def generate(self, ksel: "KernelSelection", kb: "KernelBuilder"): + """Generates a kernel based on the `KernelSelection`. + + This method should generate IR into the given `KernelBuilder`. It + can do so by consulting any state set on the `KernelSelection`. + Each `KernelSelection.args` corresponds to `KernelBuilder.args`. + Unless if the argument was set as `is_ir_arg=False`, the argument + will be a `Value`. Otherwise, it will be `None`. It is recommended + to use `KernelBuilder.arg(n)` to access. + + Generation should conclude with a call to `KernelBuilder.yield_results`. + """ + ... + + +class KernelSelection: + """Represents a selected kernel based on a concrete signature. + + The `CustomOp.select` method must yield an instance of this, and + it will be done for every invocation. At this point, the kernel + has not yet been generated, but we have selected a generation + strategy based on a concrete signature. + + This mechanism also serves as the means for servicing `meta` + registrations because it implicitly computes everything needed + (i.e. shapes, etc). + """ + + __slots__ = [ + "args", + "arg_descs", + "op", + "result_descs", + "variant", + ] + + def __init__(self, op: CustomOp, args: list[Any]): + self.op = op + self.args = args + self.arg_descs: list[Optional[ArgDescriptor]] = len(args) * [None] + self.result_descs: list[ArgDescriptor] = [] + self.variant: str = "default" + + def generate_meta_returns(self) -> Any: + results = [d.generate_meta() for d in self.result_descs] + if len(results) == 1: + return results[0] + else: + return tuple(results) + + @property + def spec_key(self) -> str: + arg_keys = ",".join(d.spec_key for d in self.arg_descs) + return_keys = ",".join(d.spec_key for d in self.result_descs) + return f"{self.op.cache_key_base}::{self.variant}({arg_keys})->({return_keys})" + + def arg_tensor(self, arg: int) -> "TensorArg": + """Declares an argument to allow any ranked tensor and to specialize for each rank + and dtype. + + Returns the argument descriptor, which can be used to further inspect or constrain + the selection. It will default to allowing all dimensions to be dynamic. + """ + arg_descs = self.arg_descs + arg_value = self.args[arg] + assert arg_descs[arg] is None, f"Already constrained argument {arg}" + assert isinstance( + arg_value, Tensor + ), f"Argument type mismatch from Torch for {arg}: Expected tensor, got {type(arg_value)}" + arg_descs[arg] = desc = TensorArg(arg_value) + return desc + + def arg_int(self, arg: int) -> "IntArg": + """Declares an argument to be an integer value that can take any value. + + Returns the argument descriptor, which can be used to further inspect or constrain + the selection. + """ + arg_descs = self.arg_descs + arg_value = self.args[arg] + assert arg_descs[arg] is None, f"Already constrained argument {arg}" + assert isinstance( + arg_value, int + ), f"Argument type mismatch from Torch for {arg}: Expected int, got {type(arg_value)}" + arg_descs[arg] = desc = IntArg(arg_value) + return desc + + def return_tensor(self, t: Tensor) -> "TensorArg": + """Marks the next return value as a Tensor. + + By default, it will be rank and dtype specialized but have completely dynamic + dimensions. Dimensions can be further constrained by modifying the returned + descriptor. + """ + desc = TensorArg(t) + self.result_descs.append(desc) + return desc + + +class TensorArg: + __slots__ = [ + "t", + "spec_dims", + "is_ir_arg", + "maybe_tensor_value", + ] + + def __init__(self, t: Tensor): + self.t = t + # Any static dims that we are specializing. Defaults to all dynamic. + self.spec_dims: list[Optional[int]] = len(t.shape) * [None] + self.is_ir_arg = True + # All descriptors have an attribute to indicate their value + # as a tensor, and those that aren't are fixated to None. + # This is to enable fast lookup in the hot path of determining + # how to dispatch. + self.maybe_tensor_value: Tensor = t + + def generate_meta(self) -> Tensor: + t = self.t + if t.device == "meta": + return t + else: + return t.clone().detach().to("meta") + + @property + def spec_key(self) -> str: + """Generates a key that will be the same for all specializations.""" + t = self.t + return f"tensor[{len(t.shape)}:{str(t.dtype)}]<{self.spec_dims}>" + + @property + def mlir_type_asm(self) -> str: + t = self.t + try: + dtype_asm = TORCH_DTYPE_TO_IREE_TYPE_ASM[t.dtype] + except KeyError as e: + raise KeyError( + f"Unknown mapping of torch dtype {t.dtype} to MLIR " + f"(possibly missing in TORCH_DTYPE_TO_IREE_TYPE_ASM table)" + ) from e + dim_asm = "x".join(["?" if d is None else str(d) for d in self.spec_dims]) + spec = f"{dim_asm}x{dtype_asm}" if dim_asm else dtype_asm + return f"tensor<{spec}>" + + +class IntArg: + __slots__ = [ + "is_ir_arg", + "v", + "spec_value", + "maybe_tensor_value", + ] + + def __init__(self, v: int): + self.v = v + self.spec_value: Optional[Any] = None + self.is_ir_arg = True + # All descriptors have an attribute to indicate their value + # as a tensor, and those that aren't are fixated to None. + # This is to enable fast lookup in the hot path of determining + # how to dispatch. + self.maybe_tensor_value: Optional[Tensor] = None + + def generate_meta(self) -> int: + return self.v + + @property + def spec_key(self) -> str: + """Generates a key that will be the same for all specializations.""" + return f"int<{self.spec_value}>" + + @property + def mlir_type_asm(self) -> str: + # TODO: We can have individual kernels constrain this to a narrower + # type. + return "i64" + + +ArgDescriptor = Union[TensorArg, IntArg] + +############################################################################### +# KernelBuilder +# Helper object for constructing IR +############################################################################### + + +class KernelBuilder(ABC): + """Support class for building a kernel.""" + + def __init__( + self, + ksel: KernelSelection, + arg_bindings: list[Value], + *, + ip: InsertionPoint, + module_body: Block, + symbol_table: SymbolTable, + ): + self.ksel = ksel + self.arg_bindings = arg_bindings + self.ip = ip + self.module_body = module_body + self.symbol_table = symbol_table + + def arg_value(self, index: int) -> Value: + """Gets the concrete IR `Value` for the argument at `index`. + + This will assert if the corresponding argument was set as `is_ir_arg=False` + during kernel selection. + """ + try: + v = self.arg_bindings[index] + except IndexError as e: + raise AssertionError( + f"Out of range access to kernel arg. Expected 0..{len(self.arg_bindings)}. Got {index}" + ) from e + assert ( + v is not None + ), f"No `Value` is available for arg {index}: it was marked as `is_ir_arg=False` during kernel selection." + return v + + @abstractmethod + def yield_results(self, *results: Value): + """Yields results of the kernel computation.""" + ... + + +class FreeFuncKernelBuilder(KernelBuilder): + """Kernel builder that emits the body of the kernel into a free function. + + This is intended to be used when compiling a standalone module that will + be directly invoked by the runtime. Further variants exist that generate + into a func but also emit a call into another local context. + """ + + def __init__( + self, + ksel: KernelSelection, + *, + module_body: Block, + symbol_table: SymbolTable, + func_name: Optional[str] = None, + is_public: bool = True, + ): + self.module_op = module_body.owner + context = self.module_op.context + if func_name is None: + func_name = ksel.op.name + with context, Location.unknown(), InsertionPoint(module_body): + arg_types = [ + IrType.parse(d.mlir_type_asm) for d in ksel.arg_descs if d.is_ir_arg + ] + result_types = [IrType.parse(d.mlir_type_asm) for d in ksel.result_descs] + ftype = FunctionType.get(arg_types, result_types) + func_op = func_d.FuncOp(func_name, ftype) + if not is_public: + func_op.attributes["sym_visibility"] = StringAttr.get("private") + entry_block: Block = func_op.add_entry_block() + symbol_table.insert(func_op) + + # Map inputs to arg bindings, lining up with arguments that are elided. + block_arguments = list(entry_block.arguments) + block_arguments.reverse() + arg_bindings: list[Optional[Value]] = [] + for desc in ksel.arg_descs: + if desc.is_ir_arg: + arg_bindings.append(block_arguments.pop()) + else: + arg_bindings.append(None) + + super().__init__( + ksel, + arg_bindings, + ip=InsertionPoint(entry_block), + module_body=module_body, + symbol_table=symbol_table, + ) + + @staticmethod + def create_module( + ksel: KernelSelection, + *, + context: Optional[Context] = None, + func_name: Optional[str] = None, + is_public: bool = True, + ) -> "FreeFuncKernelBuilder": + """Short-cut to create a new module with a single function in one shot.""" + if context is None: + context = Context() + with context, Location.unknown(): + module_op = builtin_d.ModuleOp() + return FreeFuncKernelBuilder( + ksel, + module_body=module_op.body, + symbol_table=SymbolTable(module_op), + func_name=func_name, + is_public=is_public, + ) + + def yield_results(self, *results: Value): + """Yields results of the kernel computation.""" + with self.ip, Location.unknown(): + func_d.ReturnOp(results) + + +############################################################################### +# Private utilities +############################################################################### + + +def _get_library_op(library: torch.library.Library, name: str) -> Any: + ns = getattr(torch.ops, library.ns) + return getattr(ns, name) + + +def _get_meta_impl(op: CustomOp): + def meta(*args): + sel = KernelSelection(op, args) + op.select(sel) + if logger.isEnabledFor(logging.DEBUG): + logging.debug( + "Meta dispatch on %s for specialization %s", op.name, sel.spec_key + ) + return sel.generate_meta_returns() + + return meta + + +def _create_impl_trampoline(op: CustomOp): + # Import lazily when an implementation trampoline is requested to avoid + # circular dependency between base objects and eager runtime goo. + from .eager import ( + eager_dispatch, + ) + + def handler(*args): + ksel = KernelSelection(op, args) + op.select(ksel) + if logger.isEnabledFor(logging.DEBUG): + logging.debug( + "Dispatch on %s for specialization %s", op.name, ksel.spec_key + ) + return eager_dispatch(ksel) + + return handler diff --git a/python/shark_turbine/runtime/op_reg/compiler.py b/python/shark_turbine/runtime/op_reg/compiler.py new file mode 100644 index 000000000..c15ead608 --- /dev/null +++ b/python/shark_turbine/runtime/op_reg/compiler.py @@ -0,0 +1,137 @@ +# Copyright 2023 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from dataclasses import dataclass +from timeit import default_timer +from typing import Any + +from iree.compiler.api import ( + Session, + Source, + Output, +) + +from iree.runtime import ( + VmContext, + VmFunction, + VmModule, +) + +from ...support.exceptions import ( + GeneralError, +) + +from ...support.logging import ( + runtime_logger as logger, +) + +from ..device import ( + Device, +) + +from .base import ( + FreeFuncKernelBuilder, + KernelSelection, +) + + +@dataclass(slots=True) +class KernelCompileConfig: + # Unique key for this kernel. + key: str + + # Compiler flags to pass. + flags: list[str] + + # Use the in-process compiler (default). Some compiler options are only + # available when invoked standalone/out-of-process, so this is allowed. + # Out-of-process can also be a useful debugging feature and may be + # globally controlled. + in_process: bool = True + + # Whether compiled for async invocations. + async_invocations: bool = False + + # Whether we compiled with layout specialization and can handle certain + # permutations of strided tensors. This is currently not supported but will + # be at some point. Having the option lets us annotate code paths that are + # NYI. + layout_specialized: bool = False + + # Arbitrary objects to keep alive as part of this config. This can include + # things like unbacked memory mappings, etc. + keep_alive: Any = None + + +# TODO: The cache should be more than just a simple dict. Can be persistent +KERNEL_CACHE: dict[str, tuple[VmContext, VmFunction, KernelCompileConfig]] = {} + + +def _testing_get_cache_size() -> int: + return len(KERNEL_CACHE) + + +def compile_standalone_kernel( + device: Device, ksel: KernelSelection, func_name: str = "main" +) -> tuple[VmContext, VmFunction, KernelCompileConfig]: + # Early exit on cache hit. + cache_key = f"{ksel.spec_key}::{device.type_cache_key}" + cache_hit = KERNEL_CACHE.get(cache_key) + if cache_hit is not None: + return cache_hit + + # Cache miss. + start = default_timer() + config = KernelCompileConfig(cache_key, list(device.compile_target_flags)) + kb = FreeFuncKernelBuilder.create_module(ksel, func_name=func_name) + ksel.op.generate(ksel, kb) + kb.module_op.verify() + module_asm = kb.module_op.get_asm( + binary=True, enable_debug_info=True, assume_verified=True + ) + generation_time = default_timer() - start + + if not config.in_process: + raise NotImplementedError("Out-of-process compilation not yet supported") + + # TODO: We could be caching the session per device type key. + # TODO: Create the source and get the module to build into from that vs + # reserializing (once issues are worked out for that). + start = default_timer() + session = Session() + session.set_flags(*config.flags) + inv = session.invocation() + source = Source.wrap_buffer(session, module_asm) + output = Output.open_membuffer() + inv.enable_console_diagnostics() + inv.parse_source(source) + if not inv.execute(): + # TODO: Capture diagnostics and report. + raise GeneralError(f"Kernel compilation failed. See diagnostics.") + inv.output_vm_bytecode(output) + mapped_memory = output.map_memory() + compilation_time = default_timer() - start + + # Load. + vm_instance = device.vm_instance + vm_module = VmModule.copy_buffer(vm_instance, mapped_memory) + # TODO: We should be able to wrap the buffer as below but there are some + # subtle ref-counting/shutdown sequencing issues that need to be resolved. + # vm_module = VmModule.wrap_buffer(vm_instance, mapped_memory) + vm_context = VmContext(vm_instance, [device.create_hal_module(), vm_module]) + main_function = vm_module.lookup_function("main") + + logger.debug( + "Compiled kernel %s: mlir=%d bytes, vmfb=%d bytes (generation: %sms, compilation: %sms)", + cache_key, + len(module_asm), + len(mapped_memory), + generation_time * 1000, + compilation_time * 1000, + ) + cache_hit = (vm_context, main_function, config) + KERNEL_CACHE[cache_key] = cache_hit + return cache_hit diff --git a/python/shark_turbine/runtime/op_reg/eager.py b/python/shark_turbine/runtime/op_reg/eager.py new file mode 100644 index 000000000..629118096 --- /dev/null +++ b/python/shark_turbine/runtime/op_reg/eager.py @@ -0,0 +1,132 @@ +# Copyright 2023 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +"""Custom op integration into the eager executor.""" + +from timeit import default_timer +from typing import Optional + +import torch + +from iree.runtime import ( + HalBufferView, + VmVariantList, +) + +from ...support.exceptions import ( + UnsupportedTypeError, +) + +from ...support.logging import ( + runtime_logger as logger, +) + +from ..device import ( + Device, + lookup_device_from_torch, +) + +from .base import ( + KernelSelection, +) + +from .compiler import ( + compile_standalone_kernel, +) + +__all__ = [ + "eager_dispatch", +] + + +def eager_dispatch(ksel: KernelSelection): + """Main entry-point for handling dispatch of a selected kernel via a generator.""" + # Scan arg descs and decide on a compute device. + # For now, we compute on the first device that we support. + # This is very simplisitic and will need to be extended for multi-device, etc. + device: Optional[Device] = None + torch_device: Optional[torch.device] = None + for arg_desc in ksel.arg_descs: + if not arg_desc.is_ir_arg: + continue + tensor_arg = arg_desc.maybe_tensor_value + if tensor_arg is None: + continue + torch_device = tensor_arg.device + device = lookup_device_from_torch(torch_device) + if device is not None: + break + + # Default to CPU. + if device is None: + logger.debug("Fallback to CPU device due to no supported device in arguments") + torch_device = torch.device("cpu") + device = lookup_device_from_torch(torch_device) + + # Compile. + # TODO: We can do compilation asynchronously with the device movement + vm_context, vm_f, config = compile_standalone_kernel(device, ksel) + + # Build the concrete args, issuing device movement as necessary. + arg_list = VmVariantList(len(ksel.arg_descs)) + for arg_desc in ksel.arg_descs: + if not arg_desc.is_ir_arg: + continue + tensor_arg = arg_desc.maybe_tensor_value + # Handle non-tensor args. + if tensor_arg is None: + scalar_value = arg_desc.v + if isinstance(scalar_value, int): + arg_list.push_int(scalar_value) + elif isinstance(scalar_value, float): + arg_list.push_float(scalar_value) + else: + raise UnsupportedTypeError(type(scalar_value)) + continue + # Tensor arg. + if tensor_arg.device != torch_device: + # TODO: If the source and target device are both known to us, + # we can do this "in house" vs asking torch to do it. + tensor_arg = tensor_arg.to(torch_device) + if not tensor_arg.is_contiguous(): + if config.layout_specialized: + raise NotImplementedError( + "Layout specialized kernels are not yet implemented" + ) + tensor_arg = tensor_arg.contiguous() + # Since we know we are on the same device, we can use the unsafe + # import_torch_tensor. + arg_list.push_ref(device.import_torch_tensor(tensor_arg)) + + if config.async_invocations: + raise NotImplementedError("Async execution not yet implemented") + + # Invoke. + ret_list = VmVariantList(len(ksel.result_descs)) + start = default_timer() + vm_context.invoke(vm_f, arg_list, ret_list) + invoke_time = default_timer() - start + logger.debug("Kernel invocation %s: %sms", config.key, invoke_time * 1000) + + # Unpack results. + results = [] + + for i, result_desc in enumerate(ksel.result_descs): + meta_tensor_value = result_desc.maybe_tensor_value + if meta_tensor_value is None: + # Scalar return. + raise NotImplementedError("CustomOp scalar return") + + # Tensor return. The meta tensor value already has the correct torch + # dtype and shape, so we just need to export and return it for the + # appropriate device. + bv: HalBufferView = HalBufferView.__iree_vm_cast__(ret_list.get_as_ref(i)) + results.append(device.export_torch_tensor(bv, meta_tensor_value)) + + if len(results) == 1: + return results[0] + else: + return tuple(results) diff --git a/python/shark_turbine/support/conversions.py b/python/shark_turbine/support/conversions.py new file mode 100644 index 000000000..132cfd428 --- /dev/null +++ b/python/shark_turbine/support/conversions.py @@ -0,0 +1,118 @@ +# Copyright 2023 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Any, Callable + +import numpy as np +import torch + +from iree.runtime import ( + HalElementType, +) + +from ..importers.fx_importer import ( + TORCH_DTYPE_TO_MLIR_TYPE_ASM, +) + +from .exceptions import ( + UnknownDTypeError, +) + +from .ir_imports import ( + BF16Type, + ComplexType, + F16Type, + F32Type, + F64Type, + IntegerType, + IrType, +) + +# We need the inverse of the TORCH_DTYPE_TO_MLIR_TYPE_ASM table. +MLIR_TYPE_ASM_TO_TORCH_DTYPE = {v: k for k, v in TORCH_DTYPE_TO_MLIR_TYPE_ASM.items()} + +# When emitting constants, we have to create native IREE types. +TORCH_DTYPE_TO_IREE_TYPE: dict[torch.dtype, Callable[[], IrType]] = { + torch.float16: lambda: F16Type.get(), + torch.bfloat16: lambda: BF16Type.get(), + torch.float32: lambda: F32Type.get(), + torch.float64: lambda: F64Type.get(), + torch.uint8: lambda: IntegerType.get_signless(8), + torch.int8: lambda: IntegerType.get_signless(8), + torch.int16: lambda: IntegerType.get_signless(16), + torch.int32: lambda: IntegerType.get_signless(32), + torch.int64: lambda: IntegerType.get_signless(64), + torch.bool: lambda: IntegerType.get_signless(1), + torch.qint8: lambda: IntegerType.get_signless(8), + torch.quint8: lambda: IntegerType.get_signless(8), + torch.complex32: lambda: ComplexType.get(F16Type.get()), + torch.complex64: lambda: ComplexType.get(F32Type.get()), + torch.complex128: lambda: ComplexType.get(F64Type.get()), +} + +TORCH_DTYPE_TO_IREE_TYPE_ASM = { + torch.float16: "f16", + torch.bfloat16: "bf16", + torch.float32: "f32", + torch.float64: "f64", + torch.uint8: "i8", + torch.int8: "i8", + torch.int16: "i16", + torch.int32: "i32", + torch.int64: "i64", + torch.bool: "i1", + torch.qint8: "i8", + torch.quint8: "i8", + torch.complex32: "complex", + torch.complex64: "complex", + torch.complex128: "complex", +} + +DTYPE_TO_ELEMENT_TYPE: dict[torch.dtype, HalElementType] = { + torch.float16: HalElementType.FLOAT_16, + torch.bfloat16: HalElementType.BFLOAT_16, + torch.float32: HalElementType.FLOAT_32, + torch.float64: HalElementType.FLOAT_64, + torch.uint8: HalElementType.UINT_8, + torch.int8: HalElementType.SINT_8, + torch.int16: HalElementType.SINT_16, + torch.int32: HalElementType.SINT_32, + torch.int64: HalElementType.SINT_64, + torch.bool: HalElementType.BOOL_8, + torch.qint8: HalElementType.OPAQUE_8, + torch.quint8: HalElementType.OPAQUE_8, + torch.complex64: HalElementType.COMPLEX_64, + torch.complex128: HalElementType.COMPLEX_128, +} + + +def dtype_to_element_type(dtype) -> HalElementType: + try: + return DTYPE_TO_ELEMENT_TYPE[dtype] + except KeyError: + raise UnknownDTypeError(dtype) + + +TORCH_DTYPE_TO_NUMPY = { + torch.float16: np.dtype("f2"), + torch.float32: np.dtype("f4"), + torch.float64: np.dtype("f8"), + torch.uint8: np.dtype("u1"), + torch.int8: np.dtype("i1"), + torch.int16: np.dtype("i2"), + torch.int32: np.dtype("i4"), + torch.int64: np.dtype("i8"), + torch.bool: np.dtype("?"), + torch.complex64: np.dtype("c8"), + torch.complex128: np.dtype("c16"), +} + + +def torch_dtype_to_numpy(torch_dtype: torch.dtype) -> Any: + try: + return TORCH_DTYPE_TO_NUMPY[torch_dtype] + except KeyError: + raise UnknownDTypeError(torch_dtype) diff --git a/python/shark_turbine/support/exceptions.py b/python/shark_turbine/support/exceptions.py index 240703a9f..be2c2a633 100644 --- a/python/shark_turbine/support/exceptions.py +++ b/python/shark_turbine/support/exceptions.py @@ -23,6 +23,18 @@ def __init__(self): ) +class UnsupportedTorchDeviceError(Exception): + def __init__(self, torch_device): + super().__init__( + f"Attempt to use turbine with a torch.device that is not supported by this build: {torch_device}" + ) + + +class UnsupportedTypeError(Exception): + def __init__(self, t: type, usage: str): + super().__init__(f"Python type {t} is not supported for {usage}") + + class ApiSequencingError(Exception): ... diff --git a/python/shark_turbine/aot/support/ir_imports.py b/python/shark_turbine/support/ir_imports.py similarity index 97% rename from python/shark_turbine/aot/support/ir_imports.py rename to python/shark_turbine/support/ir_imports.py index d7d7a4adc..a15e5fe12 100644 --- a/python/shark_turbine/aot/support/ir_imports.py +++ b/python/shark_turbine/support/ir_imports.py @@ -48,6 +48,7 @@ ) from iree.compiler.dialects import ( + builtin as builtin_d, flow as flow_d, func as func_d, util as util_d, diff --git a/python/shark_turbine/support/logging.py b/python/shark_turbine/support/logging.py new file mode 100644 index 000000000..3eec33091 --- /dev/null +++ b/python/shark_turbine/support/logging.py @@ -0,0 +1,9 @@ +# Copyright 2023 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging + +runtime_logger = logging.getLogger("shark_turbine.runtime") diff --git a/tests/dynamo/tensor_test.py b/tests/dynamo/tensor_test.py index d736c153f..fcd406608 100644 --- a/tests/dynamo/tensor_test.py +++ b/tests/dynamo/tensor_test.py @@ -12,7 +12,8 @@ import torch # Public API imports. -from shark_turbine.dynamo import Device, TurbineMode, DeviceTensor +from shark_turbine.runtime import Device +from shark_turbine.dynamo import TurbineMode, DeviceTensor class TensorTest(unittest.TestCase): diff --git a/tests/dynamo/device_test.py b/tests/runtime/device_test.py similarity index 61% rename from tests/dynamo/device_test.py rename to tests/runtime/device_test.py index b9b1a2288..c37750cca 100644 --- a/tests/dynamo/device_test.py +++ b/tests/runtime/device_test.py @@ -8,14 +8,19 @@ import unittest import threading +import torch + +from iree.runtime import HalElementType + # Public API imports. -from shark_turbine.dynamo import ( +from shark_turbine.runtime import ( Device, ) # Internals. -from shark_turbine.dynamo.device import ( +from shark_turbine.runtime.device import ( _CURRENT_THREAD, + get_device_from_torch, ) from shark_turbine.support.exceptions import * @@ -86,6 +91,37 @@ def run_t2(): self.assertIsNot(devices[0], devices[1]) +# CPU is always available so we can enable this unconditionally. +class TorchCPUInterop(unittest.TestCase): + def testFromTorchDevice(self): + torch_device = torch.device("cpu") + device1 = get_device_from_torch(torch_device) + print(device1) + self.assertIsNotNone(device1) + device2 = get_device_from_torch(torch_device) + self.assertIs(device1, device2) + + def testCpuDeviceCacheKey(self): + d = get_device_from_torch(torch.device("cpu")) + self.assertEqual(d.instance_cache_key, "local-task") + self.assertEqual(d.type_cache_key, "local-task") + + def testImportExportTorchTensor(self): + d = get_device_from_torch(torch.device("cpu")) + cpu_tensor = torch.tensor([1, 2, 3], dtype=torch.int32, device="cpu") + bv = d.import_torch_tensor(cpu_tensor) + print(bv) + self.assertEqual(bv.shape, [3]) + self.assertEqual(bv.element_type, HalElementType.SINT_32) + meta_tensor = cpu_tensor.to(device="meta") + readback_tensor = d.export_torch_tensor(bv, meta_tensor) + torch.testing.assert_close(cpu_tensor, readback_tensor) + + def testCompilerFlags(self): + d = get_device_from_torch(torch.device("cpu")) + self.assertIn("--iree-hal-target-backends=llvm-cpu", d.compile_target_flags) + + if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main() diff --git a/tests/runtime/op_reg/kernel_reg_test.py b/tests/runtime/op_reg/kernel_reg_test.py new file mode 100644 index 000000000..10662db14 --- /dev/null +++ b/tests/runtime/op_reg/kernel_reg_test.py @@ -0,0 +1,87 @@ +# Copyright 2023 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +import unittest + +import torch + +from shark_turbine.runtime.op_reg import * + +from shark_turbine.runtime.op_reg.compiler import _testing_get_cache_size + + +class KernelRegTest(unittest.TestCase): + def testSimple(self): + @CustomOp.register + class identity(CustomOp): + name = "test_identity" + signature = "(Tensor self) -> Tensor" + + def select(self, ksel: KernelSelection): + x = ksel.arg_tensor(0) + ksel.return_tensor(x.t) + + def generate(self, ksel: KernelSelection, kb: KernelBuilder): + # This just yields the IR value of kernel input as the output. + # Effectively in eager mode, this is a `return` from the kernel + # function. + kb.yield_results(kb.arg_bindings[0]) + + self.assertIsNotNone(torch.ops.turbine.test_identity) + + start_compile_count = _testing_get_cache_size() + + # Make sure that the meta registration works. + t = torch.tensor([[1, 2, 3]], dtype=torch.int32, device="meta") + result = identity(t) + self.assertListEqual(list(result.shape), [1, 3]) + self.assertEqual(result.dtype, torch.int32) + self.assertEqual(t.device.type, "meta") + # Meta dispatch should not trigger compilation. + self.assertEqual(_testing_get_cache_size(), start_compile_count) + + # Make sure that CPU dispatch works. + t = torch.tensor([[1, 2, 3]], dtype=torch.int32) + result = identity(t) + print("CPU result:", result) + torch.testing.assert_close(result, t) + # Novel execution should compile a new kernel. + self.assertEqual(_testing_get_cache_size(), start_compile_count + 1) + + # Second run of the same kernel should serve out of cache. + result = identity(t) + torch.testing.assert_close(result, t) + # Repeated execution should use a cached kernel. + self.assertEqual(_testing_get_cache_size(), start_compile_count + 1) + + # It should recompile for different dtype. + t = torch.tensor([[1, 2, 3]], dtype=torch.int16) + result = identity(t) + print("CPU result:", result) + torch.testing.assert_close(result, t) + # Novel execution should compile a new kernel. + self.assertEqual(_testing_get_cache_size(), start_compile_count + 2) + + # It should recompile for different rank. + t = torch.tensor([1, 2, 3], dtype=torch.int16) + result = identity(t) + print("CPU result:", result) + torch.testing.assert_close(result, t) + # Novel execution should compile a new kernel. + self.assertEqual(_testing_get_cache_size(), start_compile_count + 3) + + # It should serve out of cache for same-rank but different dims. + t = torch.tensor([1, 2, 3, 4, 5], dtype=torch.int16) + result = identity(t) + print("CPU result:", result) + torch.testing.assert_close(result, t) + self.assertEqual(_testing_get_cache_size(), start_compile_count + 3) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main()