diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
index 702947ebc0dc3..bcaf4343e2561 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
@@ -190,7 +190,8 @@ std::pair<Array<tir::ExprRV>, Array<tir::LoopRV>> MultiLevelTilingNode::SplitLoop(
   return {factors, splits};
 }
 
-std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
+std::vector<State> MultiLevelTilingNode::TileLoopNest(State state,
+                                                      int tile_inner_most_space_loop_num) const {
   Schedule& sch = state->sch;
   const BlockRV& block_rv = state->block_rv;
   // Step 1. Assuming trivial binding, pair the loops and their iter-var-types
@@ -199,6 +200,16 @@ std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
   ICHECK_EQ(loops.size(), iter_types.size());
   // Step 2. For each loop axis, tile it
   int64_t spatial_loop_product = 1;
+
+  int total_spatial_loop_num = 0;
+  std::for_each(iter_types.begin(), iter_types.end(), [&](const auto& iter_type) {
+    if (iter_type == IterVarType::kDataPar) total_spatial_loop_num++;
+  });
+  CHECK_GE(total_spatial_loop_num, tile_inner_most_space_loop_num);
+  if (tile_inner_most_space_loop_num < 0) tile_inner_most_space_loop_num = total_spatial_loop_num;
+  int outer_most_spatial_loop_skipped_num = total_spatial_loop_num - tile_inner_most_space_loop_num;
+
+  Array<LoopRV> skipped_outer_spatial_loops;
   std::vector<Array<LoopRV>> tiles(s_indices_.size() + r_indices_.size());
   state->tile_factors.resize(tiles.size());
   std::vector<Array<tir::ExprRV>> tile_factors;
@@ -208,6 +219,11 @@ std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
     const std::vector<int>* idx = nullptr;
 
     if (iter_types[i] == IterVarType::kDataPar) {
+      if (outer_most_spatial_loop_skipped_num > 0) {
+        skipped_outer_spatial_loops.push_back(loop);
+        outer_most_spatial_loop_skipped_num--;
+        continue;
+      }
       idx = &s_indices_;
       if (spatial_loop_product != -1) {
         if (const int64_t* extent = tir::GetLoopIntExtent(sch->Get(loop).get())) {
@@ -241,6 +257,11 @@ std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
   sch->Reorder(support::ConcatArrayList<LoopRV>(tiles.begin(), tiles.end()));
   // Step 4. Bind the tiles to threads
   int n_binds = std::min(tile_binds.size(), tiles.size());
+  if (skipped_outer_spatial_loops.size() && n_binds) {
+    auto& the_first_tile = tiles[0];
+    the_first_tile.insert(the_first_tile.begin(), skipped_outer_spatial_loops.begin(),
+                          skipped_outer_spatial_loops.end());
+  }
   for (int i = 0; i < n_binds; ++i) {
     LoopRV fused = sch->Fuse(tiles[i]);
     sch->Bind(fused, tile_binds[i]);
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.h b/src/meta_schedule/schedule_rule/multi_level_tiling.h
index 2b06aba9c106b..23d6599a25380 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling.h
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling.h
@@ -162,7 +162,7 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
   // SubRule 1. add write cache
   std::vector<State> AddWriteReuse(State state) const;
   // SubRule 2. tile the loop nest
-  std::vector<State> TileLoopNest(State state) const;
+  std::vector<State> TileLoopNest(State state, int tile_inner_most_space_loop_num = -1) const;
   // SubRule 3. add read cache
   std::vector<State> AddReadReuse(State state) const;
   // SubRule 4. add async pipeline
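The new `tile_inner_most_space_loop_num` parameter limits multi-level tiling to the innermost N spatial loops: the outermost spatial loops are collected rather than split, and later prepended to the first tile so they still take part in the thread binding. A minimal standalone sketch of the skip-counting logic (plain C++; the `IterType` enum and loop indices are hypothetical stand-ins for TVM's `IterVarType` and loop RVs, not the actual TVM code):

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical stand-in for TVM's IterVarType, for illustration only.
enum class IterType { kDataPar, kCommReduce };

int main() {
  // A loop nest [S0, S1, S2, R0]: three spatial loops, one reduction loop.
  std::vector<IterType> iter_types = {IterType::kDataPar, IterType::kDataPar,
                                      IterType::kDataPar, IterType::kCommReduce};
  // As passed by the tensor-core rule below: tile only the 2 innermost spatial loops.
  int tile_inner_most_space_loop_num = 2;

  // Mirror of the added logic: count spatial loops, then skip the outermost ones.
  int total_spatial_loop_num = static_cast<int>(
      std::count(iter_types.begin(), iter_types.end(), IterType::kDataPar));
  assert(total_spatial_loop_num >= tile_inner_most_space_loop_num);
  if (tile_inner_most_space_loop_num < 0) tile_inner_most_space_loop_num = total_spatial_loop_num;
  int skipped = total_spatial_loop_num - tile_inner_most_space_loop_num;

  for (std::size_t i = 0; i < iter_types.size(); ++i) {
    if (iter_types[i] == IterType::kDataPar && skipped > 0) {
      std::cout << "loop " << i << ": skipped, kept whole for the blockIdx fusion\n";
      --skipped;
      continue;
    }
    std::cout << "loop " << i << ": tiled as usual\n";
  }
  return 0;
}
```

With the default `tile_inner_most_space_loop_num = -1`, the skip count is zero and the rule behaves exactly as before, so existing callers are unaffected.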
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
index e3b51dda154aa..e038ab908dd88 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
@@ -251,7 +251,7 @@ std::vector<State> MultiLevelTilingTensorCoreNode::ApplySubRules(std::vector<State> states) {
   states = SubRule(std::move(states), [&](State state) {
     TensorCoreState tc_state = Downcast<TensorCoreState>(state);
-    return tc_state->is_mma ? MMATileLoopNest(tc_state) : TileLoopNest(state);
+    return tc_state->is_mma ? MMATileLoopNest(tc_state) : TileLoopNest(state, 2);
   });
   states = SubRule(std::move(states), [&](State state) {
     return TransformIntermediateOutputLayout(Downcast<TensorCoreState>(state));
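At this call site the non-MMA tensor-core path now tiles only the two innermost spatial loops (typically M and N), leaving any outer spatial loops, such as a batch dimension, whole. Per the bind logic added above, those skipped loops are fused into the first tile level. A toy illustration of that reordering (hypothetical loop names; `Fuse` modeled as string concatenation, not TVM's API):

```cpp
#include <iostream>
#include <string>
#include <vector>

int main() {
  // Hypothetical tiles for a batched matmul [B, M, N, K] scheduled with
  // tile_inner_most_space_loop_num = 2: B was skipped, M and N were tiled.
  std::vector<std::string> skipped_outer_spatial_loops = {"B"};
  std::vector<std::vector<std::string>> tiles = {
      {"M0", "N0"},  // first tile level, fused and bound to blockIdx.x
      {"M1", "N1"},  // second tile level, e.g. bound to threadIdx.y
  };

  // Mirror of the added bind logic: prepend the skipped loops to the first
  // tile so they take part in the blockIdx fusion.
  auto& first = tiles[0];
  first.insert(first.begin(), skipped_outer_spatial_loops.begin(),
               skipped_outer_spatial_loops.end());

  // Model Fuse as concatenation for illustration: prints "fuse(B, M0, N0)".
  std::string fused;
  for (const auto& loop : first) {
    if (!fused.empty()) fused += ", ";
    fused += loop;
  }
  std::cout << "blockIdx.x <- fuse(" << fused << ")\n";
  return 0;
}
```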