diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java index f32bf8ea91b355..c5d3d0fb49a0a5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java @@ -36,6 +36,7 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList.Builder; +import com.google.common.collect.Lists; import java.util.ArrayList; import java.util.HashMap; @@ -74,8 +75,8 @@ public List buildRules() { Set funcs = agg.getAggregateFunctions(); return !funcs.isEmpty() && funcs.stream() .allMatch(f -> (f instanceof Min || f instanceof Max || f instanceof Sum - || (f instanceof Count && !((Count) f).isCountStar())) && !f.isDistinct() - && f.child(0) instanceof Slot); + || f instanceof Count && !f.isDistinct() + && (f.children().isEmpty() || f.child(0) instanceof Slot))); }) .thenApply(ctx -> { Set enableNereidsRules = ctx.cascadesContext.getConnectContext() @@ -88,15 +89,16 @@ public List buildRules() { }) .toRule(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE), logicalAggregate(logicalProject(innerLogicalJoin())) - .when(agg -> agg.child().isAllSlots()) - .when(agg -> agg.child().child().getOtherJoinConjuncts().isEmpty()) - .whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate)) + // .when(agg -> agg.child().isAllSlots()) + // .when(agg -> agg.child().child().getOtherJoinConjuncts().isEmpty()) + .whenNot(agg -> agg.child() + .child(0).children().stream().anyMatch(p -> p instanceof LogicalAggregate)) .when(agg -> { Set funcs = agg.getAggregateFunctions(); return !funcs.isEmpty() && funcs.stream() .allMatch(f -> (f instanceof Min || f instanceof Max || f instanceof Sum - || (f instanceof Count && (!((Count) f).isCountStar()))) && !f.isDistinct() - && f.child(0) instanceof Slot); + || f instanceof Count) && !f.isDistinct() + && (f.children().isEmpty() || f.child(0) instanceof Slot)); }) .thenApply(ctx -> { Set enableNereidsRules = ctx.cascadesContext.getConnectContext() @@ -118,23 +120,6 @@ public static LogicalAggregate pushMinMaxSumCount(LogicalAggregate join, List projects) { List leftOutput = join.left().getOutput(); List rightOutput = join.right().getOutput(); - - List leftFuncs = new ArrayList<>(); - List rightFuncs = new ArrayList<>(); - for (AggregateFunction func : agg.getAggregateFunctions()) { - Slot slot = (Slot) func.child(0); - if (leftOutput.contains(slot)) { - leftFuncs.add(func); - } else if (rightOutput.contains(slot)) { - rightFuncs.add(func); - } else { - throw new IllegalStateException("Slot " + slot + " not found in join output"); - } - } - if (leftFuncs.isEmpty() && rightFuncs.isEmpty()) { - return null; - } - Set leftGroupBy = new HashSet<>(); Set rightGroupBy = new HashSet<>(); for (Expression e : agg.getGroupByExpressions()) { @@ -144,18 +129,71 @@ public static LogicalAggregate pushMinMaxSumCount(LogicalAggregate inputForAliasSet = proj.getInputSlots(); + for (Slot aliasInputSlot : inputForAliasSet) { + if (leftOutput.contains(aliasInputSlot)) { + leftGroupBy.add(aliasInputSlot); + } else if (rightOutput.contains(aliasInputSlot)) { + rightGroupBy.add(aliasInputSlot); + } else { + return null; + } + } + break; + } + } + } } } - join.getHashJoinConjuncts().forEach(e -> e.getInputSlots().forEach(slot -> { - if (leftOutput.contains(slot)) { - leftGroupBy.add(slot); - } else if (rightOutput.contains(slot)) { - rightGroupBy.add(slot); + + List leftFuncs = new ArrayList<>(); + List rightFuncs = new ArrayList<>(); + Count countStar = null; + Count rewrittenCountStar = null; + for (AggregateFunction func : agg.getAggregateFunctions()) { + if (func instanceof Count && ((Count) func).isCountStar()) { + countStar = (Count) func; + } else { + Slot slot = (Slot) func.child(0); + if (leftOutput.contains(slot)) { + leftFuncs.add(func); + } else if (rightOutput.contains(slot)) { + rightFuncs.add(func); + } else { + throw new IllegalStateException("Slot " + slot + " not found in join output"); + } + } + } + // rewrite count(*) to count(A), where A is slot from left/right group by key + if (countStar != null) { + if (!leftGroupBy.isEmpty()) { + rewrittenCountStar = (Count) countStar.withChildren(leftGroupBy.iterator().next()); + leftFuncs.add(rewrittenCountStar); + } else if (!rightGroupBy.isEmpty()) { + rewrittenCountStar = (Count) countStar.withChildren(rightGroupBy.iterator().next()); + rightFuncs.add(rewrittenCountStar); } else { - throw new IllegalStateException("Slot " + slot + " not found in join output"); + return null; + } + } + for (Expression condition : join.getHashJoinConjuncts()) { + for (Slot joinConditionSlot : condition.getInputSlots()) { + if (leftOutput.contains(joinConditionSlot)) { + leftGroupBy.add(joinConditionSlot); + } else if (rightOutput.contains(joinConditionSlot)) { + rightGroupBy.add(joinConditionSlot); + } else { + // apply failed + return null; + } } - })); + } Plan left = join.left(); Plan right = join.right(); @@ -196,6 +234,10 @@ public static LogicalAggregate pushMinMaxSumCount(LogicalAggregate pushMinMaxSumCount(LogicalAggregate newProjections = Lists.newArrayList(); + newProjections.addAll(project.getProjects()); + Set leftDifference = new HashSet(left.getOutput()); + leftDifference.removeAll(project.getProjects()); + newProjections.addAll(leftDifference); + Set rightDifference = new HashSet(right.getOutput()); + rightDifference.removeAll(project.getProjects()); + newProjections.addAll(rightDifference); + newAggChild = ((LogicalProject) agg.child()).withProjectsAndChild(newProjections, newJoin); + } + return agg.withAggOutputChild(newOutputExprs, newAggChild); } private static Expression replaceAggFunc(AggregateFunction func, Slot inputSlot) { diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxSumThroughJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxSumThroughJoinTest.java index 58ab7fbe9e925f..cffe91045d0ab2 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxSumThroughJoinTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxSumThroughJoinTest.java @@ -323,11 +323,11 @@ void testSingleCountStar() { .applyTopDown(new PushDownAggThroughJoinOneSide()) .printlnTree() .matches( - logicalAggregate( - logicalJoin( - logicalOlapScan(), + logicalJoin( + logicalAggregate( logicalOlapScan() - ) + ), + logicalOlapScan() ) ); } @@ -346,11 +346,9 @@ void testBothSideCountAndCountStar() { PlanChecker.from(MemoTestUtils.createConnectContext(), plan) .applyTopDown(new PushDownAggThroughJoinOneSide()) .matches( - logicalAggregate( - logicalJoin( - logicalOlapScan(), - logicalOlapScan() - ) + logicalJoin( + logicalAggregate(logicalOlapScan()), + logicalAggregate(logicalOlapScan()) ) ); } diff --git a/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.out b/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.out index da69919becd7f2..8267eb3e38ff91 100644 --- a/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.out +++ b/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.out @@ -1034,3 +1034,25 @@ Used: UnUsed: use_push_down_agg_through_join_one_side SyntaxError: +-- !shape -- +PhysicalResultSink +--PhysicalTopN[MERGE_SORT] +----PhysicalTopN[LOCAL_SORT] +------hashAgg[GLOBAL] +--------hashAgg[LOCAL] +----------hashJoin[INNER_JOIN] hashCondition=((dwd_tracking_sensor_init_tmp_ymd.dt = dw_user_b2c_tracking_info_tmp_ymd.dt) and (dwd_tracking_sensor_init_tmp_ymd.guid = dw_user_b2c_tracking_info_tmp_ymd.guid)) otherCondition=((dwd_tracking_sensor_init_tmp_ymd.dt >= substring(first_visit_time, 1, 10))) +------------hashAgg[GLOBAL] +--------------hashAgg[LOCAL] +----------------filter((dwd_tracking_sensor_init_tmp_ymd.dt = '2024-08-19') and (dwd_tracking_sensor_init_tmp_ymd.tracking_type = 'click')) +------------------PhysicalOlapScan[dwd_tracking_sensor_init_tmp_ymd] +------------filter((dw_user_b2c_tracking_info_tmp_ymd.dt = '2024-08-19')) +--------------PhysicalOlapScan[dw_user_b2c_tracking_info_tmp_ymd] + +Hint log: +Used: use_PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE +UnUsed: +SyntaxError: + +-- !agg_pushed -- +2 是 2024-08-19 + diff --git a/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.groovy b/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.groovy index 02e06710296333..e551fa04c9110a 100644 --- a/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.groovy +++ b/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.groovy @@ -426,4 +426,99 @@ suite("push_down_count_through_join_one_side") { qt_with_hint_groupby_pushdown_nested_queries """ explain shape plan select /*+ USE_CBO_RULE(push_down_agg_through_join_one_side) */ count(*) from (select * from count_t_one_side where score > 20) t1 join (select * from count_t_one_side where id < 100) t2 on t1.id = t2.id group by t1.name; """ + + sql """ + drop table if exists dw_user_b2c_tracking_info_tmp_ymd; + create table dw_user_b2c_tracking_info_tmp_ymd ( + guid int, + dt varchar, + first_visit_time varchar + )Engine=Olap + DUPLICATE KEY(guid) + distributed by hash(dt) buckets 3 + properties('replication_num' = '1'); + + insert into dw_user_b2c_tracking_info_tmp_ymd values (1, '2024-08-19', '2024-08-19'); + + drop table if exists dwd_tracking_sensor_init_tmp_ymd; + create table dwd_tracking_sensor_init_tmp_ymd ( + guid int, + dt varchar, + tracking_type varchar + )Engine=Olap + DUPLICATE KEY(guid) + distributed by hash(dt) buckets 3 + properties('replication_num' = '1'); + + insert into dwd_tracking_sensor_init_tmp_ymd values(1, '2024-08-19', 'click'), (1, '2024-08-19', 'click'); + """ + sql """ + set ENABLE_NEREIDS_RULES = "PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE"; + set disable_join_reorder=true; + """ + + qt_shape """ + explain shape plan + SELECT /*+use_cbo_rule(PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE)*/ + Count(*) AS accee593, + CASE + WHEN dwd_tracking_sensor_init_tmp_ymd.dt = + Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1, + 10) THEN + '是' + WHEN dwd_tracking_sensor_init_tmp_ymd.dt > + Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1, + 10) THEN + '否' + ELSE '-1' + end AS a1302fb2, + dwd_tracking_sensor_init_tmp_ymd.dt AS ad466123 + FROM dwd_tracking_sensor_init_tmp_ymd + LEFT JOIN dw_user_b2c_tracking_info_tmp_ymd + ON dwd_tracking_sensor_init_tmp_ymd.guid = + dw_user_b2c_tracking_info_tmp_ymd.guid + AND dwd_tracking_sensor_init_tmp_ymd.dt = + dw_user_b2c_tracking_info_tmp_ymd.dt + WHERE dwd_tracking_sensor_init_tmp_ymd.dt = '2024-08-19' + AND dw_user_b2c_tracking_info_tmp_ymd.dt = '2024-08-19' + AND dwd_tracking_sensor_init_tmp_ymd.dt >= + Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1, 10) + AND dwd_tracking_sensor_init_tmp_ymd.tracking_type = 'click' + GROUP BY 2, + 3 + ORDER BY 3 ASC + LIMIT 10000; + """ + + qt_agg_pushed """ + SELECT /*+use_cbo_rule(PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE)*/ + Count(*) AS accee593, + CASE + WHEN dwd_tracking_sensor_init_tmp_ymd.dt = + Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1, + 10) THEN + '是' + WHEN dwd_tracking_sensor_init_tmp_ymd.dt > + Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1, + 10) THEN + '否' + ELSE '-1' + end AS a1302fb2, + dwd_tracking_sensor_init_tmp_ymd.dt AS ad466123 + FROM dwd_tracking_sensor_init_tmp_ymd + LEFT JOIN dw_user_b2c_tracking_info_tmp_ymd + ON dwd_tracking_sensor_init_tmp_ymd.guid = + dw_user_b2c_tracking_info_tmp_ymd.guid + AND dwd_tracking_sensor_init_tmp_ymd.dt = + dw_user_b2c_tracking_info_tmp_ymd.dt + WHERE dwd_tracking_sensor_init_tmp_ymd.dt = '2024-08-19' + AND dw_user_b2c_tracking_info_tmp_ymd.dt = '2024-08-19' + AND dwd_tracking_sensor_init_tmp_ymd.dt >= + Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1, 10) + AND dwd_tracking_sensor_init_tmp_ymd.tracking_type = 'click' + GROUP BY 2, + 3 + ORDER BY 3 ASC + LIMIT 10000; + """ }