[custom ops] Begin the scaffolding for dispatch of PyTorch custom ops. (#270)

This makes it possible to define regular torch ops directly in terms of
generated MLIR. The resulting ops are specialized and cached according to
the requirements in their definition, and they are compiled for any device
that Turbine supports when dispatched against tensors on that device.

Wiring this mechanism into the AOT side, so that compiling programs that
contain our own custom ops transparently includes them with no further
glue, is left to a follow-up. The scaffolding for that is in place, but
this patch is big enough without touching AOT.

This allows users to say something like:

```
import torch

# CustomOp, KernelSelection, and KernelBuilder come from Turbine's new op
# registry; the exact import path is assumed here (it was omitted in the
# original snippet).
from shark_turbine.runtime.op_reg import (
    CustomOp,
    KernelBuilder,
    KernelSelection,
)


@CustomOp.register
class identity(CustomOp):
    name = "test_identity"
    signature = "(Tensor self) -> Tensor"

    def select(self, ksel: KernelSelection):
        # Declare one tensor argument and a result with identical metadata.
        x = ksel.arg_tensor(0)
        ksel.return_tensor(x.t)

    def generate(self, ksel: KernelSelection, kb: KernelBuilder):
        # This just yields the IR value of the kernel input as the output.
        # Effectively, in eager mode, this is a `return` from the kernel
        # function.
        kb.yield_results(kb.arg_bindings[0])


t = torch.tensor([[1, 2, 3]], dtype=torch.int32)
result = identity(t)
print("CPU result:", result)
torch.testing.assert_close(result, t)
```
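Because the op is specialized and cached per the requirements in its
definition, a repeat dispatch with a tensor of the same spec should reuse
the compiled kernel rather than recompiling. A minimal sketch of exercising
that path (nothing here inspects the cache directly):

```
# Assumes the `identity` op and tensor `t` from the example above. A second
# dispatch with an identically-specced tensor (same shape/dtype/device)
# should hit the per-op kernel cache populated by the first call.
t2 = torch.tensor([[4, 5, 6]], dtype=torch.int32)
torch.testing.assert_close(identity(t2), t2)
```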

There will be dedicated `CustomOp` subclasses for our various DSLs that
can be used for such things (offering more sugared use than open-coding
IR).
stellaraccident authored Dec 19, 2023
1 parent 5d9d08b commit 68df316
Showing 28 changed files with 1,490 additions and 319 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml

```
@@ -46,7 +46,7 @@ jobs:
       - name: Run tests
         run: |
-          pytest tests/
+          pytest -n 4 tests/
   black:
     strategy:
```
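The new `-n 4` flag is provided by the `pytest-xdist` plugin and runs the
suite across four worker processes; this assumes the plugin is available in
the CI environment.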
24 changes: 12 additions & 12 deletions python/shark_turbine/aot/builtins/jittable.py

```
@@ -31,6 +31,18 @@
     FxImporter,
 )

+from ...support.ir_imports import (
+    FlatSymbolRefAttr,
+    FunctionType,
+    Operation,
+    StringAttr,
+    SymbolTable,
+    TypeAttr,
+    Value,
+    func_d,
+    util_d,
+)
+
 from ..passes import (
     functorch_functionalize,
 )
@@ -53,18 +65,6 @@
     MaterializedGlobal,
 )

-from ..support.ir_imports import (
-    FlatSymbolRefAttr,
-    FunctionType,
-    Operation,
-    StringAttr,
-    SymbolTable,
-    TypeAttr,
-    Value,
-    func_d,
-    util_d,
-)
-
 StringAttrOrStr = Union[StringAttr, str]
```
14 changes: 7 additions & 7 deletions python/shark_turbine/aot/compiled_module.py

```
@@ -17,13 +17,7 @@

 from . import builtins

-from .support.procedural import (
-    GlobalsDef,
-    ProcedureTrace,
-    current_ir_trace,
-)
-
-from .support.ir_imports import (
+from ..support.ir_imports import (
     Context,
     Location,
     MLIRError,
@@ -33,6 +27,12 @@
     StringAttr,
 )

+from .support.procedural import (
+    GlobalsDef,
+    ProcedureTrace,
+    current_ir_trace,
+)
+
 from .support.ir_utils import (
     ModuleBuilder,
 )
```
9 changes: 5 additions & 4 deletions python/shark_turbine/aot/exporter.py

```
@@ -19,16 +19,17 @@
     Output,
 )

+from ..support.ir_imports import (
+    Context,
+    Operation,
+)
+
 from .builtins import *
 from .compiled_module import (
     CompiledModule,
     CompiledModuleMeta,
     ExportProcDef,
 )
-from .support.ir_imports import (
-    Context,
-    Operation,
-)
 from .support.procedural import (
     AbstractTypedef,
 )
```
63 changes: 9 additions & 54 deletions python/shark_turbine/aot/support/ir_utils.py

```
@@ -5,7 +5,7 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

-from typing import Any, Callable, Dict, Generator, List, Optional, Sequence, Tuple
+from typing import Any, Callable, Generator, List, Optional, Sequence, Tuple

 from pathlib import Path
 import tempfile
@@ -15,7 +15,6 @@

 from ...importers.fx_importer import (
     ContextCache,
-    TORCH_DTYPE_TO_MLIR_TYPE_ASM,
 )

 from ...importers.utils import (
@@ -26,12 +25,9 @@
     NativeTypeConverter,
 )

-from .ir_imports import (
+from ...support.ir_imports import (
     Attribute,
     Block,
-    BlockArgument,
-    BF16Type,
-    ComplexType,
     DenseElementsAttr,
     DenseResourceElementsAttr,
     F16Type,
@@ -46,7 +42,6 @@
     IrType,
     Location,
     MLIRError,
-    OpResult,
     Operation,
     RankedTensorType,
     StringAttr,
@@ -59,61 +54,21 @@
     tensor_d,
 )

+from ...support.conversions import (
+    TORCH_DTYPE_TO_IREE_TYPE,
+)
+
 from .utils import (
     RefTracker,
     logger,
 )

-###############################################################################
-# Lookup tables
-###############################################################################
-
-# We need the inverse of the TORCH_DTYPE_TO_MLIR_TYPE_ASM table.
-MLIR_TYPE_ASM_TO_TORCH_DTYPE = {v: k for k, v in TORCH_DTYPE_TO_MLIR_TYPE_ASM.items()}
-
-# When emitting constants, we have to create native IREE types.
-TORCH_DTYPE_TO_IREE_TYPE: Dict[torch.dtype, Callable[[], IrType]] = {
-    torch.float16: lambda: F16Type.get(),
-    torch.bfloat16: lambda: BF16Type.get(),
-    torch.float32: lambda: F32Type.get(),
-    torch.float64: lambda: F64Type.get(),
-    torch.uint8: lambda: IntegerType.get_signless(8),
-    torch.int8: lambda: IntegerType.get_signless(8),
-    torch.int16: lambda: IntegerType.get_signless(16),
-    torch.int32: lambda: IntegerType.get_signless(32),
-    torch.int64: lambda: IntegerType.get_signless(64),
-    torch.bool: lambda: IntegerType.get_signless(1),
-    torch.qint8: lambda: IntegerType.get_signless(8),
-    torch.quint8: lambda: IntegerType.get_signless(8),
-    torch.complex32: lambda: ComplexType.get(F16Type.get()),
-    torch.complex64: lambda: ComplexType.get(F32Type.get()),
-    torch.complex128: lambda: ComplexType.get(F64Type.get()),
-}
-
-TORCH_DTYPE_TO_IREE_TYPE_ASM = {
-    torch.float16: "f16",
-    torch.bfloat16: "bf16",
-    torch.float32: "f32",
-    torch.float64: "f64",
-    torch.uint8: "i8",
-    torch.int8: "i8",
-    torch.int16: "i16",
-    torch.int32: "i32",
-    torch.int64: "i64",
-    torch.bool: "i1",
-    torch.qint8: "i8",
-    torch.quint8: "i8",
-    torch.complex32: "complex<f16>",
-    torch.complex64: "complex<f32>",
-    torch.complex128: "complex<f64>",
-}
-
 ###############################################################################
 # Configuration
 ###############################################################################

 # Maps a name to an altered name. If returns None, then the original
-# name is used (this lets Dict.get serve as a NameMapCallback).
+# name is used (this lets dict.get serve as a NameMapCallback).
 NameMapCallback = Callable[[str], Optional[str]]

@@ -420,7 +375,7 @@ def build_index_attribute(value: int) -> IntegerAttr:

 def build_index_value(
-    value: int, constant_cache: Optional[Dict[int, Value]] = None
+    value: int, constant_cache: Optional[dict[int, Value]] = None
 ) -> Value:
     if constant_cache is not None and value in constant_cache:
         return constant_cache[value]
@@ -431,7 +386,7 @@ def build_index_value(

 def build_tensor_dim_value(
-    t: Value, dim: int, constant_cache: Optional[Dict[int, Value]] = None
+    t: Value, dim: int, constant_cache: Optional[dict[int, Value]] = None
 ) -> Value:
     dim_value = build_index_value(dim, constant_cache=constant_cache)
     return tensor_d.DimOp(t, dim_value).result
```
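The dtype tables deleted above now live in `shark_turbine.support.conversions`
(imported here as `TORCH_DTYPE_TO_IREE_TYPE`). That table maps each
`torch.dtype` to a zero-argument callable rather than to a type because MLIR
types can only be constructed while a `Context` is active, so construction is
deferred to the lookup site. A minimal sketch of consuming it (the
`element_type_for` helper is illustrative, not part of this patch):

```
import torch

from shark_turbine.support.conversions import TORCH_DTYPE_TO_IREE_TYPE
from shark_turbine.support.ir_imports import IrType


def element_type_for(dtype: torch.dtype) -> IrType:
    # Call the stored thunk under the currently active MLIR Context.
    try:
        return TORCH_DTYPE_TO_IREE_TYPE[dtype]()
    except KeyError:
        raise ValueError(f"No IREE element type for torch dtype {dtype}")
```

Similarly, the kept comment's `dict.get` remark works because a plain dict's
bound `.get` returns `None` for unmapped keys, which callers interpret as
"use the original name" (the mapping below is illustrative):

```
from shark_turbine.aot.support.ir_utils import NameMapCallback

# A dict's bound .get satisfies NameMapCallback: unmapped names yield None.
rename: NameMapCallback = {"forward": "main"}.get
```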
2 changes: 1 addition & 1 deletion python/shark_turbine/aot/support/procedural/base.py

```
@@ -17,7 +17,7 @@

 import torch

-from ..ir_imports import (
+from ....support.ir_imports import (
     F32Type,
     F64Type,
     IndexType,
```
2 changes: 1 addition & 1 deletion python/shark_turbine/aot/support/procedural/globals.py

```
@@ -19,7 +19,7 @@

 import torch

-from ..ir_imports import (
+from ....support.ir_imports import (
     IrType,
     Operation,
     Value,
```
7 changes: 5 additions & 2 deletions python/shark_turbine/aot/support/procedural/iree_emitter.py

```
@@ -12,7 +12,7 @@

 import torch

-from ..ir_imports import (
+from ....support.ir_imports import (
     IndexType,
     IntegerType,
     IrType,
@@ -23,8 +23,11 @@
     flow_d,
 )

-from ..ir_utils import (
+from ....support.conversions import (
     TORCH_DTYPE_TO_IREE_TYPE,
+)
+
+from ..ir_utils import (
     build_index_value,
 )
```
2 changes: 1 addition & 1 deletion python/shark_turbine/aot/support/procedural/primitives.py

```
@@ -24,7 +24,7 @@
     dynamic_dim,
 )

-from ..ir_imports import (
+from ....support.ir_imports import (
     F32Type,
     IrType,
     RankedTensorType,
```
2 changes: 1 addition & 1 deletion python/shark_turbine/aot/support/procedural/tracer.py

```
@@ -14,7 +14,7 @@
     Sequence,
 )

-from ..ir_imports import (
+from ....support.ir_imports import (
     Location,
     StringAttr,
     Value,
```
1 change: 0 additions & 1 deletion python/shark_turbine/dynamo/__init__.py

```
@@ -5,7 +5,6 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception


-from .device import Device
 from .tensor import (
     enable,
     TurbineMode,
```
2 changes: 1 addition & 1 deletion python/shark_turbine/dynamo/backends/cpu.py

```
@@ -7,7 +7,7 @@
 import functools
 import sys

-from ..device import (
+from ...runtime.device import (
     DeviceState,
 )
```
(Diff truncated: the remaining changed files are not shown.)
