Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
wujingyue committed Nov 27, 2024
1 parent 1427f1b commit bd177e3
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
7 changes: 7 additions & 0 deletions csrc/multidevice/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,13 @@ bool haveDifferentShardings(
if (!is_mapped_in_id_model(p_loop_id, c_loop_id, id_model)) {
return true;
}
if (p_parallel_type_to_id.count(parallel_type)) {
IterDomain* p_id = p_parallel_type_to_id.at(parallel_type);
IterDomain* c_id = c_parallel_type_to_id.at(parallel_type);
if (!exact_graph.disjointValSets().strictAreMapped(p_id, c_id)) {
return true;
}
}
}

return false;
Expand Down
5 changes: 5 additions & 0 deletions csrc/preseg_passes/insert_reshardings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ bool shouldReshardAfter(Expr* expr) {
}

void insertReshardingsBefore(Fusion* fusion) {
IdModel id_model(fusion, false, false, true);
id_model.buildPermissiveGraph();
// Remove this after we refactor this as a pre-segmenter pass.
FusionGuard fg(fusion);
for (Expr* expr : fusion->exprs()) {
Expand Down Expand Up @@ -77,6 +79,9 @@ void insertReshardingsBefore(Fusion* fusion) {
}

void insertReshardingsAfter(Fusion* fusion) {
IdModel id_model(fusion, false, false, true);
id_model.buildPermissiveGraph();

// Remove this after we refactor this as a pre-segmenter pass.
FusionGuard fg(fusion);
// Iterate backwards over fusion expressions. Reshard after will
Expand Down

0 comments on commit bd177e3

Please sign in to comment.