From 9b411928c32d4bdcf6fe4704f066085704a25e4b Mon Sep 17 00:00:00 2001 From: BiynXu <244524405@qq.com> Date: Tue, 25 Apr 2023 03:14:04 +0000 Subject: [PATCH] feat(fuse): support vertical reduce fuse reduce --- cinn/hlir/pass/fusion_merge_pass.cc | 23 ++++++++++++++++------- cinn/hlir/pass/fusion_merge_pass_util.h | 4 ---- cinn/hlir/pass/op_fusion_pass.cc | 2 +- 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/cinn/hlir/pass/fusion_merge_pass.cc b/cinn/hlir/pass/fusion_merge_pass.cc index 148c82946b..62b21db2b2 100644 --- a/cinn/hlir/pass/fusion_merge_pass.cc +++ b/cinn/hlir/pass/fusion_merge_pass.cc @@ -827,14 +827,23 @@ class FusionMergePassHelper : public FusionHelperBase { auto& consumers = input_consumers.second; std::unordered_set updated_consumers; for (auto& consumer : consumers) { - // if group is sub group - if (consumer->belong_groups.size()) { - // inset belong group to consumers. - for (auto& belong_group : consumer->belong_groups) { - updated_consumers.insert(belong_group); + std::queue fused_groups; + fused_groups.push(consumer); + while (!fused_groups.empty()) { + auto& cur = fused_groups.front(); + fused_groups.pop(); + // if group is sub group + if (cur->belong_groups.empty()) { + updated_consumers.insert(cur); + } else { + for (auto& belong_group : cur->belong_groups) { + if (belong_group->group_id == cur->group_id) { + updated_consumers.insert(cur); + } else { + fused_groups.push(belong_group); + } + } } - } else { - updated_consumers.insert(consumer); } } consumers = updated_consumers; diff --git a/cinn/hlir/pass/fusion_merge_pass_util.h b/cinn/hlir/pass/fusion_merge_pass_util.h index 03c368d6d5..550e86d91e 100644 --- a/cinn/hlir/pass/fusion_merge_pass_util.h +++ b/cinn/hlir/pass/fusion_merge_pass_util.h @@ -387,10 +387,6 @@ CONDITION_FUNC(reduce_fuse_broadcast) { } CONDITION_FUNC(reduce_fuse_reduce) { - // check reduce horizontal with reduce. - if (!horizontal_relation(helper, first, second, framework::OpPatternKind::kReduction)) { - return false; - } if (!limit_args(helper, first, second)) { return false; } diff --git a/cinn/hlir/pass/op_fusion_pass.cc b/cinn/hlir/pass/op_fusion_pass.cc index 026f2c6195..021e66e9d3 100644 --- a/cinn/hlir/pass/op_fusion_pass.cc +++ b/cinn/hlir/pass/op_fusion_pass.cc @@ -267,7 +267,7 @@ class OpFusionPassHelper : public FusionHelperBase { // producer -> fusion relation.fusion_op_kind = { // horizontal or vertical relation(Reduce + Elementwise*), check without last dimension in reduce. - {framework::kElementWise, without_last_dimension_in_reduce}, + {framework::kElementWise, is_same_size}, // must be horizontal relation, check with same output shape and without last dimension in reduce. {framework::kBroadcast, reduce_fuse_broadcast}, // must be horizontal relation and with same reduce attr.