[TIR] Fix Inlining of Non-Output Consumers in TileWithTensorIntrin with Padding #17161
Conversation
…nsorIntrin

In the TileWithTensorIntrin function, this change modifies the inlining behavior of consumer blocks: when padding is applied, only non-output consumer blocks are now inlined, so that the padding and inlining process is handled correctly for both producers and consumers.

Changes:
- Added a check using tir::IsOutputBlock so that only non-output consumer blocks are inlined.
- Updated the loop iterating over consumers to include the new check.

This fixes cases where output blocks were inappropriately inlined, and preserves the correct block shapes and dependencies.
Thank you for the contribution!
@@ -340,7 +340,9 @@ Optional<LoopRV> TileWithTensorIntrin(const tir::Schedule& sch, const tir::Block
   }
   auto consumers = sch->GetConsumers(block_rv);
   for (const auto& consumer : consumers) {
-    sch->ComputeInline(consumer);
+    auto sref = sch->GetSRef(consumer);
+    if (!tir::IsOutputBlock(sch->state(), sref, tir::GetScopeRoot(sch->state(), sref, true)))
Could you add a simple test case to check the validity of the resulting IR under this new condition?
Sure, I encountered this bug while using meta_schedule to tune a conv2d operator. Here is the TIR example:
import tvm
from tvm import te, topi, tir
from tvm.ir.module import IRModule
from tvm.script import tir as T
from tvm.tir.schedule.transform import tile_with_tensor_intrin
from tvm.tir.tensor_intrin.cuda import WMMA_SYNC_16x16x16_f16f16f16_TRANS_INTRIN
@tvm.script.ir_module
class conv2d_Module:
    @T.prim_func
    def main(A: T.Buffer((16, 3, 224, 224), "float16"), B: T.Buffer((64, 3, 7, 7), "float16"), conv2d_nchw: T.Buffer((16, 64, 112, 112), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        pad_temp = T.alloc_buffer((16, 3, 230, 230), "float16")
        conv2d_nchw_reindex = T.alloc_buffer((200704, 64), "float16")
        pad_temp_reindex = T.alloc_buffer((200704, 147), "float16")
        B_reindex = T.alloc_buffer((64, 147), "float16")
        for i0, i1, i2, i3 in T.grid(16, 3, 230, 230):
            with T.block("pad_temp"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(A[v_i0, v_i1, v_i2 - 3, v_i3 - 3])
                T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3])
                pad_temp[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(3 <= v_i2 and v_i2 < 227 and 3 <= v_i3 and v_i3 < 227, A[v_i0, v_i1, v_i2 - 3, v_i3 - 3], T.float16(0))
        for ax0, ax1 in T.grid(200704, 147):
            with T.block("pad_temp_reindex_reindex"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(pad_temp[v0 // 12544, v1 // 49, v0 % 12544 // 112 * 2 + v1 % 49 // 7, v0 % 112 * 2 + v1 % 7])
                T.writes(pad_temp_reindex[v0, v1])
                pad_temp_reindex[v0, v1] = pad_temp[v0 // 12544, v1 // 49, v0 % 12544 // 112 * 2 + v1 % 49 // 7, v0 % 112 * 2 + v1 % 7]
        for ax0, ax1 in T.grid(64, 147):
            with T.block("B_reindex_reindex"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(B[v0, v1 // 49, v1 % 49 // 7, v1 % 7])
                T.writes(B_reindex[v0, v1])
                B_reindex[v0, v1] = B[v0, v1 // 49, v1 % 49 // 7, v1 % 7]
        for ax0, ax1, ax2 in T.grid(200704, 64, 147):
            with T.block("conv2d_nchw"):
                v0, v1, v2 = T.axis.remap("SSR", [ax0, ax1, ax2])
                T.reads(pad_temp_reindex[v0, v2], B_reindex[v1, v2])
                T.writes(conv2d_nchw_reindex[v0, v1])
                T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
                with T.init():
                    conv2d_nchw_reindex[v0, v1] = T.float16(0)
                conv2d_nchw_reindex[v0, v1] = conv2d_nchw_reindex[v0, v1] + pad_temp_reindex[v0, v2] * B_reindex[v1, v2]
        for ax0, ax1 in T.grid(200704, 64):
            with T.block("conv2d_nchw_reindex"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(conv2d_nchw_reindex[v0, v1])
                T.writes(conv2d_nchw[v0 // 12544, v1, v0 % 12544 // 112, v0 % 112])
                conv2d_nchw[v0 // 12544, v1, v0 % 12544 // 112, v0 % 112] = conv2d_nchw_reindex[v0, v1]
sch = tvm.tir.Schedule(conv2d_Module)
intrin = WMMA_SYNC_16x16x16_f16f16f16_TRANS_INTRIN
block = sch.get_block("conv2d_nchw")
tiled_loop = tile_with_tensor_intrin(sch, block, intrin, True)
print(sch.mod)
And the output is:
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((16, 3, 224, 224), "float16"), B: T.Buffer((64, 3, 7, 7), "float16"), conv2d_nchw: T.Buffer((16, 64, 112, 112), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        pad_temp = T.alloc_buffer((16, 3, 230, 230), "float16")
        conv2d_nchw_reindex = T.alloc_buffer((200704, 64), "float16")
        pad_temp_reindex_pad = T.alloc_buffer((200704, 160), "float16")
        B_reindex_pad = T.alloc_buffer((64, 160), "float16")
        for i0, i1, i2, i3 in T.grid(16, 3, 230, 230):
            with T.block("pad_temp"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(A[v_i0, v_i1, v_i2 - 3, v_i3 - 3])
                T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3])
                pad_temp[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(3 <= v_i2 and v_i2 < 227 and 3 <= v_i3 and v_i3 < 227, A[v_i0, v_i1, v_i2 - 3, v_i3 - 3], T.float16(0))
        for i0, i1 in T.grid(200704, 160):
            with T.block("pad_temp_reindex_pad"):
                v0, v1 = T.axis.remap("SS", [i0, i1])
                T.reads(pad_temp[v0 // 12544, v1 // 49, v0 % 12544 // 112 * 2 + v1 % 49 // 7, v0 % 112 * 2 + v1 % 7])
                T.writes(pad_temp_reindex_pad[v0, v1])
                pad_temp_reindex_pad[v0, v1] = T.if_then_else(v1 < 147, pad_temp[v0 // 12544, v1 // 49, v0 % 12544 // 112 * 2 + v1 % 49 // 7, v0 % 112 * 2 + v1 % 7], T.float16(0))
        for i0, i1 in T.grid(64, 160):
            with T.block("B_reindex_pad"):
                v0, v1 = T.axis.remap("SS", [i0, i1])
                T.reads(B[v0, v1 // 49, v1 % 49 // 7, v1 % 7])
                T.writes(B_reindex_pad[v0, v1])
                B_reindex_pad[v0, v1] = T.if_then_else(v1 < 147, B[v0, v1 // 49, v1 % 49 // 7, v1 % 7], T.float16(0))
        for ax0_0, ax1_0, ax2_0, ax0_1, ax1_1, ax2_1 in T.grid(12544, 4, 10, 16, 16, 16):
            with T.block("conv2d_nchw"):
                v0 = T.axis.spatial(200704, ax0_0 * 16 + ax0_1)
                v1 = T.axis.spatial(64, ax1_0 * 16 + ax1_1)
                v2 = T.axis.reduce(160, ax2_0 * 16 + ax2_1)
                T.reads(pad_temp_reindex_pad[v0, v2], B_reindex_pad[v1, v2])
                T.writes(conv2d_nchw_reindex[v0, v1])
                T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
                with T.init():
                    conv2d_nchw_reindex[v0, v1] = T.float16(0)
                conv2d_nchw_reindex[v0, v1] = conv2d_nchw_reindex[v0, v1] + pad_temp_reindex_pad[v0, v2] * B_reindex_pad[v1, v2]
        for ax0, ax1 in T.grid(200704, 64):
            with T.block("conv2d_nchw_reindex"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(conv2d_nchw_reindex[v0, v1])
                T.writes(conv2d_nchw[v0 // 12544, v1, v0 % 12544 // 112, v0 % 112])
                conv2d_nchw[v0 // 12544, v1, v0 % 12544 // 112, v0 % 112] = conv2d_nchw_reindex[v0, v1]
The product of the three reduction axes (3 × 7 × 7) is 147, which is not a multiple of the intrinsic's 16-wide tile, hence padding to 160 is required; see the quick check below.
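To make the padding requirement concrete, here is a small arithmetic check (plain Python, separate from the reproduction script above; the tile width of 16 comes from the WMMA 16x16x16 intrinsic):

# The fused reduction extent is 3 * 7 * 7 = 147, which is not a multiple of
# the 16-wide WMMA reduction tile, so it gets padded up to the next multiple.
reduction_extent = 3 * 7 * 7                          # 147
tile = 16                                             # WMMA 16x16x16 reduction tile
padded_extent = -(-reduction_extent // tile) * tile   # ceiling to a multiple of 16
print(reduction_extent, padded_extent)                # 147 160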
Could you add it as a simple test case script, e.g. under tests/python/meta_schedule? During a tuning process, similar (padding) issues might be overlooked, but a test case in CI will always catch them.
OK, I will do it later.
I added the test case in Pull Request #17171.
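For reference, a minimal sketch of what such a regression test could look like, reusing the conv2d_Module reproduction above (the test name and assertions here are illustrative; the actual test added in #17171 may differ):

# Sketch of a regression test; assumes the conv2d_Module IRModule from the
# reproduction script above is available in the same test file.
import tvm
from tvm.tir.schedule.transform import tile_with_tensor_intrin
from tvm.tir.tensor_intrin.cuda import WMMA_SYNC_16x16x16_f16f16f16_TRANS_INTRIN


def test_tile_with_tensor_intrin_keeps_output_consumer():
    sch = tvm.tir.Schedule(conv2d_Module)
    block = sch.get_block("conv2d_nchw")
    # Tile with padding enabled, as in the reproduction script.
    tile_with_tensor_intrin(sch, block, WMMA_SYNC_16x16x16_f16f16f16_TRANS_INTRIN, True)
    # The output consumer must survive: get_block raises if the block was
    # (incorrectly) inlined away.
    sch.get_block("conv2d_nchw_reindex")
    # The resulting module should still print as valid TVMScript.
    sch.mod.script()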
Bug Fix

In the TileWithTensorIntrin function, when the allow_padding parameter is enabled, the original implementation inlines all consumer blocks. This behavior can lead to incorrect inlining of output blocks, causing issues with block shapes and dependencies. To ensure correct inlining, only non-output consumer blocks should be inlined.

Changes Made

- Used the tir::IsOutputBlock function to determine whether a block is an output block.
- Applied sch->ComputeInline only if the block is not an output block.

Specific Code Changes

In the TileWithTensorIntrin function:
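The consumer-inlining loop after the change, reconstructed from the diff shown earlier in this PR (exact bracing and formatting are an approximation):

// Reconstructed from the review diff above: only non-output consumers are
// inlined; output blocks keep writing the function's output buffers.
auto consumers = sch->GetConsumers(block_rv);
for (const auto& consumer : consumers) {
  auto sref = sch->GetSRef(consumer);
  if (!tir::IsOutputBlock(sch->state(), sref, tir::GetScopeRoot(sch->state(), sref, true))) {
    sch->ComputeInline(consumer);
  }
}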
Impact

These changes ensure that when padding is enabled, only non-output blocks are inlined, maintaining correct block shapes and dependencies. This fixes the issue in previous versions where output blocks might be incorrectly inlined.

Please review these changes and provide feedback for further improvements. Thank you for your time and assistance!