Skip to content

Commit

Permalink
[BugFix][MetaSchedule] MultiLevelTilingTensorCore generates inconsist…
Browse files Browse the repository at this point in the history
…ent thread-binding sketch for batched matmul
  • Loading branch information
tsu-bin committed May 22, 2024
1 parent a5862a5 commit 7c4c620
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 3 deletions.
23 changes: 22 additions & 1 deletion src/meta_schedule/schedule_rule/multi_level_tiling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,8 @@ std::pair<Array<tir::ExprRV>, Array<tir::LoopRV>> MultiLevelTilingNode::SplitLoo
return {factors, splits};
}

std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
std::vector<State> MultiLevelTilingNode::TileLoopNest(State state,
int tile_inner_most_space_loop_num) const {
Schedule& sch = state->sch;
const BlockRV& block_rv = state->block_rv;
// Step 1. Assuming trivial binding, pair the loops and their iter-var-types
Expand All @@ -199,6 +200,16 @@ std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
ICHECK_EQ(loops.size(), iter_types.size());
// Step 2. For each loop axis, tile it
int64_t spatial_loop_product = 1;

int total_spatial_loop_num = 0;
std::for_each(iter_types.begin(), iter_types.end(), [&](const auto& iter_type) {
if (iter_type == IterVarType::kDataPar) total_spatial_loop_num++;
});
CHECK_GE(total_spatial_loop_num, tile_inner_most_space_loop_num);
if (tile_inner_most_space_loop_num < 0) tile_inner_most_space_loop_num = total_spatial_loop_num;
int outer_most_spatial_loop_skipped_num = total_spatial_loop_num - tile_inner_most_space_loop_num;

Array<LoopRV> skipped_outer_spatial_loops;
std::vector<Array<LoopRV>> tiles(s_indices_.size() + r_indices_.size());
state->tile_factors.resize(tiles.size());
std::vector<Array<tir::ExprRV>> tile_factors;
Expand All @@ -208,6 +219,11 @@ std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
const std::vector<int>* idx = nullptr;

if (iter_types[i] == IterVarType::kDataPar) {
if (outer_most_spatial_loop_skipped_num > 0) {
skipped_outer_spatial_loops.push_back(loop);
outer_most_spatial_loop_skipped_num--;
continue;
}
idx = &s_indices_;
if (spatial_loop_product != -1) {
if (const int64_t* extent = tir::GetLoopIntExtent(sch->Get(loop).get())) {
Expand Down Expand Up @@ -241,6 +257,11 @@ std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
sch->Reorder(support::ConcatArrayList<LoopRV>(tiles.begin(), tiles.end()));
// Step 4. Bind the tiles to threads
int n_binds = std::min(tile_binds.size(), tiles.size());
if (skipped_outer_spatial_loops.size() && n_binds) {
auto& the_first_tile = tiles[0];
the_first_tile.insert(the_first_tile.begin(), skipped_outer_spatial_loops.begin(),
skipped_outer_spatial_loops.end());
}
for (int i = 0; i < n_binds; ++i) {
LoopRV fused = sch->Fuse(tiles[i]);
sch->Bind(fused, tile_binds[i]);
Expand Down
2 changes: 1 addition & 1 deletion src/meta_schedule/schedule_rule/multi_level_tiling.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
// SubRule 1. add write cache
std::vector<State> AddWriteReuse(State state) const;
// SubRule 2. tile the loop nest
std::vector<State> TileLoopNest(State state) const;
std::vector<State> TileLoopNest(State state, int tile_inner_most_space_loop_num = -1) const;
// SubRule 3. add read cache
std::vector<State> AddReadReuse(State state) const;
// SubRule 4. add async pipeline
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ std::vector<State> MultiLevelTilingTensorCoreNode::ApplySubRules(std::vector<Sta
});
states = SubRule(std::move(states), [&](State state) {
TensorCoreState tc_state = Downcast<TensorCoreState>(state);
return tc_state->is_mma ? MMATileLoopNest(tc_state) : TileLoopNest(state);
return tc_state->is_mma ? MMATileLoopNest(tc_state) : TileLoopNest(state, 2);
});
states = SubRule(std::move(states), [&](State state) {
return TransformIntermediateOutputLayout(Downcast<TensorCoreState>(state));
Expand Down

0 comments on commit 7c4c620

Please sign in to comment.