Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR]Fix Inlining of Non-Output Consumers in TileWithTensorIntrin with Padding #17161

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/tir/schedule/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,9 @@ Optional<LoopRV> TileWithTensorIntrin(const tir::Schedule& sch, const tir::Block
}
auto consumers = sch->GetConsumers(block_rv);
for (const auto& consumer : consumers) {
sch->ComputeInline(consumer);
auto sref = sch->GetSRef(consumer);
if (!tir::IsOutputBlock(sch->state(), sref, tir::GetScopeRoot(sch->state(), sref, true)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a simple test case to check the validity of the resulting IR under this new condition?

Copy link
Contributor Author

@YXY-0922 YXY-0922 Jul 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. I encountered this bug while using meta_schedule to tune a conv2d operator. Here is the TIR example:

import tvm
from tvm import te, topi, tir
from tvm.ir.module import IRModule
from tvm.script import tir as T
from tvm.tir.schedule.transform import tile_with_tensor_intrin
from tvm.tir.tensor_intrin.cuda import WMMA_SYNC_16x16x16_f16f16f16_TRANS_INTRIN


# Repro module for TileWithTensorIntrin with padding: a conv2d lowered to
# GEMM form whose reduction extent is 147 (3 * 7 * 7), which is not a
# multiple of the 16-element reduction tile of the WMMA intrinsic.
@tvm.script.ir_module
class conv2d_Module:
    @T.prim_func
    def main(A: T.Buffer((16, 3, 224, 224), "float16"), B: T.Buffer((64, 3, 7, 7), "float16"), conv2d_nchw: T.Buffer((16, 64, 112, 112), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        # Intermediate buffers: spatially padded input, GEMM-shaped result,
        # and the flattened (2-D) views of input and weights.
        pad_temp = T.alloc_buffer((16, 3, 230, 230), "float16")
        conv2d_nchw_reindex = T.alloc_buffer((200704, 64), "float16")
        pad_temp_reindex = T.alloc_buffer((200704, 147), "float16")
        B_reindex = T.alloc_buffer((64, 147), "float16")
        # Zero-pad A by 3 on each spatial side (224 -> 230); the guard
        # 3 <= v < 227 selects the interior, everything else becomes 0.
        for i0, i1, i2, i3 in T.grid(16, 3, 230, 230):
            with T.block("pad_temp"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(A[v_i0, v_i1, v_i2 - 3, v_i3 - 3])
                T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3])
                pad_temp[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(3 <= v_i2 and v_i2 < 227 and 3 <= v_i3 and v_i3 < 227, A[v_i0, v_i1, v_i2 - 3, v_i3 - 3], T.float16(0))
        # Flatten the padded input into a (200704, 147) matrix (im2col-style
        # indexing over the 7x7 window with stride 2).
        for ax0, ax1 in T.grid(200704, 147):
            with T.block("pad_temp_reindex_reindex"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(pad_temp[v0 // 12544, v1 // 49, v0 % 12544 // 112 * 2 + v1 % 49 // 7, v0 % 112 * 2 + v1 % 7])
                T.writes(pad_temp_reindex[v0, v1])
                pad_temp_reindex[v0, v1] = pad_temp[v0 // 12544, v1 // 49, v0 % 12544 // 112 * 2 + v1 % 49 // 7, v0 % 112 * 2 + v1 % 7]
        # Flatten the weights into a (64, 147) matrix.
        for ax0, ax1 in T.grid(64, 147):
            with T.block("B_reindex_reindex"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(B[v0, v1 // 49, v1 % 49 // 7, v1 % 7])
                T.writes(B_reindex[v0, v1])
                B_reindex[v0, v1] = B[v0, v1 // 49, v1 % 49 // 7, v1 % 7]
        # GEMM-form convolution; the reduction axis v2 has extent 147,
        # so tensorizing with a 16x16x16 intrinsic requires padding.
        for ax0, ax1, ax2 in T.grid(200704, 64, 147):
            with T.block("conv2d_nchw"):
                v0, v1, v2 = T.axis.remap("SSR", [ax0, ax1, ax2])
                T.reads(pad_temp_reindex[v0, v2], B_reindex[v1, v2])
                T.writes(conv2d_nchw_reindex[v0, v1])
                T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
                with T.init():
                    conv2d_nchw_reindex[v0, v1] = T.float16(0)
                conv2d_nchw_reindex[v0, v1] = conv2d_nchw_reindex[v0, v1] + pad_temp_reindex[v0, v2] * B_reindex[v1, v2]
        # Write the GEMM result back into the NCHW output buffer. This
        # consumer writes the function output, so (per the PR's fix) it is
        # an output block and must not be compute-inlined.
        for ax0, ax1 in T.grid(200704, 64):
            with T.block("conv2d_nchw_reindex"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(conv2d_nchw_reindex[v0, v1])
                T.writes(conv2d_nchw[v0 // 12544, v1, v0 % 12544 // 112, v0 % 112])
                conv2d_nchw[v0 // 12544, v1, v0 % 12544 // 112, v0 % 112] = conv2d_nchw_reindex[v0, v1]

# Reproduce the padded tiling: tile the "conv2d_nchw" block with the
# transposed 16x16x16 f16 WMMA intrinsic, letting tile_with_tensor_intrin
# pad the 147-element reduction axis (allow_padding=True).
intrin = WMMA_SYNC_16x16x16_f16f16f16_TRANS_INTRIN

sch = tvm.tir.Schedule(conv2d_Module)
block = sch.get_block("conv2d_nchw")
tiled_loop = tile_with_tensor_intrin(sch, block, intrin, True)

# Dump the transformed module to inspect the padded schedule.
print(sch.mod)

And the output is :

# from tvm.script import ir as I
# from tvm.script import tir as T

# NOTE(review): this is the IR printed by the repro script above — the
# reduction extent has been padded from 147 to 160, and the output-writing
# consumer block "conv2d_nchw_reindex" is preserved (not inlined).
@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((16, 3, 224, 224), "float16"), B: T.Buffer((64, 3, 7, 7), "float16"), conv2d_nchw: T.Buffer((16, 64, 112, 112), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        pad_temp = T.alloc_buffer((16, 3, 230, 230), "float16")
        conv2d_nchw_reindex = T.alloc_buffer((200704, 64), "float16")
        # Reindex buffers padded on the reduction dimension: 147 -> 160.
        pad_temp_reindex_pad = T.alloc_buffer((200704, 160), "float16")
        B_reindex_pad = T.alloc_buffer((64, 160), "float16")
        for i0, i1, i2, i3 in T.grid(16, 3, 230, 230):
            with T.block("pad_temp"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(A[v_i0, v_i1, v_i2 - 3, v_i3 - 3])
                T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3])
                pad_temp[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(3 <= v_i2 and v_i2 < 227 and 3 <= v_i3 and v_i3 < 227, A[v_i0, v_i1, v_i2 - 3, v_i3 - 3], T.float16(0))
        # The padded tail (v1 >= 147) is filled with zeros so the extra
        # reduction iterations do not change the result.
        for i0, i1 in T.grid(200704, 160):
            with T.block("pad_temp_reindex_pad"):
                v0, v1 = T.axis.remap("SS", [i0, i1])
                T.reads(pad_temp[v0 // 12544, v1 // 49, v0 % 12544 // 112 * 2 + v1 % 49 // 7, v0 % 112 * 2 + v1 % 7])
                T.writes(pad_temp_reindex_pad[v0, v1])
                pad_temp_reindex_pad[v0, v1] = T.if_then_else(v1 < 147, pad_temp[v0 // 12544, v1 // 49, v0 % 12544 // 112 * 2 + v1 % 49 // 7, v0 % 112 * 2 + v1 % 7], T.float16(0))
        for i0, i1 in T.grid(64, 160):
            with T.block("B_reindex_pad"):
                v0, v1 = T.axis.remap("SS", [i0, i1])
                T.reads(B[v0, v1 // 49, v1 % 49 // 7, v1 % 7])
                T.writes(B_reindex_pad[v0, v1])
                B_reindex_pad[v0, v1] = T.if_then_else(v1 < 147, B[v0, v1 // 49, v1 % 49 // 7, v1 % 7], T.float16(0))
        # Loops split into 16x16x16 inner tiles matching the WMMA intrinsic:
        # (200704, 64, 160) -> (12544, 4, 10) outer x (16, 16, 16) inner.
        for ax0_0, ax1_0, ax2_0, ax0_1, ax1_1, ax2_1 in T.grid(12544, 4, 10, 16, 16, 16):
            with T.block("conv2d_nchw"):
                v0 = T.axis.spatial(200704, ax0_0 * 16 + ax0_1)
                v1 = T.axis.spatial(64, ax1_0 * 16 + ax1_1)
                v2 = T.axis.reduce(160, ax2_0 * 16 + ax2_1)
                T.reads(pad_temp_reindex_pad[v0, v2], B_reindex_pad[v1, v2])
                T.writes(conv2d_nchw_reindex[v0, v1])
                T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
                with T.init():
                    conv2d_nchw_reindex[v0, v1] = T.float16(0)
                conv2d_nchw_reindex[v0, v1] = conv2d_nchw_reindex[v0, v1] + pad_temp_reindex_pad[v0, v2] * B_reindex_pad[v1, v2]
        # Output block kept intact by the fix (it writes the func output).
        for ax0, ax1 in T.grid(200704, 64):
            with T.block("conv2d_nchw_reindex"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(conv2d_nchw_reindex[v0, v1])
                T.writes(conv2d_nchw[v0 // 12544, v1, v0 % 12544 // 112, v0 % 112])
                conv2d_nchw[v0 // 12544, v1, v0 % 12544 // 112, v0 % 112] = conv2d_nchw_reindex[v0, v1]

The product of the three reduction-axis extents (3 × 7 × 7) is 147, which is not a multiple of the intrinsic's 16-element reduction tile, hence padding (147 → 160) is required.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@YXY-0922

Could you add it as a simple test-case script, e.g. under tests/python/meta_schedule?
During a tuning process, similar (padding) issues might be overlooked, but a test case will always catch them in CI.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I will do it later.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added the test case in Pull Request #17171.

sch->ComputeInline(consumer);
}
}
// Construct a mapping from tir loops back to LoopRVs
Expand Down
Loading