[BugFix][MetaSchedule] MultiLevelTilingTensorCore generates inconsistent thread-binding sketch for batched matmul #17012

Merged · 2 commits · Jun 28, 2024
23 changes: 22 additions & 1 deletion src/meta_schedule/schedule_rule/multi_level_tiling.cc
@@ -190,7 +190,8 @@ std::pair<Array<tir::ExprRV>, Array<tir::LoopRV>> MultiLevelTilingNode::SplitLoo
return {factors, splits};
}

-std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
+std::vector<State> MultiLevelTilingNode::TileLoopNest(State state,
+                                                      int tile_inner_most_space_loop_num) const {
Schedule& sch = state->sch;
const BlockRV& block_rv = state->block_rv;
// Step 1. Assuming trivial binding, pair the loops and their iter-var-types
@@ -199,6 +200,16 @@ std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
ICHECK_EQ(loops.size(), iter_types.size());
// Step 2. For each loop axis, tile it
int64_t spatial_loop_product = 1;

+int total_spatial_loop_num = 0;
+std::for_each(iter_types.begin(), iter_types.end(), [&](const auto& iter_type) {
+if (iter_type == IterVarType::kDataPar) total_spatial_loop_num++;
+});
+CHECK_GE(total_spatial_loop_num, tile_inner_most_space_loop_num);
+if (tile_inner_most_space_loop_num < 0) tile_inner_most_space_loop_num = total_spatial_loop_num;
+int outer_most_spatial_loop_skipped_num = total_spatial_loop_num - tile_inner_most_space_loop_num;
+
+Array<LoopRV> skipped_outer_spatial_loops;
std::vector<Array<LoopRV>> tiles(s_indices_.size() + r_indices_.size());
state->tile_factors.resize(tiles.size());
std::vector<Array<tir::ExprRV>> tile_factors;
@@ -208,6 +219,11 @@ std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
const std::vector<int>* idx = nullptr;

if (iter_types[i] == IterVarType::kDataPar) {
+if (outer_most_spatial_loop_skipped_num > 0) {
+skipped_outer_spatial_loops.push_back(loop);
+outer_most_spatial_loop_skipped_num--;
+continue;
+}
idx = &s_indices_;
if (spatial_loop_product != -1) {
if (const int64_t* extent = tir::GetLoopIntExtent(sch->Get(loop).get())) {
@@ -241,6 +257,11 @@ std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
sch->Reorder(support::ConcatArrayList<LoopRV>(tiles.begin(), tiles.end()));
// Step 4. Bind the tiles to threads
int n_binds = std::min(tile_binds.size(), tiles.size());
+if (skipped_outer_spatial_loops.size() && n_binds) {
+auto& the_first_tile = tiles[0];
+the_first_tile.insert(the_first_tile.begin(), skipped_outer_spatial_loops.begin(),
+skipped_outer_spatial_loops.end());
+}
for (int i = 0; i < n_binds; ++i) {
LoopRV fused = sch->Fuse(tiles[i]);
sch->Bind(fused, tile_binds[i]);
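The heart of the change is the partition above: count the spatial loops, skip the outermost total_spatial_loop_num - tile_inner_most_space_loop_num of them when splitting, and afterwards prepend the skipped loops to the first tile level so they are fused and bound to blockIdx together with it. Below is a minimal standalone C++ sketch of that flow, not TVM code: IterType, PartitionLoops, and the integer loop indices are illustrative stand-ins for TVM's IterVarType and LoopRV machinery.

#include <algorithm>
#include <cassert>
#include <iostream>
#include <utility>
#include <vector>

enum class IterType { kSpatial, kReduction };

// Partition a loop nest (outermost first): the outermost spatial loops beyond
// the requested innermost count are skipped; everything else is still tiled.
std::pair<std::vector<int>, std::vector<int>> PartitionLoops(
    const std::vector<IterType>& iter_types, int tile_inner_most_space_loop_num) {
  int total_spatial = static_cast<int>(
      std::count(iter_types.begin(), iter_types.end(), IterType::kSpatial));
  // A negative count means "tile every spatial loop", mirroring the -1 default.
  if (tile_inner_most_space_loop_num < 0) tile_inner_most_space_loop_num = total_spatial;
  assert(total_spatial >= tile_inner_most_space_loop_num);
  int to_skip = total_spatial - tile_inner_most_space_loop_num;

  std::vector<int> skipped, tiled;
  for (int i = 0; i < static_cast<int>(iter_types.size()); ++i) {
    if (iter_types[i] == IterType::kSpatial && to_skip > 0) {
      skipped.push_back(i);  // outermost spatial loop: not split
      --to_skip;
    } else {
      tiled.push_back(i);  // split into S- or R-tiles as before
    }
  }
  // In the real rule, `skipped` is prepended to tiles[0], so the subsequent
  // Fuse + Bind put the skipped loops into the outermost (blockIdx) binding.
  return {skipped, tiled};
}

int main() {
  // Batched matmul: loops (b, i, j, k) with iter types (S, S, S, R).
  // With tile_inner_most_space_loop_num = 2, the batch loop b is skipped.
  auto [skipped, tiled] = PartitionLoops(
      {IterType::kSpatial, IterType::kSpatial, IterType::kSpatial, IterType::kReduction}, 2);
  std::cout << "skipped:";
  for (int i : skipped) std::cout << ' ' << i;
  std::cout << "\ntiled:";
  for (int i : tiled) std::cout << ' ' << i;
  std::cout << '\n';
  return 0;
}

Compiled as C++17, the sketch prints "skipped: 0" and "tiled: 1 2 3": the batch loop is excluded from tiling, which is exactly what the batched-matmul fix needs.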
2 changes: 1 addition & 1 deletion src/meta_schedule/schedule_rule/multi_level_tiling.h
@@ -162,7 +162,7 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
// SubRule 1. add write cache
std::vector<State> AddWriteReuse(State state) const;
// SubRule 2. tile the loop nest
-std::vector<State> TileLoopNest(State state) const;
+std::vector<State> TileLoopNest(State state, int tile_inner_most_space_loop_num = -1) const;
// SubRule 3. add read cache
std::vector<State> AddReadReuse(State state) const;
// SubRule 4. add async pipeline
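Note that the new parameter defaults to -1, which the implementation treats as "tile every spatial loop", so callers that do not pass the new argument keep the previous behavior; only opted-in call sites such as the tensor-core rule below change.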
@@ -251,7 +251,7 @@ std::vector<State> MultiLevelTilingTensorCoreNode::ApplySubRules(std::vector<Sta
});
states = SubRule(std::move(states), [&](State state) {
TensorCoreState tc_state = Downcast<TensorCoreState>(state);
-return tc_state->is_mma ? MMATileLoopNest(tc_state) : TileLoopNest(state);
+return tc_state->is_mma ? MMATileLoopNest(tc_state) : TileLoopNest(state, 2);
});
states = SubRule(std::move(states), [&](State state) {
return TransformIntermediateOutputLayout(Downcast<TensorCoreState>(state));
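With this one-line change, the non-MMA tensor-core path tiles only the two innermost spatial loops, presumably the reindexed row and column axes that map onto the 16x16 wmma fragments. Any leading spatial loops, such as the batch axis of a batched matmul, are no longer split five ways; they are fused into the outermost tile and bound to blockIdx with it, which is what makes the generated thread-binding sketch consistent. The updated expectations in the test below show the effect.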
@@ -903,39 +903,39 @@ def test_conv_1x1():
def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer((1, 1, 64, 64), "float16"), conv2d_nhwc: T.Buffer((1, 16, 16, 64), "float32")):
T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
# with T.block("root"):
-conv2d_nhwc_reindex_shared = T.alloc_buffer((2, 2, 8, 2, 16, 16), scope="shared")
-conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer((2, 2, 8, 2, 16, 16), scope="wmma.accumulator")
+conv2d_nhwc_reindex_shared = T.alloc_buffer((2, 1, 8, 4, 16, 16), scope="shared")
+conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer((2, 1, 8, 4, 16, 16), scope="wmma.accumulator")
PadInput_reindex_shared = T.alloc_buffer((256, 64), "float16", scope="shared")
weight_reindex_shared = T.alloc_buffer((1, 1, 64, 64), "float16", scope="shared")
PadInput_reindex_shared_wmma_matrix_a = T.alloc_buffer((256, 64), "float16", scope="wmma.matrix_a")
weight_reindex_shared_wmma_matrix_b = T.alloc_buffer((1, 1, 64, 64), "float16", scope="wmma.matrix_b")
-for ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused in T.thread_binding(4, thread="blockIdx.y"):
-for ax0_1_ax1_1_ax2_0_1_ax3_0_1_fused in T.thread_binding(1, thread="blockIdx.x"):
-for ax0_2_ax1_2_ax2_0_2_ax3_0_2_fused in T.thread_binding(1, thread="threadIdx.y"):
-for ax4_0_0 in range(1):
+for ax0_ax1_ax2_0_0_ax3_0_0_fused in T.thread_binding(1, thread="blockIdx.y"):
+for ax2_0_1_ax3_0_1_fused in T.thread_binding(1, thread="blockIdx.x"):
+for ax2_0_2_ax3_0_2_fused in T.thread_binding(2, thread="threadIdx.y"):
+for ax4_0_0 in range(2):
for ax0_ax1_fused in range(8192):
with T.block("PadInput_reindex_shared"):
-v0 = T.axis.spatial(256, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused // 2 * 128 + ax0_ax1_fused // 64)
-v1 = T.axis.spatial(64, ax0_ax1_fused % 64)
+v0 = T.axis.spatial(256, ax0_ax1_fused // 32)
+v1 = T.axis.spatial(64, ax4_0_0 * 32 + ax0_ax1_fused % 32)
T.reads(inputs[0, v0 // 16, v0 % 16, v1])
T.writes(PadInput_reindex_shared[v0, v1])
T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 2})
T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 8})
PadInput_reindex_shared[v0, v1] = inputs[0, v0 // 16, v0 % 16, v1]
for ax0_ax1_ax2_ax3_fused in range(2048):
with T.block("weight_reindex_shared"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial(1, 0)
-v2 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused // 32)
-v3 = T.axis.spatial(64, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused % 2 * 32 + ax0_ax1_ax2_ax3_fused % 32)
+v2 = T.axis.spatial(64, ax4_0_0 * 32 + ax0_ax1_ax2_ax3_fused // 64)
+v3 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 64)
T.reads(weight[v0, v1, v2, v3])
T.writes(weight_reindex_shared[v0, v1, v2, v3])
T.block_attr({"buffer_dim_align": [[0, 2, 32, 8]], "meta_schedule.cooperative_fetch": 8})
T.block_attr({"buffer_dim_align": [[0, 2, 32, 8]], "meta_schedule.cooperative_fetch": 4})
weight_reindex_shared[v0, v1, v2, v3] = weight[v0, v1, v2, v3]
for ax4_0_1 in range(1):
-for ax0_0, ax1_0 in T.grid(8, 4):
+for ax0_0, ax1_0 in T.grid(8, 2):
with T.block("PadInput_reindex_shared_wmma.matrix_a_o"):
-v0_o = T.axis.spatial(16, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused // 2 * 8 + ax0_0)
-v1_o = T.axis.spatial(4, ax1_0)
+v0_o = T.axis.spatial(16, ax2_0_2_ax3_0_2_fused * 8 + ax0_0)
+v1_o = T.axis.spatial(4, ax4_0_0 * 2 + ax1_0)
T.reads(PadInput_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_a_shared"})
@@ -945,10 +945,11 @@ def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer(
T.reads(PadInput_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = PadInput_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
-for ax0, ax1, ax2_0, ax3_0 in T.grid(1, 1, 4, 2):
+for ax0, ax1, ax2_0, ax3_0 in T.grid(1, 1, 2, 4):
with T.block("weight_reindex_shared_wmma.matrix_b_o"):
v0_o, v1_o, v2_o = T.axis.remap("SSS", [ax0, ax1, ax2_0])
v3_o = T.axis.spatial(4, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused % 2 * 2 + ax3_0)
v0_o, v1_o = T.axis.remap("SS", [ax0, ax1])
v2_o = T.axis.spatial(4, ax4_0_0 * 2 + ax2_0)
v3_o = T.axis.spatial(4, ax3_0)
T.reads(weight_reindex_shared[v0_o, v1_o, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16])
T.writes(weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16])
T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_b_shared"})
@@ -958,38 +959,38 @@ def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer(
T.reads(weight_reindex_shared[v0_o, v1_o, v2_o * 16 + v2_i, v3_o * 16 + v3_i])
T.writes(weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v2_o * 16 + v2_i, v3_o * 16 + v3_i])
weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v2_o * 16 + v2_i, v3_o * 16 + v3_i] = weight_reindex_shared[v0_o, v1_o, v2_o * 16 + v2_i, v3_o * 16 + v3_i]
-for ax0_3, ax1_3, ax2_0_3, ax3_0_3, ax4_0_2, ax0_4, ax1_4, ax2_0_4, ax3_0_4 in T.grid(1, 1, 8, 2, 4, 1, 1, 1, 1):
+for ax2_0_3, ax3_0_3, ax4_0_2, ax2_0_4, ax3_0_4 in T.grid(8, 1, 2, 1, 4):
with T.block("conv2d_nhwc_o"):
-v0_o = T.axis.spatial(1, ax0_3 + ax0_4)
-v1_o = T.axis.spatial(1, ax1_3 + ax1_4)
-v2_o = T.axis.spatial(16, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused // 2 * 8 + ax2_0_3 + ax2_0_4)
-v3_o = T.axis.spatial(4, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused % 2 * 2 + ax3_0_3 + ax3_0_4)
-v4_o = T.axis.reduce(4, ax4_0_0 * 4 + ax4_0_1 * 4 + ax4_0_2)
+v0_o = T.axis.spatial(1, 0)
+v1_o = T.axis.spatial(1, 0)
+v2_o = T.axis.spatial(16, ax2_0_2_ax3_0_2_fused * 8 + ax2_0_3 + ax2_0_4)
+v3_o = T.axis.spatial(4, ax3_0_3 * 4 + ax3_0_4)
+v4_o = T.axis.reduce(4, ax4_0_0 * 2 + ax4_0_1 * 2 + ax4_0_2)
T.reads(PadInput_reindex_shared_wmma_matrix_a[v2_o * 16:v2_o * 16 + 16, v4_o * 16:v4_o * 16 + 16], weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v4_o * 16:v4_o * 16 + 16, v3_o * 16:v3_o * 16 + 16])
-T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, v3_o // 2, v2_o % 8, v3_o % 2, 0:16, 0:16])
+T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, v3_o, 0:16, 0:16])
T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1})
with T.init():
for ax2_1, ax3_1 in T.grid(16, 16):
with T.block("conv2d_nhwc_init"):
v2_i_init, v3_i_init = T.axis.remap("SS", [ax2_1, ax3_1])
T.reads()
-T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, v3_o // 2, v2_o % 8, v3_o % 2, v2_i_init, v3_i_init])
-conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, v3_o // 2, v2_o % 8, v3_o % 2, v2_i_init, v3_i_init] = T.float32(0)
+T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, v3_o, v2_i_init, v3_i_init])
+conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, v3_o, v2_i_init, v3_i_init] = T.float32(0)
for ax2_1, ax3_1, ax4_1 in T.grid(16, 16, 16):
with T.block("conv2d_nhwc"):
v2_i, v3_i, v4_i = T.axis.remap("SSR", [ax2_1, ax3_1, ax4_1])
-T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, v3_o // 2, v2_o % 8, v3_o % 2, v2_i, v3_i], PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v4_o * 16 + v4_i, v3_o * 16 + v3_i])
-T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, v3_o // 2, v2_o % 8, v3_o % 2, v2_i, v3_i])
+T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, v3_o, v2_i, v3_i], PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v4_o * 16 + v4_i, v3_o * 16 + v3_i])
+T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, v3_o, v2_i, v3_i])
T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
-conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, v3_o // 2, v2_o % 8, v3_o % 2, v2_i, v3_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, v3_o // 2, v2_o % 8, v3_o % 2, v2_i, v3_i] + T.Cast("float32", PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i]) * T.Cast("float32", weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v4_o * 16 + v4_i, v3_o * 16 + v3_i])
+conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, v3_o, v2_i, v3_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, v3_o, v2_i, v3_i] + T.Cast("float32", PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i]) * T.Cast("float32", weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v4_o * 16 + v4_i, v3_o * 16 + v3_i])
for ax2 in range(8):
-for ax0_ax1_fused in T.thread_binding(1, thread="threadIdx.y"):
-for ax2_1, ax3 in T.grid(1, 2):
+for ax0_ax1_fused in T.thread_binding(2, thread="threadIdx.y"):
+for ax2_1, ax3 in T.grid(1, 4):
with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"):
-v0_o = T.axis.spatial(2, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused // 2)
-v1_o = T.axis.spatial(2, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused % 2)
+v0_o = T.axis.spatial(2, ax0_ax1_fused)
+v1_o = T.axis.spatial(1, 0)
v2_o = T.axis.spatial(8, ax2 + ax2_1)
-v3_o = T.axis.spatial(2, ax3)
+v3_o = T.axis.spatial(4, ax3)
v4_o = T.axis.spatial(1, 0)
v5_o = T.axis.spatial(1, 0)
T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16])
@@ -1001,29 +1002,27 @@ def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer(
T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i])
T.writes(conv2d_nhwc_reindex_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i])
conv2d_nhwc_reindex_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i]
-for ax0_ax1_ax3_ax4_ax5_fused in range(512):
+for ax0_ax1_ax3_ax4_ax5_fused in range(2048):
with T.block("conv2d_nhwc_reindex_shared"):
-v0 = T.axis.spatial(2, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused // 2)
-v1 = T.axis.spatial(2, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused % 2)
+v0 = T.axis.spatial(2, ax0_ax1_ax3_ax4_ax5_fused // 1024)
+v1 = T.axis.spatial(1, 0)
v2 = T.axis.spatial(8, ax2)
-v3 = T.axis.spatial(2, ax0_ax1_ax3_ax4_ax5_fused // 256)
+v3 = T.axis.spatial(4, ax0_ax1_ax3_ax4_ax5_fused % 1024 // 256)
v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16)
v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16)
T.reads(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5])
-T.writes(conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 128) // 16, (v4 + v2 * 16 + v0 * 128) % 16, v5 + v3 * 16 + v1 * 32])
+T.writes(conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 128) // 16, (v4 + v2 * 16 + v0 * 128) % 16, v5 + v3 * 16])
T.block_attr({"meta_schedule.cooperative_fetch": 1})
-conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 128) // 16, (v4 + v2 * 16 + v0 * 128) % 16, v5 + v3 * 16 + v1 * 32] = conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5]
+conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 128) // 16, (v4 + v2 * 16 + v0 * 128) % 16, v5 + v3 * 16] = conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5]
# fmt: on

decision_0 = [
("SamplePerfectTile", [1, 1, 1, 1, 1]),
("SamplePerfectTile", [1, 1, 1, 1, 1]),
("SamplePerfectTile", [2, 1, 1, 8, 1]),
("SamplePerfectTile", [2, 1, 1, 2, 1]),
("SamplePerfectTile", [1, 1, 4]),
("SamplePerfectTile", [1, 1, 2, 8, 1]),
("SamplePerfectTile", [1, 1, 1, 1, 4]),
("SamplePerfectTile", [2, 1, 2]),
("SampleCategorical", 0),
("SampleCategorical", 1),
("SampleCategorical", 3),
("SampleCategorical", 2),
]

mod = te.create_prim_func(
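A closing note on the updated decision_0 list above, assuming the usual MetaSchedule trace conventions: because the two outermost spatial loops are no longer split, the trace records only two five-way spatial SamplePerfectTile decisions plus the three-way reduction split instead of four, and the SampleCategorical entries select the cooperative-fetch vector lengths annotated on the shared-memory blocks.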