Skip to content

Commit

Permalink
[enhance](nereids) add rule MultiDistinctSplit (#45209)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

Problem Summary:

This pr add a rewrite rule, which can do this 2 type of rewrite:
1. This rewrite can greatly improve the execution speed of multiple
count(distinct) operations. When 3be, ndv=10000000, the performance can
be improved by three to four times.

select count(distinct a),count(distinct b),count(distinct c) from t;
->
with tmp as (select * from t) 
select * from (select count(distinct a) from tmp) t1 cross join  (select count(distinct b) from tmp) t2 cross join  (select count(distinct c) from tmp) t3


2.Before this PR, the following SQL statement would fail to execute due
to an error: "The query contains multi count distinct or sum distinct,
each can't have multi columns". This PR rewrites this type of SQL
statement as follows, making it executable without an error.

select count(distinct a,d),count(distinct b,c),count(distinct c) from t;
->
with tmp as (select * from t) 
select * from (select count(distinct a,d) from tmp) t1 cross join  (select count(distinct b,c) from tmp) t2 cross join  (select count(distinct c) from tmp) t3

### Release note

Support multi count distinct with different parameters
  • Loading branch information
feiniaofeiafei authored Jan 3, 2025
1 parent 2254956 commit c28c00a
Show file tree
Hide file tree
Showing 19 changed files with 1,274 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@
import org.apache.doris.nereids.rules.rewrite.SimplifyEncodeDecode;
import org.apache.doris.nereids.rules.rewrite.SimplifyWindowExpression;
import org.apache.doris.nereids.rules.rewrite.SplitLimit;
import org.apache.doris.nereids.rules.rewrite.SplitMultiDistinct;
import org.apache.doris.nereids.rules.rewrite.SumLiteralRewrite;
import org.apache.doris.nereids.rules.rewrite.TransposeSemiJoinAgg;
import org.apache.doris.nereids.rules.rewrite.TransposeSemiJoinAggProject;
Expand Down Expand Up @@ -565,6 +566,9 @@ private static List<RewriteJob> getWholeTreeRewriteJobs(
rewriteJobs.addAll(jobs(topic("or expansion",
custom(RuleType.OR_EXPANSION, () -> OrExpansion.INSTANCE))));
}
rewriteJobs.addAll(jobs(topic("split multi distinct",
custom(RuleType.SPLIT_MULTI_DISTINCT, () -> SplitMultiDistinct.INSTANCE))));

if (needSubPathPushDown) {
rewriteJobs.addAll(jobs(
topic("variant element_at push down",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ public enum RuleType {
MERGE_TOP_N(RuleTypeClass.REWRITE),
BUILD_AGG_FOR_UNION(RuleTypeClass.REWRITE),
COUNT_DISTINCT_REWRITE(RuleTypeClass.REWRITE),
SPLIT_MULTI_DISTINCT(RuleTypeClass.REWRITE),
INNER_TO_CROSS_JOIN(RuleTypeClass.REWRITE),
CROSS_TO_INNER_JOIN(RuleTypeClass.REWRITE),
PRUNE_EMPTY_PARTITION(RuleTypeClass.REWRITE),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.OrderExpression;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.generator.TableGeneratingFunction;
Expand Down Expand Up @@ -139,36 +138,6 @@ private void checkExpressionInputTypes(Plan plan) {
}

private void checkAggregate(LogicalAggregate<? extends Plan> aggregate) {
Set<AggregateFunction> aggregateFunctions = aggregate.getAggregateFunctions();
boolean distinctMultiColumns = false;
for (AggregateFunction func : aggregateFunctions) {
if (!func.isDistinct()) {
continue;
}
if (func.arity() <= 1) {
continue;
}
for (int i = 1; i < func.arity(); i++) {
if (!func.child(i).getInputSlots().isEmpty() && !(func.child(i) instanceof OrderExpression)) {
// think about group_concat(distinct col_1, ',')
distinctMultiColumns = true;
break;
}
}
if (distinctMultiColumns) {
break;
}
}

long distinctFunctionNum = 0;
for (AggregateFunction aggregateFunction : aggregateFunctions) {
distinctFunctionNum += aggregateFunction.isDistinct() ? 1 : 0;
}

if (distinctMultiColumns && distinctFunctionNum > 1) {
throw new AnalysisException(
"The query contains multi count distinct or sum distinct, each can't have multi columns");
}
for (Expression expr : aggregate.getGroupByExpressions()) {
if (expr.anyMatch(AggregateFunction.class::isInstance)) {
throw new AnalysisException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupConcat;
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinctCount;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum0;
import org.apache.doris.nereids.trees.expressions.functions.agg.SupportMultiDistinct;
import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
Expand Down Expand Up @@ -1809,15 +1809,8 @@ private List<Expression> getHashAggregatePartitionExpressions(
}

private AggregateFunction tryConvertToMultiDistinct(AggregateFunction function) {
if (function instanceof Count && function.isDistinct()) {
return new MultiDistinctCount(function.getArgument(0),
function.getArguments().subList(1, function.arity()).toArray(new Expression[0]));
} else if (function instanceof Sum && function.isDistinct()) {
return ((Sum) function).convertToMultiDistinct();
} else if (function instanceof Sum0 && function.isDistinct()) {
return ((Sum0) function).convertToMultiDistinct();
} else if (function instanceof GroupConcat && function.isDistinct()) {
return ((GroupConcat) function).convertToMultiDistinct();
if (function instanceof SupportMultiDistinct && function.isDistinct()) {
return ((SupportMultiDistinct) function).convertToMultiDistinct();
}
return function;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.OrderExpression;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.Avg;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
Expand Down Expand Up @@ -57,6 +58,36 @@ private LogicalAggregate checkDistinct(LogicalAggregate<? extends Plan> aggregat
}
}
}

boolean distinctMultiColumns = false;
for (AggregateFunction func : aggregate.getAggregateFunctions()) {
if (!func.isDistinct()) {
continue;
}
if (func.arity() <= 1) {
continue;
}
for (int i = 1; i < func.arity(); i++) {
if (!func.child(i).getInputSlots().isEmpty() && !(func.child(i) instanceof OrderExpression)) {
// think about group_concat(distinct col_1, ',')
distinctMultiColumns = true;
break;
}
}
if (distinctMultiColumns) {
break;
}
}

long distinctFunctionNum = 0;
for (AggregateFunction aggregateFunction : aggregate.getAggregateFunctions()) {
distinctFunctionNum += aggregateFunction.isDistinct() ? 1 : 0;
}

if (distinctMultiColumns && distinctFunctionNum > 1) {
throw new AnalysisException(
"The query contains multi count distinct or sum distinct, each can't have multi columns");
}
return aggregate;
}
}
Loading

0 comments on commit c28c00a

Please sign in to comment.