Skip to content

Commit

Permalink
fix comment
Browse files Browse the repository at this point in the history
  • Loading branch information
feiniaofeiafei committed Dec 24, 2024
1 parent ab2e173 commit 791a242
Show file tree
Hide file tree
Showing 8 changed files with 264 additions and 134 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
import org.apache.doris.nereids.rules.rewrite.CountLiteralRewrite;
import org.apache.doris.nereids.rules.rewrite.CreatePartitionTopNFromWindow;
import org.apache.doris.nereids.rules.rewrite.DeferMaterializeTopNResult;
import org.apache.doris.nereids.rules.rewrite.DistinctSplit;
import org.apache.doris.nereids.rules.rewrite.SplitMultiDistinct;
import org.apache.doris.nereids.rules.rewrite.EliminateAggCaseWhen;
import org.apache.doris.nereids.rules.rewrite.EliminateAggregate;
import org.apache.doris.nereids.rules.rewrite.EliminateAssertNumRows;
Expand Down Expand Up @@ -445,7 +445,6 @@ public class Rewriter extends AbstractBatchJobExecutor {
new CollectCteConsumerOutput()
)
),
// topic("distinct split", topDown(new DistinctSplit())),
topic("Collect used column", custom(RuleType.COLLECT_COLUMNS, QueryColumnCollector::new)
)
)
Expand Down Expand Up @@ -552,8 +551,8 @@ private static List<RewriteJob> getWholeTreeRewriteJobs(
rewriteJobs.addAll(jobs(topic("or expansion",
custom(RuleType.OR_EXPANSION, () -> OrExpansion.INSTANCE))));
}
rewriteJobs.addAll(jobs(topic("distinct split",
custom(RuleType.DISTINCT_SPLIT, () -> DistinctSplit.INSTANCE))));
rewriteJobs.addAll(jobs(topic("split multi distinct",
custom(RuleType.SPLIT_MULTI_DISTINCT, () -> SplitMultiDistinct.INSTANCE))));

