Commit ba422f1: Allow symbol use.

stellaraccident committed Jan 31, 2024
1 parent c43fb6e
Showing 5 changed files with 71 additions and 40 deletions.
6 changes: 4 additions & 2 deletions python/shark_turbine/kernel/_support/indexing.py
@@ -93,8 +93,8 @@ def index_symbol(name: str) -> IndexSymbol:
 
 def index_expr(value: Any) -> IndexExpr:
     expr = sympy.sympify(value)
-    if not expr.is_integer:
-        raise ValueError(f"Expected Integer from {value}. Got {expr} ({type(expr)})")
+    # if not expr.is_integer:
+    #     raise ValueError(f"Expected Integer from {value}. Got {expr} ({type(expr)})")
     return expr
 
 
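Context for the hunk above (an illustration, not part of the commit): sympy reports is_integer as None when it cannot prove an expression integral, so the old guard rejected otherwise-legitimate symbolic index expressions. A minimal standalone sketch, assuming a bare symbol with no integer assumption:

    import sympy

    # A bare symbol carries no integrality assumption, so sympy cannot
    # prove the expression is an integer: is_integer evaluates to None.
    M = sympy.Symbol("M")
    expr = sympy.sympify(M + 1)
    print(expr.is_integer)  # None (falsy): the old guard raised ValueError here
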
@@ -428,6 +428,8 @@ def new_unbacked_symbol(self) -> IndexSymbol:
     def bind_shaped(
         self, instance: Any, shaped_type: ShapedType, dims: Dims
     ) -> _ShapedBinding:
+        if instance in self.shaped_bindings:
+            raise ValueError(f"Argument binding {instance} is already bound")
         symbolic_shape = shaped_type.symbolic_shape
         rank = shaped_type.rank
         if rank != len(dims):
13 changes: 11 additions & 2 deletions python/shark_turbine/kernel/_support/tracing.py
@@ -22,8 +22,10 @@
 from .indexing import (
     backed_sym_index_type,
     BoundedRelation,
+    IndexExpr,
     Grid,
     KernelBuffer,
+    SymIndex,
 )
 
 from ..lang.types import (
@@ -98,10 +100,17 @@ class KernelTracer(SubgraphTracer):
     # Register our custom proxies.
     def proxy(self, node: fx.Node) -> fx.Proxy:
         t = node.type
-        if t is not None and issubclass(t, KernelBuffer):
-            return KernelBufferProxy(node, self, t)
+        if t is not None:
+            if issubclass(t, KernelBuffer):
+                return KernelBufferProxy(node, self, t)
         return super().proxy(node)
 
+    def create_arg(self, a):
+        # Let IndexExpr persist as arguments.
+        if isinstance(a, IndexExpr):
+            return a
+        return super().create_arg(a)
+
 
 class CapturedTrace:
     def __init__(self, region_graph: RegionGraph, root_graph: str):
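
A standalone sketch of the create_arg pattern above (stand-in names, not the project's classes): torch.fx calls Tracer.create_arg to lower values referenced by a traced call into graph arguments, and returning a value unchanged embeds it in the node's args as a literal.

    import torch.fx as fx

    class SymValue:
        """Hypothetical stand-in for IndexExpr."""

    class PassthroughTracer(fx.Tracer):
        def create_arg(self, a):
            # Embed SymValue instances directly in node args rather than
            # rejecting or proxying them.
            if isinstance(a, SymValue):
                return a
            return super().create_arg(a)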
40 changes: 39 additions & 1 deletion python/shark_turbine/kernel/compiler/vector_codegen.py
@@ -421,6 +421,7 @@ def _(emitter: ThreadEmitter, node: fx.Node):
     except ValueError as e:
         raise ValidationError("Malformed arguments") from e
 
+    vector_shape = cast_py_literal(emitter, vector_shape)
     kb_src, kb_ir_type, kb_py_type = cast_kernel_buffer(emitter, kb)
     ref_shape = kb_py_type.symbolic_shape
     slice_spec = cast_slice_spec(emitter, ref_shape, multi_index)
@@ -486,6 +487,8 @@ def _(emitter: ThreadEmitter, node: fx.Node):
     except ValueError as e:
         raise ValidationError("Malformed arguments") from e
 
+    shape = cast_py_literal(emitter, shape)
+
     # TODO: Have better way to get the dtype.
     if dtype == torch.float32:
         element_type = F32Type.get()
@@ -696,6 +699,33 @@ def emit_reduction(
 ###############################################################################
 
 
+def cast_py_literal(emitter: ThreadEmitter, value) -> Any:
+    """Treats the given value as a Python literal.
+
+    An exception will be raised if it cannot be computed statically.
+    """
+    if isinstance(value, IndexExpr):
+        simplified = IndexingContext.current().simplify_expr(value)
+        try:
+            return int(simplified)
+        except TypeError as e:
+            raise CodegenError(
+                f"Literal value required but got symbolic value requiring "
+                f"dynamic resolution: {simplified}"
+            ) from e
+    elif isinstance(value, tuple):
+        return tuple(cast_py_literal(emitter, v) for v in value)
+    elif isinstance(value, list):
+        return [cast_py_literal(emitter, v) for v in value]
+    elif isinstance(value, dict):
+        return {
+            cast_py_literal(emitter, k): cast_py_literal(emitter, v)
+            for k, v in value.items()
+        }
+    elif isinstance(value, (int, float, str)):
+        return value
+
+
 def cast_py_value(emitter: ThreadEmitter, value) -> IRProxyValue:
     """
     Converts the given value to an IR Value.
@@ -710,7 +740,15 @@ def cast_py_value(emitter: ThreadEmitter, value) -> IRProxyValue:
             return node_values[0]
         except KeyError:
             raise CodegenError(f"Producer node `{value}` has no IR Value")
-
+    elif isinstance(value, IndexExpr):
+        simplified = IndexingContext.current().simplify_expr(value)
+        try:
+            value = int(simplified)
+        except TypeError as e:
+            raise CodegenError(
+                f"Dynamically resolved symbolic values not yet implemented. Got: "
+                f"{simplified}"
+            ) from e
     return ScalarBuilder.constant(value)
 
 
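The folding behavior of cast_py_literal can be mimicked with plain sympy (a hedged, standalone sketch; the real simplify_expr is approximated here by substituting bound constants):

    import sympy

    def fold_literal(value, bindings):
        # Resolve a symbolic expression to a concrete int, recursing
        # through tuples the way cast_py_literal does.
        if isinstance(value, sympy.Expr):
            simplified = value.subs(bindings)
            return int(simplified)  # TypeError if still symbolic
        if isinstance(value, tuple):
            return tuple(fold_literal(v, bindings) for v in value)
        return value

    M, K = sympy.symbols("M K", positive=True, integer=True)
    print(fold_literal((M // 2, K), {M: 128, K: 64}))  # (64, 64)
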
4 changes: 2 additions & 2 deletions tests/kernel/dispatch_codegen_test.py
@@ -35,8 +35,8 @@ def softmax_kernel(
         print(trace.region_graph)
         mb = builder.ModuleBuilder()
         with indexing.IndexingContext() as idxc:
-            idxc.bind_constant(M, 128)
-            idxc.bind_constant(K, 64)
+            idxc.bind_shaped(0, tk.lang.InputBuffer[M, K], (128, 64))
+            idxc.bind_shaped(1, tk.lang.OutputBuffer[M, K], (128, 64))
             idxc.finalize()
 
         sig = kernel_codegen.KernelSignature()
48 changes: 15 additions & 33 deletions tests/kernel/vector_codegen_test.py
@@ -64,8 +64,8 @@ def softmax_kernel(
         print(trace.region_graph)
         mb = builder.ModuleBuilder()
         with indexing.IndexingContext() as idxc:
-            idxc.bind_constant(M, 128)
-            idxc.bind_constant(K, 64)
+            idxc.bind_shaped(0, tk.lang.KernelBuffer[M, K], (128, 64))
+            idxc.bind_shaped(1, tk.lang.KernelBuffer[M, K], (128, 64))
             idxc.finalize()
 
         sig = kernel_codegen.KernelSignature()
@@ -104,8 +104,8 @@ def prefetch_sum(i, iter_args):
         print(trace.region_graph)
         mb = builder.ModuleBuilder()
         with indexing.IndexingContext() as idxc:
-            idxc.bind_constant(M, 128)
-            idxc.bind_constant(K, 64)
+            idxc.bind_shaped(0, tk.lang.KernelBuffer[M, K], (128, 64))
+            idxc.bind_shaped(1, tk.lang.KernelBuffer[M, K], (128, 64))
             idxc.finalize()
 
         sig = kernel_codegen.KernelSignature()
@@ -127,53 +127,35 @@ def testGemmFx(self):
         N = tkl.sym.N
         M = tkl.sym.M
         K = tkl.sym.K
+        BLOCK_SIZE = tkl.sym.BLOCK_SIZE
 
-        GRID_N = tkl.sym.GRID_N
-        GRID_M = tkl.sym.GRID_M
-
-        def inner_gemm(
+        @tk.gen.thread(N // BLOCK_SIZE, M // BLOCK_SIZE)
+        def gemm_kernel(
             A: tkl.KernelBuffer[N, K],
             B: tkl.KernelBuffer[K, M],
             output: tkl.KernelBuffer[N, M],
-            k: int,
-            block_size: int,
         ):
             grid_n = tkl.program_id(0)
             grid_m = tkl.program_id(1)
 
-            acc = tkl.constant((block_size, block_size), torch.float32, 0.0)
+            acc = tkl.constant((BLOCK_SIZE, BLOCK_SIZE), torch.float32, 0.0)
 
-            @tkl.for_loop(0, k // block_size, init_args=[acc])
+            @tkl.for_loop(0, K // BLOCK_SIZE, init_args=[acc])
             def body(i, c):
-                a = tkl.load(A, (grid_n, i * block_size), (block_size, block_size))
-                b = tkl.load(B, (i * block_size, grid_m), (block_size, block_size))
+                a = tkl.load(A, (grid_n, i * BLOCK_SIZE), (BLOCK_SIZE, BLOCK_SIZE))
+                b = tkl.load(B, (i * BLOCK_SIZE, grid_m), (BLOCK_SIZE, BLOCK_SIZE))
                 return (tkl.dot(a, b, c),)
 
             tkl.store(output, (grid_n, grid_m), body[0])
 
-        @tk.gen.thread(GRID_N, GRID_M)
-        def gemm_kernel(
-            A: tkl.KernelBuffer[N, K],
-            B: tkl.KernelBuffer[K, M],
-            output: tkl.KernelBuffer[N, M],
-        ):
-            # TODO: We should find a way to parameterize these so we can autotune over them.
-            # TODO: Ideally, we should be getting k from the symbol. The symbol value
-            # is currently not available at tracing time which is a problem.
-            k = 512
-            block_size = 32
-            inner_gemm(A, B, output, k, block_size)
-
         trace = gemm_kernel._trace
         print(trace.region_graph)
         mb = builder.ModuleBuilder()
         with indexing.IndexingContext() as idxc:
-            BLOCK_SIZE = 32
-            idxc.bind_constant(N, 512)
-            idxc.bind_constant(M, 512)
-            idxc.bind_constant(K, 512)
-            idxc.bind_constant(GRID_N, 512 // BLOCK_SIZE)
-            idxc.bind_constant(GRID_M, 512 // BLOCK_SIZE)
+            idxc.bind_shaped("A", tkl.KernelBuffer[N, K], (512, 1024))
+            idxc.bind_shaped("B", tkl.KernelBuffer[K, M], (1024, 2048))
+            idxc.bind_shaped("output", tkl.KernelBuffer[N, M], (512, 2048))
+            idxc.bind_constant(BLOCK_SIZE, 32)
             idxc.finalize()
 
         sig = kernel_codegen.KernelSignature()
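
As a quick arithmetic check of the new symbolic grid (using the bindings in this test): A is (512, 1024) so N=512 and K=1024, B is (1024, 2048) so M=2048, and BLOCK_SIZE is bound to 32, so @tk.gen.thread(N // BLOCK_SIZE, M // BLOCK_SIZE) should resolve to a (16, 64) grid. A plain-sympy sketch of that resolution:

    import sympy

    N, M, BLOCK_SIZE = sympy.symbols("N M BLOCK_SIZE", positive=True, integer=True)
    grid = (N // BLOCK_SIZE, M // BLOCK_SIZE)
    # Substitute the constants bound by the test's IndexingContext.
    print([g.subs({N: 512, M: 2048, BLOCK_SIZE: 32}) for g in grid])  # [16, 64]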
