Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
feat(fuse): support vertical reduce fuse reduce
Browse files Browse the repository at this point in the history
  • Loading branch information
BiynXu authored and 6clc committed May 5, 2023
1 parent 2fe484a commit 9b41192
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 12 deletions.
23 changes: 16 additions & 7 deletions cinn/hlir/pass/fusion_merge_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -827,14 +827,23 @@ class FusionMergePassHelper : public FusionHelperBase {
auto& consumers = input_consumers.second;
std::unordered_set<GroupPtr, Hasher, Comparator> updated_consumers;
for (auto& consumer : consumers) {
// if group is sub group
if (consumer->belong_groups.size()) {
// inset belong group to consumers.
for (auto& belong_group : consumer->belong_groups) {
updated_consumers.insert(belong_group);
std::queue<GroupPtr> fused_groups;
fused_groups.push(consumer);
while (!fused_groups.empty()) {
auto& cur = fused_groups.front();
fused_groups.pop();
// if group is sub group
if (cur->belong_groups.empty()) {
updated_consumers.insert(cur);
} else {
for (auto& belong_group : cur->belong_groups) {
if (belong_group->group_id == cur->group_id) {
updated_consumers.insert(cur);
} else {
fused_groups.push(belong_group);
}
}
}
} else {
updated_consumers.insert(consumer);
}
}
consumers = updated_consumers;
Expand Down
4 changes: 0 additions & 4 deletions cinn/hlir/pass/fusion_merge_pass_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -387,10 +387,6 @@ CONDITION_FUNC(reduce_fuse_broadcast) {
}

CONDITION_FUNC(reduce_fuse_reduce) {
// check reduce horizontal with reduce.
if (!horizontal_relation(helper, first, second, framework::OpPatternKind::kReduction)) {
return false;
}
if (!limit_args(helper, first, second)) {
return false;
}
Expand Down
2 changes: 1 addition & 1 deletion cinn/hlir/pass/op_fusion_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ class OpFusionPassHelper : public FusionHelperBase {
// producer -> fusion
relation.fusion_op_kind = {
// horizontal or vertical relation(Reduce + Elementwise*), check without last dimension in reduce.
{framework::kElementWise, without_last_dimension_in_reduce},
{framework::kElementWise, is_same_size},
// must be horizontal relation, check with same output shape and without last dimension in reduce.
{framework::kBroadcast, reduce_fuse_broadcast},
// must be horizontal relation and with same reduce attr.
Expand Down

0 comments on commit 9b41192

Please sign in to comment.