From 13c384e8e8fadab0b0c386e478260096b4cc5afa Mon Sep 17 00:00:00 2001 From: feiniaofeiafei Date: Mon, 9 Dec 2024 10:12:37 +0800 Subject: [PATCH 1/5] add rule count distinct split add rule count distinct split add rule count distinct split add regression test add regression fix code style change by comment change to custom rewrite change to custom rewrite fix regression fix regression --- .../doris/nereids/jobs/executor/Rewriter.java | 5 + .../apache/doris/nereids/rules/RuleType.java | 1 + .../nereids/rules/analysis/CheckAnalysis.java | 31 -- .../implementation/AggregateStrategies.java | 4 +- .../nereids/rules/rewrite/DistinctSplit.java | 261 ++++++++++ .../expressions/functions/agg/Count.java | 8 +- .../functions/agg/GroupConcat.java | 3 +- .../trees/expressions/functions/agg/Sum.java | 3 +- .../trees/expressions/functions/agg/Sum0.java | 3 +- .../functions/agg/SupportMultiDistinct.java | 23 + .../rules/rewrite/DistinctSplitTest.java | 97 ++++ .../distinct_split/disitinct_split.out | 449 ++++++++++++++++++ .../distinct_split/disitinct_split.groovy | 197 ++++++++ .../aggregate_without_roll_up.groovy | 4 +- .../mv/dimension/dimension_1.groovy | 2 +- .../mv/dimension/dimension_2_3.groovy | 2 +- .../mv/dimension/dimension_2_4.groovy | 2 +- .../aggregate_strategies.groovy | 6 - 18 files changed, 1052 insertions(+), 49 deletions(-) create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctSplit.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/SupportMultiDistinct.java create mode 100644 fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/DistinctSplitTest.java create mode 100644 regression-test/data/nereids_rules_p0/distinct_split/disitinct_split.out create mode 100644 regression-test/suites/nereids_rules_p0/distinct_split/disitinct_split.groovy diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index 5b276258263f37..38aebc44154996 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -56,6 +56,7 @@ import org.apache.doris.nereids.rules.rewrite.CountLiteralRewrite; import org.apache.doris.nereids.rules.rewrite.CreatePartitionTopNFromWindow; import org.apache.doris.nereids.rules.rewrite.DeferMaterializeTopNResult; +import org.apache.doris.nereids.rules.rewrite.DistinctSplit; import org.apache.doris.nereids.rules.rewrite.EliminateAggCaseWhen; import org.apache.doris.nereids.rules.rewrite.EliminateAggregate; import org.apache.doris.nereids.rules.rewrite.EliminateAssertNumRows; @@ -444,6 +445,7 @@ public class Rewriter extends AbstractBatchJobExecutor { new CollectCteConsumerOutput() ) ), + // topic("distinct split", topDown(new DistinctSplit())), topic("Collect used column", custom(RuleType.COLLECT_COLUMNS, QueryColumnCollector::new) ) ) @@ -550,6 +552,9 @@ private static List getWholeTreeRewriteJobs( rewriteJobs.addAll(jobs(topic("or expansion", custom(RuleType.OR_EXPANSION, () -> OrExpansion.INSTANCE)))); } + rewriteJobs.addAll(jobs(topic("distinct split", + custom(RuleType.DISTINCT_SPLIT, () -> DistinctSplit.INSTANCE)))); + if (needSubPathPushDown) { rewriteJobs.addAll(jobs( topic("variant element_at push down", diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 86d0495b851bd2..a4b9b410358d0d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -322,6 +322,7 @@ public enum RuleType { MERGE_TOP_N(RuleTypeClass.REWRITE), BUILD_AGG_FOR_UNION(RuleTypeClass.REWRITE), COUNT_DISTINCT_REWRITE(RuleTypeClass.REWRITE), + DISTINCT_SPLIT(RuleTypeClass.REWRITE), INNER_TO_CROSS_JOIN(RuleTypeClass.REWRITE), CROSS_TO_INNER_JOIN(RuleTypeClass.REWRITE), PRUNE_EMPTY_PARTITION(RuleTypeClass.REWRITE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAnalysis.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAnalysis.java index 7ca8637446b0d6..13455720b07a5c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAnalysis.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAnalysis.java @@ -21,7 +21,6 @@ import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.OrderExpression; import org.apache.doris.nereids.trees.expressions.WindowExpression; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.generator.TableGeneratingFunction; @@ -139,36 +138,6 @@ private void checkExpressionInputTypes(Plan plan) { } private void checkAggregate(LogicalAggregate aggregate) { - Set aggregateFunctions = aggregate.getAggregateFunctions(); - boolean distinctMultiColumns = false; - for (AggregateFunction func : aggregateFunctions) { - if (!func.isDistinct()) { - continue; - } - if (func.arity() <= 1) { - continue; - } - for (int i = 1; i < func.arity(); i++) { - if (!func.child(i).getInputSlots().isEmpty() && !(func.child(i) instanceof OrderExpression)) { - // think about group_concat(distinct col_1, ',') - distinctMultiColumns = true; - break; - } - } - if (distinctMultiColumns) { - break; - } - } - - long distinctFunctionNum = 0; - for (AggregateFunction aggregateFunction : aggregateFunctions) { - distinctFunctionNum += aggregateFunction.isDistinct() ? 1 : 0; - } - - if (distinctMultiColumns && distinctFunctionNum > 1) { - throw new AnalysisException( - "The query contains multi count distinct or sum distinct, each can't have multi columns"); - } for (Expression expr : aggregate.getGroupByExpressions()) { if (expr.anyMatch(AggregateFunction.class::isInstance)) { throw new AnalysisException( diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java index 3d25bce0b48b5e..e98a9c6767daea 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java @@ -52,7 +52,6 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.GroupConcat; import org.apache.doris.nereids.trees.expressions.functions.agg.Max; import org.apache.doris.nereids.trees.expressions.functions.agg.Min; -import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinctCount; import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; import org.apache.doris.nereids.trees.expressions.functions.agg.Sum0; import org.apache.doris.nereids.trees.expressions.functions.scalar.If; @@ -1810,8 +1809,7 @@ private List getHashAggregatePartitionExpressions( private AggregateFunction tryConvertToMultiDistinct(AggregateFunction function) { if (function instanceof Count && function.isDistinct()) { - return new MultiDistinctCount(function.getArgument(0), - function.getArguments().subList(1, function.arity()).toArray(new Expression[0])); + return ((Count) function).convertToMultiDistinct(); } else if (function instanceof Sum && function.isDistinct()) { return ((Sum) function).convertToMultiDistinct(); } else if (function instanceof Sum0 && function.isDistinct()) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctSplit.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctSplit.java new file mode 100644 index 00000000000000..cf557245d2ae30 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctSplit.java @@ -0,0 +1,261 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.StatementContext; +import org.apache.doris.nereids.jobs.JobContext; +import org.apache.doris.nereids.rules.rewrite.DistinctSplit.DistinctSplitContext; +import org.apache.doris.nereids.trees.copier.DeepCopierContext; +import org.apache.doris.nereids.trees.copier.LogicalPlanDeepCopier; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.OrderExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.SupportMultiDistinct; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter; +import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter; +import org.apache.doris.nereids.util.ExpressionUtils; + +import com.google.common.collect.ImmutableList; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * LogicalAggregate(output:count(distinct a) as c1, count(distinct b) as c2) + * +--Plan + * -> + * LogicalCTEAnchor + * +--LogicalCTEProducer + * +--Plan + * +--LogicalProject(c1, c2) + * +--LogicalJoin + * +--LogicalAggregate(output:count(distinct a)) + * +--LogicalCTEConsumer + * +--LogicalAggregate(output:count(distinct b)) + * +--LogicalCTEConsumer + * */ +public class DistinctSplit extends DefaultPlanRewriter implements CustomRewriter { + public static DistinctSplit INSTANCE = new DistinctSplit(); + + /**DistinctSplitContext*/ + public static class DistinctSplitContext { + List> cteProducerList; + StatementContext statementContext; + CascadesContext cascadesContext; + + public DistinctSplitContext(StatementContext statementContext, CascadesContext cascadesContext) { + this.statementContext = statementContext; + this.cteProducerList = new ArrayList<>(); + this.cascadesContext = cascadesContext; + } + } + + @Override + public Plan rewriteRoot(Plan plan, JobContext jobContext) { + DistinctSplitContext ctx = new DistinctSplitContext( + jobContext.getCascadesContext().getStatementContext(), jobContext.getCascadesContext()); + plan = plan.accept(this, ctx); + for (int i = ctx.cteProducerList.size() - 1; i >= 0; i--) { + LogicalCTEProducer producer = ctx.cteProducerList.get(i); + plan = new LogicalCTEAnchor<>(producer.getCteId(), producer, plan); + } + return plan; + } + + @Override + public Plan visitLogicalCTEAnchor( + LogicalCTEAnchor anchor, DistinctSplitContext ctx) { + Plan child1 = anchor.child(0).accept(this, ctx); + DistinctSplitContext consumerContext = + new DistinctSplitContext(ctx.statementContext, ctx.cascadesContext); + Plan child2 = anchor.child(1).accept(this, consumerContext); + for (int i = consumerContext.cteProducerList.size() - 1; i >= 0; i--) { + LogicalCTEProducer producer = consumerContext.cteProducerList.get(i); + child2 = new LogicalCTEAnchor<>(producer.getCteId(), producer, child2); + } + return anchor.withChildren(ImmutableList.of(child1, child2)); + } + + @Override + public Plan visitLogicalAggregate(LogicalAggregate agg, DistinctSplitContext ctx) { + List distinctFuncWithAlias = new ArrayList<>(); + List otherAggFuncs = new ArrayList<>(); + if (!needTransform((LogicalAggregate) agg, distinctFuncWithAlias, otherAggFuncs)) { + return agg; + } + + LogicalAggregate cloneAgg = (LogicalAggregate) LogicalPlanDeepCopier.INSTANCE + .deepCopy(agg, new DeepCopierContext()); + LogicalCTEProducer producer = new LogicalCTEProducer<>(ctx.statementContext.getNextCTEId(), + cloneAgg.child()); + ctx.cteProducerList.add(producer); + Map originToProducerSlot = new HashMap<>(); + for (int i = 0; i < agg.child().getOutput().size(); ++i) { + Slot originSlot = agg.child().getOutput().get(i); + Slot cloneSlot = cloneAgg.child().getOutput().get(i); + originToProducerSlot.put(originSlot, cloneSlot); + } + distinctFuncWithAlias = ExpressionUtils.replace((List) distinctFuncWithAlias, originToProducerSlot); + otherAggFuncs = ExpressionUtils.replace((List) otherAggFuncs, originToProducerSlot); + // construct cte consumer and aggregate + List> newAggs = new ArrayList<>(); + // All otherAggFuncs are placed in the first one + Map newToOriginDistinctFuncAlias = new HashMap<>(); + List outputJoinGroupBys = new ArrayList<>(); + for (int i = 0; i < distinctFuncWithAlias.size(); ++i) { + Expression distinctAggFunc = distinctFuncWithAlias.get(i).child(0); + LogicalCTEConsumer consumer = new LogicalCTEConsumer(ctx.statementContext.getNextRelationId(), + producer.getCteId(), "", producer); + ctx.cascadesContext.putCTEIdToConsumer(consumer); + Map producerToConsumerSlotMap = new HashMap<>(); + for (Map.Entry entry : consumer.getConsumerToProducerOutputMap().entrySet()) { + producerToConsumerSlotMap.put(entry.getValue(), entry.getKey()); + } + List replacedGroupBy = ExpressionUtils.replace(cloneAgg.getGroupByExpressions(), + producerToConsumerSlotMap); + Expression newDistinctAggFunc = ExpressionUtils.replace(distinctAggFunc, producerToConsumerSlotMap); + List outputExpressions = replacedGroupBy.stream() + .map(Slot.class::cast).collect(Collectors.toList()); + Alias alias = new Alias(newDistinctAggFunc); + outputExpressions.add(alias); + if (i == 0) { + List otherAggFuncAliases = otherAggFuncs.stream() + .map(e -> ExpressionUtils.replace(e, producerToConsumerSlotMap)).collect(Collectors.toList()); + for (Expression otherAggFuncAlias : otherAggFuncAliases) { + // otherAggFunc is instance of Alias + Alias outputOtherFunc = new Alias(otherAggFuncAlias.child(0)); + outputExpressions.add(outputOtherFunc); + newToOriginDistinctFuncAlias.put(outputOtherFunc, (Alias) otherAggFuncAlias); + } + // save replacedGroupBy + outputJoinGroupBys.addAll(replacedGroupBy); + } + LogicalAggregate newAgg = new LogicalAggregate<>(replacedGroupBy, outputExpressions, consumer); + newAggs.add(newAgg); + newToOriginDistinctFuncAlias.put(alias, distinctFuncWithAlias.get(i)); + } + List groupBy = agg.getGroupByExpressions(); + LogicalJoin join = constructJoin(newAggs, groupBy); + return constructProject(groupBy, newToOriginDistinctFuncAlias, + outputJoinGroupBys, join); + } + + private static boolean isDistinctMultiColumns(AggregateFunction func) { + if (func.arity() <= 1) { + return false; + } + for (int i = 1; i < func.arity(); ++i) { + // think about group_concat(distinct col_1, ',') + if (!(func.child(i) instanceof OrderExpression) && !func.child(i).getInputSlots().isEmpty()) { + return true; + } + } + return false; + } + + private static boolean needTransform(LogicalAggregate agg, List aliases, List otherAggFuncs) { + // TODO with source repeat aggregate need to be supported in future + if (agg.getSourceRepeat().isPresent()) { + return false; + } + Set distinctFunc = new HashSet<>(); + boolean distinctMultiColumns = false; + for (NamedExpression namedExpression : agg.getOutputExpressions()) { + if (!(namedExpression instanceof Alias) || !(namedExpression.child(0) instanceof AggregateFunction)) { + continue; + } + AggregateFunction aggFunc = (AggregateFunction) namedExpression.child(0); + if (aggFunc instanceof SupportMultiDistinct && aggFunc.isDistinct()) { + aliases.add((Alias) namedExpression); + distinctFunc.add(aggFunc); + distinctMultiColumns = distinctMultiColumns || isDistinctMultiColumns(aggFunc); + } else { + otherAggFuncs.add((Alias) namedExpression); + } + } + if (distinctFunc.size() <= 1) { + return false; + } + if (!distinctMultiColumns && !agg.getGroupByExpressions().isEmpty()) { + return false; + } + return true; + } + + private static LogicalProject constructProject(List groupBy, Map joinOutput, + List outputJoinGroupBys, LogicalJoin join) { + List projects = new ArrayList<>(); + for (Map.Entry entry : joinOutput.entrySet()) { + projects.add(new Alias(entry.getValue().getExprId(), entry.getKey().toSlot(), entry.getValue().getName())); + } + // outputJoinGroupBys.size() == agg.getGroupByExpressions().size() + for (int i = 0; i < groupBy.size(); ++i) { + Slot slot = (Slot) groupBy.get(i); + projects.add(new Alias(slot.getExprId(), outputJoinGroupBys.get(i), slot.getName())); + } + return new LogicalProject<>(projects, join); + } + + private static LogicalJoin constructJoin(List> newAggs, + List groupBy) { + LogicalJoin join; + if (groupBy.isEmpty()) { + join = new LogicalJoin<>(JoinType.CROSS_JOIN, newAggs.get(0), newAggs.get(1), null); + for (int j = 2; j < newAggs.size(); ++j) { + join = new LogicalJoin<>(JoinType.CROSS_JOIN, join, newAggs.get(j), null); + } + } else { + int len = groupBy.size(); + List leftSlots = newAggs.get(0).getOutput(); + List rightSlots = newAggs.get(1).getOutput(); + List hashConditions = new ArrayList<>(); + for (int i = 0; i < len; ++i) { + hashConditions.add(new EqualTo(leftSlots.get(i), rightSlots.get(i))); + } + join = new LogicalJoin<>(JoinType.INNER_JOIN, hashConditions, newAggs.get(0), newAggs.get(1), null); + for (int j = 2; j < newAggs.size(); ++j) { + List belowJoinSlots = join.left().getOutput(); + List belowRightSlots = newAggs.get(j).getOutput(); + List aboveHashConditions = new ArrayList<>(); + for (int i = 0; i < len; ++i) { + aboveHashConditions.add(new EqualTo(belowJoinSlots.get(i), belowRightSlots.get(i))); + } + join = new LogicalJoin<>(JoinType.INNER_JOIN, aboveHashConditions, join, newAggs.get(j), null); + } + } + return join; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Count.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Count.java index 21e6ee1cba6b21..ba16b07ed5fafa 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Count.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Count.java @@ -37,7 +37,7 @@ /** count agg function. */ public class Count extends NotNullableAggregateFunction - implements ExplicitlyCastableSignature, SupportWindowAnalytic, RollUpTrait { + implements ExplicitlyCastableSignature, SupportWindowAnalytic, RollUpTrait, SupportMultiDistinct { public static final List SIGNATURES = ImmutableList.of( // count(*) @@ -162,4 +162,10 @@ public boolean canRollUp() { public Expression resultForEmptyInput() { return new BigIntLiteral(0); } + + @Override + public AggregateFunction convertToMultiDistinct() { + return new MultiDistinctCount(getArgument(0), + getArguments().subList(1, arity()).toArray(new Expression[0])); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupConcat.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupConcat.java index 2505329b2fe901..61cd525e65117b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupConcat.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupConcat.java @@ -37,7 +37,7 @@ * AggregateFunction 'group_concat'. This class is generated by GenerateFunction. */ public class GroupConcat extends NullableAggregateFunction - implements ExplicitlyCastableSignature { + implements ExplicitlyCastableSignature, SupportMultiDistinct { public static final List SIGNATURES = ImmutableList.of( FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(VarcharType.SYSTEM_DEFAULT), @@ -133,6 +133,7 @@ public List getSignatures() { return SIGNATURES; } + @Override public MultiDistinctGroupConcat convertToMultiDistinct() { Preconditions.checkArgument(distinct, "can't convert to multi_distinct_group_concat because there is no distinct args"); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum.java index e55f926ae4d3ca..b6a8cd86566b93 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum.java @@ -47,7 +47,7 @@ */ public class Sum extends NullableAggregateFunction implements UnaryExpression, ExplicitlyCastableSignature, ComputePrecisionForSum, SupportWindowAnalytic, - RollUpTrait { + RollUpTrait, SupportMultiDistinct { public static final List SIGNATURES = ImmutableList.of( FunctionSignature.ret(BigIntType.INSTANCE).args(BooleanType.INSTANCE), @@ -78,6 +78,7 @@ public Sum(boolean distinct, boolean alwaysNullable, Expression arg) { super("sum", distinct, alwaysNullable, arg); } + @Override public MultiDistinctSum convertToMultiDistinct() { Preconditions.checkArgument(distinct, "can't convert to multi_distinct_sum because there is no distinct args"); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum0.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum0.java index 5a1f0f9fb93d34..9d220237a699a8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum0.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum0.java @@ -54,7 +54,7 @@ */ public class Sum0 extends NotNullableAggregateFunction implements UnaryExpression, ExplicitlyCastableSignature, ComputePrecisionForSum, - SupportWindowAnalytic, RollUpTrait { + SupportWindowAnalytic, RollUpTrait, SupportMultiDistinct { public static final List SIGNATURES = ImmutableList.of( FunctionSignature.ret(BigIntType.INSTANCE).args(BooleanType.INSTANCE), @@ -81,6 +81,7 @@ public Sum0(boolean distinct, Expression arg) { super("sum0", distinct, arg); } + @Override public MultiDistinctSum0 convertToMultiDistinct() { Preconditions.checkArgument(distinct, "can't convert to multi_distinct_sum because there is no distinct args"); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/SupportMultiDistinct.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/SupportMultiDistinct.java new file mode 100644 index 00000000000000..848c529e5c32b2 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/SupportMultiDistinct.java @@ -0,0 +1,23 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.expressions.functions.agg; + +/** MultiDistinctTrait*/ +public interface SupportMultiDistinct { + AggregateFunction convertToMultiDistinct(); +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/DistinctSplitTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/DistinctSplitTest.java new file mode 100644 index 00000000000000..41801abd352568 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/DistinctSplitTest.java @@ -0,0 +1,97 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.util.MatchingUtils; +import org.apache.doris.nereids.util.MemoPatternMatchSupported; +import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.utframe.TestWithFeService; + +import org.junit.jupiter.api.Test; + +public class DistinctSplitTest extends TestWithFeService implements MemoPatternMatchSupported { + @Override + protected void runBeforeAll() throws Exception { + createDatabase("test"); + createTable("create table test.test_distinct_multi(a int, b int, c int, d varchar(10), e date)" + + "distributed by hash(a) properties('replication_num'='1');"); + connectContext.setDatabase("test"); + connectContext.getSessionVariable().setDisableNereidsRules("PRUNE_EMPTY_PARTITION"); + } + + @Test + void multiCountWithoutGby() { + String sql = "select count(distinct b), count(distinct a) from test_distinct_multi"; + PlanChecker.from(connectContext).checkExplain(sql, planner -> { + Plan plan = planner.getOptimizedPlan(); + MatchingUtils.assertMatches(plan, physicalCTEAnchor(physicalCTEProducer(any()), physicalResultSink(physicalProject(physicalNestedLoopJoin( + physicalHashAggregate(physicalDistribute(physicalHashAggregate(physicalHashAggregate(physicalDistribute(physicalHashAggregate(any())))))), + physicalDistribute(physicalHashAggregate(physicalDistribute(physicalHashAggregate(physicalHashAggregate(physicalDistribute(physicalHashAggregate(any()))))))) + ))))); + }); + } + + @Test + void multiSumWithoutGby() { + String sql = "select sum(distinct b), sum(distinct a) from test_distinct_multi"; + PlanChecker.from(connectContext).checkExplain(sql, planner -> { + Plan plan = planner.getOptimizedPlan(); + MatchingUtils.assertMatches(plan, physicalCTEAnchor(physicalCTEProducer(any()), physicalResultSink(physicalProject(physicalNestedLoopJoin( + physicalHashAggregate(physicalDistribute(physicalHashAggregate(physicalHashAggregate(physicalDistribute(physicalHashAggregate(any())))))), + physicalDistribute(physicalHashAggregate(physicalDistribute(physicalHashAggregate(physicalHashAggregate(physicalDistribute(physicalHashAggregate(any()))))))) + ))))); + }); + } + + @Test + void sumCountWithoutGby() { + String sql = "select sum(distinct b), count(distinct a) from test_distinct_multi"; + PlanChecker.from(connectContext).checkExplain(sql, planner -> { + Plan plan = planner.getOptimizedPlan(); + MatchingUtils.assertMatches(plan, physicalCTEAnchor(physicalCTEProducer(any()), physicalResultSink(physicalProject(physicalNestedLoopJoin( + physicalHashAggregate(physicalDistribute(physicalHashAggregate(physicalHashAggregate(physicalDistribute(physicalHashAggregate(any())))))), + physicalDistribute(physicalHashAggregate(physicalDistribute(physicalHashAggregate(physicalHashAggregate(physicalDistribute(physicalHashAggregate(any()))))))) + ))))); + }); + } + + @Test + void countMultiColumnsWithoutGby() { + String sql = "select count(distinct b,c), count(distinct a,b) from test_distinct_multi"; + PlanChecker.from(connectContext).checkExplain(sql, planner -> { + Plan plan = planner.getOptimizedPlan(); + MatchingUtils.assertMatches(plan, physicalCTEAnchor(physicalCTEProducer(any()), physicalResultSink(physicalProject(physicalNestedLoopJoin( + physicalHashAggregate(physicalHashAggregate(physicalDistribute(physicalHashAggregate(any())))), + physicalDistribute(physicalHashAggregate(physicalHashAggregate(physicalDistribute(physicalHashAggregate(any()))))) + ))))); + }); + } + + @Test + void countMultiColumnsWithGby() { + String sql = "select count(distinct b,c), count(distinct a,b) from test_distinct_multi group by d"; + PlanChecker.from(connectContext).checkExplain(sql, planner -> { + Plan plan = planner.getOptimizedPlan(); + MatchingUtils.assertMatches(plan, physicalCTEAnchor(physicalCTEProducer(any()), physicalResultSink(physicalDistribute(physicalProject(physicalHashJoin( + physicalHashAggregate(physicalHashAggregate(physicalDistribute(physicalHashAggregate(any())))), + physicalHashAggregate(physicalHashAggregate(physicalDistribute(physicalHashAggregate(any())))) + )))))); + }); + } +} diff --git a/regression-test/data/nereids_rules_p0/distinct_split/disitinct_split.out b/regression-test/data/nereids_rules_p0/distinct_split/disitinct_split.out new file mode 100644 index 00000000000000..cd10c541cc24c7 --- /dev/null +++ b/regression-test/data/nereids_rules_p0/distinct_split/disitinct_split.out @@ -0,0 +1,449 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !000_count -- +2 + +-- !001_count -- +1 2 + +-- !010_count -- +2 + +-- !010_count_same_column_with_groupby -- +1 +1 + +-- !011_count_same_column_with_groupby -- +1 1 +1 1 + +-- !011_count_diff_column_with_groupby -- +1 1 +2 1 + +-- !011_count_diff_column_with_groupby_multi -- +1 1 +1 1 +1 1 + +-- !011_count_diff_column_with_groupby_all -- +1 1 +1 1 +1 1 + +-- !100 -- +2 + +-- !101 -- +2 3 + +-- !101_count_one_col_and_two_col -- +2 2 + +-- !101_count_one_col_and_two_col -- +2 2 + +-- !110_count_diff_column_with_groupby -- +1 +2 + +-- !110_count_same_column_with_groupby1 -- +1 +1 + +-- !110_count_same_column_with_groupby2 -- +1 +1 + +-- !111_count_same_column_with_groupby1 -- +1 1 +2 2 + +-- !111_count_same_column_with_groupby2 -- +1 1 +1 1 +1 1 + +-- !111_count_diff_column_with_groupby -- +1 1 +2 2 + +-- !000_count_other_func -- +2 2 22 1 + +-- !001_count_other_func -- +1 2 2 22 1 + +-- !010_count_other_func -- +2 2 22 1 2 + +-- !011_count_other_func -- +1 1 2 8 2 2 +1 1 2 14 1 1 + +-- !100_count_other_func -- +2 1 2 22 1 + +-- !101_count_other_func -- +2 3 2 22 1 + +-- !110_count_other_func -- +1 2 6 1 3 +2 2 16 1 4 + +-- !111_count_other_func -- +1 1 2 6 1 3 +2 2 2 16 1 4 + +-- !001_three -- +1 2 2 + +-- !001_four -- +1 2 2 3 + +-- !001_five -- +1 2 2 3 4 + +-- !011_three -- +1 1 1 +1 2 2 +1 2 2 + +-- !011_four -- +1 1 1 1 +1 2 2 1 +1 2 2 1 + +-- !011_five -- +1 1 1 1 1 +1 2 2 1 2 +1 2 2 1 2 + +-- !011_three_gby_multi -- +1 1 1 +1 1 1 +1 1 1 +1 1 1 +1 1 2 + +-- !011_four_gby_multi -- +1 1 1 1 +1 1 1 1 +1 1 1 1 +1 1 1 1 +1 1 1 1 +1 1 1 1 + +-- !011_five_gby_multi -- +1 1 1 1 1 +1 1 1 1 1 +1 1 1 1 1 +1 1 1 1 1 +1 1 2 1 2 + +-- !101_three -- +2 3 2 + +-- !101_four -- +2 3 2 2 + +-- !101_five -- +2 3 5 2 6 + +-- !111_three -- +1 1 1 +2 2 2 + +-- !111_four -- +1 1 1 1 +1 1 1 1 +1 1 1 1 +2 2 2 2 + +-- !111_five -- +1 1 1 1 1 +1 1 1 1 1 +1 1 1 1 1 +2 2 3 2 3 + +-- !111_three_gby_multi -- +1 1 1 +1 1 1 +1 1 1 + +-- !111_four_gby_multi -- +1 1 1 1 +1 1 1 1 +1 1 1 1 +1 1 1 1 +1 1 1 1 + +-- !111_five_gby_multi -- +1 1 1 1 1 +1 1 1 1 1 +1 1 1 1 1 +1 1 1 1 1 +1 1 1 1 1 +1 1 1 1 1 + +-- !00_sum -- +2 + +-- !10_sum -- +2 3 + +-- !01_sum -- +2 +2 + +-- !11_sum -- +2 1 +2 2 + +-- !00_avg -- +2.0 + +-- !10_avg -- +2.0 1.5 + +-- !01_avg -- +2.0 +2.0 + +-- !11_avg -- +2.0 1.0 +2.0 2.0 + +-- !count_sum_avg_no_gby -- +2 2 3.5 + +-- !count_multi_sum_avg_no_gby -- +2 5 3.5 + +-- !count_sum_avg_with_gby -- +2 1 3.5 +2 1 4.0 + +-- !count_multi_sum_avg_with_gby -- +2 2 4.0 +2 3 3.5 + +-- !multi_sum_has_upper -- +5 + +-- !000_count_has_upper -- +2 + +-- !010_count_has_upper -- +102 + +-- !011_count_diff_column_with_groupby_all_has_upper -- +1 1 + +-- !100_has_upper -- +3 + +-- !101_has_upper -- +105 + +-- !111_count_same_column_with_groupby1_has_upper -- +2 1 + +-- !010_count_sum_other_func_has_upper -- +2 4 25 2 + +-- !010_count_other_func_has_upper -- +1 4 25 2 + +-- !cte_producer -- +2 1 + +-- !cte_consumer -- +2 1 2 1 + +-- !cte_multi_producer -- +2 1 3 2 3 1 + +-- !multi_cte_nest -- +2 1 3 2 3 1 2 1 3 2 3 1 + +-- !multi_cte_nest2 -- +2 1 3 2 3 1 2 1 3 2 3 1 2 1 3 2 3 1 2 1 3 2 3 1 2 1 3 2 3 1 2 1 3 2 3 1 + +-- !cte_consumer_count_multi_column_with_group_by -- +1 1 1 1 +1 1 2 2 +1 1 2 2 +2 2 1 1 +2 2 1 1 +2 2 2 2 +2 2 2 2 +2 2 2 2 +2 2 2 2 + +-- !cte_consumer_count_multi_column_without_group_by -- +3 2 3 2 + +-- !cte_multi_producer_multi_column -- +1 2 1 2 3 1 +1 2 3 2 3 1 +2 3 1 2 3 1 +2 3 3 2 3 1 + +-- !cte_multi_nested -- +1 2 1 2 3 1 3 2 1 1 +1 2 1 2 3 1 3 2 3 2 +1 2 1 2 3 1 3 2 3 2 +1 2 3 2 3 1 3 2 1 1 +1 2 3 2 3 1 3 2 3 2 +1 2 3 2 3 1 3 2 3 2 +2 3 1 2 3 1 3 2 1 1 +2 3 1 2 3 1 3 2 3 2 +2 3 1 2 3 1 3 2 3 2 +2 3 3 2 3 1 3 2 1 1 +2 3 3 2 3 1 3 2 3 2 +2 3 3 2 3 1 3 2 3 2 + +-- !multi_count_without_gby -- +PhysicalCteAnchor ( cteId=CTEId#0 ) +--PhysicalCteProducer ( cteId=CTEId#0 ) +----PhysicalOlapScan[test_distinct_multi] +--PhysicalResultSink +----NestedLoopJoin[CROSS_JOIN] +------hashAgg[DISTINCT_GLOBAL] +--------hashAgg[DISTINCT_LOCAL] +----------hashAgg[GLOBAL] +------------hashAgg[LOCAL] +--------------PhysicalCteConsumer ( cteId=CTEId#0 ) +------hashAgg[DISTINCT_GLOBAL] +--------hashAgg[DISTINCT_LOCAL] +----------hashAgg[GLOBAL] +------------hashAgg[LOCAL] +--------------PhysicalCteConsumer ( cteId=CTEId#0 ) + +-- !multi_sum_without_gby -- +PhysicalCteAnchor ( cteId=CTEId#0 ) +--PhysicalCteProducer ( cteId=CTEId#0 ) +----PhysicalOlapScan[test_distinct_multi] +--PhysicalResultSink +----NestedLoopJoin[CROSS_JOIN] +------hashAgg[DISTINCT_GLOBAL] +--------hashAgg[DISTINCT_LOCAL] +----------hashAgg[GLOBAL] +------------hashAgg[LOCAL] +--------------PhysicalCteConsumer ( cteId=CTEId#0 ) +------hashAgg[DISTINCT_GLOBAL] +--------hashAgg[DISTINCT_LOCAL] +----------hashAgg[GLOBAL] +------------hashAgg[LOCAL] +--------------PhysicalCteConsumer ( cteId=CTEId#0 ) + +-- !sum_count_without_gby -- +PhysicalCteAnchor ( cteId=CTEId#0 ) +--PhysicalCteProducer ( cteId=CTEId#0 ) +----PhysicalOlapScan[test_distinct_multi] +--PhysicalResultSink +----NestedLoopJoin[CROSS_JOIN] +------hashAgg[DISTINCT_GLOBAL] +--------hashAgg[DISTINCT_LOCAL] +----------hashAgg[GLOBAL] +------------hashAgg[LOCAL] +--------------PhysicalCteConsumer ( cteId=CTEId#0 ) +------hashAgg[DISTINCT_GLOBAL] +--------hashAgg[DISTINCT_LOCAL] +----------hashAgg[GLOBAL] +------------hashAgg[LOCAL] +--------------PhysicalCteConsumer ( cteId=CTEId#0 ) + +-- !multi_count_mulitcols_without_gby -- +PhysicalCteAnchor ( cteId=CTEId#0 ) +--PhysicalCteProducer ( cteId=CTEId#0 ) +----PhysicalOlapScan[test_distinct_multi] +--PhysicalResultSink +----NestedLoopJoin[CROSS_JOIN] +------hashAgg[DISTINCT_LOCAL] +--------hashAgg[GLOBAL] +----------hashAgg[LOCAL] +------------PhysicalCteConsumer ( cteId=CTEId#0 ) +------hashAgg[DISTINCT_LOCAL] +--------hashAgg[GLOBAL] +----------hashAgg[LOCAL] +------------PhysicalCteConsumer ( cteId=CTEId#0 ) + +-- !multi_count_mulitcols_with_gby -- +PhysicalCteAnchor ( cteId=CTEId#0 ) +--PhysicalCteProducer ( cteId=CTEId#0 ) +----PhysicalOlapScan[test_distinct_multi] +--PhysicalResultSink +----hashJoin[INNER_JOIN] hashCondition=((.d = .d)) otherCondition=() +------hashAgg[DISTINCT_LOCAL] +--------hashAgg[GLOBAL] +----------hashAgg[LOCAL] +------------PhysicalCteConsumer ( cteId=CTEId#0 ) +------hashAgg[DISTINCT_LOCAL] +--------hashAgg[GLOBAL] +----------hashAgg[LOCAL] +------------PhysicalCteConsumer ( cteId=CTEId#0 ) + +-- !three_count_mulitcols_without_gby -- +PhysicalCteAnchor ( cteId=CTEId#0 ) +--PhysicalCteProducer ( cteId=CTEId#0 ) +----PhysicalOlapScan[test_distinct_multi] +--PhysicalResultSink +----NestedLoopJoin[CROSS_JOIN] +------NestedLoopJoin[CROSS_JOIN] +--------hashAgg[DISTINCT_LOCAL] +----------hashAgg[GLOBAL] +------------hashAgg[LOCAL] +--------------PhysicalCteConsumer ( cteId=CTEId#0 ) +--------hashAgg[DISTINCT_LOCAL] +----------hashAgg[GLOBAL] +------------hashAgg[LOCAL] +--------------PhysicalCteConsumer ( cteId=CTEId#0 ) +------hashAgg[DISTINCT_LOCAL] +--------hashAgg[GLOBAL] +----------hashAgg[LOCAL] +------------PhysicalCteConsumer ( cteId=CTEId#0 ) + +-- !four_count_mulitcols_with_gby -- +PhysicalCteAnchor ( cteId=CTEId#0 ) +--PhysicalCteProducer ( cteId=CTEId#0 ) +----PhysicalOlapScan[test_distinct_multi] +--PhysicalResultSink +----hashJoin[INNER_JOIN] hashCondition=((.d = .d)) otherCondition=() +------hashJoin[INNER_JOIN] hashCondition=((.d = .d)) otherCondition=() +--------hashAgg[DISTINCT_LOCAL] +----------hashAgg[GLOBAL] +------------hashAgg[LOCAL] +--------------PhysicalCteConsumer ( cteId=CTEId#0 ) +--------hashJoin[INNER_JOIN] hashCondition=((.d = .d)) otherCondition=() +----------hashAgg[DISTINCT_LOCAL] +------------hashAgg[GLOBAL] +--------------hashAgg[LOCAL] +----------------PhysicalCteConsumer ( cteId=CTEId#0 ) +----------hashAgg[DISTINCT_LOCAL] +------------hashAgg[GLOBAL] +--------------hashAgg[LOCAL] +----------------PhysicalCteConsumer ( cteId=CTEId#0 ) +------hashAgg[DISTINCT_LOCAL] +--------hashAgg[GLOBAL] +----------hashAgg[LOCAL] +------------PhysicalCteConsumer ( cteId=CTEId#0 ) + +-- !multi_count_with_gby -- +PhysicalResultSink +--hashAgg[GLOBAL] +----hashAgg[LOCAL] +------PhysicalOlapScan[test_distinct_multi] + +-- !multi_sum_with_gby -- +PhysicalResultSink +--hashAgg[GLOBAL] +----hashAgg[LOCAL] +------PhysicalOlapScan[test_distinct_multi] + +-- !sum_count_with_gby -- +PhysicalResultSink +--hashAgg[GLOBAL] +----hashAgg[LOCAL] +------PhysicalOlapScan[test_distinct_multi] + diff --git a/regression-test/suites/nereids_rules_p0/distinct_split/disitinct_split.groovy b/regression-test/suites/nereids_rules_p0/distinct_split/disitinct_split.groovy new file mode 100644 index 00000000000000..2876e18449181e --- /dev/null +++ b/regression-test/suites/nereids_rules_p0/distinct_split/disitinct_split.groovy @@ -0,0 +1,197 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("distinct_split") { + sql "set runtime_filter_mode = OFF" + sql "drop table if exists test_distinct_multi" + sql "create table test_distinct_multi(a int, b int, c int, d varchar(10), e date) distributed by hash(a) properties('replication_num'='1');" + sql "insert into test_distinct_multi values(1,2,3,'abc','2024-01-02'),(1,2,4,'abc','2024-01-03'),(2,2,4,'abcd','2024-01-02'),(1,2,3,'abcd','2024-01-04'),(1,2,4,'eee','2024-02-02'),(2,2,4,'abc','2024-01-02');" + + // first bit 0 means distinct 1 col, 1 means distinct more than 1 col; second bit 0 means without group by, 1 means with group by; + // third bit 0 means there is 1 count(distinct) in projects, 1 means more than 1 count(distinct) in projects. + + //000 distinct has 1 column, no group by, projection column has 1 count (distinct). four stages agg + qt_000_count """select count(distinct a) from test_distinct_multi""" + + //001 distinct has 1 column, no group by, and multiple counts (distinct) in the projection column. The two-stage agg is slow for single point calculation in the second stage + qt_001_count """select count(distinct b), count(distinct a) from test_distinct_multi""" + + //010 distinct has 1 column with group by, and the projection column has 1 count (distinct). two-stage agg. The second stage follows group by hash + qt_010_count """select count(distinct a) from test_distinct_multi group by b order by 1""" + qt_010_count_same_column_with_groupby """select count(distinct a) from test_distinct_multi group by a order by 1""" + + //011 distinct has one column with group by, and the projection column has multiple counts (distinct). two stages agg. The second stage follows group by hash + qt_011_count_same_column_with_groupby """select count(distinct a),count(distinct b) from test_distinct_multi group by a order by 1,2""" + qt_011_count_diff_column_with_groupby """select count(distinct a),count(distinct b) from test_distinct_multi group by c order by 1,2""" + qt_011_count_diff_column_with_groupby_multi """select count(distinct a),count(distinct b) from test_distinct_multi group by a,c order by 1,2""" + qt_011_count_diff_column_with_groupby_all """select count(distinct a),count(distinct b) from test_distinct_multi group by a,b,c order by 1,2""" + + //100 distinct columns with no group by, projection column with 1 count (distinct). Three stage agg, second stage gather + qt_100 """select count(distinct a,b) from test_distinct_multi""" + + //101 distinct has multiple columns, no group by, and multiple counts (distinct) in the projection column (intercept). If the intercept is removed, it can be executed, but the result is incorrect + qt_101 """select count(distinct a,b), count(distinct a,c) from test_distinct_multi""" + qt_101_count_one_col_and_two_col """select count(distinct a,b), count(distinct c) from test_distinct_multi""" + qt_101_count_one_col_and_two_col """select count(distinct a,b), count(distinct a) from test_distinct_multi""" + + //110 distinct has multiple columns, including group by, and the projection column has one count (distinct). three-stage agg. The second stage follows group by hash + qt_110_count_diff_column_with_groupby """select count(distinct a,b) from test_distinct_multi group by c order by 1""" + qt_110_count_same_column_with_groupby1 """select count(distinct a,b) from test_distinct_multi group by a order by 1""" + qt_110_count_same_column_with_groupby2 """select count(distinct a,b) from test_distinct_multi group by a,b order by 1""" + + //111 distinct has multiple columns, including group by, and the projection column has multiple counts (distinct) (intercept). If the intercept is removed, it can be executed, but the result is incorrect + qt_111_count_same_column_with_groupby1 """select count(distinct a,b), count(distinct a,c) from test_distinct_multi group by c order by 1,2""" + qt_111_count_same_column_with_groupby2 """select count(distinct a,b), count(distinct c) from test_distinct_multi group by a,c order by 1,2""" + qt_111_count_diff_column_with_groupby """select count(distinct a,b), count(distinct a) from test_distinct_multi group by c order by 1,2""" + + // testing other functions + qt_000_count_other_func """select count(distinct a), max(b),sum(c),min(a) from test_distinct_multi""" + qt_001_count_other_func """select count(distinct b), count(distinct a), max(b),sum(c),min(a) from test_distinct_multi""" + qt_010_count_other_func """select count(distinct a), max(b),sum(c),min(a),b from test_distinct_multi group by b order by 1,2,3,4,5""" + qt_011_count_other_func """select count(distinct a), count(distinct b),max(b),sum(c),min(a),a from test_distinct_multi group by a order by 1,2,3,4,5,6""" + qt_100_count_other_func """select count(distinct a,b), count(distinct b),max(b),sum(c),min(a) from test_distinct_multi""" + qt_101_count_other_func """select count(distinct a,b), count(distinct a,c),max(b),sum(c),min(a) from test_distinct_multi""" + qt_110_count_other_func """select count(distinct a,b),max(b),sum(c),min(a),c from test_distinct_multi group by c order by 1,2,3,4,5""" + qt_111_count_other_func """select count(distinct a,b), count(distinct a,c),max(b),sum(c),min(a),c from test_distinct_multi group by c order by 1,2,3,4,5,6""" + + // multi distinct three four five + qt_001_three """select count(distinct b), count(distinct a), count(distinct c) from test_distinct_multi""" + qt_001_four """select count(distinct b), count(distinct a), count(distinct c), count(distinct d) from test_distinct_multi""" + qt_001_five """select count(distinct b), count(distinct a), count(distinct c), count(distinct d), count(distinct e) from test_distinct_multi""" + + qt_011_three """select count(distinct b), count(distinct a), count(distinct c) from test_distinct_multi group by d order by 1,2,3""" + qt_011_four """select count(distinct b), count(distinct a), count(distinct c), count(distinct d) from test_distinct_multi group by d order by 1,2,3,4""" + qt_011_five """select count(distinct b), count(distinct a), count(distinct c), count(distinct d), count(distinct e) from test_distinct_multi group by d order by 1,2,3,4,5""" + qt_011_three_gby_multi """select count(distinct b), count(distinct a), count(distinct c) from test_distinct_multi group by d,a order by 1,2,3""" + qt_011_four_gby_multi """select count(distinct b), count(distinct a), count(distinct c), count(distinct d) from test_distinct_multi group by d,c,a order by 1,2,3,4""" + qt_011_five_gby_multi """select count(distinct b), count(distinct a), count(distinct c), count(distinct d), count(distinct e) from test_distinct_multi group by d,b,a order by 1,2,3,4,5""" + + qt_101_three """select count(distinct a,b), count(distinct a,c) , count(distinct a) from test_distinct_multi""" + qt_101_four """select count(distinct a,b), count(distinct a,c) , count(distinct a), count(distinct c) from test_distinct_multi""" + qt_101_five """select count(distinct a,b), count(distinct a,c) , count(distinct a,d), count(distinct c) , count(distinct a,b,c,d) from test_distinct_multi""" + + qt_111_three """select count(distinct a,b), count(distinct a,c) , count(distinct a) from test_distinct_multi group by c order by 1,2,3""" + qt_111_four """select count(distinct a,b), count(distinct a,c) , count(distinct a), count(distinct c) from test_distinct_multi group by e order by 1,2,3,4""" + qt_111_five """select count(distinct a,b), count(distinct a,c) , count(distinct a,d), count(distinct c) , count(distinct a,b,c,d) from test_distinct_multi group by e order by 1,2,3,4,5""" + qt_111_three_gby_multi """select count(distinct a,b), count(distinct a,c) , count(distinct a) from test_distinct_multi group by c,a order by 1,2,3""" + qt_111_four_gby_multi """select count(distinct a,b), count(distinct a,c) , count(distinct a), count(distinct c) from test_distinct_multi group by e,a,b order by 1,2,3,4""" + qt_111_five_gby_multi """select count(distinct a,b), count(distinct a,c) , count(distinct a,d), count(distinct c) , count(distinct a,b,c,d) from test_distinct_multi group by e,a,b,c,d order by 1,2,3,4,5""" + + // sum has two dimensions: 1. Is there one or more projection columns (0 for one, 1 for more) 2. Is there a group by (0 for none, 1 for yes) + qt_00_sum """select sum(distinct b) from test_distinct_multi""" + qt_10_sum """select sum(distinct b), sum(distinct a) from test_distinct_multi""" + qt_01_sum """select sum(distinct b) from test_distinct_multi group by a order by 1""" + qt_11_sum """select sum(distinct b), sum(distinct a) from test_distinct_multi group by a order by 1,2""" + + // avg has two dimensions: 1. Is there one or more projection columns (0 for one, 1 for more) 2. Is there a group by (0 for no, 1 for yes) + qt_00_avg """select avg(distinct b) from test_distinct_multi""" + qt_10_avg """select avg(distinct b), avg(distinct a) from test_distinct_multi""" + qt_01_avg """select avg(distinct b) from test_distinct_multi group by a order by 1""" + qt_11_avg """select avg(distinct b), avg(distinct a) from test_distinct_multi group by a order by 1,2""" + + //group_concat + sql """select group_concat(distinct d order by d) from test_distinct_multi""" + sql """select group_concat(distinct d order by d), group_concat(distinct cast(a as string) order by cast(a as string)) from test_distinct_multi""" + sql """select group_concat(distinct d order by d) from test_distinct_multi group by a order by 1""" + sql """select group_concat(distinct d order by d), group_concat(distinct cast(a as string) order by cast(a as string)) from test_distinct_multi group by a order by 1,2""" + + // mixed distinct function + qt_count_sum_avg_no_gby "select sum(distinct b), count(distinct a), avg(distinct c) from test_distinct_multi" + qt_count_multi_sum_avg_no_gby "select sum(distinct b), count(distinct a,d), avg(distinct c) from test_distinct_multi" + qt_count_sum_avg_with_gby "select sum(distinct b), count(distinct a), avg(distinct c) from test_distinct_multi group by b,a order by 1,2,3" + qt_count_multi_sum_avg_with_gby "select sum(distinct b), count(distinct a,d), avg(distinct c) from test_distinct_multi group by a,b order by 1,2,3" +//这里以下还需要验证结果 + // There is a reference query in the upper layer + qt_multi_sum_has_upper """select c1+ c2 from (select sum(distinct b) c1, sum(distinct a) c2 from test_distinct_multi) t""" + qt_000_count_has_upper """select abs(c1) from (select count(distinct a) c1 from test_distinct_multi) t""" + qt_010_count_has_upper """select c1+100 from (select count(distinct a) c1 from test_distinct_multi group by b) t order by 1""" + qt_011_count_diff_column_with_groupby_all_has_upper """select max(c2), max(c1) from (select count(distinct a) c1,count(distinct b) c2 from test_distinct_multi group by a,b,c) t""" + qt_100_has_upper """select c1+1 from (select count(distinct a,b) c1 from test_distinct_multi) t where c1>0""" + qt_101_has_upper """select c1+c2+100 from (select count(distinct a,b) c1, count(distinct a,c) c2 from test_distinct_multi) t""" + qt_111_count_same_column_with_groupby1_has_upper """select max(c1), min(c2) from (select count(distinct a,b) c1, count(distinct a,c) c2 from test_distinct_multi group by c) t""" + qt_010_count_sum_other_func_has_upper """select sum(c0),max(c1+c2), min(c2+c3+c4),max(b) from (select sum(distinct b) c0,count(distinct a) c1, max(b) c2,sum(c) c3,min(a) c4,b from test_distinct_multi group by b) t""" + qt_010_count_other_func_has_upper"""select sum(c0), max(c1+c2), min(c2+c3+c4),max(b) from (select count(distinct b) c0,count(distinct a) c1, max(b) c2,sum(c) c3,min(a) c4,b from test_distinct_multi group by b) t""" + + // In cte or in nested cte. + qt_cte_producer """with t1 as (select a,b from test_distinct_multi) + select count(distinct t.a), count(distinct tt.b) from t1 t cross join t1 tt;""" + qt_cte_consumer """with t1 as (select count(distinct a), count(distinct b) from test_distinct_multi) + select * from t1 t cross join t1 tt;""" + qt_cte_multi_producer """ + with t1 as (select count(distinct a), count(distinct b) from test_distinct_multi), + t2 as (select sum(distinct a), sum(distinct b) from test_distinct_multi), + t3 as (select sum(distinct a), count(distinct b) from test_distinct_multi) + select * from t1,t2,t3; + """ + qt_multi_cte_nest """ + with t1 as (select count(distinct a), count(distinct b) from test_distinct_multi), + t2 as (select sum(distinct a), sum(distinct b) from test_distinct_multi), + t3 as (select sum(distinct a), count(distinct b) from test_distinct_multi) + select * from t1,t2,t3, (with t1 as (select count(distinct a), count(distinct b) from test_distinct_multi), + t2 as (select sum(distinct a), sum(distinct b) from test_distinct_multi), + t3 as (select sum(distinct a), count(distinct b) from test_distinct_multi) + select * from t1,t2,t3) tmp; + """ + qt_multi_cte_nest2 """ + with t1 as (with t1 as (select count(distinct a), count(distinct b) from test_distinct_multi), + t2 as (select sum(distinct a), sum(distinct b) from test_distinct_multi), + t3 as (select sum(distinct a), count(distinct b) from test_distinct_multi) + select * from t1,t2,t3, (with t1 as (select count(distinct a), count(distinct b) from test_distinct_multi), + t2 as (select sum(distinct a), sum(distinct b) from test_distinct_multi), + t3 as (select sum(distinct a), count(distinct b) from test_distinct_multi) + select * from t1,t2,t3) tmp) + select * from t1,t1,(with t1 as (select count(distinct a), count(distinct b) from test_distinct_multi), + t2 as (select sum(distinct a), sum(distinct b) from test_distinct_multi), + t3 as (select sum(distinct a), count(distinct b) from test_distinct_multi) + select * from t1,t2,t3, (with t1 as (select count(distinct a), count(distinct b) from test_distinct_multi), + t2 as (select sum(distinct a), sum(distinct b) from test_distinct_multi), + t3 as (select sum(distinct a), count(distinct b) from test_distinct_multi) + select * from t1,t2,t3) tmp) t + """ + qt_cte_consumer_count_multi_column_with_group_by """with t1 as (select count(distinct a,b), count(distinct b,c) from test_distinct_multi group by d) + select * from t1 t cross join t1 tt order by 1,2,3,4;""" + qt_cte_consumer_count_multi_column_without_group_by """with t1 as (select sum(distinct a), count(distinct b,c) from test_distinct_multi) + select * from t1 t cross join t1 tt;""" + qt_cte_multi_producer_multi_column """ + with t1 as (select count(distinct a), count(distinct b,d) from test_distinct_multi group by c), + t2 as (select sum(distinct a), sum(distinct b) from test_distinct_multi group by c), + t3 as (select sum(distinct a), count(distinct b) from test_distinct_multi) + select * from t1,t2,t3 order by 1,2,3,4,5,6; + """ + qt_cte_multi_nested """ + with tmp as (with t1 as (select count(distinct a), count(distinct b,d) from test_distinct_multi group by c), + t2 as (select sum(distinct a), sum(distinct b) from test_distinct_multi group by c), + t3 as (select sum(distinct a), count(distinct b) from test_distinct_multi) + select * from t1,t2,t3) + select * from tmp, (select sum(distinct a), count(distinct b,c) from test_distinct_multi) t, (select sum(distinct a), count(distinct b,c) from test_distinct_multi group by d) tt order by 1,2,3,4,5,6,7,8,9,10 + """ + + // shape + sql "SET ignore_shape_nodes='PhysicalDistribute,PhysicalProject'" + qt_multi_count_without_gby """explain shape plan select count(distinct b), count(distinct a) from test_distinct_multi""" + qt_multi_sum_without_gby """explain shape plan select sum(distinct b), sum(distinct a) from test_distinct_multi""" + qt_sum_count_without_gby """explain shape plan select sum(distinct b), count(distinct a) from test_distinct_multi""" + qt_multi_count_mulitcols_without_gby """explain shape plan select count(distinct b,c), count(distinct a,b) from test_distinct_multi""" + qt_multi_count_mulitcols_with_gby """explain shape plan select count(distinct b,c), count(distinct a,b) from test_distinct_multi group by d""" + qt_three_count_mulitcols_without_gby """explain shape plan select count(distinct b,c), count(distinct a,b), count(distinct a,b,c) from test_distinct_multi""" + qt_four_count_mulitcols_with_gby """explain shape plan select count(distinct b,c), count(distinct a,b),count(distinct b,c,d), count(distinct a,b,c) from test_distinct_multi group by d""" + + // should not rewrite + qt_multi_count_with_gby """explain shape plan select count(distinct b), count(distinct a) from test_distinct_multi group by c""" + qt_multi_sum_with_gby """explain shape plan select sum(distinct b), sum(distinct a) from test_distinct_multi group by c""" + qt_sum_count_with_gby """explain shape plan select sum(distinct b), count(distinct a) from test_distinct_multi group by a""" +} \ No newline at end of file diff --git a/regression-test/suites/nereids_rules_p0/mv/agg_without_roll_up/aggregate_without_roll_up.groovy b/regression-test/suites/nereids_rules_p0/mv/agg_without_roll_up/aggregate_without_roll_up.groovy index 356b96267a88f5..f5545bc41b22df 100644 --- a/regression-test/suites/nereids_rules_p0/mv/agg_without_roll_up/aggregate_without_roll_up.groovy +++ b/regression-test/suites/nereids_rules_p0/mv/agg_without_roll_up/aggregate_without_roll_up.groovy @@ -360,7 +360,7 @@ suite("aggregate_without_roll_up") { "from orders " + "where O_ORDERDATE < '2023-12-30' and O_ORDERDATE > '2023-12-01'" order_qt_query3_0_before "${query3_0}" - async_mv_rewrite_success(db, mv3_0, query3_0, "mv3_0") + async_mv_rewrite_fail(db, mv3_0, query3_0, "mv3_0") order_qt_query3_0_after "${query3_0}" sql """ DROP MATERIALIZED VIEW IF EXISTS mv3_0""" @@ -883,7 +883,7 @@ suite("aggregate_without_roll_up") { "on lineitem.L_ORDERKEY = orders.O_ORDERKEY " + "where orders.O_ORDERDATE < '2023-12-30' and orders.O_ORDERDATE > '2023-12-01' " order_qt_query20_0_before "${query20_0}" - async_mv_rewrite_success(db, mv20_0, query20_0, "mv20_0") + async_mv_rewrite_fail(db, mv20_0, query20_0, "mv20_0") order_qt_query20_0_after "${query20_0}" sql """ DROP MATERIALIZED VIEW IF EXISTS mv20_0""" diff --git a/regression-test/suites/nereids_rules_p0/mv/dimension/dimension_1.groovy b/regression-test/suites/nereids_rules_p0/mv/dimension/dimension_1.groovy index 3aed3b0f9e24df..59cff69ee895b9 100644 --- a/regression-test/suites/nereids_rules_p0/mv/dimension/dimension_1.groovy +++ b/regression-test/suites/nereids_rules_p0/mv/dimension/dimension_1.groovy @@ -410,7 +410,7 @@ suite("partition_mv_rewrite_dimension_1") { count(*) from orders_1 """ - mv_rewrite_success(agg_sql_1, agg_mv_name_1) + mv_rewrite_fail(agg_sql_1, agg_mv_name_1) compare_res(agg_sql_1 + " order by 1,2,3,4,5,6") sql """DROP MATERIALIZED VIEW IF EXISTS ${agg_mv_name_1};""" diff --git a/regression-test/suites/nereids_rules_p0/mv/dimension/dimension_2_3.groovy b/regression-test/suites/nereids_rules_p0/mv/dimension/dimension_2_3.groovy index c7ee359cdef2e4..a50d77bf3cc9f8 100644 --- a/regression-test/suites/nereids_rules_p0/mv/dimension/dimension_2_3.groovy +++ b/regression-test/suites/nereids_rules_p0/mv/dimension/dimension_2_3.groovy @@ -145,7 +145,7 @@ suite("partition_mv_rewrite_dimension_2_3") { count(*) from orders_2_3 left join lineitem_2_3 on lineitem_2_3.l_orderkey = orders_2_3.o_orderkey""" - mv_rewrite_success(sql_stmt_1, mv_name_1) + mv_rewrite_fail(sql_stmt_1, mv_name_1) compare_res(sql_stmt_1 + " order by 1,2,3,4,5,6") sql """DROP MATERIALIZED VIEW IF EXISTS ${mv_name_1};""" diff --git a/regression-test/suites/nereids_rules_p0/mv/dimension/dimension_2_4.groovy b/regression-test/suites/nereids_rules_p0/mv/dimension/dimension_2_4.groovy index e59b2771dd4e57..05c57974389ac3 100644 --- a/regression-test/suites/nereids_rules_p0/mv/dimension/dimension_2_4.groovy +++ b/regression-test/suites/nereids_rules_p0/mv/dimension/dimension_2_4.groovy @@ -577,7 +577,7 @@ suite("partition_mv_rewrite_dimension_2_4") { count(distinct case when O_SHIPPRIORITY > 2 and o_orderkey IN (2) then o_custkey else null end) as cnt_2 from orders_2_4 where o_orderkey > (-3) + 5 """ - mv_rewrite_success(sql_stmt_13, mv_name_13) + mv_rewrite_fail(sql_stmt_13, mv_name_13) compare_res(sql_stmt_13 + " order by 1") sql """DROP MATERIALIZED VIEW IF EXISTS ${mv_name_13};""" diff --git a/regression-test/suites/nereids_syntax_p0/aggregate_strategies.groovy b/regression-test/suites/nereids_syntax_p0/aggregate_strategies.groovy index 1b546db0ff8eae..aeb39fb275af3d 100644 --- a/regression-test/suites/nereids_syntax_p0/aggregate_strategies.groovy +++ b/regression-test/suites/nereids_syntax_p0/aggregate_strategies.groovy @@ -149,12 +149,6 @@ suite("aggregate_strategies") { from $tableName )a group by c""" - - - test { - sql "select count(distinct id, name), count(distinct id) from $tableName" - exception "The query contains multi count distinct or sum distinct, each can't have multi columns" - } } test_aggregate_strategies('test_bucket1_table', 1) From a5e68e3b3514474ee238525b673c225779cb2bc4 Mon Sep 17 00:00:00 2001 From: feiniaofeiafei Date: Wed, 18 Dec 2024 21:14:18 +0800 Subject: [PATCH 2/5] split other agg function into seperate aggregate --- .../nereids/rules/rewrite/DistinctSplit.java | 41 ++++++++++++++----- .../distinct_split/disitinct_split.out | 29 +++++++++++-- .../distinct_split/disitinct_split.groovy | 3 +- 3 files changed, 58 insertions(+), 15 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctSplit.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctSplit.java index cf557245d2ae30..b273c5b6ba71f7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctSplit.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctSplit.java @@ -152,14 +152,6 @@ public Plan visitLogicalAggregate(LogicalAggregate agg, Distinct Alias alias = new Alias(newDistinctAggFunc); outputExpressions.add(alias); if (i == 0) { - List otherAggFuncAliases = otherAggFuncs.stream() - .map(e -> ExpressionUtils.replace(e, producerToConsumerSlotMap)).collect(Collectors.toList()); - for (Expression otherAggFuncAlias : otherAggFuncAliases) { - // otherAggFunc is instance of Alias - Alias outputOtherFunc = new Alias(otherAggFuncAlias.child(0)); - outputExpressions.add(outputOtherFunc); - newToOriginDistinctFuncAlias.put(outputOtherFunc, (Alias) otherAggFuncAlias); - } // save replacedGroupBy outputJoinGroupBys.addAll(replacedGroupBy); } @@ -167,10 +159,39 @@ public Plan visitLogicalAggregate(LogicalAggregate agg, Distinct newAggs.add(newAgg); newToOriginDistinctFuncAlias.put(alias, distinctFuncWithAlias.get(i)); } + buildOtherAggFuncAggregate(otherAggFuncs, producer, ctx, cloneAgg, newToOriginDistinctFuncAlias, newAggs); List groupBy = agg.getGroupByExpressions(); LogicalJoin join = constructJoin(newAggs, groupBy); - return constructProject(groupBy, newToOriginDistinctFuncAlias, - outputJoinGroupBys, join); + return constructProject(groupBy, newToOriginDistinctFuncAlias, outputJoinGroupBys, join); + } + + private static void buildOtherAggFuncAggregate(List otherAggFuncs, LogicalCTEProducer producer, + DistinctSplitContext ctx, LogicalAggregate cloneAgg, Map newToOriginDistinctFuncAlias, + List> newAggs) { + if (otherAggFuncs.isEmpty()) { + return; + } + LogicalCTEConsumer consumer = new LogicalCTEConsumer(ctx.statementContext.getNextRelationId(), + producer.getCteId(), "", producer); + ctx.cascadesContext.putCTEIdToConsumer(consumer); + Map producerToConsumerSlotMap = new HashMap<>(); + for (Map.Entry entry : consumer.getConsumerToProducerOutputMap().entrySet()) { + producerToConsumerSlotMap.put(entry.getValue(), entry.getKey()); + } + List replacedGroupBy = ExpressionUtils.replace(cloneAgg.getGroupByExpressions(), + producerToConsumerSlotMap); + List outputExpressions = replacedGroupBy.stream() + .map(Slot.class::cast).collect(Collectors.toList()); + List otherAggFuncAliases = otherAggFuncs.stream() + .map(e -> ExpressionUtils.replace(e, producerToConsumerSlotMap)).collect(Collectors.toList()); + for (Expression otherAggFuncAlias : otherAggFuncAliases) { + // otherAggFunc is instance of Alias + Alias outputOtherFunc = new Alias(otherAggFuncAlias.child(0)); + outputExpressions.add(outputOtherFunc); + newToOriginDistinctFuncAlias.put(outputOtherFunc, (Alias) otherAggFuncAlias); + } + LogicalAggregate newAgg = new LogicalAggregate<>(replacedGroupBy, outputExpressions, consumer); + newAggs.add(newAgg); } private static boolean isDistinctMultiColumns(AggregateFunction func) { diff --git a/regression-test/data/nereids_rules_p0/distinct_split/disitinct_split.out b/regression-test/data/nereids_rules_p0/distinct_split/disitinct_split.out index cd10c541cc24c7..a1e600a15d7cea 100644 --- a/regression-test/data/nereids_rules_p0/distinct_split/disitinct_split.out +++ b/regression-test/data/nereids_rules_p0/distinct_split/disitinct_split.out @@ -411,10 +411,6 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) --PhysicalResultSink ----hashJoin[INNER_JOIN] hashCondition=((.d = .d)) otherCondition=() ------hashJoin[INNER_JOIN] hashCondition=((.d = .d)) otherCondition=() ---------hashAgg[DISTINCT_LOCAL] -----------hashAgg[GLOBAL] -------------hashAgg[LOCAL] ---------------PhysicalCteConsumer ( cteId=CTEId#0 ) --------hashJoin[INNER_JOIN] hashCondition=((.d = .d)) otherCondition=() ----------hashAgg[DISTINCT_LOCAL] ------------hashAgg[GLOBAL] @@ -424,11 +420,36 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) ------------hashAgg[GLOBAL] --------------hashAgg[LOCAL] ----------------PhysicalCteConsumer ( cteId=CTEId#0 ) +--------hashAgg[DISTINCT_LOCAL] +----------hashAgg[GLOBAL] +------------hashAgg[LOCAL] +--------------PhysicalCteConsumer ( cteId=CTEId#0 ) ------hashAgg[DISTINCT_LOCAL] --------hashAgg[GLOBAL] ----------hashAgg[LOCAL] ------------PhysicalCteConsumer ( cteId=CTEId#0 ) +-- !has_other_func -- +PhysicalCteAnchor ( cteId=CTEId#0 ) +--PhysicalCteProducer ( cteId=CTEId#0 ) +----PhysicalOlapScan[test_distinct_multi] +--PhysicalResultSink +----NestedLoopJoin[CROSS_JOIN] +------NestedLoopJoin[CROSS_JOIN] +--------hashAgg[DISTINCT_GLOBAL] +----------hashAgg[DISTINCT_LOCAL] +------------hashAgg[GLOBAL] +--------------hashAgg[LOCAL] +----------------PhysicalCteConsumer ( cteId=CTEId#0 ) +--------hashAgg[DISTINCT_GLOBAL] +----------hashAgg[DISTINCT_LOCAL] +------------hashAgg[GLOBAL] +--------------hashAgg[LOCAL] +----------------PhysicalCteConsumer ( cteId=CTEId#0 ) +------hashAgg[GLOBAL] +--------hashAgg[LOCAL] +----------PhysicalCteConsumer ( cteId=CTEId#0 ) + -- !multi_count_with_gby -- PhysicalResultSink --hashAgg[GLOBAL] diff --git a/regression-test/suites/nereids_rules_p0/distinct_split/disitinct_split.groovy b/regression-test/suites/nereids_rules_p0/distinct_split/disitinct_split.groovy index 2876e18449181e..6c85ecf45aee29 100644 --- a/regression-test/suites/nereids_rules_p0/distinct_split/disitinct_split.groovy +++ b/regression-test/suites/nereids_rules_p0/distinct_split/disitinct_split.groovy @@ -114,7 +114,7 @@ suite("distinct_split") { qt_count_multi_sum_avg_no_gby "select sum(distinct b), count(distinct a,d), avg(distinct c) from test_distinct_multi" qt_count_sum_avg_with_gby "select sum(distinct b), count(distinct a), avg(distinct c) from test_distinct_multi group by b,a order by 1,2,3" qt_count_multi_sum_avg_with_gby "select sum(distinct b), count(distinct a,d), avg(distinct c) from test_distinct_multi group by a,b order by 1,2,3" -//这里以下还需要验证结果 + // There is a reference query in the upper layer qt_multi_sum_has_upper """select c1+ c2 from (select sum(distinct b) c1, sum(distinct a) c2 from test_distinct_multi) t""" qt_000_count_has_upper """select abs(c1) from (select count(distinct a) c1 from test_distinct_multi) t""" @@ -189,6 +189,7 @@ suite("distinct_split") { qt_multi_count_mulitcols_with_gby """explain shape plan select count(distinct b,c), count(distinct a,b) from test_distinct_multi group by d""" qt_three_count_mulitcols_without_gby """explain shape plan select count(distinct b,c), count(distinct a,b), count(distinct a,b,c) from test_distinct_multi""" qt_four_count_mulitcols_with_gby """explain shape plan select count(distinct b,c), count(distinct a,b),count(distinct b,c,d), count(distinct a,b,c) from test_distinct_multi group by d""" + qt_has_other_func "explain shape plan select count(distinct b), count(distinct a), max(b),sum(c),min(a) from test_distinct_multi" // should not rewrite qt_multi_count_with_gby """explain shape plan select count(distinct b), count(distinct a) from test_distinct_multi group by c""" From 51e7f520fae911352f9c32c997ec91ac7608531f Mon Sep 17 00:00:00 2001 From: feiniaofeiafei Date: Thu, 19 Dec 2024 12:18:00 +0800 Subject: [PATCH 3/5] fix regression --- .../nereids_rules_p0/distinct_split/disitinct_split.groovy | 1 + 1 file changed, 1 insertion(+) diff --git a/regression-test/suites/nereids_rules_p0/distinct_split/disitinct_split.groovy b/regression-test/suites/nereids_rules_p0/distinct_split/disitinct_split.groovy index 6c85ecf45aee29..60392f67e37710 100644 --- a/regression-test/suites/nereids_rules_p0/distinct_split/disitinct_split.groovy +++ b/regression-test/suites/nereids_rules_p0/distinct_split/disitinct_split.groovy @@ -17,6 +17,7 @@ suite("distinct_split") { sql "set runtime_filter_mode = OFF" + sql "set disable_join_reorder=true" sql "drop table if exists test_distinct_multi" sql "create table test_distinct_multi(a int, b int, c int, d varchar(10), e date) distributed by hash(a) properties('replication_num'='1');" sql "insert into test_distinct_multi values(1,2,3,'abc','2024-01-02'),(1,2,4,'abc','2024-01-03'),(2,2,4,'abcd','2024-01-02'),(1,2,3,'abcd','2024-01-04'),(1,2,4,'eee','2024-02-02'),(2,2,4,'abc','2024-01-02');" From 609ee046cb3babb881636a8f52a50a4a107b7c4b Mon Sep 17 00:00:00 2001 From: feiniaofeiafei Date: Tue, 24 Dec 2024 22:21:53 +0800 Subject: [PATCH 4/5] fix comment --- .../doris/nereids/jobs/executor/Rewriter.java | 7 +- .../apache/doris/nereids/rules/RuleType.java | 2 +- .../implementation/AggregateStrategies.java | 11 +- .../rules/rewrite/CheckMultiDistinct.java | 31 +++ ...inctSplit.java => SplitMultiDistinct.java} | 55 ++--- .../functions/agg/SupportMultiDistinct.java | 4 +- .../rules/rewrite/DistinctSplitTest.java | 97 --------- .../rules/rewrite/SplitMultiDistinctTest.java | 191 ++++++++++++++++++ 8 files changed, 264 insertions(+), 134 deletions(-) rename fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/{DistinctSplit.java => SplitMultiDistinct.java} (86%) delete mode 100644 fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/DistinctSplitTest.java create mode 100644 fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SplitMultiDistinctTest.java diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index 38aebc44154996..d4ed9c21776224 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -56,7 +56,7 @@ import org.apache.doris.nereids.rules.rewrite.CountLiteralRewrite; import org.apache.doris.nereids.rules.rewrite.CreatePartitionTopNFromWindow; import org.apache.doris.nereids.rules.rewrite.DeferMaterializeTopNResult; -import org.apache.doris.nereids.rules.rewrite.DistinctSplit; +import org.apache.doris.nereids.rules.rewrite.SplitMultiDistinct; import org.apache.doris.nereids.rules.rewrite.EliminateAggCaseWhen; import org.apache.doris.nereids.rules.rewrite.EliminateAggregate; import org.apache.doris.nereids.rules.rewrite.EliminateAssertNumRows; @@ -445,7 +445,6 @@ public class Rewriter extends AbstractBatchJobExecutor { new CollectCteConsumerOutput() ) ), - // topic("distinct split", topDown(new DistinctSplit())), topic("Collect used column", custom(RuleType.COLLECT_COLUMNS, QueryColumnCollector::new) ) ) @@ -552,8 +551,8 @@ private static List getWholeTreeRewriteJobs( rewriteJobs.addAll(jobs(topic("or expansion", custom(RuleType.OR_EXPANSION, () -> OrExpansion.INSTANCE)))); } - rewriteJobs.addAll(jobs(topic("distinct split", - custom(RuleType.DISTINCT_SPLIT, () -> DistinctSplit.INSTANCE)))); + rewriteJobs.addAll(jobs(topic("split multi distinct", + custom(RuleType.SPLIT_MULTI_DISTINCT, () -> SplitMultiDistinct.INSTANCE)))); if (needSubPathPushDown) { rewriteJobs.addAll(jobs( diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index a4b9b410358d0d..86b1d114d0b68d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -322,7 +322,7 @@ public enum RuleType { MERGE_TOP_N(RuleTypeClass.REWRITE), BUILD_AGG_FOR_UNION(RuleTypeClass.REWRITE), COUNT_DISTINCT_REWRITE(RuleTypeClass.REWRITE), - DISTINCT_SPLIT(RuleTypeClass.REWRITE), + SPLIT_MULTI_DISTINCT(RuleTypeClass.REWRITE), INNER_TO_CROSS_JOIN(RuleTypeClass.REWRITE), CROSS_TO_INNER_JOIN(RuleTypeClass.REWRITE), PRUNE_EMPTY_PARTITION(RuleTypeClass.REWRITE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java index e98a9c6767daea..0a1b1c4e9b2d30 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java @@ -54,6 +54,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.Min; import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; import org.apache.doris.nereids.trees.expressions.functions.agg.Sum0; +import org.apache.doris.nereids.trees.expressions.functions.agg.SupportMultiDistinct; import org.apache.doris.nereids.trees.expressions.functions.scalar.If; import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; @@ -1808,14 +1809,8 @@ private List getHashAggregatePartitionExpressions( } private AggregateFunction tryConvertToMultiDistinct(AggregateFunction function) { - if (function instanceof Count && function.isDistinct()) { - return ((Count) function).convertToMultiDistinct(); - } else if (function instanceof Sum && function.isDistinct()) { - return ((Sum) function).convertToMultiDistinct(); - } else if (function instanceof Sum0 && function.isDistinct()) { - return ((Sum0) function).convertToMultiDistinct(); - } else if (function instanceof GroupConcat && function.isDistinct()) { - return ((GroupConcat) function).convertToMultiDistinct(); + if (function instanceof SupportMultiDistinct && function.isDistinct()) { + return ((SupportMultiDistinct) function).convertToMultiDistinct(); } return function; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CheckMultiDistinct.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CheckMultiDistinct.java index 4488a94b8d14c0..bc6fd8437239af 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CheckMultiDistinct.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CheckMultiDistinct.java @@ -20,6 +20,7 @@ import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.expressions.OrderExpression; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.agg.Avg; import org.apache.doris.nereids.trees.expressions.functions.agg.Count; @@ -57,6 +58,36 @@ private LogicalAggregate checkDistinct(LogicalAggregate aggregat } } } + + boolean distinctMultiColumns = false; + for (AggregateFunction func : aggregate.getAggregateFunctions()) { + if (!func.isDistinct()) { + continue; + } + if (func.arity() <= 1) { + continue; + } + for (int i = 1; i < func.arity(); i++) { + if (!func.child(i).getInputSlots().isEmpty() && !(func.child(i) instanceof OrderExpression)) { + // think about group_concat(distinct col_1, ',') + distinctMultiColumns = true; + break; + } + } + if (distinctMultiColumns) { + break; + } + } + + long distinctFunctionNum = 0; + for (AggregateFunction aggregateFunction : aggregate.getAggregateFunctions()) { + distinctFunctionNum += aggregateFunction.isDistinct() ? 1 : 0; + } + + if (distinctMultiColumns && distinctFunctionNum > 1) { + // throw new AnalysisException( + // "The query contains multi count distinct or sum distinct, each can't have multi columns"); + } return aggregate; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctSplit.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SplitMultiDistinct.java similarity index 86% rename from fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctSplit.java rename to fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SplitMultiDistinct.java index b273c5b6ba71f7..56df1485a83829 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctSplit.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SplitMultiDistinct.java @@ -20,7 +20,7 @@ import org.apache.doris.nereids.CascadesContext; import org.apache.doris.nereids.StatementContext; import org.apache.doris.nereids.jobs.JobContext; -import org.apache.doris.nereids.rules.rewrite.DistinctSplit.DistinctSplitContext; +import org.apache.doris.nereids.rules.rewrite.SplitMultiDistinct.DistinctSplitContext; import org.apache.doris.nereids.trees.copier.DeepCopierContext; import org.apache.doris.nereids.trees.copier.LogicalPlanDeepCopier; import org.apache.doris.nereids.trees.expressions.Alias; @@ -67,8 +67,8 @@ * +--LogicalAggregate(output:count(distinct b)) * +--LogicalCTEConsumer * */ -public class DistinctSplit extends DefaultPlanRewriter implements CustomRewriter { - public static DistinctSplit INSTANCE = new DistinctSplit(); +public class SplitMultiDistinct extends DefaultPlanRewriter implements CustomRewriter { + public static SplitMultiDistinct INSTANCE = new SplitMultiDistinct(); /**DistinctSplitContext*/ public static class DistinctSplitContext { @@ -111,6 +111,8 @@ public Plan visitLogicalCTEAnchor( @Override public Plan visitLogicalAggregate(LogicalAggregate agg, DistinctSplitContext ctx) { + Plan newChild = agg.child().accept(this, ctx); + agg = agg.withChildren(ImmutableList.of(newChild)); List distinctFuncWithAlias = new ArrayList<>(); List otherAggFuncs = new ArrayList<>(); if (!needTransform((LogicalAggregate) agg, distinctFuncWithAlias, otherAggFuncs)) { @@ -137,18 +139,12 @@ public Plan visitLogicalAggregate(LogicalAggregate agg, Distinct List outputJoinGroupBys = new ArrayList<>(); for (int i = 0; i < distinctFuncWithAlias.size(); ++i) { Expression distinctAggFunc = distinctFuncWithAlias.get(i).child(0); - LogicalCTEConsumer consumer = new LogicalCTEConsumer(ctx.statementContext.getNextRelationId(), - producer.getCteId(), "", producer); - ctx.cascadesContext.putCTEIdToConsumer(consumer); Map producerToConsumerSlotMap = new HashMap<>(); - for (Map.Entry entry : consumer.getConsumerToProducerOutputMap().entrySet()) { - producerToConsumerSlotMap.put(entry.getValue(), entry.getKey()); - } - List replacedGroupBy = ExpressionUtils.replace(cloneAgg.getGroupByExpressions(), - producerToConsumerSlotMap); + List outputExpressions = new ArrayList<>(); + List replacedGroupBy = new ArrayList<>(); + LogicalCTEConsumer consumer = constructConsumerAndReplaceGroupBy(ctx, producer, cloneAgg, outputExpressions, + producerToConsumerSlotMap, replacedGroupBy); Expression newDistinctAggFunc = ExpressionUtils.replace(distinctAggFunc, producerToConsumerSlotMap); - List outputExpressions = replacedGroupBy.stream() - .map(Slot.class::cast).collect(Collectors.toList()); Alias alias = new Alias(newDistinctAggFunc); outputExpressions.add(alias); if (i == 0) { @@ -171,17 +167,11 @@ private static void buildOtherAggFuncAggregate(List otherAggFuncs, Logica if (otherAggFuncs.isEmpty()) { return; } - LogicalCTEConsumer consumer = new LogicalCTEConsumer(ctx.statementContext.getNextRelationId(), - producer.getCteId(), "", producer); - ctx.cascadesContext.putCTEIdToConsumer(consumer); Map producerToConsumerSlotMap = new HashMap<>(); - for (Map.Entry entry : consumer.getConsumerToProducerOutputMap().entrySet()) { - producerToConsumerSlotMap.put(entry.getValue(), entry.getKey()); - } - List replacedGroupBy = ExpressionUtils.replace(cloneAgg.getGroupByExpressions(), - producerToConsumerSlotMap); - List outputExpressions = replacedGroupBy.stream() - .map(Slot.class::cast).collect(Collectors.toList()); + List outputExpressions = new ArrayList<>(); + List replacedGroupBy = new ArrayList<>(); + LogicalCTEConsumer consumer = constructConsumerAndReplaceGroupBy(ctx, producer, cloneAgg, outputExpressions, + producerToConsumerSlotMap, replacedGroupBy); List otherAggFuncAliases = otherAggFuncs.stream() .map(e -> ExpressionUtils.replace(e, producerToConsumerSlotMap)).collect(Collectors.toList()); for (Expression otherAggFuncAlias : otherAggFuncAliases) { @@ -194,6 +184,20 @@ private static void buildOtherAggFuncAggregate(List otherAggFuncs, Logica newAggs.add(newAgg); } + private static LogicalCTEConsumer constructConsumerAndReplaceGroupBy(DistinctSplitContext ctx, + LogicalCTEProducer producer, LogicalAggregate cloneAgg, List outputExpressions, + Map producerToConsumerSlotMap, List replacedGroupBy) { + LogicalCTEConsumer consumer = new LogicalCTEConsumer(ctx.statementContext.getNextRelationId(), + producer.getCteId(), "", producer); + ctx.cascadesContext.putCTEIdToConsumer(consumer); + for (Map.Entry entry : consumer.getConsumerToProducerOutputMap().entrySet()) { + producerToConsumerSlotMap.put(entry.getValue(), entry.getKey()); + } + replacedGroupBy.addAll(ExpressionUtils.replace(cloneAgg.getGroupByExpressions(), producerToConsumerSlotMap)); + outputExpressions.addAll(replacedGroupBy.stream().map(Slot.class::cast).collect(Collectors.toList())); + return consumer; + } + private static boolean isDistinctMultiColumns(AggregateFunction func) { if (func.arity() <= 1) { return false; @@ -230,6 +234,11 @@ private static boolean needTransform(LogicalAggregate agg, List ali if (distinctFunc.size() <= 1) { return false; } + // when this aggregate is not distinctMultiColumns, and group by expressions is not empty + // e.g. sql1: select count(distinct a), count(distinct b) from t1 group by c; + // sql2: select count(distinct a) from t1 group by c; + // the physical plan of sql1 and sql2 is similar, both are 2-phase aggregate, + // so there is no need to do this rewrite if (!distinctMultiColumns && !agg.getGroupByExpressions().isEmpty()) { return false; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/SupportMultiDistinct.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/SupportMultiDistinct.java index 848c529e5c32b2..9feaf2025c4f63 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/SupportMultiDistinct.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/SupportMultiDistinct.java @@ -17,7 +17,9 @@ package org.apache.doris.nereids.trees.expressions.functions.agg; -/** MultiDistinctTrait*/ +/** aggregate functions which have corresponding MultiDistinctXXX class, + * e.g. SUM,SUM0,COUNT,GROUP_CONCAT + * */ public interface SupportMultiDistinct { AggregateFunction convertToMultiDistinct(); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/DistinctSplitTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/DistinctSplitTest.java deleted file mode 100644 index 41801abd352568..00000000000000 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/DistinctSplitTest.java +++ /dev/null @@ -1,97 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package org.apache.doris.nereids.rules.rewrite; - -import org.apache.doris.nereids.trees.plans.Plan; -import org.apache.doris.nereids.util.MatchingUtils; -import org.apache.doris.nereids.util.MemoPatternMatchSupported; -import org.apache.doris.nereids.util.PlanChecker; -import org.apache.doris.utframe.TestWithFeService; - -import org.junit.jupiter.api.Test; - -public class DistinctSplitTest extends TestWithFeService implements MemoPatternMatchSupported { - @Override - protected void runBeforeAll() throws Exception { - createDatabase("test"); - createTable("create table test.test_distinct_multi(a int, b int, c int, d varchar(10), e date)" - + "distributed by hash(a) properties('replication_num'='1');"); - connectContext.setDatabase("test"); - connectContext.getSessionVariable().setDisableNereidsRules("PRUNE_EMPTY_PARTITION"); - } - - @Test - void multiCountWithoutGby() { - String sql = "select count(distinct b), count(distinct a) from test_distinct_multi"; - PlanChecker.from(connectContext).checkExplain(sql, planner -> { - Plan plan = planner.getOptimizedPlan(); - MatchingUtils.assertMatches(plan, physicalCTEAnchor(physicalCTEProducer(any()), physicalResultSink(physicalProject(physicalNestedLoopJoin( - physicalHashAggregate(physicalDistribute(physicalHashAggregate(physicalHashAggregate(physicalDistribute(physicalHashAggregate(any())))))), - physicalDistribute(physicalHashAggregate(physicalDistribute(physicalHashAggregate(physicalHashAggregate(physicalDistribute(physicalHashAggregate(any()))))))) - ))))); - }); - } - - @Test - void multiSumWithoutGby() { - String sql = "select sum(distinct b), sum(distinct a) from test_distinct_multi"; - PlanChecker.from(connectContext).checkExplain(sql, planner -> { - Plan plan = planner.getOptimizedPlan(); - MatchingUtils.assertMatches(plan, physicalCTEAnchor(physicalCTEProducer(any()), physicalResultSink(physicalProject(physicalNestedLoopJoin( - physicalHashAggregate(physicalDistribute(physicalHashAggregate(physicalHashAggregate(physicalDistribute(physicalHashAggregate(any())))))), - physicalDistribute(physicalHashAggregate(physicalDistribute(physicalHashAggregate(physicalHashAggregate(physicalDistribute(physicalHashAggregate(any()))))))) - ))))); - }); - } - - @Test - void sumCountWithoutGby() { - String sql = "select sum(distinct b), count(distinct a) from test_distinct_multi"; - PlanChecker.from(connectContext).checkExplain(sql, planner -> { - Plan plan = planner.getOptimizedPlan(); - MatchingUtils.assertMatches(plan, physicalCTEAnchor(physicalCTEProducer(any()), physicalResultSink(physicalProject(physicalNestedLoopJoin( - physicalHashAggregate(physicalDistribute(physicalHashAggregate(physicalHashAggregate(physicalDistribute(physicalHashAggregate(any())))))), - physicalDistribute(physicalHashAggregate(physicalDistribute(physicalHashAggregate(physicalHashAggregate(physicalDistribute(physicalHashAggregate(any()))))))) - ))))); - }); - } - - @Test - void countMultiColumnsWithoutGby() { - String sql = "select count(distinct b,c), count(distinct a,b) from test_distinct_multi"; - PlanChecker.from(connectContext).checkExplain(sql, planner -> { - Plan plan = planner.getOptimizedPlan(); - MatchingUtils.assertMatches(plan, physicalCTEAnchor(physicalCTEProducer(any()), physicalResultSink(physicalProject(physicalNestedLoopJoin( - physicalHashAggregate(physicalHashAggregate(physicalDistribute(physicalHashAggregate(any())))), - physicalDistribute(physicalHashAggregate(physicalHashAggregate(physicalDistribute(physicalHashAggregate(any()))))) - ))))); - }); - } - - @Test - void countMultiColumnsWithGby() { - String sql = "select count(distinct b,c), count(distinct a,b) from test_distinct_multi group by d"; - PlanChecker.from(connectContext).checkExplain(sql, planner -> { - Plan plan = planner.getOptimizedPlan(); - MatchingUtils.assertMatches(plan, physicalCTEAnchor(physicalCTEProducer(any()), physicalResultSink(physicalDistribute(physicalProject(physicalHashJoin( - physicalHashAggregate(physicalHashAggregate(physicalDistribute(physicalHashAggregate(any())))), - physicalHashAggregate(physicalHashAggregate(physicalDistribute(physicalHashAggregate(any())))) - )))))); - }); - } -} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SplitMultiDistinctTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SplitMultiDistinctTest.java new file mode 100644 index 00000000000000..074135695a1a31 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SplitMultiDistinctTest.java @@ -0,0 +1,191 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.util.MatchingUtils; +import org.apache.doris.nereids.util.MemoPatternMatchSupported; +import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.utframe.TestWithFeService; + +import org.junit.jupiter.api.Test; + +public class SplitMultiDistinctTest extends TestWithFeService implements MemoPatternMatchSupported { + @Override + protected void runBeforeAll() throws Exception { + createDatabase("test"); + createTable("create table test.test_distinct_multi(a int, b int, c int, d varchar(10), e date)" + + "distributed by hash(a) properties('replication_num'='1');"); + connectContext.setDatabase("test"); + connectContext.getSessionVariable().setDisableNereidsRules("PRUNE_EMPTY_PARTITION"); + } + + @Test + void multiCountWithoutGby() { + String sql = "select count(distinct b), count(distinct a) from test_distinct_multi"; + PlanChecker.from(connectContext).checkExplain(sql, planner -> { + Plan plan = planner.getOptimizedPlan(); + MatchingUtils.assertMatches(plan, + physicalCTEAnchor( + physicalCTEProducer(any()), + physicalResultSink( + physicalProject( + physicalNestedLoopJoin( + physicalHashAggregate( + physicalDistribute( + physicalHashAggregate( + physicalHashAggregate( + physicalDistribute( + physicalHashAggregate(any())))))), + physicalDistribute( + physicalHashAggregate( + physicalDistribute( + physicalHashAggregate( + physicalHashAggregate( + physicalDistribute( + physicalHashAggregate(any()))))))) + ) + ) + ) + ) + ); + }); + } + + @Test + void multiSumWithoutGby() { + String sql = "select sum(distinct b), sum(distinct a) from test_distinct_multi"; + PlanChecker.from(connectContext).checkExplain(sql, planner -> { + Plan plan = planner.getOptimizedPlan(); + MatchingUtils.assertMatches(plan, + physicalCTEAnchor( + physicalCTEProducer(any()), + physicalResultSink( + physicalProject( + physicalNestedLoopJoin( + physicalHashAggregate( + physicalDistribute( + physicalHashAggregate( + physicalHashAggregate( + physicalDistribute( + physicalHashAggregate(any())))))), + physicalDistribute( + physicalHashAggregate( + physicalDistribute( + physicalHashAggregate( + physicalHashAggregate( + physicalDistribute( + physicalHashAggregate(any()))))))) + ) + ) + ) + ) + ); + }); + } + + @Test + void sumCountWithoutGby() { + String sql = "select sum(distinct b), count(distinct a) from test_distinct_multi"; + PlanChecker.from(connectContext).checkExplain(sql, planner -> { + Plan plan = planner.getOptimizedPlan(); + MatchingUtils.assertMatches(plan, + physicalCTEAnchor( + physicalCTEProducer(any()), + physicalResultSink( + physicalProject( + physicalNestedLoopJoin( + physicalHashAggregate( + physicalDistribute( + physicalHashAggregate( + physicalHashAggregate( + physicalDistribute( + physicalHashAggregate(any())))))), + physicalDistribute( + physicalHashAggregate( + physicalDistribute( + physicalHashAggregate( + physicalHashAggregate( + physicalDistribute( + physicalHashAggregate(any()))))))) + ) + ) + ) + ) + ); + }); + } + + @Test + void countMultiColumnsWithoutGby() { + String sql = "select count(distinct b,c), count(distinct a,b) from test_distinct_multi"; + PlanChecker.from(connectContext).checkExplain(sql, planner -> { + Plan plan = planner.getOptimizedPlan(); + MatchingUtils.assertMatches(plan, + physicalCTEAnchor( + physicalCTEProducer(any()), + physicalResultSink( + physicalProject( + physicalNestedLoopJoin( + physicalHashAggregate( + physicalHashAggregate( + physicalDistribute( + physicalHashAggregate(any())))), + physicalDistribute( + physicalHashAggregate( + physicalHashAggregate( + physicalDistribute( + physicalHashAggregate(any()))))) + ) + ) + ) + ) + ); + }); + } + + @Test + void countMultiColumnsWithGby() { + String sql = "select count(distinct b,c), count(distinct a,b) from test_distinct_multi group by d"; + PlanChecker.from(connectContext).checkExplain(sql, planner -> { + Plan plan = planner.getOptimizedPlan(); + MatchingUtils.assertMatches(plan, + physicalCTEAnchor( + physicalCTEProducer( + any()), + physicalResultSink( + physicalDistribute( + physicalProject( + physicalHashJoin( + physicalHashAggregate( + physicalHashAggregate( + physicalDistribute( + physicalHashAggregate(any())))), + physicalHashAggregate( + physicalHashAggregate( + physicalDistribute( + physicalHashAggregate(any())))) + ) + ) + ) + ) + ) + ); + }); + } +} From a6d2c1cdd8bdd574c14008adc0f44f3a1ac4ec60 Mon Sep 17 00:00:00 2001 From: feiniaofeiafei Date: Wed, 25 Dec 2024 10:53:17 +0800 Subject: [PATCH 5/5] add regression --- .../doris/nereids/jobs/executor/Rewriter.java | 2 +- .../rules/rewrite/CheckMultiDistinct.java | 4 +-- .../distinct_split/disitinct_split.out | 30 +++++++++++++++++++ .../distinct_split/disitinct_split.groovy | 11 +++++++ 4 files changed, 44 insertions(+), 3 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index d4ed9c21776224..8653b93a88ea29 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -56,7 +56,6 @@ import org.apache.doris.nereids.rules.rewrite.CountLiteralRewrite; import org.apache.doris.nereids.rules.rewrite.CreatePartitionTopNFromWindow; import org.apache.doris.nereids.rules.rewrite.DeferMaterializeTopNResult; -import org.apache.doris.nereids.rules.rewrite.SplitMultiDistinct; import org.apache.doris.nereids.rules.rewrite.EliminateAggCaseWhen; import org.apache.doris.nereids.rules.rewrite.EliminateAggregate; import org.apache.doris.nereids.rules.rewrite.EliminateAssertNumRows; @@ -136,6 +135,7 @@ import org.apache.doris.nereids.rules.rewrite.SimplifyEncodeDecode; import org.apache.doris.nereids.rules.rewrite.SimplifyWindowExpression; import org.apache.doris.nereids.rules.rewrite.SplitLimit; +import org.apache.doris.nereids.rules.rewrite.SplitMultiDistinct; import org.apache.doris.nereids.rules.rewrite.SumLiteralRewrite; import org.apache.doris.nereids.rules.rewrite.TransposeSemiJoinAgg; import org.apache.doris.nereids.rules.rewrite.TransposeSemiJoinAggProject; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CheckMultiDistinct.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CheckMultiDistinct.java index bc6fd8437239af..dd76457c41181f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CheckMultiDistinct.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CheckMultiDistinct.java @@ -85,8 +85,8 @@ private LogicalAggregate checkDistinct(LogicalAggregate aggregat } if (distinctMultiColumns && distinctFunctionNum > 1) { - // throw new AnalysisException( - // "The query contains multi count distinct or sum distinct, each can't have multi columns"); + throw new AnalysisException( + "The query contains multi count distinct or sum distinct, each can't have multi columns"); } return aggregate; } diff --git a/regression-test/data/nereids_rules_p0/distinct_split/disitinct_split.out b/regression-test/data/nereids_rules_p0/distinct_split/disitinct_split.out index a1e600a15d7cea..2a1dd6fd9d6705 100644 --- a/regression-test/data/nereids_rules_p0/distinct_split/disitinct_split.out +++ b/regression-test/data/nereids_rules_p0/distinct_split/disitinct_split.out @@ -303,6 +303,12 @@ 2 3 3 2 3 1 3 2 3 2 2 3 3 2 3 1 3 2 3 2 +-- !2_agg_count_distinct -- +2 2 + +-- !3_agg_count_distinct -- +1 1 + -- !multi_count_without_gby -- PhysicalCteAnchor ( cteId=CTEId#0 ) --PhysicalCteProducer ( cteId=CTEId#0 ) @@ -450,6 +456,23 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) --------hashAgg[LOCAL] ----------PhysicalCteConsumer ( cteId=CTEId#0 ) +-- !2_agg -- +PhysicalCteAnchor ( cteId=CTEId#0 ) +--PhysicalCteProducer ( cteId=CTEId#0 ) +----PhysicalOlapScan[test_distinct_multi] +--PhysicalResultSink +----hashAgg[GLOBAL] +------hashAgg[LOCAL] +--------hashJoin[INNER_JOIN] hashCondition=((.c = .c)) otherCondition=() +----------hashAgg[DISTINCT_LOCAL] +------------hashAgg[GLOBAL] +--------------hashAgg[LOCAL] +----------------PhysicalCteConsumer ( cteId=CTEId#0 ) +----------hashAgg[DISTINCT_LOCAL] +------------hashAgg[GLOBAL] +--------------hashAgg[LOCAL] +----------------PhysicalCteConsumer ( cteId=CTEId#0 ) + -- !multi_count_with_gby -- PhysicalResultSink --hashAgg[GLOBAL] @@ -468,3 +491,10 @@ PhysicalResultSink ----hashAgg[LOCAL] ------PhysicalOlapScan[test_distinct_multi] +-- !has_grouping -- +PhysicalResultSink +--hashAgg[GLOBAL] +----hashAgg[LOCAL] +------PhysicalRepeat +--------PhysicalOlapScan[test_distinct_multi] + diff --git a/regression-test/suites/nereids_rules_p0/distinct_split/disitinct_split.groovy b/regression-test/suites/nereids_rules_p0/distinct_split/disitinct_split.groovy index 60392f67e37710..02812b269a33eb 100644 --- a/regression-test/suites/nereids_rules_p0/distinct_split/disitinct_split.groovy +++ b/regression-test/suites/nereids_rules_p0/distinct_split/disitinct_split.groovy @@ -181,6 +181,10 @@ suite("distinct_split") { select * from tmp, (select sum(distinct a), count(distinct b,c) from test_distinct_multi) t, (select sum(distinct a), count(distinct b,c) from test_distinct_multi group by d) tt order by 1,2,3,4,5,6,7,8,9,10 """ + // multi aggregate + qt_2_agg_count_distinct """select count(distinct c1) c3, count(distinct c2) c4 from (select count(distinct a,b) c1, count(distinct a,c) c2 from test_distinct_multi group by c) t""" + qt_3_agg_count_distinct """select count(distinct c3), count(distinct c4) from (select count(distinct c1) c3, count(distinct c2) c4 from (select count(distinct a,b) c1, count(distinct a,c) c2 from test_distinct_multi group by c) t) tt""" + // shape sql "SET ignore_shape_nodes='PhysicalDistribute,PhysicalProject'" qt_multi_count_without_gby """explain shape plan select count(distinct b), count(distinct a) from test_distinct_multi""" @@ -191,9 +195,16 @@ suite("distinct_split") { qt_three_count_mulitcols_without_gby """explain shape plan select count(distinct b,c), count(distinct a,b), count(distinct a,b,c) from test_distinct_multi""" qt_four_count_mulitcols_with_gby """explain shape plan select count(distinct b,c), count(distinct a,b),count(distinct b,c,d), count(distinct a,b,c) from test_distinct_multi group by d""" qt_has_other_func "explain shape plan select count(distinct b), count(distinct a), max(b),sum(c),min(a) from test_distinct_multi" + qt_2_agg """explain shape plan select max(c1), min(c2) from (select count(distinct a,b) c1, count(distinct a,c) c2 from test_distinct_multi group by c) t""" // should not rewrite qt_multi_count_with_gby """explain shape plan select count(distinct b), count(distinct a) from test_distinct_multi group by c""" qt_multi_sum_with_gby """explain shape plan select sum(distinct b), sum(distinct a) from test_distinct_multi group by c""" qt_sum_count_with_gby """explain shape plan select sum(distinct b), count(distinct a) from test_distinct_multi group by a""" + qt_has_grouping """explain shape plan select count(distinct b), count(distinct a) from test_distinct_multi group by grouping sets((a,b),(c));""" + test { + sql """select count(distinct a,b), count(distinct a) from test_distinct_multi + group by grouping sets((a,b),(c));""" + exception "The query contains multi count distinct or sum distinct, each can't have multi columns" + } } \ No newline at end of file