Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/resize_scheduler_opt' into resiz…
Browse files Browse the repository at this point in the history
…e_scheduler_opt
  • Loading branch information
naoyam committed Jan 11, 2025
2 parents 46b6e3f + 21b18c2 commit 8fe98b2
Showing 1 changed file with 7 additions and 23 deletions.
30 changes: 7 additions & 23 deletions csrc/fusion_segmenter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3603,37 +3603,30 @@ class MergeUpCastArithDownCast {
bool merged = true;
while (merged) {
merged = false;
std::unordered_set<SegmentedGroup*> considered;
std::unordered_set<SegmentedGroup*> considered_groups;
for (SegmentedGroup* group : segment_candidate_finder_->groups()) {
if (!isUpCast(group) || considered.count(group)) {
if (!isUpCast(group) || considered_groups.count(group)) {
continue;
}

std::cerr << "Initial group: " << group << "\n";
for (auto expr : group->exprs()) {
std::cerr << expr->toString();
}

// This group consists of a single up-cast expr
auto closed_groups = getClosedGroup(group);
if (!closed_groups.has_value()) {
if (closed_groups.size() < 2) {
continue;
}

for (auto group : *closed_groups) {
considered.insert(group);
for (auto group : closed_groups) {
considered_groups.insert(group);
}

if (mergeClosedGroups(closed_groups.value())) {
if (mergeClosedGroups(closed_groups)) {
merged = true;
break;
}
}
}
}

std::optional<std::vector<SegmentedGroup*>> getClosedGroup(
SegmentedGroup* initial_group) {
std::vector<SegmentedGroup*> getClosedGroup(SegmentedGroup* initial_group) {
std::vector<SegmentedGroup*> groups_to_merge;
std::unordered_set<SegmentedGroup*> groups_to_merge_set;

Expand Down Expand Up @@ -3675,14 +3668,6 @@ class MergeUpCastArithDownCast {

SegmentedGroup* mergeClosedGroups(
const std::vector<SegmentedGroup*>& groups) {
std::cerr << "Try merging upcast-arith-downcast groups\n";
for (auto group : groups) {
std::cerr << toString(group) << "\n";
for (auto expr : group->exprs()) {
std::cerr << "\t" << expr->toString();
}
}

auto sched_type = tryMerge(
segment_candidate_finder_->segmented_fusion_.get(),
segment_candidate_finder_->runtimeInfo(),
Expand All @@ -3692,7 +3677,6 @@ class MergeUpCastArithDownCast {
return nullptr;
}

std::cerr << "Merge upcast-arith-downcast groups: " << sched_type << "\n";
auto joined_group = segment_candidate_finder_->mergeAllGivenGroups(groups);

return joined_group;
Expand Down

0 comments on commit 8fe98b2

Please sign in to comment.