[tk] Expose dtypes (#393)
This patch exposes data types to users. Currently, creating constant
vectors of the exposed types and performing arithmetic on them is supported; a sketch follows below.

What is not supported yet:

1. Scalar conversion
2. Vector conversion
3. KernelBuffer element type
4. Promotion/Truncation

Test output:
https://gist.github.com/Groverkss/30d5aaad0ea2b44960ba88b5340463ad
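
For context, the kind of kernel code this unlocks looks roughly like the following. This is a sketch only: the shark_turbine.kernel.lang import path and the tkl.constant signature are assumed from the surrounding codebase, not shown in this commit.

import shark_turbine.kernel.lang as tkl

def body():
    # Constant vectors of an exposed dtype, plus arithmetic on them.
    a = tkl.constant((16,), tkl.f16, 1.0)
    b = tkl.constant((16,), tkl.f16, 2.0)
    return a + b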
Groverkss authored Feb 5, 2024
1 parent 78060b8 commit 0bb2a8c
Showing 10 changed files with 257 additions and 120 deletions.
58 changes: 58 additions & 0 deletions python/shark_turbine/kernel/_support/dtype.py
@@ -0,0 +1,58 @@
__all__ = [
    "DataType",
    "bool",
    "i4",
    "i8",
    "i16",
    "i32",
    "i64",
    "f16",
    "f32",
    "f64",
    "index",
]

_INT_TYPES = ["i1", "i4", "i8", "i16", "i32", "i64"]
_FLOAT_TYPES = ["f16", "f32", "f64"]
_INDEX_TYPES = ["index"]


class DataType:
    _name: str
    _ir_type_asm: str

    def __init__(self, name, ir_type_asm=None):
        self._name = name
        self._ir_type_asm = ir_type_asm if ir_type_asm else name

    def ir_type_asm(self):
        return self._ir_type_asm

    def __str__(self):
        return self._name

    def __repr__(self):
        return f"DataType({self._ir_type_asm})"

    def is_int_asm(self):
        return self._name in _INT_TYPES

    def is_float_asm(self):
        return self._name in _FLOAT_TYPES

    def is_index_asm(self):
        return self._name in _INDEX_TYPES


bool = DataType("bool", "i1")
i4 = DataType("i4")
i8 = DataType("i8")
i16 = DataType("i16")
i32 = DataType("i32")
i64 = DataType("i64")
f32 = DataType("f32")
f64 = DataType("f64")
f16 = DataType("f16")
index = DataType("index")
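
As a quick illustration of the module above (the import path matches the new file; expected output in comments):

from shark_turbine.kernel._support import dtype

print(str(dtype.f32))              # "f32"
print(repr(dtype.bool))            # "DataType(i1)" - bool carries MLIR asm "i1"
print(dtype.f64.ir_type_asm())     # "f64"
print(dtype.i32.is_int_asm())      # True
print(dtype.f16.is_float_asm())    # True
print(dtype.index.is_index_asm())  # True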
58 changes: 10 additions & 48 deletions python/shark_turbine/kernel/_support/indexing.py
@@ -10,6 +10,7 @@
from .. import ops

from . import context
+from . import dtype

__all__ = [
"backed_sym_index_type",
@@ -28,6 +29,9 @@
"TemporaryBuffer",
]

+DataType = dtype.DataType
+DefaultDataType = dtype.f32


class NotSetType:
...
@@ -37,46 +41,6 @@ class NotSetType:

SubtypeT = TypeVar("SubtypeT")

-###############################################################################
-# ElementType
-###############################################################################
-
-
-class ElementType(ABC):
-    @staticmethod
-    def cast(something) -> "ElementType":
-        if isinstance(something, torch.dtype):
-            return TorchElementType(something)
-        else:
-            raise TypeError(
-                f"Cannot convert {something} (of type {type(something)}) to an element type"
-            )
-
-    @abstractmethod
-    def ir_type_asm(self) -> str:
-        ...
-
-
-class TorchElementType(ElementType):
-    def __init__(self, dtype: torch.dtype):
-        self.dtype = dtype
-
-    def __repr__(self):
-        return repr(self.dtype)
-
-    def __eq__(self, other):
-        return isinstance(other, TorchElementType) and self.dtype == other.dtype
-
-    def ir_type_asm(self) -> str:
-        dtype = self.dtype
-        if dtype == torch.float32:
-            return "f32"
-        else:
-            raise ValueError(f"Torch dtype {dtype} cannot be mapped to MLIR type")
-
-
-DefaultElementType = TorchElementType(torch.float32)
-
###############################################################################
# Index symbols and expressions
# These are just light-weight helpers around sympy symbols and expressions.
@@ -224,7 +188,7 @@ class _KernelBufferMeta(type):
This lets us specialize with symbolic shape information.
"""

-    element_type: ElementType
+    element_type: DataType
usage: KernelBufferUsage
symbolic_shape: Optional[SymbolicShapeExpr]
rank: Optional[int]
@@ -235,7 +199,7 @@ def __new__(
bases,
dct,
):
-        element_type = dct.get("element_type") or DefaultElementType
+        element_type = dct.get("element_type") or DefaultDataType
dct["element_type"] = element_type
usage = dct.get("usage") or KernelBufferUsage.NONE
dct["usage"] = usage
@@ -253,7 +217,7 @@ def __new__(
def new_subtype(
cls: Type[SubtypeT],
*,
-        element_type: Union[NotSetType, ElementType] = NotSet,
+        element_type: Union[NotSetType, DataType] = NotSet,
symbolic_shape: Union[NotSetType, Optional[SymbolicShapeable]] = NotSet,
usage: Union[NotSetType, KernelBufferUsage] = NotSet,
) -> Type[SubtypeT]:
@@ -272,9 +236,7 @@ class Subtype(cls):

return Subtype

-    def of(
-        cls: Type[SubtypeT], element_type: Union[Any, ElementType, torch.dtype]
-    ) -> Type[SubtypeT]:
+    def of(cls: Type[SubtypeT], element_type: Union[Any, DataType]) -> Type[SubtypeT]:
return cls.new_subtype(element_type=element_type)

def __repr__(cls):
@@ -291,7 +253,7 @@ def is_kernel_buffer_meta_derived(t: type) -> bool:

def _kernel_buffer_type_repr(
*,
-    element_type: ElementType,
+    element_type: DataType,
usage: KernelBufferUsage,
symbolic_shape: Optional[tuple[IndexExpr]],
) -> str:
@@ -300,7 +262,7 @@ def _kernel_buffer_type_repr(
stem = f"{root}[{', '.join(repr(s) for s in symbolic_shape)}]"
else:
stem = f"{root}"
-    if element_type != DefaultDataType:
+    if element_type != DefaultDataType:
stem += f".of({element_type})"
return stem
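
With the signatures above, a DataType can be attached to a buffer type directly. A minimal sketch (import paths assumed; per the commit message, full KernelBuffer element-type support is still pending):

from shark_turbine.kernel._support import dtype
from shark_turbine.kernel._support.indexing import KernelBuffer

# of() forwards to new_subtype(element_type=...), which now accepts the new
# DataType objects in place of the removed torch-backed ElementType.
Buf = KernelBuffer.of(dtype.f16)
print(Buf)  # rendered by _kernel_buffer_type_repr, e.g. "KernelBuffer.of(f16)"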

4 changes: 4 additions & 0 deletions python/shark_turbine/kernel/_support/tracing.py
@@ -42,6 +42,7 @@
)

from . import context
+from .dtype import DataType

try:
from typing import assert_type
@@ -109,6 +110,9 @@ def create_arg(self, a):
# Let IndexExpr persist as arguments.
if isinstance(a, IndexExpr):
return a
+        # Let DataType persist as arguments.
+        if isinstance(a, DataType):
+            return a
return super().create_arg(a)
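
This override matters because torch.fx's default create_arg only knows how to embed certain literal types into the traced graph and raises NotImplementedError for anything else. A standalone sketch of the dispatch (a simplified stand-in; the real base-class logic lives in torch.fx.Tracer):

from shark_turbine.kernel._support.dtype import DataType, f16

_SUPPORTED_LITERALS = (int, float, str, type(None))  # simplified stand-in

def create_arg(a):
    if isinstance(a, DataType):
        return a  # persist dtypes verbatim, as in the patch above
    if isinstance(a, _SUPPORTED_LITERALS):
        return a
    raise NotImplementedError(f"argument of type: {type(a)}")

assert create_arg(f16) is f16  # the dtype flows into the graph unchanged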

