
[TIR]Fix Inlining of Non-Output Consumers in TileWithTensorIntrin with Padding #17161

Closed

Conversation

YXY-0922 (Contributor)

Bug Fix

In the TileWithTensorIntrin function, when the allow_padding parameter is enabled, the original implementation inlines all consumer blocks. This behavior can lead to incorrect inlining of output blocks, causing issues with block shapes and dependencies. To ensure correct inlining operations, only non-output consumer blocks should be inlined.

Changes Made

  • Added Non-Output Block Check: Before inlining consumer blocks, added a check to ensure only non-output blocks are inlined.
    • Used the tir::IsOutputBlock function to determine if a block is an output block.
    • Applied sch->ComputeInline only if the block is not an output block.

Specific Code Changes

  1. Modified the consumer inlining logic in the TileWithTensorIntrin function:
    for (const auto& consumer : consumers) {
      auto sref = sch->GetSRef(consumer);
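      // Skip output blocks: they write buffers that live outside the scope
      // (e.g. the PrimFunc's output buffers) and therefore cannot be inlined.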
      if (!tir::IsOutputBlock(sch->state(), sref, tir::GetScopeRoot(sch->state(), sref, true)))
        sch->ComputeInline(consumer);
    }
    

Impact

These changes ensure that when padding is enabled, only non-output blocks will be inlined, maintaining correct block shapes and dependencies. This fixes the issue in previous versions where output blocks might be incorrectly inlined.
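
To illustrate why the check matters, here is a minimal, self-contained sketch (a toy module, not taken from this PR or its tests): TIR's compute_inline is expected to reject a block that writes a PrimFunc output buffer, so calling it unconditionally on every consumer can fail whenever one of the consumers is an output block.

import tvm
from tvm.script import tir as T

@tvm.script.ir_module
class ToyModule:
    @T.prim_func
    def main(A: T.Buffer((16,), "float16"), C: T.Buffer((16,), "float16")):
        B = T.alloc_buffer((16,), "float16")
        for i in range(16):
            with T.block("B"):
                vi = T.axis.spatial(16, i)
                B[vi] = A[vi] * T.float16(2)
        for i in range(16):
            with T.block("C"):
                vi = T.axis.spatial(16, i)
                C[vi] = B[vi] + T.float16(1)

sch = tvm.tir.Schedule(ToyModule)
sch.compute_inline(sch.get_block("B"))      # "B" writes an intermediate buffer, so inlining is fine
try:
    sch.compute_inline(sch.get_block("C"))  # "C" writes the output buffer C
except Exception as err:                    # a schedule error is expected here
    print("output block was not inlined:", type(err).__name__)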

Please review these changes and provide feedback for further improvements. Thank you for your time and assistance!

…nsorIntrin

In the TileWithTensorIntrin function, modified the inlining behavior of consumer blocks. Now, when padding is applied, the function inlines only non-output consumer blocks. This ensures that the padding and inlining process is correctly handled for both producers and consumers.

Changes:
- Added a check to ensure only non-output consumer blocks are inlined using tir::IsOutputBlock.
- Updated the loop iterating over consumers to include the new check.

This fix addresses issues where output blocks were being inappropriately inlined, maintaining the correct block shapes and dependencies.
YXY-0922 changed the title from "Fix Inlining of Non-Output Consumers in TileWithTensorIntrin with Padding" to "[TIR]Fix Inlining of Non-Output Consumers in TileWithTensorIntrin with Padding" on Jul 16, 2024
cbalint13 self-assigned this on Jul 16, 2024
cbalint13 (Contributor)

@YXY-0922

Thank you for the contribution!

@@ -340,7 +340,9 @@ Optional<LoopRV> TileWithTensorIntrin(const tir::Schedule& sch, const tir::Block
   }
   auto consumers = sch->GetConsumers(block_rv);
   for (const auto& consumer : consumers) {
-    sch->ComputeInline(consumer);
+    auto sref = sch->GetSRef(consumer);
+    if (!tir::IsOutputBlock(sch->state(), sref, tir::GetScopeRoot(sch->state(), sref, true)))
+      sch->ComputeInline(consumer);
Contributor

Could you add a simple test case to check the validity of the resulting IR under this new condition?

YXY-0922 (Contributor Author), Jul 17, 2024

Sure, I encountered this bug while using meta_schedule to tune a conv2d operator. Here is the TIR example:

import tvm
from tvm import te, topi, tir
from tvm.ir.module import IRModule
from tvm.script import tir as T
from tvm.tir.schedule.transform import tile_with_tensor_intrin
from tvm.tir.tensor_intrin.cuda import WMMA_SYNC_16x16x16_f16f16f16_TRANS_INTRIN


@tvm.script.ir_module
class conv2d_Module:
    @T.prim_func
    def main(A: T.Buffer((16, 3, 224, 224), "float16"), B: T.Buffer((64, 3, 7, 7), "float16"), conv2d_nchw: T.Buffer((16, 64, 112, 112), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        pad_temp = T.alloc_buffer((16, 3, 230, 230), "float16")
        conv2d_nchw_reindex = T.alloc_buffer((200704, 64), "float16")
        pad_temp_reindex = T.alloc_buffer((200704, 147), "float16")
        B_reindex = T.alloc_buffer((64, 147), "float16")
        for i0, i1, i2, i3 in T.grid(16, 3, 230, 230):
            with T.block("pad_temp"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(A[v_i0, v_i1, v_i2 - 3, v_i3 - 3])
                T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3])
                pad_temp[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(3 <= v_i2 and v_i2 < 227 and 3 <= v_i3 and v_i3 < 227, A[v_i0, v_i1, v_i2 - 3, v_i3 - 3], T.float16(0))
        for ax0, ax1 in T.grid(200704, 147):
            with T.block("pad_temp_reindex_reindex"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(pad_temp[v0 // 12544, v1 // 49, v0 % 12544 // 112 * 2 + v1 % 49 // 7, v0 % 112 * 2 + v1 % 7])
                T.writes(pad_temp_reindex[v0, v1])
                pad_temp_reindex[v0, v1] = pad_temp[v0 // 12544, v1 // 49, v0 % 12544 // 112 * 2 + v1 % 49 // 7, v0 % 112 * 2 + v1 % 7]
        for ax0, ax1 in T.grid(64, 147):
            with T.block("B_reindex_reindex"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(B[v0, v1 // 49, v1 % 49 // 7, v1 % 7])
                T.writes(B_reindex[v0, v1])
                B_reindex[v0, v1] = B[v0, v1 // 49, v1 % 49 // 7, v1 % 7]
        for ax0, ax1, ax2 in T.grid(200704, 64, 147):
            with T.block("conv2d_nchw"):
                v0, v1, v2 = T.axis.remap("SSR", [ax0, ax1, ax2])
                T.reads(pad_temp_reindex[v0, v2], B_reindex[v1, v2])
                T.writes(conv2d_nchw_reindex[v0, v1])
                T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
                with T.init():
                    conv2d_nchw_reindex[v0, v1] = T.float16(0)
                conv2d_nchw_reindex[v0, v1] = conv2d_nchw_reindex[v0, v1] + pad_temp_reindex[v0, v2] * B_reindex[v1, v2]
        for ax0, ax1 in T.grid(200704, 64):
            with T.block("conv2d_nchw_reindex"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(conv2d_nchw_reindex[v0, v1])
                T.writes(conv2d_nchw[v0 // 12544, v1, v0 % 12544 // 112, v0 % 112])
                conv2d_nchw[v0 // 12544, v1, v0 % 12544 // 112, v0 % 112] = conv2d_nchw_reindex[v0, v1]

sch = tvm.tir.Schedule(conv2d_Module)

intrin = WMMA_SYNC_16x16x16_f16f16f16_TRANS_INTRIN
block = sch.get_block("conv2d_nchw")

tiled_loop = tile_with_tensor_intrin(sch, block, intrin, True)

print(sch.mod)

And the output is:

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((16, 3, 224, 224), "float16"), B: T.Buffer((64, 3, 7, 7), "float16"), conv2d_nchw: T.Buffer((16, 64, 112, 112), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        pad_temp = T.alloc_buffer((16, 3, 230, 230), "float16")
        conv2d_nchw_reindex = T.alloc_buffer((200704, 64), "float16")
        pad_temp_reindex_pad = T.alloc_buffer((200704, 160), "float16")
        B_reindex_pad = T.alloc_buffer((64, 160), "float16")
        for i0, i1, i2, i3 in T.grid(16, 3, 230, 230):
            with T.block("pad_temp"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(A[v_i0, v_i1, v_i2 - 3, v_i3 - 3])
                T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3])
                pad_temp[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(3 <= v_i2 and v_i2 < 227 and 3 <= v_i3 and v_i3 < 227, A[v_i0, v_i1, v_i2 - 3, v_i3 - 3], T.float16(0))
        for i0, i1 in T.grid(200704, 160):
            with T.block("pad_temp_reindex_pad"):
                v0, v1 = T.axis.remap("SS", [i0, i1])
                T.reads(pad_temp[v0 // 12544, v1 // 49, v0 % 12544 // 112 * 2 + v1 % 49 // 7, v0 % 112 * 2 + v1 % 7])
                T.writes(pad_temp_reindex_pad[v0, v1])
                pad_temp_reindex_pad[v0, v1] = T.if_then_else(v1 < 147, pad_temp[v0 // 12544, v1 // 49, v0 % 12544 // 112 * 2 + v1 % 49 // 7, v0 % 112 * 2 + v1 % 7], T.float16(0))
        for i0, i1 in T.grid(64, 160):
            with T.block("B_reindex_pad"):
                v0, v1 = T.axis.remap("SS", [i0, i1])
                T.reads(B[v0, v1 // 49, v1 % 49 // 7, v1 % 7])
                T.writes(B_reindex_pad[v0, v1])
                B_reindex_pad[v0, v1] = T.if_then_else(v1 < 147, B[v0, v1 // 49, v1 % 49 // 7, v1 % 7], T.float16(0))
        for ax0_0, ax1_0, ax2_0, ax0_1, ax1_1, ax2_1 in T.grid(12544, 4, 10, 16, 16, 16):
            with T.block("conv2d_nchw"):
                v0 = T.axis.spatial(200704, ax0_0 * 16 + ax0_1)
                v1 = T.axis.spatial(64, ax1_0 * 16 + ax1_1)
                v2 = T.axis.reduce(160, ax2_0 * 16 + ax2_1)
                T.reads(pad_temp_reindex_pad[v0, v2], B_reindex_pad[v1, v2])
                T.writes(conv2d_nchw_reindex[v0, v1])
                T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
                with T.init():
                    conv2d_nchw_reindex[v0, v1] = T.float16(0)
                conv2d_nchw_reindex[v0, v1] = conv2d_nchw_reindex[v0, v1] + pad_temp_reindex_pad[v0, v2] * B_reindex_pad[v1, v2]
        for ax0, ax1 in T.grid(200704, 64):
            with T.block("conv2d_nchw_reindex"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(conv2d_nchw_reindex[v0, v1])
                T.writes(conv2d_nchw[v0 // 12544, v1, v0 % 12544 // 112, v0 % 112])
                conv2d_nchw[v0 // 12544, v1, v0 % 12544 // 112, v0 % 112] = conv2d_nchw_reindex[v0, v1]

The product of the three reduction extents (3 × 7 × 7) is 147, which is not a multiple of 16, hence padding to 160 is required.
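
For reference, a quick sketch of the padding arithmetic (the tile extent 16 comes from the 16x16x16 WMMA intrinsic used above):

# Reduction extent of conv2d_nchw: channels * kernel_h * kernel_w.
reduction_extent = 3 * 7 * 7                          # 147
tile = 16                                             # WMMA reduction tile
padded_extent = -(-reduction_extent // tile) * tile   # ceil(147 / 16) * 16 = 160
print(reduction_extent, padded_extent)                # 147 160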

Contributor

@YXY-0922

Could you add it as a simple test case script, e.g. under tests/python/meta_schedule? During a tuning process, similar (padding) issues might be overlooked, but a test case in CI will always catch them.

Contributor Author

OK, I will do it later.

Contributor Author

I added the test case in Pull Request #17171.
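
For illustration, a minimal sketch of what such a regression test might look like (the test name is hypothetical, it reuses the conv2d_Module repro from the earlier comment, and the actual test added in #17171 may differ):

import tvm
from tvm.tir.schedule.transform import tile_with_tensor_intrin
from tvm.tir.tensor_intrin.cuda import WMMA_SYNC_16x16x16_f16f16f16_TRANS_INTRIN


def test_tile_with_tensor_intrin_padding_keeps_output_block():
    # conv2d_Module is the repro IRModule defined in the comment above.
    sch = tvm.tir.Schedule(conv2d_Module)
    block = sch.get_block("conv2d_nchw")
    tile_with_tensor_intrin(sch, block, WMMA_SYNC_16x16x16_f16f16f16_TRANS_INTRIN, True)
    # The output consumer must survive tiling; get_block raises if it was inlined away.
    sch.get_block("conv2d_nchw_reindex")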

cbalint13 (Contributor)

Closing; the work here was completed in #17171.

Thank you @YXY-0922!

cbalint13 closed this on Jul 22, 2024