Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tk] Expose dtypes #393

Merged
merged 8 commits into from
Feb 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions python/shark_turbine/kernel/_support/dtype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
__all__ = [
"DataType",
"bool",
"i4",
"i8",
"i16",
"i32",
"i64",
"f16",
"f32",
"f64",
"index",
]

_INT_TYPES = ["i1", "i4", "i8", "i16", "i32", "i64"]
_FLOAT_TYPES = ["f16", "f32", "f64"]
_INDEX_TYPES = ["index"]


class DataType:
_name: str
_ir_type_asm: str

def __init__(self, name, ir_type_asm=None):
self._name = name
self._ir_type_asm = ir_type_asm if ir_type_asm else name

def ir_type_asm(self):
return self._ir_type_asm

def __str__(self):
return self._name

def __repr__(self):
return f"DataType({self._ir_type_asm})"

def is_int_asm(self):
return self._name in _INT_TYPES

def is_float_asm(self):
return self._name in _FLOAT_TYPES

def is_index_asm(self):
return self._name in _INDEX_TYPES


bool = DataType("bool", "i1")
i4 = DataType("i4")
i8 = DataType("i8")
i16 = DataType("i16")
i32 = DataType("i32")
i64 = DataType("i64")
f32 = DataType("f32")
f64 = DataType("f64")
f16 = DataType("f16")
f32 = DataType("f32")
f64 = DataType("f64")
index = DataType("index")
58 changes: 10 additions & 48 deletions python/shark_turbine/kernel/_support/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .. import ops

from . import context
from . import dtype

__all__ = [
"backed_sym_index_type",
Expand All @@ -28,6 +29,9 @@
"TemporaryBuffer",
]

DataType = dtype.DataType
DefaultDataType = dtype.f32


class NotSetType:
...
Expand All @@ -37,46 +41,6 @@ class NotSetType:

SubtypeT = TypeVar("SubtypeT")

###############################################################################
# ElementType
###############################################################################


class ElementType(ABC):
@staticmethod
def cast(something) -> "ElementType":
if isinstance(something, torch.dtype):
return TorchElementType(something)
else:
raise TypeError(
f"Cannot convert {something} (of type {type(something)}) to an element type"
)

@abstractmethod
def ir_type_asm(self) -> str:
...


class TorchElementType(ElementType):
def __init__(self, dtype: torch.dtype):
self.dtype = dtype

def __repr__(self):
return repr(self.dtype)

def __eq__(self, other):
return isinstance(other, TorchElementType) and self.dtype == other.dtype

def ir_type_asm(self) -> str:
dtype = self.dtype
if dtype == torch.float32:
return "f32"
else:
raise ValueError(f"Torch dtype {dtype} cannot be mapped to MLIR type")


DefaultElementType = TorchElementType(torch.float32)

###############################################################################
# Index symbols and expressions
# These are just light-weight helpers around sympy symbols and expressions.
Expand Down Expand Up @@ -224,7 +188,7 @@ class _KernelBufferMeta(type):
This lets us specialize with symbolic shape information.
"""

element_type: ElementType
element_type: DataType
usage: KernelBufferUsage
symbolic_shape: Optional[SymbolicShapeExpr]
rank: Optional[int]
Expand All @@ -235,7 +199,7 @@ def __new__(
bases,
dct,
):
element_type = dct.get("element_type") or DefaultElementType
element_type = dct.get("element_type") or DefaultDataType
dct["element_type"] = element_type
usage = dct.get("usage") or KernelBufferUsage.NONE
dct["usage"] = usage
Expand All @@ -253,7 +217,7 @@ def __new__(
def new_subtype(
cls: Type[SubtypeT],
*,
element_type: Union[NotSetType, ElementType] = NotSet,
element_type: Union[NotSetType, DataType] = NotSet,
symbolic_shape: Union[NotSetType, Optional[SymbolicShapeable]] = NotSet,
usage: Union[NotSetType, KernelBufferUsage] = NotSet,
) -> Type[SubtypeT]:
Expand All @@ -272,9 +236,7 @@ class Subtype(cls):

return Subtype

def of(
cls: Type[SubtypeT], element_type: Union[Any, ElementType, torch.dtype]
) -> Type[SubtypeT]:
def of(cls: Type[SubtypeT], element_type: Union[Any, DataType]) -> Type[SubtypeT]:
return cls.new_subtype(element_type=element_type)

def __repr__(cls):
Expand All @@ -291,7 +253,7 @@ def is_kernel_buffer_meta_derived(t: type) -> bool:

def _kernel_buffer_type_repr(
*,
element_type: ElementType,
element_type: DataType,
usage: KernelBufferUsage,
symbolic_shape: Optional[tuple[IndexExpr]],
) -> str:
Expand All @@ -300,7 +262,7 @@ def _kernel_buffer_type_repr(
stem = f"{root}[{', '.join(repr(s) for s in symbolic_shape)}]"
else:
stem = f"{root}"
if element_type != DefaultElementType:
if element_type != DefaultDataType:
stem += f".of({element_type})"
return stem

Expand Down
4 changes: 4 additions & 0 deletions python/shark_turbine/kernel/_support/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
)

from . import context
from .dtype import DataType

try:
from typing import assert_type
Expand Down Expand Up @@ -109,6 +110,9 @@ def create_arg(self, a):
# Let IndexExpr persist as arguments.
if isinstance(a, IndexExpr):
return a
# Let DataType persist as arguments.
if isinstance(a, DataType):
return a
return super().create_arg(a)


Expand Down
Loading
Loading