Skip to content

Commit

Permalink
Start adding the library of custom ops. (#296)
Browse files Browse the repository at this point in the history
Defines a couple of IREE builtins:

* `ops.iree.trace_tensor`
* `ops.iree.trace_tensors`

Extends the infra for better support:

* Adds support for `Tensor[]` arguments to custom ops.
  • Loading branch information
stellaraccident authored Feb 2, 2024
1 parent 6f67a97 commit a176f99
Show file tree
Hide file tree
Showing 18 changed files with 1,018 additions and 110 deletions.
4 changes: 3 additions & 1 deletion python/shark_turbine/aot/support/ir_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,10 @@ def create_tensor_global(
array = np.array(detached_tensor)
# We know that a Numpy array is a ReadableBuffer so ignore type error.
contents = memoryview(array) # type: ignore
shape_desc = "_".join([str(d) for d in t.shape])
blob_name = f"torch_tensor_{shape_desc}_{str(t.dtype)}"
elements_attr = DenseResourceElementsAttr.get_from_buffer(
contents, "from_py", tensor_type
contents, blob_name, tensor_type
)
ir_attrs["initial_value"] = elements_attr

Expand Down
23 changes: 16 additions & 7 deletions python/shark_turbine/dynamo/type_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(self, context: Context):
self.torch_type_to_native
)

def torch_type_to_native(self, torch_type: IrType) -> IrType:
def torch_type_to_native(self, torch_type: IrType, signless: bool = True) -> IrType:
"""Converts a presumed torch type to a corresponding native type.
This mirrors the type conversion in torch-mlir's BackendTypeConversion.cpp.
Expand All @@ -56,6 +56,8 @@ def torch_type_to_native(self, torch_type: IrType) -> IrType:
!torch.float -> f64
!torch.bool -> i1
!torch.vtensor -> tensor
If `signless=False`, then integer types will retain their signs.
"""
# We don't presently have API support for introspecting torch type,
# and even if we did, it is likely that this is more efficient.
Expand All @@ -66,7 +68,11 @@ def torch_type_to_native(self, torch_type: IrType) -> IrType:
if name == "bool":
return IntegerType.get_signless(1)
if name == "int":
return IntegerType.get_signless(64)
return (
IntegerType.get_signless(64)
if signless
else IntegerType.get_signed(64)
)
elif name == "float":
return F64Type.get()
elif name == "vtensor":
Expand All @@ -75,22 +81,25 @@ def torch_type_to_native(self, torch_type: IrType) -> IrType:
dim_list_str, dtype_str = tm.groups()
dim_list = parse_tensor_dim_list(dim_list_str)
dtype = self.convert_torch_element_type_to_native(
IrType.parse(dtype_str)
IrType.parse(dtype_str), signless=signless
)
# TODO: Eliminate RankedTensorType dependence on Location.
# See: https://github.com/nod-ai/SHARK-Turbine/issues/145
with Location.unknown():
return RankedTensorType.get(dim_list, dtype)
raise TypeError(f"Unsupported torch type conversion for {torch_type}")

def convert_torch_element_type_to_native(self, torch_type: IrType) -> IrType:
def convert_torch_element_type_to_native(
    self, torch_type: IrType, signless: bool = True
) -> IrType:
    """Converts a torch element (dtype) type to its native equivalent.

    Args:
        torch_type: The element type as produced by torch (e.g. `si64`, `f32`).
        signless: If True (the default), signed integer types are rewritten
            to signless integers of the same width, matching the native type
            system. If False, the type is returned unchanged, preserving sign.

    Returns:
        The converted IrType (non-integer types are always passed through).
    """
    # Torch uses the builtin type hierarchy of IntegerType and FloatType
    # to represent dtypes. These are mostly the same, but it always uses
    # signed IntegerTypes which we must convert to signless for the native
    # type system.
    if signless:
        if IntegerType.isinstance(torch_type):
            signed_int_type = IntegerType(torch_type)
            return IntegerType.get_signless(signed_int_type.width)
    return torch_type

def materialize_native_to_torch(
Expand Down
7 changes: 7 additions & 0 deletions python/shark_turbine/ops/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Copyright 2023 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from . import iree
74 changes: 74 additions & 0 deletions python/shark_turbine/ops/iree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright 2023 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""Custom ops for built-in IREE functionality."""

from ..support.ir_imports import (
RankedTensorType,
StringAttr,
Value,
flow_d,
tensor_d,
)

from ..runtime.op_reg import (
CustomOp,
KernelBuilder,
KernelSelection,
def_library,
)

# Public API of this module. Note: the exported names must match the ops
# actually defined below; a stale entry here makes `import *` raise
# AttributeError.
__all__ = [
    "trace_tensor",
    "trace_tensors",
]

# Torch op library namespace under which the IREE builtins are registered
# (i.e. `torch.ops.iree.*`).
IREE_LIBRARY = def_library("iree")


################################################################################
# trace_tensor / trace_tensors
# See the flow.tensor_trace op for details. In essence:
# * trace_key is a name to label tensors with (intended for log filtering)
# * tensor or tensors are values to log a value for
################################################################################


def _emit_tensor_trace(kb: KernelBuilder, key: str, ts: list[Value]):
    """Emits a single flow.tensor_trace op labeling `ts` with `key`.

    The trace op requires an SSA value for every dynamic dimension of every
    traced tensor, so those are materialized (via tensor.dim) first.
    """
    dynamic_dims: list[Value] = []
    for tensor_value in ts:
        ranked_type = RankedTensorType(tensor_value.type)
        dynamic_dims.extend(
            tensor_d.dim(tensor_value, kb.constant_index(axis))
            for axis in range(ranked_type.rank)
            if ranked_type.is_dynamic_dim(axis)
        )
    flow_d.TensorTraceOp(StringAttr.get(key), ts, dynamic_dims)


@CustomOp.register(library=IREE_LIBRARY)
class trace_tensor(CustomOp):
    """`ops.iree.trace_tensor`: logs one tensor under a trace key.

    Lowers to flow.tensor_trace; `trace_key` labels the emitted value so it
    can be filtered in logs.
    """

    signature = "trace_tensor(str trace_key, Tensor tensor) -> ()"

    def select(self, ksel: KernelSelection):
        # Arg 0 is the string trace key; arg 1 is the tensor payload.
        ksel.attr_str(0)
        ksel.arg_tensor(1)

    def generate(self, ksel: KernelSelection, kb: KernelBuilder):
        trace_key = ksel.arg_descs[0].v
        _emit_tensor_trace(kb, trace_key, [kb.arg_bindings[1]])
        # Op returns (): yield with no results.
        kb.yield_results()


@CustomOp.register(library=IREE_LIBRARY)
class trace_tensors(CustomOp):
    """`ops.iree.trace_tensors`: logs a list of tensors under one trace key.

    Same lowering as `trace_tensor` but accepts a `Tensor[]` argument; all
    tensors are emitted in a single flow.tensor_trace op.
    """

    signature = "trace_tensors(str trace_key, Tensor[] tensors) -> ()"

    def select(self, ksel: KernelSelection):
        # Arg 0 is the string trace key; arg 1 is the tensor list payload.
        ksel.attr_str(0)
        ksel.arg_tensor_list(1)

    def generate(self, ksel: KernelSelection, kb: KernelBuilder):
        tensors = kb.arg_bindings[1]
        # An empty list is a no-op: emit nothing rather than a trace op
        # with zero operands.
        if tensors:
            _emit_tensor_trace(kb, ksel.arg_descs[0].v, tensors)
        kb.yield_results()
2 changes: 1 addition & 1 deletion python/shark_turbine/runtime/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def _device_import_torch_tensor_cpu(device: Device, t: torch.Tensor) -> HalBuffe
memory_type=MemoryType.DEVICE_LOCAL,
allowed_usage=BufferUsage.DEFAULT,
device=hal_device,
buffer=t.numpy(),
buffer=t.detach().numpy(),
element_type=element_type,
)
return bv
Expand Down
Loading

0 comments on commit a176f99

Please sign in to comment.