Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Groverkss committed Jan 19, 2024
1 parent 1f4bd46 commit 911eb01
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 17 deletions.
10 changes: 5 additions & 5 deletions tests/kernel/dispatch_codegen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,22 @@ def softmax_kernel(
output_row = numerator / torch.sum(numerator)
output[row_index, :] = output_row

gm = softmax_kernel._trace.gm
print(gm.graph)
trace = softmax_kernel._trace
print(trace.region_graph)
mb = builder.ModuleBuilder()
with indexing.IndexingContext() as idxc:
idxc.bind_constant(M, 128)
idxc.bind_constant(K, 64)

sig = kernel_codegen.KernelSignature()
sig.add_from_graph_placeholders(gm.graph)
sig.add_from_graph_placeholders(trace.get_root_graph())
sig.add_grid(softmax_kernel.grid_type)

try:
exe = dispatch_codegen.StreamExecutable(mb)
dispatch_entrypoint = exe.define_entrypoint("dispatch", sig)
emitter = vector_codegen.ThreadEmitter(dispatch_entrypoint)
emitter.emit_graph(gm.graph)
emitter = vector_codegen.ThreadEmitter(dispatch_entrypoint, trace)
emitter.emit()
emitter.finish()
finally:
print(mb.module_op.get_asm())
Expand Down
4 changes: 2 additions & 2 deletions tests/kernel/simple_kernel_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def iota_kernel(out: tk.lang.KernelBuffer[M]):
i = tk.lang.program_id(0)
out[i] = i

print(iota_kernel._trace.gm.graph)
print(iota_kernel._trace.region_graph)
# Prints:
# .graph():
# %out : shark_turbine.kernel.lang.types.KernelBuffer [num_users=1] = placeholder[target=out]
Expand Down Expand Up @@ -70,7 +70,7 @@ def softmax(x):
generated = softmax(input)
actual = torch.softmax(input, -1)
torch.testing.assert_close(generated, actual)
print(softmax_kernel._trace.gm.graph)
print(softmax_kernel._trace.region_graph)
# Prints:
# graph():
# %input_1 : shark_turbine.kernel.lang.types.KernelBuffer [num_users=1] = placeholder[target=input]
Expand Down
20 changes: 10 additions & 10 deletions tests/kernel/vector_codegen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,21 @@ def iota_kernel(out: tk.lang.OutputBuffer[M]):
secret_value = ((i * (33 - i) + 4) % 8) // 2
out[i] = secret_value

gm = iota_kernel._trace.gm
print(gm.graph)
trace = iota_kernel._trace
print(trace.region_graph)
mb = builder.ModuleBuilder()
with indexing.IndexingContext() as idxc:
idxc.bind_constant(M, 17)
sig = kernel_codegen.KernelSignature()
sig.add_from_graph_placeholders(gm.graph)
sig.add_from_graph_placeholders(trace.get_root_graph())
sig.add_grid(iota_kernel.grid_type)
print(sig)
bound_sig, func_op = kernel_codegen.FunctionalKernelSignature.create(
sig, mb
)
try:
emitter = vector_codegen.ThreadEmitter(bound_sig)
emitter.emit_graph(gm.graph)
emitter = vector_codegen.ThreadEmitter(bound_sig, trace)
emitter.emit()
emitter.finish()
finally:
print(mb.module_op.get_asm())
Expand All @@ -58,23 +58,23 @@ def softmax_kernel(
output_row = numerator / torch.sum(numerator)
output[row_index, :] = output_row

gm = softmax_kernel._trace.gm
print(gm.graph)
trace = softmax_kernel._trace
print(trace.region_graph)
mb = builder.ModuleBuilder()
with indexing.IndexingContext() as idxc:
idxc.bind_constant(M, 128)
idxc.bind_constant(K, 64)

sig = kernel_codegen.KernelSignature()
sig.add_from_graph_placeholders(gm.graph)
sig.add_from_graph_placeholders(trace.get_root_graph())
sig.add_grid(softmax_kernel.grid_type)
print(sig)
bound_sig, func_op = kernel_codegen.FunctionalKernelSignature.create(
sig, mb
)
emitter = vector_codegen.ThreadEmitter(bound_sig, trace)
try:
emitter = vector_codegen.ThreadEmitter(bound_sig)
emitter.emit_graph(gm.graph)
emitter.emit()
finally:
emitter.finish()
print(mb.module_op.get_asm())
Expand Down

0 comments on commit 911eb01

Please sign in to comment.