Add comprehensive tests to test the kernel across available dtypes. #484

Open · wants to merge 1 commit into base: main
21 changes: 21 additions & 0 deletions core/shark_turbine/kernel/_support/dtype.py
@@ -16,6 +16,19 @@
_FLOAT_TYPES = ["f16", "f32", "f64"]
_INDEX_TYPES = ["index"]

import torch

_TKL_TO_TORCH_DTYPE = {
"f16": torch.half,
"f32": torch.float,
"f64": torch.double,
"i1": torch.bool,
"i8": torch.int8,
"i16": torch.int16,
"i32": torch.int32,
"i64": torch.int64,
}


# TODO: this should really be a type.
class DataType:
@@ -44,6 +57,14 @@ def is_float_asm(self):
def is_index_asm(self):
return self._name in _INDEX_TYPES

def to_torch_type(self):
try:
return _TKL_TO_TORCH_DTYPE[self._name]
except KeyError:
print(
f"The support for '{self._name}' dtype to torch type isn't implemented."
)


bool = DataType("bool", "i1")
i4 = DataType("i4")
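For orientation, here is a minimal usage sketch (not part of this diff) of the new `to_torch_type` helper, mirroring how the coverage tests below call it:

```python
import torch
import shark_turbine.kernel.lang as tkl

# Allocate torch tensors whose element type matches a kernel-level dtype,
# exactly as the new coverage tests do.
x = torch.randn(128, 64).to(tkl.f16.to_torch_type())      # torch.half
y = torch.zeros(128, 64, dtype=tkl.i32.to_torch_type())   # torch.int32

# Dtypes without an entry in _TKL_TO_TORCH_DTYPE (e.g. i4) only print a
# message and implicitly return None, so callers should guard the result.
```

Raising a `KeyError` or `NotImplementedError` for unmapped dtypes would arguably be more defensive than printing, since a silent `None` surfaces later as an opaque torch error.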
8 changes: 8 additions & 0 deletions core/shark_turbine/kernel/_support/tracing.py
@@ -286,6 +286,14 @@ def handle_exp2(self, op, val):
kwargs={},
)

def handle_rsqrt(self, op, val):
return self.region_graph.create_proxy(
"call_function",
target=op,
args=(val,),
kwargs={},
)

def handle_vector_constant(
self, op, shape: Tuple[int, ...], dtype, value: int | float
):
3 changes: 3 additions & 0 deletions core/shark_turbine/kernel/compiler/builder.py
@@ -306,5 +306,8 @@ def binary_truediv_float(
def unary_exp2_float(self, val: IRProxyValue) -> IRProxyValue:
return IRProxyValue(math_d.exp2(val.ir_value))

def unary_rsqrt_float(self, val: IRProxyValue) -> IRProxyValue:
return IRProxyValue(math_d.rsqrt(val.ir_value))


ScalarBuilder = _ScalarBuilder()
1 change: 1 addition & 0 deletions core/shark_turbine/kernel/compiler/vector_codegen.py
@@ -238,6 +238,7 @@ def _(emitter: ThreadEmitter, node: fx.Node):

UNARY_ARITHMETIC_OPS = [
(tkl.exp2, "exp2"),
(tkl.rsqrt, "rsqrt"),
]


2 changes: 2 additions & 0 deletions core/shark_turbine/kernel/lang/prims.py
@@ -22,6 +22,7 @@
"broadcast_in_dim",
"transpose",
"to_dtype",
"rsqrt",
]


@@ -37,6 +38,7 @@ def is_debug() -> bool:
# Math Operations
exp2 = ops.exp2
constant = ops.vector_constant
rsqrt = ops.rsqrt

# Reduction Operations
max = ops.vector_max
6 changes: 6 additions & 0 deletions core/shark_turbine/kernel/ops/math.py
@@ -10,6 +10,7 @@

__all__ = [
"exp2",
"rsqrt",
"vector_constant",
]

@@ -22,3 +23,8 @@ def exp2(val):
@define_op
def vector_constant(shape: Tuple[int, ...], dtype, value: int | float) -> "Vector":
...


@define_op
def rsqrt(val):
...
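As a usage illustration only, here is a minimal kernel (assumed names and shapes, not part of this PR) that exercises the new `tkl.rsqrt` op through the same API surface the coverage tests below rely on:

```python
import torch
import shark_turbine.kernel as tk
import shark_turbine.kernel.lang as tkl

M = tkl.sym.M
K = tkl.sym.K


@tk.gen.thread(M)
def rsqrt_kernel(
    input: tkl.InputBuffer[M, K, tkl.f32],
    output: tkl.OutputBuffer[M, K, tkl.f32],
):
    # Each program instance reads one row and writes back 1/sqrt(x) elementwise.
    row_index = tkl.program_id(0)
    input_row = input[row_index, :]
    output[row_index, :] = tkl.rsqrt(input_row)


input = torch.rand(128, 64) + 1.0   # keep values positive for rsqrt
output = torch.zeros(128, 64)
with tk.gen.TestLaunchContext():
    rsqrt_kernel(input, output)
```

This exercises the wiring added above: `tkl.rsqrt` resolves to `ops.rsqrt`, the tracer records it through `handle_rsqrt`, and codegen ultimately reaches `unary_rsqrt_float`, which emits `math.rsqrt`.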
167 changes: 167 additions & 0 deletions core/tests/kernel/coverage_test.py
@@ -0,0 +1,167 @@
import torch
import shark_turbine.kernel as tk
import shark_turbine.kernel.lang as tkl
import pytest


FLOAT_DTYPES = [tkl.f16, tkl.f32, tkl.f64]
INT_DTYPES = [
tkl.bool,
tkl.i8,
tkl.i16,
tkl.i32,
tkl.i64,
]


def rms_norm_krnl(dtype, input, weight, output):
M = tkl.sym.M
K = tkl.sym.K

@tk.gen.thread(M)
def rms_norm_kernel(
input: tkl.OutputBuffer[M, K, dtype],
weight: tk.lang.InputBuffer[M, K, dtype],
output: tk.lang.OutputBuffer[M, K, dtype],
):
row_index = tk.lang.program_id(0)
eps = tkl.constant((1,), dtype, 0.00001)
zero = tkl.constant((1,), dtype, 0.0)
input_row = input[row_index, :]
sq_inp = input_row * input_row
sq_inp_red = tkl.sum(sq_inp)
# TODO: The input_row * zero is just dummy computation to pass in the right shapes,
# otherwise it leads to 'error: unknown: 'math.exp2' op operand #0 must be floating-point-like, but got 'vector<f16>'
denom = tkl.rsqrt(input_row * zero + sq_inp_red)
denom_eta = denom + eps
output[row_index, :] = denom_eta * input_row * weight[row_index, :]

with tk.gen.TestLaunchContext():
rms_norm_kernel(input, weight, output)


def iota_krnl(dtype, input):
M = tkl.sym.M

@tk.gen.thread(M)
def iota_kernel(out: tkl.OutputBuffer[M, dtype]):
a = (
tkl.constant((17, 37, 19), dtype, 5)
if dtype in INT_DTYPES
else tkl.constant((17, 37, 19), dtype, 5.0)
)
b = (
tkl.constant((17, 37, 19), dtype, 10)
if dtype in INT_DTYPES
else tkl.constant((17, 37, 19), dtype, 10.0)
)
c = (
tkl.constant((17, 37, 19), dtype, 2)
if dtype in INT_DTYPES
else tkl.constant((17, 37, 19), dtype, 2.0)
)
if dtype in INT_DTYPES:
c = (a * b) // c
else:
c = (a * b) / c
c = c + a - b

with tk.gen.TestLaunchContext():
iota_kernel(input)


def softmax_krnl(dtype, input, output):
M = tkl.sym.M
K = tkl.sym.K

@tk.gen.thread(M)
def softmax_kernel(
input: tk.lang.InputBuffer[M, K, dtype],
output: tk.lang.OutputBuffer[M, K, dtype],
):
row_index = tk.lang.program_id(0)
input_row = input[row_index, :]
numerator = tkl.exp2(input_row - tkl.max(input_row))
if dtype in INT_DTYPES:
output_row = numerator // tkl.sum(numerator)
else:
output_row = numerator / tkl.sum(numerator)
output[row_index, :] = output_row

with tk.gen.TestLaunchContext():
softmax_kernel(input, output)


def gemm_fx_kernel(dtype, A, B, output):
N = tkl.sym.N
M = tkl.sym.M
K = tkl.sym.K
BLOCK_SIZE = tkl.sym.BLOCK_SIZE

@tk.gen.thread(N // BLOCK_SIZE, M // BLOCK_SIZE)
def gemm_kernel(
A: tkl.InputBuffer[N, K, dtype],
B: tkl.InputBuffer[K, M, dtype],
output: tkl.OutputBuffer[N, M, dtype],
):
grid_n = tkl.program_id(0)
grid_m = tkl.program_id(1)

acc = None
# TODO: Only considering the float and integer cases.
if dtype in INT_DTYPES:
acc = tkl.constant((BLOCK_SIZE, BLOCK_SIZE), dtype, 0)
else:
acc = tkl.constant((BLOCK_SIZE, BLOCK_SIZE), dtype, 0.0)

@tkl.for_loop(0, K // BLOCK_SIZE, init_args=[acc])
def body(i, c):
a = tkl.load(A, (grid_n, i * BLOCK_SIZE), (BLOCK_SIZE, BLOCK_SIZE))
b = tkl.load(B, (i * BLOCK_SIZE, grid_m), (BLOCK_SIZE, BLOCK_SIZE))
return (tkl.dot(a, b, c),)

tkl.store(output, (grid_n, grid_m), body[0])

with tk.gen.TestLaunchContext({BLOCK_SIZE: 32}):
gemm_kernel(A, B, output)


@pytest.mark.parametrize(
("dtype",),
[(x,) for x in FLOAT_DTYPES + INT_DTYPES],
)
def test_iota_krnl(dtype):
input = torch.zeros(17)
iota_krnl(dtype, input)


@pytest.mark.parametrize(
("dtype",),
[(x,) for x in FLOAT_DTYPES],
)
def test_rms_norm_krnl(dtype):
input = torch.randn(128, 64).to(dtype.to_torch_type())
weight = torch.randn(128, 64).to(dtype.to_torch_type())
output = torch.randn(128, 64).to(dtype.to_torch_type())
rms_norm_krnl(dtype, input, weight, output)


@pytest.mark.parametrize(
("dtype",),
[(x,) for x in FLOAT_DTYPES],
)
def test_softmax_krnl(dtype):
input = torch.randn(128, 64).to(dtype.to_torch_type())
output = torch.randn(128, 64).to(dtype.to_torch_type())
softmax_krnl(dtype, input, output)


@pytest.mark.parametrize(
("dtype",),
[(x,) for x in FLOAT_DTYPES + INT_DTYPES],
)
def test_gemm_krnl(dtype):
A = torch.randn(512, 1024).to(dtype.to_torch_type())
B = torch.randn(1024, 2048).to(dtype.to_torch_type())
output = torch.zeros(512, 2048).to(dtype.to_torch_type())
gemm_fx_kernel(dtype, A, B, output)
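A possible follow-up, sketched here as a suggestion rather than part of this PR: the integer dtypes could be kept in the float-only parametrizations as expected failures, so unsupported dtype/kernel combinations show up as `xfail` in the report instead of being silently skipped. `pytest.param` and `pytest.mark.xfail` are standard pytest features; the test name below is hypothetical.

```python
import pytest
import shark_turbine.kernel.lang as tkl

FLOAT_DTYPES = [tkl.f16, tkl.f32, tkl.f64]
INT_DTYPES = [tkl.bool, tkl.i8, tkl.i16, tkl.i32, tkl.i64]

# Float dtypes run normally; integer dtypes are recorded as expected failures
# rather than being left out of the float-only kernels entirely.
ALL_DTYPES = FLOAT_DTYPES + [
    pytest.param(d, marks=pytest.mark.xfail(reason="integer dtypes not supported"))
    for d in INT_DTYPES
]


@pytest.mark.parametrize("dtype", ALL_DTYPES)
def test_softmax_krnl_all_dtypes(dtype):
    # Body identical to test_softmax_krnl above.
    ...
```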