if (needSubPathPushDown) {
rewriteJobs.addAll(jobs(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ public enum RuleType {
MERGE_TOP_N(RuleTypeClass.REWRITE),
BUILD_AGG_FOR_UNION(RuleTypeClass.REWRITE),
COUNT_DISTINCT_REWRITE(RuleTypeClass.REWRITE),
DISTINCT_SPLIT(RuleTypeClass.REWRITE),
SPLIT_MULTI_DISTINCT(RuleTypeClass.REWRITE),
INNER_TO_CROSS_JOIN(RuleTypeClass.REWRITE),
CROSS_TO_INNER_JOIN(RuleTypeClass.REWRITE),
PRUNE_EMPTY_PARTITION(RuleTypeClass.REWRITE),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum0;
import org.apache.doris.nereids.trees.expressions.functions.agg.SupportMultiDistinct;
import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
Expand Down Expand Up @@ -1808,14 +1809,8 @@ private List<Expression> getHashAggregatePartitionExpressions(
}

private AggregateFunction tryConvertToMultiDistinct(AggregateFunction function) {
if (function instanceof Count && function.isDistinct()) {
return ((Count) function).convertToMultiDistinct();
} else if (function instanceof Sum && function.isDistinct()) {
return ((Sum) function).convertToMultiDistinct();
} else if (function instanceof Sum0 && function.isDistinct()) {
return ((Sum0) function).convertToMultiDistinct();
} else if (function instanceof GroupConcat && function.isDistinct()) {
return ((GroupConcat) function).convertToMultiDistinct();
if (function instanceof SupportMultiDistinct && function.isDistinct()) {
return ((SupportMultiDistinct) function).convertToMultiDistinct();
}
return function;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.OrderExpression;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.Avg;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
Expand Down Expand Up @@ -57,6 +58,36 @@ private LogicalAggregate checkDistinct(LogicalAggregate<? extends Plan> aggregat
}
}
}

boolean distinctMultiColumns = false;
for (AggregateFunction func : aggregate.getAggregateFunctions()) {
if (!func.isDistinct()) {
continue;
}
if (func.arity() <= 1) {
continue;
}
for (int i = 1; i < func.arity(); i++) {
if (!func.child(i).getInputSlots().isEmpty() && !(func.child(i) instanceof OrderExpression)) {
// think about group_concat(distinct col_1, ',')
distinctMultiColumns = true;
break;
}
}
if (distinctMultiColumns) {
break;
}
}

long distinctFunctionNum = 0;
for (AggregateFunction aggregateFunction : aggregate.getAggregateFunctions()) {
distinctFunctionNum += aggregateFunction.isDistinct() ? 1 : 0;
}

if (distinctMultiColumns && distinctFunctionNum > 1) {
// throw new AnalysisException(
// "The query contains multi count distinct or sum distinct, each can't have multi columns");
}
return aggregate;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.rules.rewrite.DistinctSplit.DistinctSplitContext;
import org.apache.doris.nereids.rules.rewrite.SplitMultiDistinct.DistinctSplitContext;
import org.apache.doris.nereids.trees.copier.DeepCopierContext;
import org.apache.doris.nereids.trees.copier.LogicalPlanDeepCopier;
import org.apache.doris.nereids.trees.expressions.Alias;
Expand Down Expand Up @@ -67,8 +67,8 @@
* +--LogicalAggregate(output:count(distinct b))
* +--LogicalCTEConsumer
* */
public class DistinctSplit extends DefaultPlanRewriter<DistinctSplitContext> implements CustomRewriter {
public static DistinctSplit INSTANCE = new DistinctSplit();
public class SplitMultiDistinct extends DefaultPlanRewriter<DistinctSplitContext> implements CustomRewriter {
public static SplitMultiDistinct INSTANCE = new SplitMultiDistinct();

/**DistinctSplitContext*/
public static class DistinctSplitContext {
Expand Down Expand Up @@ -111,6 +111,8 @@ public Plan visitLogicalCTEAnchor(

@Override
public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> agg, DistinctSplitContext ctx) {
Plan newChild = agg.child().accept(this, ctx);
agg = agg.withChildren(ImmutableList.of(newChild));
List<Alias> distinctFuncWithAlias = new ArrayList<>();
List<Alias> otherAggFuncs = new ArrayList<>();
if (!needTransform((LogicalAggregate<Plan>) agg, distinctFuncWithAlias, otherAggFuncs)) {
Expand All @@ -137,18 +139,12 @@ public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> agg, Distinct
List<Expression> outputJoinGroupBys = new ArrayList<>();
for (int i = 0; i < distinctFuncWithAlias.size(); ++i) {
Expression distinctAggFunc = distinctFuncWithAlias.get(i).child(0);
LogicalCTEConsumer consumer = new LogicalCTEConsumer(ctx.statementContext.getNextRelationId(),
producer.getCteId(), "", producer);
ctx.cascadesContext.putCTEIdToConsumer(consumer);
Map<Slot, Slot> producerToConsumerSlotMap = new HashMap<>();
for (Map.Entry<Slot, Slot> entry : consumer.getConsumerToProducerOutputMap().entrySet()) {
producerToConsumerSlotMap.put(entry.getValue(), entry.getKey());
}
List<Expression> replacedGroupBy = ExpressionUtils.replace(cloneAgg.getGroupByExpressions(),
producerToConsumerSlotMap);
List<NamedExpression> outputExpressions = new ArrayList<>();
List<Expression> replacedGroupBy = new ArrayList<>();
LogicalCTEConsumer consumer = constructConsumerAndReplaceGroupBy(ctx, producer, cloneAgg, outputExpressions,
producerToConsumerSlotMap, replacedGroupBy);
Expression newDistinctAggFunc = ExpressionUtils.replace(distinctAggFunc, producerToConsumerSlotMap);
List<NamedExpression> outputExpressions = replacedGroupBy.stream()
.map(Slot.class::cast).collect(Collectors.toList());
Alias alias = new Alias(newDistinctAggFunc);
outputExpressions.add(alias);
if (i == 0) {
Expand All @@ -171,17 +167,11 @@ private static void buildOtherAggFuncAggregate(List<Alias> otherAggFuncs, Logica
if (otherAggFuncs.isEmpty()) {
return;
}
LogicalCTEConsumer consumer = new LogicalCTEConsumer(ctx.statementContext.getNextRelationId(),
producer.getCteId(), "", producer);
ctx.cascadesContext.putCTEIdToConsumer(consumer);
Map<Slot, Slot> producerToConsumerSlotMap = new HashMap<>();
for (Map.Entry<Slot, Slot> entry : consumer.getConsumerToProducerOutputMap().entrySet()) {
producerToConsumerSlotMap.put(entry.getValue(), entry.getKey());
}
List<Expression> replacedGroupBy = ExpressionUtils.replace(cloneAgg.getGroupByExpressions(),
producerToConsumerSlotMap);
List<NamedExpression> outputExpressions = replacedGroupBy.stream()
.map(Slot.class::cast).collect(Collectors.toList());
List<NamedExpression> outputExpressions = new ArrayList<>();
List<Expression> replacedGroupBy = new ArrayList<>();
LogicalCTEConsumer consumer = constructConsumerAndReplaceGroupBy(ctx, producer, cloneAgg, outputExpressions,
producerToConsumerSlotMap, replacedGroupBy);
List<Expression> otherAggFuncAliases = otherAggFuncs.stream()
.map(e -> ExpressionUtils.replace(e, producerToConsumerSlotMap)).collect(Collectors.toList());
for (Expression otherAggFuncAlias : otherAggFuncAliases) {
Expand All @@ -194,6 +184,20 @@ private static void buildOtherAggFuncAggregate(List<Alias> otherAggFuncs, Logica
newAggs.add(newAgg);
}

private static LogicalCTEConsumer constructConsumerAndReplaceGroupBy(DistinctSplitContext ctx,
LogicalCTEProducer<Plan> producer, LogicalAggregate<Plan> cloneAgg, List<NamedExpression> outputExpressions,
Map<Slot, Slot> producerToConsumerSlotMap, List<Expression> replacedGroupBy) {
LogicalCTEConsumer consumer = new LogicalCTEConsumer(ctx.statementContext.getNextRelationId(),
producer.getCteId(), "", producer);
ctx.cascadesContext.putCTEIdToConsumer(consumer);
for (Map.Entry<Slot, Slot> entry : consumer.getConsumerToProducerOutputMap().entrySet()) {
producerToConsumerSlotMap.put(entry.getValue(), entry.getKey());
}
replacedGroupBy.addAll(ExpressionUtils.replace(cloneAgg.getGroupByExpressions(), producerToConsumerSlotMap));
outputExpressions.addAll(replacedGroupBy.stream().map(Slot.class::cast).collect(Collectors.toList()));
return consumer;
}

private static boolean isDistinctMultiColumns(AggregateFunction func) {
if (func.arity() <= 1) {
return false;
Expand Down Expand Up @@ -230,6 +234,11 @@ private static boolean needTransform(LogicalAggregate<Plan> agg, List<Alias> ali
if (distinctFunc.size() <= 1) {
return false;
}
// when this aggregate is not distinctMultiColumns, and group by expressions is not empty
// e.g. sql1: select count(distinct a), count(distinct b) from t1 group by c;
// sql2: select count(distinct a) from t1 group by c;
// the physical plan of sql1 and sql2 is similar, both are 2-phase aggregate,
// so there is no need to do this rewrite
if (!distinctMultiColumns && !agg.getGroupByExpressions().isEmpty()) {
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@

package org.apache.doris.nereids.trees.expressions.functions.agg;

/** MultiDistinctTrait*/
/** aggregate functions which have corresponding MultiDistinctXXX class,
* e.g. SUM,SUM0,COUNT,GROUP_CONCAT
* */
public interface SupportMultiDistinct {
AggregateFunction convertToMultiDistinct();
}

This file was deleted.

Loading

0 comments on commit 791a242

Please sign in to comment.