diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
index 702947ebc0dc..bcaf4343e256 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
@@ -190,7 +190,8 @@ std::pair<Array<tir::ExprRV>, Array<tir::LoopRV>> MultiLevelTilingNode::SplitLoo
   return {factors, splits};
 }
 
-std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
+std::vector<State> MultiLevelTilingNode::TileLoopNest(State state,
+                                                      int tile_inner_most_space_loop_num) const {
   Schedule& sch = state->sch;
   const BlockRV& block_rv = state->block_rv;
   // Step 1. Assuming trivial binding, pair the loops and their iter-var-types
@@ -199,6 +200,16 @@ std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
   ICHECK_EQ(loops.size(), iter_types.size());
   // Step 2. For each loop axis, tile it
   int64_t spatial_loop_product = 1;
+
+  int total_spatial_loop_num = 0;
+  std::for_each(iter_types.begin(), iter_types.end(), [&](const auto& iter_type) {
+    if (iter_type == IterVarType::kDataPar) total_spatial_loop_num++;
+  });
+  CHECK_GE(total_spatial_loop_num, tile_inner_most_space_loop_num);
+  if (tile_inner_most_space_loop_num < 0) tile_inner_most_space_loop_num = total_spatial_loop_num;
+  int outer_most_spatial_loop_skipped_num = total_spatial_loop_num - tile_inner_most_space_loop_num;
+
+  Array<LoopRV> skipped_outer_spatial_loops;
   std::vector<Array<LoopRV>> tiles(s_indices_.size() + r_indices_.size());
   state->tile_factors.resize(tiles.size());
   std::vector<Array<tir::ExprRV>> tile_factors;
@@ -208,6 +219,11 @@ std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
     const std::vector<int>* idx = nullptr;
 
     if (iter_types[i] == IterVarType::kDataPar) {
+      if (outer_most_spatial_loop_skipped_num > 0) {
+        skipped_outer_spatial_loops.push_back(loop);
+        outer_most_spatial_loop_skipped_num--;
+        continue;
+      }
       idx = &s_indices_;
       if (spatial_loop_product != -1) {
         if (const int64_t* extent = tir::GetLoopIntExtent(sch->Get(loop).get())) {
@@ -241,6 +257,11 @@ std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
   sch->Reorder(support::ConcatArrayList(tiles.begin(), tiles.end()));
   // Step 4. Bind the tiles to threads
   int n_binds = std::min(tile_binds.size(), tiles.size());
+  if (skipped_outer_spatial_loops.size() && n_binds) {
+    auto& the_first_tile = tiles[0];
+    the_first_tile.insert(the_first_tile.begin(), skipped_outer_spatial_loops.begin(),
+                          skipped_outer_spatial_loops.end());
+  }
   for (int i = 0; i < n_binds; ++i) {
     LoopRV fused = sch->Fuse(tiles[i]);
     sch->Bind(fused, tile_binds[i]);
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.h b/src/meta_schedule/schedule_rule/multi_level_tiling.h
index 2b06aba9c106..23d6599a2538 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling.h
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling.h
@@ -162,7 +162,7 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
   // SubRule 1. add write cache
   std::vector<State> AddWriteReuse(State state) const;
   // SubRule 2. tile the loop nest
-  std::vector<State> TileLoopNest(State state) const;
+  std::vector<State> TileLoopNest(State state, int tile_inner_most_space_loop_num = -1) const;
   // SubRule 3. add read cache
   std::vector<State> AddReadReuse(State state) const;
   // SubRule 4. add async pipeline
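The new parameter restricts tiling to the innermost tile_inner_most_space_loop_num spatial loops; the skipped outermost spatial loops are collected and, just before fusing and binding, prepended to the first tile so they still land in the first (blockIdx-level) binding. As an illustration only (not part of this patch; the toy workload, tile factors, and thread bindings are hypothetical), the same pattern can be reproduced with TensorIR schedule primitives:

# Sketch of what TileLoopNest(state, /*tile_inner_most_space_loop_num=*/1)
# does on a toy 2-D workload: only `j` is tiled, and the skipped outer
# spatial loop `i` is prepended to the first tile before Fuse + Bind.
import tvm
from tvm.script import tir as T

@T.prim_func
def add_one(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32")):
    for i, j in T.grid(128, 128):
        with T.block("add_one"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] + T.float32(1)

sch = tvm.tir.Schedule(add_one)
i, j = sch.get_loops(sch.get_block("add_one"))
j0, j1, j2 = sch.split(j, factors=[4, 8, 4])  # hypothetical tile factors
fused = sch.fuse(i, j0)  # the untiled outer spatial loop joins the first tile
sch.bind(fused, "blockIdx.x")
sch.bind(j1, "threadIdx.x")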
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
index e3b51dda154a..e038ab908dd8 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
@@ -251,7 +251,7 @@ std::vector<State> MultiLevelTilingTensorCoreNode::ApplySubRules(std::vector<St
   states = SubRule(std::move(states), [&](State state) {
     TensorCoreState tc_state = Downcast<TensorCoreState>(state);
-    return tc_state->is_mma ? MMATileLoopNest(tc_state) : TileLoopNest(state);
+    return tc_state->is_mma ? MMATileLoopNest(tc_state) : TileLoopNest(state, 2);
   });
   states = SubRule(std::move(states), [&](State state) {
     return TransformIntermediateOutputLayout(Downcast<TensorCoreState>(state));
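With TileLoopNest(state, 2), the non-MMA (WMMA) tensor-core path now tiles only the two innermost spatial loops. In the updated test expectation below, the two outermost spatial iters of conv2d_nhwc (the unit loops ax0 and ax1) are therefore no longer split five ways: they are fused directly into the blockIdx.y binding (ax0_ax1_ax2_0_0_ax3_0_0_fused), and the trace loses their two SamplePerfectTile decisions. A minimal sketch of the loop-counting logic, using a hypothetical helper that mirrors the C++ above:

def num_skipped_outer_spatial_loops(iter_types, tile_inner_most_space_loop_num=-1):
    # Count the data-parallel ("S") loops; a negative argument keeps the
    # old behavior of tiling all of them (the default in the C++ change).
    total = sum(t == "S" for t in iter_types)
    assert total >= tile_inner_most_space_loop_num
    if tile_inner_most_space_loop_num < 0:
        tile_inner_most_space_loop_num = total
    return total - tile_inner_most_space_loop_num

# conv2d_nhwc has four spatial iters and one reduction iter ("SSSSR"):
assert num_skipped_outer_spatial_loops("SSSSR", 2) == 2  # ax0, ax1 stay untiled
assert num_skipped_outer_spatial_loops("SSSSR") == 0     # default: tile everything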
diff --git a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py
index 034bddd97132..da00f294ba0e 100644
--- a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py
+++ b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py
@@ -903,39 +903,39 @@ def test_conv_1x1():
     def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer((1, 1, 64, 64), "float16"), conv2d_nhwc: T.Buffer((1, 16, 16, 64), "float32")):
         T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
         # with T.block("root"):
-        conv2d_nhwc_reindex_shared = T.alloc_buffer((2, 2, 8, 2, 16, 16), scope="shared")
-        conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer((2, 2, 8, 2, 16, 16), scope="wmma.accumulator")
+        conv2d_nhwc_reindex_shared = T.alloc_buffer((2, 1, 8, 4, 16, 16), scope="shared")
+        conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer((2, 1, 8, 4, 16, 16), scope="wmma.accumulator")
         PadInput_reindex_shared = T.alloc_buffer((256, 64), "float16", scope="shared")
         weight_reindex_shared = T.alloc_buffer((1, 1, 64, 64), "float16", scope="shared")
         PadInput_reindex_shared_wmma_matrix_a = T.alloc_buffer((256, 64), "float16", scope="wmma.matrix_a")
         weight_reindex_shared_wmma_matrix_b = T.alloc_buffer((1, 1, 64, 64), "float16", scope="wmma.matrix_b")
-        for ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused in T.thread_binding(4, thread="blockIdx.y"):
-            for ax0_1_ax1_1_ax2_0_1_ax3_0_1_fused in T.thread_binding(1, thread="blockIdx.x"):
-                for ax0_2_ax1_2_ax2_0_2_ax3_0_2_fused in T.thread_binding(1, thread="threadIdx.y"):
-                    for ax4_0_0 in range(1):
+        for ax0_ax1_ax2_0_0_ax3_0_0_fused in T.thread_binding(1, thread="blockIdx.y"):
+            for ax2_0_1_ax3_0_1_fused in T.thread_binding(1, thread="blockIdx.x"):
+                for ax2_0_2_ax3_0_2_fused in T.thread_binding(2, thread="threadIdx.y"):
+                    for ax4_0_0 in range(2):
                         for ax0_ax1_fused in range(8192):
                             with T.block("PadInput_reindex_shared"):
-                                v0 = T.axis.spatial(256, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused // 2 * 128 + ax0_ax1_fused // 64)
-                                v1 = T.axis.spatial(64, ax0_ax1_fused % 64)
+                                v0 = T.axis.spatial(256, ax0_ax1_fused // 32)
+                                v1 = T.axis.spatial(64, ax4_0_0 * 32 + ax0_ax1_fused % 32)
                                 T.reads(inputs[0, v0 // 16, v0 % 16, v1])
                                 T.writes(PadInput_reindex_shared[v0, v1])
-                                T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 2})
+                                T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 8})
                                 PadInput_reindex_shared[v0, v1] = inputs[0, v0 // 16, v0 % 16, v1]
                         for ax0_ax1_ax2_ax3_fused in range(2048):
                             with T.block("weight_reindex_shared"):
                                 v0 = T.axis.spatial(1, 0)
                                 v1 = T.axis.spatial(1, 0)
-                                v2 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused // 32)
-                                v3 = T.axis.spatial(64, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused % 2 * 32 + ax0_ax1_ax2_ax3_fused % 32)
+                                v2 = T.axis.spatial(64, ax4_0_0 * 32 + ax0_ax1_ax2_ax3_fused // 64)
+                                v3 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 64)
                                 T.reads(weight[v0, v1, v2, v3])
                                 T.writes(weight_reindex_shared[v0, v1, v2, v3])
-                                T.block_attr({"buffer_dim_align": [[0, 2, 32, 8]], "meta_schedule.cooperative_fetch": 8})
+                                T.block_attr({"buffer_dim_align": [[0, 2, 32, 8]], "meta_schedule.cooperative_fetch": 4})
                                 weight_reindex_shared[v0, v1, v2, v3] = weight[v0, v1, v2, v3]
                         for ax4_0_1 in range(1):
-                            for ax0_0, ax1_0 in T.grid(8, 4):
+                            for ax0_0, ax1_0 in T.grid(8, 2):
                                 with T.block("PadInput_reindex_shared_wmma.matrix_a_o"):
-                                    v0_o = T.axis.spatial(16, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused // 2 * 8 + ax0_0)
-                                    v1_o = T.axis.spatial(4, ax1_0)
+                                    v0_o = T.axis.spatial(16, ax2_0_2_ax3_0_2_fused * 8 + ax0_0)
+                                    v1_o = T.axis.spatial(4, ax4_0_0 * 2 + ax1_0)
                                     T.reads(PadInput_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
                                     T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
                                     T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_a_shared"})
@@ -945,10 +945,11 @@ def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer(
                                     T.reads(PadInput_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
                                     T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
                                     PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = PadInput_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
-                            for ax0, ax1, ax2_0, ax3_0 in T.grid(1, 1, 4, 2):
+                            for ax0, ax1, ax2_0, ax3_0 in T.grid(1, 1, 2, 4):
                                 with T.block("weight_reindex_shared_wmma.matrix_b_o"):
-                                    v0_o, v1_o, v2_o = T.axis.remap("SSS", [ax0, ax1, ax2_0])
-                                    v3_o = T.axis.spatial(4, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused % 2 * 2 + ax3_0)
+                                    v0_o, v1_o = T.axis.remap("SS", [ax0, ax1])
+                                    v2_o = T.axis.spatial(4, ax4_0_0 * 2 + ax2_0)
+                                    v3_o = T.axis.spatial(4, ax3_0)
                                     T.reads(weight_reindex_shared[v0_o, v1_o, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16])
                                     T.writes(weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16])
                                     T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_b_shared"})
@@ -958,38 +959,38 @@ def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer(
                                     T.reads(weight_reindex_shared[v0_o, v1_o, v2_o * 16 + v2_i, v3_o * 16 + v3_i])
                                     T.writes(weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v2_o * 16 + v2_i, v3_o * 16 + v3_i])
                                     weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v2_o * 16 + v2_i, v3_o * 16 + v3_i] = weight_reindex_shared[v0_o, v1_o, v2_o * 16 + v2_i, v3_o * 16 + v3_i]
-                            for ax0_3, ax1_3, ax2_0_3, ax3_0_3, ax4_0_2, ax0_4, ax1_4, ax2_0_4, ax3_0_4 in T.grid(1, 1, 8, 2, 4, 1, 1, 1, 1):
+                            for ax2_0_3, ax3_0_3, ax4_0_2, ax2_0_4, ax3_0_4 in T.grid(8, 1, 2, 1, 4):
                                 with T.block("conv2d_nhwc_o"):
-                                    v0_o = T.axis.spatial(1, ax0_3 + ax0_4)
-                                    v1_o = T.axis.spatial(1, ax1_3 + ax1_4)
-                                    v2_o = T.axis.spatial(16, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused // 2 * 8 + ax2_0_3 + ax2_0_4)
-                                    v3_o = T.axis.spatial(4, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused % 2 * 2 + ax3_0_3 + ax3_0_4)
-                                    v4_o = T.axis.reduce(4, ax4_0_0 * 4 + ax4_0_1 * 4 + ax4_0_2)
+                                    v0_o = T.axis.spatial(1, 0)
+                                    v1_o = T.axis.spatial(1, 0)
+                                    v2_o = T.axis.spatial(16, ax2_0_2_ax3_0_2_fused * 8 + ax2_0_3 + ax2_0_4)
+                                    v3_o = T.axis.spatial(4, ax3_0_3 * 4 + ax3_0_4)
+                                    v4_o = T.axis.reduce(4, ax4_0_0 * 2 + ax4_0_1 * 2 + ax4_0_2)
                                     T.reads(PadInput_reindex_shared_wmma_matrix_a[v2_o * 16:v2_o * 16 + 16, v4_o * 16:v4_o * 16 + 16], weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v4_o * 16:v4_o * 16 + 16, v3_o * 16:v3_o * 16 + 16])
-                                    T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, v3_o // 2, v2_o % 8, v3_o % 2, 0:16, 0:16])
+                                    T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, v3_o, 0:16, 0:16])
                                     T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1})
                                     with T.init():
                                         for ax2_1, ax3_1 in T.grid(16, 16):
                                             with T.block("conv2d_nhwc_init"):
                                                 v2_i_init, v3_i_init = T.axis.remap("SS", [ax2_1, ax3_1])
                                                 T.reads()
-                                                T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, v3_o // 2, v2_o % 8, v3_o % 2, v2_i_init, v3_i_init])
-                                                conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, v3_o // 2, v2_o % 8, v3_o % 2, v2_i_init, v3_i_init] = T.float32(0)
+                                                T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, v3_o, v2_i_init, v3_i_init])
+                                                conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, v3_o, v2_i_init, v3_i_init] = T.float32(0)
                                     for ax2_1, ax3_1, ax4_1 in T.grid(16, 16, 16):
                                         with T.block("conv2d_nhwc"):
                                             v2_i, v3_i, v4_i = T.axis.remap("SSR", [ax2_1, ax3_1, ax4_1])
-                                            T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, v3_o // 2, v2_o % 8, v3_o % 2, v2_i, v3_i], PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v4_o * 16 + v4_i, v3_o * 16 + v3_i])
-                                            T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, v3_o // 2, v2_o % 8, v3_o % 2, v2_i, v3_i])
+                                            T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, v3_o, v2_i, v3_i], PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v4_o * 16 + v4_i, v3_o * 16 + v3_i])
+                                            T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, v3_o, v2_i, v3_i])
                                             T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
-                                            conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, v3_o // 2, v2_o % 8, v3_o % 2, v2_i, v3_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, v3_o // 2, v2_o % 8, v3_o % 2, v2_i, v3_i] + T.Cast("float32", PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i]) * T.Cast("float32", weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v4_o * 16 + v4_i, v3_o * 16 + v3_i])
+                                            conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, v3_o, v2_i, v3_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, v3_o, v2_i, v3_i] + T.Cast("float32", PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i]) * T.Cast("float32", weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v4_o * 16 + v4_i, v3_o * 16 + v3_i])
                 for ax2 in range(8):
-                    for ax0_ax1_fused in T.thread_binding(1, thread="threadIdx.y"):
-                        for ax2_1, ax3 in T.grid(1, 2):
+                    for ax0_ax1_fused in T.thread_binding(2, thread="threadIdx.y"):
+                        for ax2_1, ax3 in T.grid(1, 4):
                             with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"):
-                                v0_o = T.axis.spatial(2, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused // 2)
-                                v1_o = T.axis.spatial(2, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused % 2)
+                                v0_o = T.axis.spatial(2, ax0_ax1_fused)
+                                v1_o = T.axis.spatial(1, 0)
                                 v2_o = T.axis.spatial(8, ax2 + ax2_1)
-                                v3_o = T.axis.spatial(2, ax3)
+                                v3_o = T.axis.spatial(4, ax3)
                                 v4_o = T.axis.spatial(1, 0)
                                 v5_o = T.axis.spatial(1, 0)
                                 T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16])
@@ -1001,29 +1002,27 @@ def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer(
                                         T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i])
                                         T.writes(conv2d_nhwc_reindex_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i])
                                         conv2d_nhwc_reindex_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i]
-                    for ax0_ax1_ax3_ax4_ax5_fused in range(512):
+                    for ax0_ax1_ax3_ax4_ax5_fused in range(2048):
                         with T.block("conv2d_nhwc_reindex_shared"):
-                            v0 = T.axis.spatial(2, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused // 2)
-                            v1 = T.axis.spatial(2, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused % 2)
+                            v0 = T.axis.spatial(2, ax0_ax1_ax3_ax4_ax5_fused // 1024)
+                            v1 = T.axis.spatial(1, 0)
                             v2 = T.axis.spatial(8, ax2)
-                            v3 = T.axis.spatial(2, ax0_ax1_ax3_ax4_ax5_fused // 256)
+                            v3 = T.axis.spatial(4, ax0_ax1_ax3_ax4_ax5_fused % 1024 // 256)
                             v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16)
                             v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16)
                             T.reads(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5])
-                            T.writes(conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 128) // 16, (v4 + v2 * 16 + v0 * 128) % 16, v5 + v3 * 16 + v1 * 32])
+                            T.writes(conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 128) // 16, (v4 + v2 * 16 + v0 * 128) % 16, v5 + v3 * 16])
                             T.block_attr({"meta_schedule.cooperative_fetch": 1})
-                            conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 128) // 16, (v4 + v2 * 16 + v0 * 128) % 16, v5 + v3 * 16 + v1 * 32] = conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5]
+                            conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 128) // 16, (v4 + v2 * 16 + v0 * 128) % 16, v5 + v3 * 16] = conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5]
     # fmt: on
     decision_0 = [
-        ("SamplePerfectTile", [1, 1, 1, 1, 1]),
-        ("SamplePerfectTile", [1, 1, 1, 1, 1]),
-        ("SamplePerfectTile", [2, 1, 1, 8, 1]),
-        ("SamplePerfectTile", [2, 1, 1, 2, 1]),
-        ("SamplePerfectTile", [1, 1, 4]),
+        ("SamplePerfectTile", [1, 1, 2, 8, 1]),
+        ("SamplePerfectTile", [1, 1, 1, 1, 4]),
+        ("SamplePerfectTile", [2, 1, 2]),
         ("SampleCategorical", 0),
-        ("SampleCategorical", 1),
         ("SampleCategorical", 3),
+        ("SampleCategorical", 2),
     ]
     mod = te.create_prim_func(