add gemm and for_loop tests in vector_codegen
Groverkss committed Jan 19, 2024
1 parent 911eb01 commit 5686f84
Showing 1 changed file with 107 additions and 0 deletions.
tests/kernel/vector_codegen_test.py: 107 additions, 0 deletions
@@ -3,6 +3,7 @@

import torch
import shark_turbine.kernel as tk
import shark_turbine.kernel.lang as tkl

from shark_turbine.kernel.compiler import (
builder,
@@ -80,6 +81,112 @@ def softmax_kernel(
        print(mb.module_op.get_asm())
        mb.module_op.verify()

    def testForLoopFx(self):
        @tk.gen.thread(M)
        def for_loop_kernel(
            input: tk.lang.KernelBuffer[M, K], output: tk.lang.KernelBuffer[M, K]
        ):
            row_idx = tkl.program_id(0)
            sum = input[row_idx, 0]
            prefetch = input[row_idx, 1]

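            # i runs over [2, 5) and iter_args carries (sum, prefetch) between
            # iterations; prefetch_sum then holds the final iter_args.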
            @tkl.for_loop(2, 5, init_args=[sum, prefetch])
            def prefetch_sum(i, iter_args):
                new_sum = iter_args[0] + iter_args[1]
                new_prefetch = input[row_idx, i]
                return new_sum, new_prefetch

            output[row_idx, 0] = prefetch_sum[0]

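        # Trace the kernel, bind symbol values, build the kernel signature, and
        # emit per-thread vector code through ThreadEmitter.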
        trace = for_loop_kernel._trace
        print(trace.region_graph)
        mb = builder.ModuleBuilder()
        with indexing.IndexingContext() as idxc:
            idxc.bind_constant(M, 128)
            idxc.bind_constant(K, 64)

            sig = kernel_codegen.KernelSignature()
            sig.add_from_graph_placeholders(trace.get_root_graph())
            sig.add_grid(for_loop_kernel.grid_type)
            print(sig)
            bound_sig, func_op = kernel_codegen.FunctionalKernelSignature.create(
                sig, mb
            )
            emitter = vector_codegen.ThreadEmitter(bound_sig, trace)
            try:
                emitter.emit()
            finally:
                emitter.finish()
            print(mb.module_op.get_asm())
            mb.module_op.verify()

    def testGemmFx(self):
        N = tkl.sym.N
        M = tkl.sym.M
        K = tkl.sym.K

        GRID_N = tkl.sym.GRID_N
        GRID_M = tkl.sym.GRID_M

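        # inner_gemm computes one block_size x block_size output tile per program id;
        # k and block_size are plain Python ints supplied at trace time (see the
        # TODOs in gemm_kernel below).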
        def inner_gemm(
            A: tkl.KernelBuffer[N, K],
            B: tkl.KernelBuffer[K, M],
            output: tkl.KernelBuffer[N, M],
            k: int,
            block_size: int,
        ):
            grid_n = tkl.program_id(0)
            grid_m = tkl.program_id(1)

            acc = tkl.constant((block_size, block_size), torch.float32, 0.0)

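            # Accumulate over the K dimension in block_size steps: each iteration
            # loads one tile of A and B and folds it into the accumulator via tkl.dot.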
            @tkl.for_loop(0, k // block_size, init_args=[acc])
            def body(i, c):
                a = tkl.load(A, (grid_n, i * block_size), (block_size, block_size))
                b = tkl.load(B, (i * block_size, grid_m), (block_size, block_size))
                return (tkl.dot(a, b, c),)

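            # body[0] is the final accumulator tile for this (grid_n, grid_m) block.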
            tkl.store(output, (grid_n, grid_m), body[0])

        @tk.gen.thread(GRID_N, GRID_M)
        def gemm_kernel(
            A: tkl.KernelBuffer[N, K],
            B: tkl.KernelBuffer[K, M],
            output: tkl.KernelBuffer[N, M],
        ):
            # TODO: We should find a way to parameterize these so we can autotune over them.
            # TODO: Ideally, we should be getting k from the symbol. The symbol value
            # is currently not available at tracing time which is a problem.
            k = 512
            block_size = 32
            inner_gemm(A, B, output, k, block_size)

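        # Same lowering pipeline as the test above, with GEMM-sized symbol bindings
        # and a 2D grid.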
        trace = gemm_kernel._trace
        print(trace.region_graph)
        mb = builder.ModuleBuilder()
        with indexing.IndexingContext() as idxc:
            BLOCK_SIZE = 32
            idxc.bind_constant(N, 512)
            idxc.bind_constant(M, 512)
            idxc.bind_constant(K, 512)
            idxc.bind_constant(GRID_N, 512 // BLOCK_SIZE)
            idxc.bind_constant(GRID_M, 512 // BLOCK_SIZE)

            sig = kernel_codegen.KernelSignature()
            sig.add_from_graph_placeholders(trace.get_root_graph())
            sig.add_grid(gemm_kernel.grid_type)
            print(sig)
            bound_sig, func_op = kernel_codegen.FunctionalKernelSignature.create(
                sig, mb
            )
            emitter = vector_codegen.ThreadEmitter(bound_sig, trace)
            try:
                emitter.emit()
            finally:
                emitter.finish()
            print(mb.module_op.get_asm())
            mb.module_op.verify()


if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)