diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctSplit.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctSplit.java index cf557245d2ae305..b273c5b6ba71f76 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctSplit.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctSplit.java @@ -152,14 +152,6 @@ public Plan visitLogicalAggregate(LogicalAggregate agg, Distinct Alias alias = new Alias(newDistinctAggFunc); outputExpressions.add(alias); if (i == 0) { - List otherAggFuncAliases = otherAggFuncs.stream() - .map(e -> ExpressionUtils.replace(e, producerToConsumerSlotMap)).collect(Collectors.toList()); - for (Expression otherAggFuncAlias : otherAggFuncAliases) { - // otherAggFunc is instance of Alias - Alias outputOtherFunc = new Alias(otherAggFuncAlias.child(0)); - outputExpressions.add(outputOtherFunc); - newToOriginDistinctFuncAlias.put(outputOtherFunc, (Alias) otherAggFuncAlias); - } // save replacedGroupBy outputJoinGroupBys.addAll(replacedGroupBy); } @@ -167,10 +159,39 @@ public Plan visitLogicalAggregate(LogicalAggregate agg, Distinct newAggs.add(newAgg); newToOriginDistinctFuncAlias.put(alias, distinctFuncWithAlias.get(i)); } + buildOtherAggFuncAggregate(otherAggFuncs, producer, ctx, cloneAgg, newToOriginDistinctFuncAlias, newAggs); List groupBy = agg.getGroupByExpressions(); LogicalJoin join = constructJoin(newAggs, groupBy); - return constructProject(groupBy, newToOriginDistinctFuncAlias, - outputJoinGroupBys, join); + return constructProject(groupBy, newToOriginDistinctFuncAlias, outputJoinGroupBys, join); + } + + private static void buildOtherAggFuncAggregate(List otherAggFuncs, LogicalCTEProducer producer, + DistinctSplitContext ctx, LogicalAggregate cloneAgg, Map newToOriginDistinctFuncAlias, + List> newAggs) { + if (otherAggFuncs.isEmpty()) { + return; + } + LogicalCTEConsumer consumer = new LogicalCTEConsumer(ctx.statementContext.getNextRelationId(), + producer.getCteId(), "", producer); + ctx.cascadesContext.putCTEIdToConsumer(consumer); + Map producerToConsumerSlotMap = new HashMap<>(); + for (Map.Entry entry : consumer.getConsumerToProducerOutputMap().entrySet()) { + producerToConsumerSlotMap.put(entry.getValue(), entry.getKey()); + } + List replacedGroupBy = ExpressionUtils.replace(cloneAgg.getGroupByExpressions(), + producerToConsumerSlotMap); + List outputExpressions = replacedGroupBy.stream() + .map(Slot.class::cast).collect(Collectors.toList()); + List otherAggFuncAliases = otherAggFuncs.stream() + .map(e -> ExpressionUtils.replace(e, producerToConsumerSlotMap)).collect(Collectors.toList()); + for (Expression otherAggFuncAlias : otherAggFuncAliases) { + // otherAggFunc is instance of Alias + Alias outputOtherFunc = new Alias(otherAggFuncAlias.child(0)); + outputExpressions.add(outputOtherFunc); + newToOriginDistinctFuncAlias.put(outputOtherFunc, (Alias) otherAggFuncAlias); + } + LogicalAggregate newAgg = new LogicalAggregate<>(replacedGroupBy, outputExpressions, consumer); + newAggs.add(newAgg); } private static boolean isDistinctMultiColumns(AggregateFunction func) { diff --git a/regression-test/data/nereids_rules_p0/distinct_split/disitinct_split.out b/regression-test/data/nereids_rules_p0/distinct_split/disitinct_split.out index cd10c541cc24c79..a1e600a15d7ceab 100644 --- a/regression-test/data/nereids_rules_p0/distinct_split/disitinct_split.out +++ b/regression-test/data/nereids_rules_p0/distinct_split/disitinct_split.out @@ -411,10 +411,6 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) --PhysicalResultSink ----hashJoin[INNER_JOIN] hashCondition=((.d = .d)) otherCondition=() ------hashJoin[INNER_JOIN] hashCondition=((.d = .d)) otherCondition=() ---------hashAgg[DISTINCT_LOCAL] -----------hashAgg[GLOBAL] -------------hashAgg[LOCAL] ---------------PhysicalCteConsumer ( cteId=CTEId#0 ) --------hashJoin[INNER_JOIN] hashCondition=((.d = .d)) otherCondition=() ----------hashAgg[DISTINCT_LOCAL] ------------hashAgg[GLOBAL] @@ -424,11 +420,36 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) ------------hashAgg[GLOBAL] --------------hashAgg[LOCAL] ----------------PhysicalCteConsumer ( cteId=CTEId#0 ) +--------hashAgg[DISTINCT_LOCAL] +----------hashAgg[GLOBAL] +------------hashAgg[LOCAL] +--------------PhysicalCteConsumer ( cteId=CTEId#0 ) ------hashAgg[DISTINCT_LOCAL] --------hashAgg[GLOBAL] ----------hashAgg[LOCAL] ------------PhysicalCteConsumer ( cteId=CTEId#0 ) +-- !has_other_func -- +PhysicalCteAnchor ( cteId=CTEId#0 ) +--PhysicalCteProducer ( cteId=CTEId#0 ) +----PhysicalOlapScan[test_distinct_multi] +--PhysicalResultSink +----NestedLoopJoin[CROSS_JOIN] +------NestedLoopJoin[CROSS_JOIN] +--------hashAgg[DISTINCT_GLOBAL] +----------hashAgg[DISTINCT_LOCAL] +------------hashAgg[GLOBAL] +--------------hashAgg[LOCAL] +----------------PhysicalCteConsumer ( cteId=CTEId#0 ) +--------hashAgg[DISTINCT_GLOBAL] +----------hashAgg[DISTINCT_LOCAL] +------------hashAgg[GLOBAL] +--------------hashAgg[LOCAL] +----------------PhysicalCteConsumer ( cteId=CTEId#0 ) +------hashAgg[GLOBAL] +--------hashAgg[LOCAL] +----------PhysicalCteConsumer ( cteId=CTEId#0 ) + -- !multi_count_with_gby -- PhysicalResultSink --hashAgg[GLOBAL] diff --git a/regression-test/suites/nereids_rules_p0/distinct_split/disitinct_split.groovy b/regression-test/suites/nereids_rules_p0/distinct_split/disitinct_split.groovy index 2876e18449181ee..6c85ecf45aee290 100644 --- a/regression-test/suites/nereids_rules_p0/distinct_split/disitinct_split.groovy +++ b/regression-test/suites/nereids_rules_p0/distinct_split/disitinct_split.groovy @@ -114,7 +114,7 @@ suite("distinct_split") { qt_count_multi_sum_avg_no_gby "select sum(distinct b), count(distinct a,d), avg(distinct c) from test_distinct_multi" qt_count_sum_avg_with_gby "select sum(distinct b), count(distinct a), avg(distinct c) from test_distinct_multi group by b,a order by 1,2,3" qt_count_multi_sum_avg_with_gby "select sum(distinct b), count(distinct a,d), avg(distinct c) from test_distinct_multi group by a,b order by 1,2,3" -//这里以下还需要验证结果 + // There is a reference query in the upper layer qt_multi_sum_has_upper """select c1+ c2 from (select sum(distinct b) c1, sum(distinct a) c2 from test_distinct_multi) t""" qt_000_count_has_upper """select abs(c1) from (select count(distinct a) c1 from test_distinct_multi) t""" @@ -189,6 +189,7 @@ suite("distinct_split") { qt_multi_count_mulitcols_with_gby """explain shape plan select count(distinct b,c), count(distinct a,b) from test_distinct_multi group by d""" qt_three_count_mulitcols_without_gby """explain shape plan select count(distinct b,c), count(distinct a,b), count(distinct a,b,c) from test_distinct_multi""" qt_four_count_mulitcols_with_gby """explain shape plan select count(distinct b,c), count(distinct a,b),count(distinct b,c,d), count(distinct a,b,c) from test_distinct_multi group by d""" + qt_has_other_func "explain shape plan select count(distinct b), count(distinct a), max(b),sum(c),min(a) from test_distinct_multi" // should not rewrite qt_multi_count_with_gby """explain shape plan select count(distinct b), count(distinct a) from test_distinct_multi group by c"""