From 389b2f416bdc8fd16014f2fb669cbb7bcc6a8409 Mon Sep 17 00:00:00 2001 From: Kunwar Grover Date: Sun, 4 Feb 2024 15:58:50 +0530 Subject: [PATCH 1/8] Remove torch.float32 for tkl.f32 --- python/shark_turbine/kernel/_support/dtype.py | 53 +++++++++++++++++ .../shark_turbine/kernel/_support/indexing.py | 58 ++++--------------- .../shark_turbine/kernel/_support/tracing.py | 4 ++ .../kernel/compiler/vector_codegen.py | 2 +- python/shark_turbine/kernel/lang/__init__.py | 13 +++++ python/shark_turbine/kernel/lang/types.py | 1 - 6 files changed, 81 insertions(+), 50 deletions(-) create mode 100644 python/shark_turbine/kernel/_support/dtype.py diff --git a/python/shark_turbine/kernel/_support/dtype.py b/python/shark_turbine/kernel/_support/dtype.py new file mode 100644 index 000000000..d4f3ea38f --- /dev/null +++ b/python/shark_turbine/kernel/_support/dtype.py @@ -0,0 +1,53 @@ +__all__ = [ + "DataType", + "bool", + "i4", + "i8", + "i16", + "i32", + "i64", + "f16", + "f32", + "f64", + "index", +] + +_INT_TYPES = ["i1", "i4", "i8", "i16", "i32", "i64"] +_FLOAT_TYPES = ["f16", "f32", "f64"] +_INDEX_TYPES = ["index"] + + +class DataType: + name: str + + def __init__(self, name): + self.name = name + + def ir_type_asm(self): + return self.name + + def is_int(self): + return self.name in _INT_TYPES + + def is_float(self): + return self.name in _FLOAT_TYPES + + def is_index(self): + return self.name in _INDEX_TYPES + + def is_bool(self): + return self.name == "bool" + + +bool = DataType("bool") +i4 = DataType("i4") +i8 = DataType("i8") +i16 = DataType("i16") +i32 = DataType("i32") +i64 = DataType("i64") +f32 = DataType("f32") +f64 = DataType("f64") +f16 = DataType("f16") +f32 = DataType("f32") +f64 = DataType("f64") +index = DataType("index") diff --git a/python/shark_turbine/kernel/_support/indexing.py b/python/shark_turbine/kernel/_support/indexing.py index d9c405c29..ab2b153c8 100644 --- a/python/shark_turbine/kernel/_support/indexing.py +++ b/python/shark_turbine/kernel/_support/indexing.py @@ -10,6 +10,7 @@ from .. import ops from . import context +from . import dtype __all__ = [ "backed_sym_index_type", @@ -28,6 +29,9 @@ "TemporaryBuffer", ] +DataType = dtype.DataType +DefaultDataType = dtype.f32 + class NotSetType: ... @@ -37,46 +41,6 @@ class NotSetType: SubtypeT = TypeVar("SubtypeT") -############################################################################### -# ElementType -############################################################################### - - -class ElementType(ABC): - @staticmethod - def cast(something) -> "ElementType": - if isinstance(something, torch.dtype): - return TorchElementType(something) - else: - raise TypeError( - f"Cannot convert {something} (of type {type(something)}) to an element type" - ) - - @abstractmethod - def ir_type_asm(self) -> str: - ... - - -class TorchElementType(ElementType): - def __init__(self, dtype: torch.dtype): - self.dtype = dtype - - def __repr__(self): - return repr(self.dtype) - - def __eq__(self, other): - return isinstance(other, TorchElementType) and self.dtype == other.dtype - - def ir_type_asm(self) -> str: - dtype = self.dtype - if dtype == torch.float32: - return "f32" - else: - raise ValueError(f"Torch dtype {dtype} cannot be mapped to MLIR type") - - -DefaultElementType = TorchElementType(torch.float32) - ############################################################################### # Index symbols and expressions # These are just light-weight helpers around sympy symbols and expressions. 
@@ -224,7 +188,7 @@ class _KernelBufferMeta(type): This lets us specialize with symbolic shape information. """ - element_type: ElementType + element_type: DataType usage: KernelBufferUsage symbolic_shape: Optional[SymbolicShapeExpr] rank: Optional[int] @@ -235,7 +199,7 @@ def __new__( bases, dct, ): - element_type = dct.get("element_type") or DefaultElementType + element_type = dct.get("element_type") or DefaultDataType dct["element_type"] = element_type usage = dct.get("usage") or KernelBufferUsage.NONE dct["usage"] = usage @@ -253,7 +217,7 @@ def __new__( def new_subtype( cls: Type[SubtypeT], *, - element_type: Union[NotSetType, ElementType] = NotSet, + element_type: Union[NotSetType, DataType] = NotSet, symbolic_shape: Union[NotSetType, Optional[SymbolicShapeable]] = NotSet, usage: Union[NotSetType, KernelBufferUsage] = NotSet, ) -> Type[SubtypeT]: @@ -272,9 +236,7 @@ class Subtype(cls): return Subtype - def of( - cls: Type[SubtypeT], element_type: Union[Any, ElementType, torch.dtype] - ) -> Type[SubtypeT]: + def of(cls: Type[SubtypeT], element_type: Union[Any, DataType]) -> Type[SubtypeT]: return cls.new_subtype(element_type=element_type) def __repr__(cls): @@ -291,7 +253,7 @@ def is_kernel_buffer_meta_derived(t: type) -> bool: def _kernel_buffer_type_repr( *, - element_type: ElementType, + element_type: DataType, usage: KernelBufferUsage, symbolic_shape: Optional[tuple[IndexExpr]], ) -> str: @@ -300,7 +262,7 @@ def _kernel_buffer_type_repr( stem = f"{root}[{', '.join(repr(s) for s in symbolic_shape)}]" else: stem = f"{root}" - if element_type != DefaultElementType: + if element_type != DefaultDataType: stem += f".of({element_type})" return stem diff --git a/python/shark_turbine/kernel/_support/tracing.py b/python/shark_turbine/kernel/_support/tracing.py index 90716e0d9..19907dd4b 100644 --- a/python/shark_turbine/kernel/_support/tracing.py +++ b/python/shark_turbine/kernel/_support/tracing.py @@ -42,6 +42,7 @@ ) from . import context +from .dtype import DataType try: from typing import assert_type @@ -109,6 +110,9 @@ def create_arg(self, a): # Let IndexExpr persist as arguments. if isinstance(a, IndexExpr): return a + # Let DataType persist as arguments. + if isinstance(a, DataType): + return a return super().create_arg(a) diff --git a/python/shark_turbine/kernel/compiler/vector_codegen.py b/python/shark_turbine/kernel/compiler/vector_codegen.py index 6f95edf20..f53e6087e 100644 --- a/python/shark_turbine/kernel/compiler/vector_codegen.py +++ b/python/shark_turbine/kernel/compiler/vector_codegen.py @@ -496,7 +496,7 @@ def _(emitter: ThreadEmitter, node: fx.Node): shape = cast_py_literal(emitter, shape) # TODO: Have better way to get the dtype. 
- if dtype == torch.float32: + if dtype == tkl.f32: element_type = F32Type.get() vector_type = VectorType.get(shape, element_type) dense_value = DenseElementsAttr.get_splat(vector_type, FloatAttr.get_f32(value)) diff --git a/python/shark_turbine/kernel/lang/__init__.py b/python/shark_turbine/kernel/lang/__init__.py index 3b3d40b29..84217c691 100644 --- a/python/shark_turbine/kernel/lang/__init__.py +++ b/python/shark_turbine/kernel/lang/__init__.py @@ -12,3 +12,16 @@ TemporaryBuffer, sym, ) + +from .._support.dtype import ( + bool, + i4, + i8, + i16, + i32, + i64, + f16, + f32, + f64, + index, +) diff --git a/python/shark_turbine/kernel/lang/types.py b/python/shark_turbine/kernel/lang/types.py index be6a59f5d..1f42fdb11 100644 --- a/python/shark_turbine/kernel/lang/types.py +++ b/python/shark_turbine/kernel/lang/types.py @@ -1,6 +1,5 @@ from typing import Type - __all__ = [ "Index", "Vector", From 23ba315a50b7a9a369d4460e1718eb68ffda4d78 Mon Sep 17 00:00:00 2001 From: Kunwar Grover Date: Mon, 5 Feb 2024 04:34:45 +0530 Subject: [PATCH 2/8] Fix asm for bool --- python/shark_turbine/kernel/_support/dtype.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/shark_turbine/kernel/_support/dtype.py b/python/shark_turbine/kernel/_support/dtype.py index d4f3ea38f..823c1881a 100644 --- a/python/shark_turbine/kernel/_support/dtype.py +++ b/python/shark_turbine/kernel/_support/dtype.py @@ -39,7 +39,7 @@ def is_bool(self): return self.name == "bool" -bool = DataType("bool") +bool = DataType("i1") i4 = DataType("i4") i8 = DataType("i8") i16 = DataType("i16") From 36697b97113d953ecb418931336b6e05e5350b60 Mon Sep 17 00:00:00 2001 From: Kunwar Grover Date: Mon, 5 Feb 2024 04:39:12 +0530 Subject: [PATCH 3/8] dtype improvements --- python/shark_turbine/kernel/_support/dtype.py | 32 +++++++++++-------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/python/shark_turbine/kernel/_support/dtype.py b/python/shark_turbine/kernel/_support/dtype.py index 823c1881a..a6c1109ce 100644 --- a/python/shark_turbine/kernel/_support/dtype.py +++ b/python/shark_turbine/kernel/_support/dtype.py @@ -18,28 +18,34 @@ class DataType: - name: str + _name: str + _ir_type_asm: str - def __init__(self, name): - self.name = name + def __init__(self, name, ir_type_asm=None): + self._name = name + if ir_type_asm is None: + self._ir_type_asm = name def ir_type_asm(self): - return self.name + return self._ir_type_asm - def is_int(self): - return self.name in _INT_TYPES + def __str__(self): + return self._name - def is_float(self): - return self.name in _FLOAT_TYPES + def __repr__(self): + return f"DataType({self._ir_type_asm})" - def is_index(self): - return self.name in _INDEX_TYPES + def is_int_asm(self): + return self._name in _INT_TYPES - def is_bool(self): - return self.name == "bool" + def is_float_asm(self): + return self._name in _FLOAT_TYPES + def is_index_asm(self): + return self._name in _INDEX_TYPES -bool = DataType("i1") + +bool = DataType("bool", "i1") i4 = DataType("i4") i8 = DataType("i8") i16 = DataType("i16") From a0c4610a64c03ec926faa0631ddb7dff630f0fa4 Mon Sep 17 00:00:00 2001 From: Kunwar Grover Date: Mon, 5 Feb 2024 05:43:39 +0530 Subject: [PATCH 4/8] Add constant support for types --- .../shark_turbine/kernel/compiler/builder.py | 66 +++++++++++-------- python/shark_turbine/kernel/compiler/ir.py | 2 + .../kernel/compiler/vector_codegen.py | 32 +++++---- 3 files changed, 58 insertions(+), 42 deletions(-) diff --git a/python/shark_turbine/kernel/compiler/builder.py 
b/python/shark_turbine/kernel/compiler/builder.py index 4311b28ef..63c8e3ad5 100644 --- a/python/shark_turbine/kernel/compiler/builder.py +++ b/python/shark_turbine/kernel/compiler/builder.py @@ -17,6 +17,7 @@ IndexType, IntegerAttr, IntegerType, + DenseElementsAttr, IrType, Location, Operation, @@ -26,9 +27,11 @@ arith_d, math_d, builtin_d, + F16Type, + F32Type, + F64Type, ) - # TODO: Have a way upstream to check if a floating point type. FLOAT_TYPES_ASM = { "bf16", @@ -89,6 +92,9 @@ def is_floating_point_type(self, t: IrType) -> bool: def is_integer_type(self, t: IrType) -> bool: return IntegerType.isinstance(t) + def is_index_type(self, t: IrType) -> bool: + return IndexType.isinstance(t) + def promote(self, value: Value, to_type: IrType) -> Value: value_type = value.type # Short-circuit if already the right type. @@ -104,25 +110,37 @@ def promote(self, value: Value, to_type: IrType) -> Value: ) return handler(value, to_type) - def zero_attr(self, t: IrType) -> Attribute: - attr_name = f"zero_attr_{t}" - try: - handler = getattr(self, attr_name) - except AttributeError: - raise CodegenError( - f"Cannot derive a zero value for type `{t}` (tried '{attr_name}')" - ) - return handler(t) + def constant_attr(self, val: int | float, element_type: IrType) -> Attribute: + if self.is_integer_type(element_type) or self.is_index_type(element_type): + if not isinstance(val, int): + raise TypeError(f"Expected an integer value, got {val}") + return IntegerAttr.get(element_type, val) - def constant(self, py_value) -> IRProxyValue: - attr_name = f"py_constant_{type(py_value).__name__}" - try: - handler = getattr(self, attr_name) - except AttributeError: - raise CodegenError( - f"Cannot convert Python value to constant: {py_value} of type {type(py_value)} (tried '{attr_name}')" - ) - return handler(py_value) + if self.is_floating_point_type(element_type): + if not isinstance(val, float): + raise TypeError(f"Expected a float value, got {val}") + return FloatAttr.get(element_type, val) + + raise CodegenError( + f"Cannot create a constant attribute for type `{element_type}`" + ) + + def zero_attr(self, t: IrType) -> Attribute: + if self.is_integer_type(t) or self.is_index_type(t): + return self.constant_attr(0, t) + if self.is_floating_point_type(t): + return self.constant_attr(0.0, t) + raise CodegenError(f"Cannot create a zero attribute for type `{t}`") + + def constant(self, py_value, element_type: IrType) -> IRProxyValue: + attr = self.constant_attr(py_value, element_type) + return IRProxyValue(arith_d.constant(element_type, attr)) + + def constant_vector(self, py_value, shape, element_type: IrType) -> IRProxyValue: + attr = self.constant_attr(py_value, element_type) + vector_type = VectorType.get(shape, element_type) + splat = DenseElementsAttr.get_splat(vector_type, attr) + return IRProxyValue(arith_d.constant(vector_type, splat)) def binary_arithmetic( self, op: str, lhs: IRProxyValue, rhs: IRProxyValue @@ -185,16 +203,6 @@ def promote_index_to_f32(self, value: Value, to_type: IrType) -> Value: def zero_attr_f32(self, t: IrType) -> Attribute: return FloatAttr.get(t, 0.0) - def py_constant_int(self, py_value) -> IRProxyValue: - # If coming from a stock 'int' Python type with no idea how to convert it, - # there isn't much smart we can do. We conservatively treat 'index' as - # reasonable. - result_type = IndexType.get() - return IRProxyValue( - arith_d.constant(result_type, IntegerAttr.get(result_type, py_value)), - py_value, - ) - # Binary index arithmetic. 
def binary_add_index_index( self, lhs: IRProxyValue, rhs: IRProxyValue diff --git a/python/shark_turbine/kernel/compiler/ir.py b/python/shark_turbine/kernel/compiler/ir.py index a37a858a0..b9b8b8e56 100644 --- a/python/shark_turbine/kernel/compiler/ir.py +++ b/python/shark_turbine/kernel/compiler/ir.py @@ -8,7 +8,9 @@ Block, Context, DenseElementsAttr, + F16Type, F32Type, + F64Type, FloatAttr, FunctionType, IndexType, diff --git a/python/shark_turbine/kernel/compiler/vector_codegen.py b/python/shark_turbine/kernel/compiler/vector_codegen.py index f53e6087e..1e7042fda 100644 --- a/python/shark_turbine/kernel/compiler/vector_codegen.py +++ b/python/shark_turbine/kernel/compiler/vector_codegen.py @@ -26,6 +26,8 @@ is_kernel_buffer_meta_derived, ) +from .._support import dtype + from .._support.tracing import CapturedTrace from .. import lang as tkl @@ -57,6 +59,7 @@ VectorType, DenseElementsAttr, F32Type, + IndexType, FloatAttr, InsertionPoint, IrType, @@ -494,16 +497,10 @@ def _(emitter: ThreadEmitter, node: fx.Node): raise ValidationError("Malformed arguments") from e shape = cast_py_literal(emitter, shape) - - # TODO: Have better way to get the dtype. - if dtype == tkl.f32: - element_type = F32Type.get() - vector_type = VectorType.get(shape, element_type) - dense_value = DenseElementsAttr.get_splat(vector_type, FloatAttr.get_f32(value)) - result = arith_d.ConstantOp(vector_type, dense_value).result - emitter.bind_node_proxy(node, IRProxyValue(result)) - else: - raise CodegenError(f"NYI: Constant type {dtype}") + dtype = cast_dtype(emitter, dtype) + constant = ScalarBuilder.constant_vector(value, shape, dtype) + print(constant) + emitter.bind_node_proxy(node, constant) ############################################################################### @@ -788,7 +785,7 @@ def cast_py_value(emitter: ThreadEmitter, value) -> IRProxyValue: f"Dynamically resolved symbolic values not yet implemented. Got: " f"{simplified}" ) from e - return ScalarBuilder.constant(value) + return ScalarBuilder.constant(value, IndexType.get()) def cast_py_lvalue(emitter: ThreadEmitter, py_value: fx.Node) -> tuple[Value, fx.Node]: @@ -859,6 +856,15 @@ def cast_vector( return vector_d.splat(vector_type, value) +def cast_dtype(emitter: ThreadEmitter, dtype: dtype.DataType) -> IrType: + try: + ir_dtype = IrType.parse(dtype.ir_type_asm()) + except CodegenError as e: + raise CodegenError(f"Failed to convert dtype {dtype} to IR type") from e + + return ir_dtype + + ############################################################################### # Slice and indexing ############################################################################### @@ -995,7 +1001,7 @@ def cast_dynamic_index_value(emitter: ThreadEmitter, py_index) -> IRProxyValue: except TypeError: # Need to materialize the expression. 
raise CodegenError(f"NYI: Materialized index expression {py_index}") - return ScalarBuilder.constant(int_value) + return ScalarBuilder.constant(int_value, IndexType.get()) def extract_slice_starts( @@ -1006,7 +1012,7 @@ def extract_slice_starts( def _extract(i): atom = slice_spec[i] if atom is None: - return ScalarBuilder.constant(0) + return ScalarBuilder.constant(0, IndexType.get()) elif isinstance(atom, slice): return cast_dynamic_index_value(emitter, atom.start).ir_value else: From 14d1d8d30708b9e89a649637ea17b5332b1cfdac Mon Sep 17 00:00:00 2001 From: Kunwar Grover Date: Mon, 5 Feb 2024 06:20:46 +0530 Subject: [PATCH 5/8] Support codegen for any arithmetic type --- python/shark_turbine/kernel/_support/dtype.py | 3 +- .../shark_turbine/kernel/compiler/builder.py | 63 ++++++++++--------- .../kernel/compiler/vector_codegen.py | 1 - 3 files changed, 36 insertions(+), 31 deletions(-) diff --git a/python/shark_turbine/kernel/_support/dtype.py b/python/shark_turbine/kernel/_support/dtype.py index a6c1109ce..6c1cc1f5d 100644 --- a/python/shark_turbine/kernel/_support/dtype.py +++ b/python/shark_turbine/kernel/_support/dtype.py @@ -23,8 +23,7 @@ class DataType: def __init__(self, name, ir_type_asm=None): self._name = name - if ir_type_asm is None: - self._ir_type_asm = name + self._ir_type_asm = ir_type_asm if ir_type_asm else name def ir_type_asm(self): return self._ir_type_asm diff --git a/python/shark_turbine/kernel/compiler/builder.py b/python/shark_turbine/kernel/compiler/builder.py index 63c8e3ad5..c40fd9ec8 100644 --- a/python/shark_turbine/kernel/compiler/builder.py +++ b/python/shark_turbine/kernel/compiler/builder.py @@ -147,7 +147,14 @@ def binary_arithmetic( ) -> IRProxyValue: lhs_ir_type = lhs.ir_value.type rhs_ir_type = rhs.ir_value.type - attr_name = f"binary_{op}_{lhs_ir_type}_{rhs_ir_type}" + + if lhs_ir_type != rhs_ir_type: + raise CodegenError( + f"Cannot perform binary arithmetic operation '{op}' between {lhs_ir_type} and {rhs_ir_type} due to element type mismatch" + ) + + typeclass = "float" if self.is_floating_point_type(lhs_ir_type) else "integer" + attr_name = f"binary_{op}_{typeclass}" try: handler = getattr(self, attr_name) except AttributeError: @@ -163,7 +170,14 @@ def binary_vector_arithmetic( rhs_ir = rhs.ir_value lhs_element_type = VectorType(lhs_ir.type).element_type rhs_element_type = VectorType(rhs_ir.type).element_type - attr_name = f"binary_{op}_{lhs_element_type}_{rhs_element_type}" + + if lhs_element_type != rhs_element_type: + raise CodegenError( + f"Cannot perform binary arithmetic operation '{op}' between {lhs_ir.type} and {rhs_ir.type} due to element type mismatch" + ) + + typeclass = "float" if self.is_floating_point_type(lhs_element_type) else "integer" + attr_name = f"binary_{op}_{typeclass}" try: handler = getattr(self, attr_name) except AttributeError: @@ -174,7 +188,8 @@ def binary_vector_arithmetic( def unary_arithmetic(self, op: str, val: IRProxyValue) -> IRProxyValue: val_ir_type = val.ir_value.type - attr_name = f"unary_{op}_{val_ir_type}" + typeclass = "float" if self.is_floating_point_type(val_ir_type) else "integer" + attr_name = f"unary_{op}_{typeclass}" try: handler = getattr(self, attr_name) except AttributeError: @@ -186,7 +201,8 @@ def unary_arithmetic(self, op: str, val: IRProxyValue) -> IRProxyValue: def unary_vector_arithmetic(self, op: str, val: IRProxyValue) -> IRProxyValue: val_ir = val.ir_value val_element_type = VectorType(val_ir.type).element_type - attr_name = f"unary_{op}_{val_element_type}" + typeclass = "float" if 
self.is_floating_point_type(val_element_type) else "integer" + attr_name = f"unary_{op}_{typeclass}" try: handler = getattr(self, attr_name) except AttributeError: @@ -195,59 +211,50 @@ def unary_vector_arithmetic(self, op: str, val: IRProxyValue) -> IRProxyValue: ) return handler(val) + ### Specializations + def promote_index_to_f32(self, value: Value, to_type: IrType) -> Value: i32_type = IntegerType.get_signless(32) i32 = arith_d.index_cast(i32_type, value) return arith_d.sitofp(to_type, i32) - def zero_attr_f32(self, t: IrType) -> Attribute: - return FloatAttr.get(t, 0.0) - - # Binary index arithmetic. - def binary_add_index_index( - self, lhs: IRProxyValue, rhs: IRProxyValue - ) -> IRProxyValue: + # Binary integer/integer arithmetic. + def binary_add_integer(self, lhs: IRProxyValue, rhs: IRProxyValue) -> IRProxyValue: return IRProxyValue(arith_d.addi(lhs.ir_value, rhs.ir_value)) - def binary_mul_index_index( - self, lhs: IRProxyValue, rhs: IRProxyValue - ) -> IRProxyValue: + def binary_mul_integer(self, lhs: IRProxyValue, rhs: IRProxyValue) -> IRProxyValue: return IRProxyValue(arith_d.muli(lhs.ir_value, rhs.ir_value)) - def binary_sub_index_index( - self, lhs: IRProxyValue, rhs: IRProxyValue - ) -> IRProxyValue: + def binary_sub_integer(self, lhs: IRProxyValue, rhs: IRProxyValue) -> IRProxyValue: return IRProxyValue(arith_d.subi(lhs.ir_value, rhs.ir_value)) - def binary_mod_index_index( - self, lhs: IRProxyValue, rhs: IRProxyValue - ) -> IRProxyValue: + def binary_mod_integer(self, lhs: IRProxyValue, rhs: IRProxyValue) -> IRProxyValue: return IRProxyValue(arith_d.remsi(lhs.ir_value, rhs.ir_value)) - def binary_floordiv_index_index( + def binary_floordiv_integer( self, lhs: IRProxyValue, rhs: IRProxyValue ) -> IRProxyValue: return IRProxyValue(arith_d.floordivsi(lhs.ir_value, rhs.ir_value)) - # Binary f32 arithmetic - def binary_add_f32_f32(self, lhs: IRProxyValue, rhs: IRProxyValue) -> IRProxyValue: + # Binary float arithmetic + def binary_add_float(self, lhs: IRProxyValue, rhs: IRProxyValue) -> IRProxyValue: return IRProxyValue(arith_d.addf(lhs.ir_value, rhs.ir_value)) - def binary_mul_f32_f32(self, lhs: IRProxyValue, rhs: IRProxyValue) -> IRProxyValue: + def binary_mul_float(self, lhs: IRProxyValue, rhs: IRProxyValue) -> IRProxyValue: return IRProxyValue(arith_d.mulf(lhs.ir_value, rhs.ir_value)) - def binary_sub_f32_f32(self, lhs: IRProxyValue, rhs: IRProxyValue) -> IRProxyValue: + def binary_sub_float(self, lhs: IRProxyValue, rhs: IRProxyValue) -> IRProxyValue: return IRProxyValue(arith_d.subf(lhs.ir_value, rhs.ir_value)) - def binary_mod_f32_f32(self, lhs: IRProxyValue, rhs: IRProxyValue) -> IRProxyValue: + def binary_mod_float(self, lhs: IRProxyValue, rhs: IRProxyValue) -> IRProxyValue: return IRProxyValue(arith_d.remf(lhs.ir_value, rhs.ir_value)) - def binary_truediv_f32_f32( + def binary_truediv_float( self, lhs: IRProxyValue, rhs: IRProxyValue ) -> IRProxyValue: return IRProxyValue(arith_d.divf(lhs.ir_value, rhs.ir_value)) - def unary_exp2_f32(self, val: IRProxyValue) -> IRProxyValue: + def unary_exp2_float(self, val: IRProxyValue) -> IRProxyValue: return IRProxyValue(math_d.exp2(val.ir_value)) diff --git a/python/shark_turbine/kernel/compiler/vector_codegen.py b/python/shark_turbine/kernel/compiler/vector_codegen.py index 1e7042fda..953dabc41 100644 --- a/python/shark_turbine/kernel/compiler/vector_codegen.py +++ b/python/shark_turbine/kernel/compiler/vector_codegen.py @@ -499,7 +499,6 @@ def _(emitter: ThreadEmitter, node: fx.Node): shape = cast_py_literal(emitter, shape) 
dtype = cast_dtype(emitter, dtype) constant = ScalarBuilder.constant_vector(value, shape, dtype) - print(constant) emitter.bind_node_proxy(node, constant) From bb9f7f8cea4baa5544ef3ad20de0446e7e799d50 Mon Sep 17 00:00:00 2001 From: Kunwar Grover Date: Mon, 5 Feb 2024 06:20:57 +0530 Subject: [PATCH 6/8] Add some tests --- tests/kernel/arith_test.py | 67 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 tests/kernel/arith_test.py diff --git a/tests/kernel/arith_test.py b/tests/kernel/arith_test.py new file mode 100644 index 000000000..1ec1e3b54 --- /dev/null +++ b/tests/kernel/arith_test.py @@ -0,0 +1,67 @@ +import logging +import unittest + +import torch +import shark_turbine.kernel as tk +import shark_turbine.kernel.lang as tkl + +from shark_turbine.kernel.compiler import ( + builder, + kernel_codegen, + vector_codegen, +) +from shark_turbine.kernel._support import ( + indexing, +) + +M = tk.lang.sym.M +K = tk.lang.sym.K + + +class Test(unittest.TestCase): + # This test is using the compiler "the hard way" until we have all of the + # API layering in place. + def testIotaFx(self): + @tk.gen.thread(M) + def iota_kernel(out: tk.lang.OutputBuffer[M]): + # Integer types + for dtype in [tkl.bool, tkl.i4, tkl.i8, tkl.i16, tkl.i32, tkl.i64, tkl.index]: + a = tkl.constant((17, 37, 19), dtype, 5) + b = tkl.constant((17, 37, 19), dtype, 10) + c = tkl.constant((17, 37, 19), dtype, 2) + c = (a * b) // c + c = c + a - b + + # Float types + for dtype in [tkl.f16, tkl.f32, tkl.f64]: + a = tkl.constant((17, 37, 19), dtype, 5.0) + b = tkl.constant((17, 37, 19), dtype, 10.0) + c = tkl.constant((17, 37, 19), dtype, 2.0) + c = (a * b) / c + c = c + a - b + + trace = iota_kernel._trace + print(trace.region_graph) + mb = builder.ModuleBuilder() + with indexing.IndexingContext() as idxc: + idxc.bind_constant(M, 17) + idxc.finalize() + sig = kernel_codegen.KernelSignature() + sig.add_from_graph_placeholders(trace.get_root_graph()) + sig.add_grid(iota_kernel.grid_type) + print(sig) + bound_sig, func_op = kernel_codegen.FunctionalKernelSignature.create( + sig, mb + ) + try: + emitter = vector_codegen.ThreadEmitter(bound_sig, trace) + emitter.emit() + emitter.finish() + finally: + print(mb.module_op.get_asm()) + mb.module_op.verify() + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main() From f2a0f18bcac45adbdcf29791c929a4c107e8ea04 Mon Sep 17 00:00:00 2001 From: Kunwar Grover Date: Mon, 5 Feb 2024 10:28:30 +0530 Subject: [PATCH 7/8] Black --- python/shark_turbine/kernel/compiler/builder.py | 8 ++++++-- tests/kernel/arith_test.py | 10 +++++++++- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/python/shark_turbine/kernel/compiler/builder.py b/python/shark_turbine/kernel/compiler/builder.py index c40fd9ec8..12c9dca54 100644 --- a/python/shark_turbine/kernel/compiler/builder.py +++ b/python/shark_turbine/kernel/compiler/builder.py @@ -176,7 +176,9 @@ def binary_vector_arithmetic( f"Cannot perform binary arithmetic operation '{op}' between {lhs_ir.type} and {rhs_ir.type} due to element type mismatch" ) - typeclass = "float" if self.is_floating_point_type(lhs_element_type) else "integer" + typeclass = ( + "float" if self.is_floating_point_type(lhs_element_type) else "integer" + ) attr_name = f"binary_{op}_{typeclass}" try: handler = getattr(self, attr_name) @@ -201,7 +203,9 @@ def unary_arithmetic(self, op: str, val: IRProxyValue) -> IRProxyValue: def unary_vector_arithmetic(self, op: str, val: IRProxyValue) -> IRProxyValue: 
val_ir = val.ir_value val_element_type = VectorType(val_ir.type).element_type - typeclass = "float" if self.is_floating_point_type(val_element_type) else "integer" + typeclass = ( + "float" if self.is_floating_point_type(val_element_type) else "integer" + ) attr_name = f"unary_{op}_{typeclass}" try: handler = getattr(self, attr_name) diff --git a/tests/kernel/arith_test.py b/tests/kernel/arith_test.py index 1ec1e3b54..ad1f828bd 100644 --- a/tests/kernel/arith_test.py +++ b/tests/kernel/arith_test.py @@ -25,7 +25,15 @@ def testIotaFx(self): @tk.gen.thread(M) def iota_kernel(out: tk.lang.OutputBuffer[M]): # Integer types - for dtype in [tkl.bool, tkl.i4, tkl.i8, tkl.i16, tkl.i32, tkl.i64, tkl.index]: + for dtype in [ + tkl.bool, + tkl.i4, + tkl.i8, + tkl.i16, + tkl.i32, + tkl.i64, + tkl.index, + ]: a = tkl.constant((17, 37, 19), dtype, 5) b = tkl.constant((17, 37, 19), dtype, 10) c = tkl.constant((17, 37, 19), dtype, 2) From 34e3b491ef609e9effccd9a17c10d542630a6c34 Mon Sep 17 00:00:00 2001 From: Kunwar Grover Date: Mon, 5 Feb 2024 10:30:44 +0530 Subject: [PATCH 8/8] Fix tests --- tests/kernel/vector_codegen_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kernel/vector_codegen_test.py b/tests/kernel/vector_codegen_test.py index 25bc3781c..79ec20a66 100644 --- a/tests/kernel/vector_codegen_test.py +++ b/tests/kernel/vector_codegen_test.py @@ -138,7 +138,7 @@ def gemm_kernel( grid_n = tkl.program_id(0) grid_m = tkl.program_id(1) - acc = tkl.constant((BLOCK_SIZE, BLOCK_SIZE), torch.float32, 0.0) + acc = tkl.constant((BLOCK_SIZE, BLOCK_SIZE), tkl.f32, 0.0) @tkl.for_loop(0, K // BLOCK_SIZE, init_args=[acc]) def body(i, c):
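
Taken together, the series swaps torch dtypes for turbine-kernel DataType objects (tkl.f32, tkl.i32, ...), lowers constants through ScalarBuilder.constant_vector, and dispatches arithmetic by "integer"/"float" typeclass instead of per-type handler names. Below is a minimal sketch of how a kernel uses the resulting API, modelled on the arith_test.py added in patch 6; the kernel name, buffer symbol, and shapes are illustrative only, not part of the series.

import shark_turbine.kernel as tk
import shark_turbine.kernel.lang as tkl

M = tk.lang.sym.M


@tk.gen.thread(M)
def scaled_add_kernel(out: tk.lang.OutputBuffer[M]):
    # Element types now come from tkl, not torch; e.g. tkl.bool prints as "bool"
    # but carries "i1" as its IR asm string.
    a = tkl.constant((16, 16), tkl.f16, 1.0)   # float dtypes require a float value
    b = tkl.constant((16, 16), tkl.f16, 2.0)
    i = tkl.constant((16, 16), tkl.i32, 3)     # integer/index dtypes require an int value
    # These ops route through the binary_*_float / binary_*_integer handlers.
    c = (a * b) / a + b
    d = (i * i) // i + i

During codegen, cast_dtype parses each DataType's ir_type_asm() (for example "f16" or "i32") into an IrType, and constant_vector splats the Python value over the requested shape via DenseElementsAttr, so any arithmetic dtype listed in dtype.py should lower without the per-type NYI errors that the old torch.float32 check produced.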