Add AOT support.

stellaraccident committed Dec 22, 2023
1 parent 0d45767 commit 92c8035
Showing 9 changed files with 436 additions and 47 deletions.
4 changes: 3 additions & 1 deletion python/shark_turbine/aot/support/ir_utils.py
@@ -246,8 +246,10 @@ def create_tensor_global(
         array = np.array(detached_tensor)
         # We know that a Numpy array is a ReadableBuffer so ignore type error.
         contents = memoryview(array)  # type: ignore
+        shape_desc = "_".join([str(d) for d in t.shape])
+        blob_name = f"torch_tensor_{shape_desc}_{str(t.dtype)}"
         elements_attr = DenseResourceElementsAttr.get_from_buffer(
-            contents, "from_py", tensor_type
+            contents, blob_name, tensor_type
         )
         ir_attrs["initial_value"] = elements_attr

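The new blob name makes each serialized tensor resource self-describing instead of the generic "from_py". A minimal sketch of the naming scheme (the example tensor is hypothetical; the two formatting lines mirror the hunk above):

    import torch

    t = torch.zeros(2, 3, dtype=torch.float32)
    shape_desc = "_".join([str(d) for d in t.shape])
    blob_name = f"torch_tensor_{shape_desc}_{str(t.dtype)}"
    print(blob_name)  # torch_tensor_2_3_torch.float32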
23 changes: 16 additions & 7 deletions python/shark_turbine/dynamo/type_conversion.py
@@ -46,7 +46,7 @@ def __init__(self, context: Context):
             self.torch_type_to_native
         )
 
-    def torch_type_to_native(self, torch_type: IrType) -> IrType:
+    def torch_type_to_native(self, torch_type: IrType, signless: bool = True) -> IrType:
         """Converts a presumed torch type to a corresponding native type.
 
         This mirrors the type conversion in torch-mlir's BackendTypeConversion.cpp.
@@ -56,6 +56,8 @@ def torch_type_to_native(self, torch_type: IrType) -> IrType:
          !torch.float -> f64
          !torch.bool -> i1
          !torch.vtensor -> tensor
+
+        If `signless=False`, then integer types will retain their signs.
         """
         # We don't presently have API support for introspecting torch type,
         # and even if we did, it is likely that this is more efficient.
@@ -66,7 +68,11 @@ def torch_type_to_native(self, torch_type: IrType) -> IrType:
         if name == "bool":
             return IntegerType.get_signless(1)
         if name == "int":
-            return IntegerType.get_signless(64)
+            return (
+                IntegerType.get_signless(64)
+                if signless
+                else IntegerType.get_signed(64)
+            )
         elif name == "float":
             return F64Type.get()
         elif name == "vtensor":
@@ -75,22 +81,25 @@ def torch_type_to_native(self, torch_type: IrType) -> IrType:
             dim_list_str, dtype_str = tm.groups()
             dim_list = parse_tensor_dim_list(dim_list_str)
             dtype = self.convert_torch_element_type_to_native(
-                IrType.parse(dtype_str)
+                IrType.parse(dtype_str), signless=signless
             )
             # TODO: Eliminate RankedTensorType dependence on Location.
             # See: https://github.com/nod-ai/SHARK-Turbine/issues/145
             with Location.unknown():
                 return RankedTensorType.get(dim_list, dtype)
         raise TypeError(f"Unsupported torch type conversion for {torch_type}")
 
-    def convert_torch_element_type_to_native(self, torch_type: IrType) -> IrType:
+    def convert_torch_element_type_to_native(
+        self, torch_type: IrType, signless: bool = True
+    ) -> IrType:
         # Torch uses the builtin type hierarchy of IntegerType and FloatType
         # to represent dtypes. These are mostly the same, but it always uses
         # signed IntegerTypes which we must convert to signless for the native
         # type system.
-        if IntegerType.isinstance(torch_type):
-            signed_int_type = IntegerType(torch_type)
-            return IntegerType.get_signless(signed_int_type.width)
+        if signless:
+            if IntegerType.isinstance(torch_type):
+                signed_int_type = IntegerType(torch_type)
+                return IntegerType.get_signless(signed_int_type.width)
         return torch_type
 
     def materialize_native_to_torch(
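A sketch of what the new flag changes, following the conversion table in the docstring (the converter construction is not shown in this diff; the class and context setup here are assumed):

    # Assumes an MLIR Context with the torch dialect registered, and that the
    # converter class in this module is constructed as shown in the hunk header.
    conv = NativeTypeConverter(context)  # name assumed
    t = IrType.parse("!torch.int")
    conv.torch_type_to_native(t)                  # i64  (signless, the default)
    conv.torch_type_to_native(t, signless=False)  # si64 (sign retained)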
4 changes: 2 additions & 2 deletions python/shark_turbine/ops/iree.py
@@ -44,7 +44,7 @@ def _emit_tensor_trace(kb: KernelBuilder, key: str, ts: list[Value]):
 @CustomOp.register(library=IREE_LIBRARY)
 class trace_tensor(CustomOp):
     name = "trace_tensor"
-    signature = "(str trace_key, Tensor t) -> ()"
+    signature = "(str trace_key, Tensor tensor) -> ()"
 
     def select(self, ksel: KernelSelection):
         ksel.attr_str(0)
@@ -58,7 +58,7 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder):
 @CustomOp.register(library=IREE_LIBRARY)
 class trace_tensors(CustomOp):
     name = "trace_tensors"
-    signature = "(str trace_key, Tensor[] self) -> ()"
+    signature = "(str trace_key, Tensor[] tensors) -> ()"
 
     def select(self, ksel: KernelSelection):
         ksel.attr_str(0)
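The renames are not cosmetic: argument names in a torch.library schema become the op's keyword-argument names, and a Tensor argument named `self` gets special method-binding treatment in Torch schemas, so `tensor`/`tensors` avoid that. Assuming IREE_LIBRARY defines the `iree` namespace, calls would look like:

    import torch

    t = torch.randn(4)
    torch.ops.iree.trace_tensor("activation", t)
    torch.ops.iree.trace_tensors(trace_key="layer0", tensors=[t, t * 2])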
2 changes: 1 addition & 1 deletion python/shark_turbine/runtime/device.py
@@ -263,7 +263,7 @@ def _device_import_torch_tensor_cpu(device: Device, t: torch.Tensor) -> HalBuffe
         memory_type=MemoryType.DEVICE_LOCAL,
         allowed_usage=BufferUsage.DEFAULT,
         device=hal_device,
-        buffer=t.numpy(),
+        buffer=t.detach().numpy(),
         element_type=element_type,
     )
     return bv
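The `detach()` matters because PyTorch refuses to export a tensor that participates in autograd. A standalone illustration:

    import torch

    t = torch.ones(2, requires_grad=True)
    # t.numpy() raises RuntimeError: Can't call numpy() on Tensor that
    # requires grad. Use tensor.detach().numpy() instead.
    arr = t.detach().numpy()  # shares storage with t; no autograd edge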
127 changes: 92 additions & 35 deletions python/shark_turbine/runtime/op_reg/base.py
@@ -8,7 +8,7 @@
 dispatcher.
 """
 
-from typing import Any, Callable, Optional, Type, Union
+from typing import Any, Callable, Optional, Sequence, Type, Union
 
 from abc import ABC, abstractmethod, abstractproperty
 import functools
@@ -67,6 +67,14 @@ def def_library(ns) -> torch.library.Library:
     return torch.library.Library(ns, "DEF")
 
 
+def default_dispatch_keys() -> list[str]:
+    # TODO: Dynamically determine what devices to register against.
+    # Note that we have to register against specific keys instead of the
+    # fallback, as fallback is too broad and breaks certain elements of
+    # fx tracing.
+    return ["CPU"]
+
+
 # All such custom kernels are registered in the 'turbine' library/namespace.
 # We also allow extending existing libraries outside of this, but that is
 # the non default case.
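With the widened `dispatch_key` type, an op can now be registered against one key, several keys, or `None` to fall back to `default_dispatch_keys()`. A self-contained sketch of the normalization that the `__init__` hunk below performs (the helper name here is hypothetical):

    def normalize_dispatch_keys(dispatch_key):
        # None -> library defaults; bare string -> single-key list.
        if dispatch_key is None:
            dispatch_key = ["CPU"]  # i.e. default_dispatch_keys()
        elif isinstance(dispatch_key, str):
            dispatch_key = [dispatch_key]
        return list(dispatch_key)

    assert normalize_dispatch_keys(None) == ["CPU"]
    assert normalize_dispatch_keys("CUDA") == ["CUDA"]
    assert normalize_dispatch_keys(("CPU", "CUDA")) == ["CPU", "CUDA"]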
@@ -81,7 +89,7 @@ def register(
     op_class: Optional[Type["CustomOp"]] = None,
     *,
     library: torch.library.Library = TURBINE_LIBRARY,
-    dispatch_key: str = "",
+    dispatch_key: Union[str, Sequence[str], None] = None,
     register_meta: bool = True,
     register_impl: bool = True,
 ) -> Callable:
@@ -120,7 +128,7 @@ def __init__(
         self,
         *,
         library: torch.library.Library,
-        dispatch_key: str,
+        dispatch_key: Union[str, Sequence[str], None],
         register_meta: bool,
         register_impl: bool,
     ):
@@ -138,7 +146,15 @@ def __init__(
             library.impl(name, _get_meta_impl(self), "Meta")
 
         if register_impl:
-            library.impl(name, _create_impl_trampoline(self), dispatch_key)
+            if dispatch_key is None:
+                dispatch_key = default_dispatch_keys()
+            elif isinstance(dispatch_key, str):
+                dispatch_key = [dispatch_key]
+            for k in dispatch_key:
+                library.impl(name, _create_impl_trampoline(self), k)
+
+        fq_name = f"{library.ns}.{name}"
+        ALL_CUSTOM_OP_REGS[fq_name] = self
 
     @abstractproperty
     def name(self) -> str:
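Every successful registration now lands in ALL_CUSTOM_OP_REGS under `"{library.ns}.{name}"`, which is what lets the AOT path recover the CustomOp behind a qualified op name seen in a traced graph. A hypothetical lookup:

    # "turbine.example_op" is a made-up fully qualified name.
    op = ALL_CUSTOM_OP_REGS.get("turbine.example_op")
    if op is not None:
        ...  # drive op.select(...) / op.generate(...) during AOT expansion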
@@ -190,7 +206,12 @@ def generate(self, ksel: "KernelSelection", kb: "KernelBuilder"):
         ...
 
 
-class KernelSelection:
+# All instantiated CustomOp instances, keyed by fully qualified name. This is
+# used by the AOT compiler to expand custom ops that were captured in a trace.
+ALL_CUSTOM_OP_REGS: dict[str, CustomOp] = {}
+
+
+class KernelSelection(ABC):
     """Represents a selected kernel based on a concrete signature.
 
     The `CustomOp.select` method must yield an instance of this, and
@@ -204,17 +225,15 @@ class KernelSelection:
     """
 
     __slots__ = [
-        "args",
         "arg_descs",
         "op",
         "result_descs",
         "variant",
     ]
 
-    def __init__(self, op: CustomOp, args: list[Any]):
+    def __init__(self, op: CustomOp, arg_arity: int):
         self.op = op
-        self.args = args
-        self.arg_descs: list[Optional[ArgDescriptor]] = len(args) * [None]
+        self.arg_descs: list[Optional[ArgDescriptor]] = arg_arity * [None]
         self.result_descs: list[ArgDescriptor] = []
         self.variant: str = "default"

@@ -237,8 +256,11 @@ def __repr__(self):
 
     def generate_meta_returns(self) -> Any:
         results = [d.generate_meta() for d in self.result_descs]
-        if len(results) == 1:
+        arity = len(results)
+        if arity == 1:
             return results[0]
+        elif arity == 0:
+            return None
         else:
             return tuple(results)

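The new zero-arity branch matters for ops whose schema returns `()`, like the trace ops in ops/iree.py above: their meta function must return None, one result must come back bare, and several as a tuple, mirroring Python call conventions. A standalone restatement of the same dispatch:

    def pack_results(results: list):
        arity = len(results)
        if arity == 1:
            return results[0]
        elif arity == 0:
            return None
        return tuple(results)

    assert pack_results([]) is None
    assert pack_results([1]) == 1
    assert pack_results([1, 2]) == (1, 2)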
@@ -255,13 +277,67 @@ def spec_key(self) -> str:
                 f"Error generating spec_key from:\n{textwrap.indent(repr(self), ' ')}"
             ) from e
 
+    @abstractmethod
+    def arg_tensor(self, arg: int) -> "TensorArg":
+        """Declares an argument to allow any ranked tensor and to specialize for each rank
+        and dtype.
+        Returns the argument descriptor, which can be used to further inspect or constrain
+        the selection. It will default to allowing all dimensions to be dynamic.
+        """
+        ...
+
+    @abstractmethod
+    def arg_tensor_list(self, arg: int) -> "TensorListArg":
+        """Declares an argument to accept a list of tensors which will be specialized
+        for the list size and each rank/dtype.
+        Returns the argument descriptor, which can be used to further inspect or constrain
+        the selection. It will default to allowing all dimensions to be dynamic.
+        """
+        ...
+
+    @abstractmethod
+    def arg_int(self, arg: int) -> "IntArg":
+        """Declares an argument to be an integer value that can take any value.
+        Returns the argument descriptor, which can be used to further inspect or constrain
+        the selection.
+        """
+        ...
+
+    @abstractmethod
+    def attr_str(self, arg: int) -> "AttrArg":
+        """Declares an argument to be a string attribute.
+        Such arguments are not materialized in the IR as Values but may be used to
+        generate the IR. In AOT contexts, they must be derived from static values.
+        """
+        ...
+
+    @abstractmethod
+    def return_tensor(self, t: Tensor) -> "TensorArg":
+        """Marks the next return value as a Tensor.
+        By default, it will be rank and dtype specialized but have completely dynamic
+        dimensions. Dimensions can be further constrained by modifying the returned
+        descriptor.
+        """
+        ...
+
+
+class EagerKernelSelection(KernelSelection):
+    """Kernel selection specialized for eager arguments."""
+
+    __slots__ = [
+        "args",
+    ]
+
+    def __init__(self, op: CustomOp, args: list[Any]):
+        super().__init__(op, len(args))
+        self.args = args
+
     def arg_tensor(self, arg: int) -> "TensorArg":
         arg_descs = self.arg_descs
         arg_value = self.args[arg]
         assert arg_descs[arg] is None, f"Already constrained argument {arg}"
@@ -272,12 +348,6 @@ def arg_tensor(self, arg: int) -> "TensorArg":
         return desc
 
     def arg_tensor_list(self, arg: int) -> "TensorListArg":
-        """Declares an argument to accept a list of tensors which will be specialized
-        for the list size and each rank/dtype.
-        Returns the argument descriptor, which can be used to further inspect or constrain
-        the selection. It will default to allowing all dimensions to be dynamic.
-        """
         arg_descs = self.arg_descs
         arg_value = self.args[arg]
         assert arg_descs[arg] is None, f"Already constrained argument {arg}"
@@ -288,11 +358,6 @@ def arg_tensor_list(self, arg: int) -> "TensorListArg":
         return desc
 
     def arg_int(self, arg: int) -> "IntArg":
-        """Declares an argument to be an integer value that can take any value.
-        Returns the argument descriptor, which can be used to further inspect or constrain
-        the selection.
-        """
         arg_descs = self.arg_descs
         arg_value = self.args[arg]
         assert arg_descs[arg] is None, f"Already constrained argument {arg}"
@@ -303,11 +368,6 @@ def arg_int(self, arg: int) -> "IntArg":
         return desc
 
     def attr_str(self, arg: int) -> "AttrArg":
-        """Declares an argument to be a string attribute.
-        Such arguments are not materialized in the IR as Values but may be used to
-        generate the IR. In AOT contexts, they must be derived from static values.
-        """
         arg_descs = self.arg_descs
         arg_value = self.args[arg]
         assert arg_descs[arg] is None, f"Already constrained argument {arg}"
@@ -318,12 +378,6 @@ def attr_str(self, arg: int) -> "AttrArg":
         return desc
 
     def return_tensor(self, t: Tensor) -> "TensorArg":
-        """Marks the next return value as a Tensor.
-        By default, it will be rank and dtype specialized but have completely dynamic
-        dimensions. Dimensions can be further constrained by modifying the returned
-        descriptor.
-        """
         desc = TensorArg(t)
         self.result_descs.append(desc)
         return desc
@@ -538,6 +592,7 @@ def __init__(
         self.ip = ip
         self.module_body = module_body
         self.symbol_table = symbol_table
+        self.yielded = False
 
     def arg_value(self, index: int) -> Union[list[Value], Value]:
         """Gets the concrete IR `Value` for the argument at `index`.
@@ -629,7 +684,7 @@ def __init__(
         for desc in ksel.arg_descs:
             arity = desc.ir_arity
             if not desc.is_list:
-                if desc.ir_arity == 1:
+                if arity == 1:
                     arg_bindings.append(block_arguments[block_arg_index])
                     block_arg_index += 1
                 else:
@@ -671,8 +726,10 @@ def create_module(
 
     def yield_results(self, *results: Value):
         """Yields results of the kernel computation."""
+        assert not self.yielded, "yield_results has already been called"
         with self.ip, Location.unknown():
             func_d.ReturnOp(results)
+        self.yielded = True
 
 
 ###############################################################################
@@ -687,7 +744,7 @@ def _get_library_op(library: torch.library.Library, name: str) -> Any:
 
 def _get_meta_impl(op: CustomOp):
     def meta(*args):
-        sel = KernelSelection(op, args)
+        sel = EagerKernelSelection(op, args)
         op.select(sel)
         if logger.isEnabledFor(logging.DEBUG):
             logging.debug(
@@ -706,7 +763,7 @@ def _create_impl_trampoline(op: CustomOp):
     )
 
     def handler(*args):
-        ksel = KernelSelection(op, args)
+        ksel = EagerKernelSelection(op, args)
         op.select(ksel)
         if logger.isEnabledFor(logging.DEBUG):
             logging.debug(
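Taken together: KernelSelection is now the abstract selection protocol, EagerKernelSelection binds it to live arguments, and both the meta and concrete trampolines construct the eager variant. A hedged sketch of an op written against this interface (the op itself is hypothetical; the select pattern mirrors trace_tensor in ops/iree.py, and the empty yield is an assumption for a `-> ()` op):

    @CustomOp.register(library=TURBINE_LIBRARY)
    class log_tensor(CustomOp):  # hypothetical op
        name = "log_tensor"
        signature = "(str trace_key, Tensor tensor) -> ()"

        def select(self, ksel: KernelSelection):
            ksel.attr_str(0)    # trace_key: static string attribute
            ksel.arg_tensor(1)  # specialized on rank/dtype, dims dynamic

        def generate(self, ksel: KernelSelection, kb: KernelBuilder):
            # Emit IR via kb, then close the kernel exactly once; the new
            # `yielded` guard rejects a second call.
            kb.yield_results()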
22 changes: 22 additions & 0 deletions python/shark_turbine/support/conversions.py
@@ -53,6 +53,28 @@
     torch.complex128: lambda: ComplexType.get(F64Type.get()),
 }
 
+TORCH_DTYPE_TO_SIGNED_MLIR_TYPE_ASM = {
+    torch.float16: "f16",
+    torch.bfloat16: "bf16",
+    torch.float32: "f32",
+    torch.float64: "f64",
+    torch.uint8: "ui8",
+    torch.int8: "si8",
+    torch.int16: "si16",
+    torch.int32: "si32",
+    torch.int64: "si64",
+    torch.bool: "i1",
+    torch.qint8: "si8",
+    torch.quint8: "ui8",
+    torch.complex32: "complex<f16>",
+    torch.complex64: "complex<f32>",
+    torch.complex128: "complex<f64>",
+}
+
+SIGNED_MLIR_TYPE_ASM_TO_TORCH_DTYPE = dict(
+    (v, k) for k, v in TORCH_DTYPE_TO_SIGNED_MLIR_TYPE_ASM.items()
+)
+
 TORCH_DTYPE_TO_IREE_TYPE_ASM = {
     torch.float16: "f16",
     torch.bfloat16: "bf16",
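The signed table keeps the `si`/`ui` prefixes that the signless mapping above it erases, and the reverse map recovers a dtype from an ASM string. Since dict() keeps the last writer on key collisions, the duplicated ASM strings resolve to the later torch dtype:

    import torch

    assert TORCH_DTYPE_TO_SIGNED_MLIR_TYPE_ASM[torch.int32] == "si32"
    assert SIGNED_MLIR_TYPE_ASM_TO_TORCH_DTYPE["si32"] == torch.int32
    # "si8" is written by torch.int8 first, then torch.qint8, so qint8 wins:
    assert SIGNED_MLIR_TYPE_ASM_TO_TORCH_DTYPE["si8"] == torch.qint8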