Skip to content

Commit

Permalink
change to custom rewrite
Browse files Browse the repository at this point in the history
  • Loading branch information
feiniaofeiafei committed Dec 11, 2024
1 parent 444c440 commit 1fe5f2c
Show file tree
Hide file tree
Showing 8 changed files with 161 additions and 118 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,7 @@ public class Rewriter extends AbstractBatchJobExecutor {
new CollectCteConsumerOutput()
)
),
// topic("distinct split", topDown(new DistinctSplit())),
topic("Collect used column", custom(RuleType.COLLECT_COLUMNS, QueryColumnCollector::new)
)
)
Expand Down Expand Up @@ -549,7 +550,8 @@ private static List<RewriteJob> getWholeTreeRewriteJobs(
rewriteJobs.addAll(jobs(topic("or expansion",
custom(RuleType.OR_EXPANSION, () -> OrExpansion.INSTANCE))));
}
rewriteJobs.addAll(jobs(topic("distinct split", topDown(new DistinctSplit()))));
rewriteJobs.addAll(jobs(topic("distinct split",
custom(RuleType.DISTINCT_SPLIT, () -> DistinctSplit.INSTANCE))));

if (needSubPathPushDown) {
rewriteJobs.addAll(jobs(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.OrderExpression;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.generator.TableGeneratingFunction;
Expand Down Expand Up @@ -139,36 +138,6 @@ private void checkExpressionInputTypes(Plan plan) {
}

private void checkAggregate(LogicalAggregate<? extends Plan> aggregate) {
Set<AggregateFunction> aggregateFunctions = aggregate.getAggregateFunctions();
boolean distinctMultiColumns = false;
for (AggregateFunction func : aggregateFunctions) {
if (!func.isDistinct()) {
continue;
}
if (func.arity() <= 1) {
continue;
}
for (int i = 1; i < func.arity(); i++) {
if (!func.child(i).getInputSlots().isEmpty() && !(func.child(i) instanceof OrderExpression)) {
// think about group_concat(distinct col_1, ',')
distinctMultiColumns = true;
break;
}
}
if (distinctMultiColumns) {
break;
}
}

long distinctFunctionNum = 0;
for (AggregateFunction aggregateFunction : aggregateFunctions) {
distinctFunctionNum += aggregateFunction.isDistinct() ? 1 : 0;
}

// if (distinctMultiColumns && distinctFunctionNum > 1) {
// throw new AnalysisException(
// "The query contains multi count distinct or sum distinct, each can't have multi columns");
// }
for (Expression expr : aggregate.getGroupByExpressions()) {
if (expr.anyMatch(AggregateFunction.class::isInstance)) {
throw new AnalysisException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.rules.rewrite.DistinctSplit.DistinctSplitContext;
import org.apache.doris.nereids.trees.copier.DeepCopierContext;
import org.apache.doris.nereids.trees.copier.LogicalPlanDeepCopier;
import org.apache.doris.nereids.trees.expressions.Alias;
Expand All @@ -38,8 +39,12 @@
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.nereids.util.ExpressionUtils;

import com.google.common.collect.ImmutableList;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
Expand All @@ -62,40 +67,61 @@
* +--LogicalAggregate(output:count(distinct b))
* +--LogicalCTEConsumer
* */
public class DistinctSplit extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalAggregate()
// TODO with source repeat aggregate need to be supported in future
.whenNot(agg -> agg.getSourceRepeat().isPresent())
.thenApply(ctx -> doSplit(ctx.root, ctx.cascadesContext))
.toRule(RuleType.DISTINCT_SPLIT);
public class DistinctSplit extends DefaultPlanRewriter<DistinctSplitContext> implements CustomRewriter {
public static DistinctSplit INSTANCE = new DistinctSplit();

/**
 * Per-traversal state handed to every visit method of {@link DistinctSplit}.
 *
 * <p>It bundles the statement context and cascades context supplied by the caller with a
 * list of CTE producers that starts out empty.
 */
public static class DistinctSplitContext {
    StatementContext statementContext;
    CascadesContext cascadesContext;
    // Always begins empty for a freshly created context.
    List<LogicalCTEProducer<? extends Plan>> cteProducerList = new ArrayList<>();

    /**
     * Creates a context with no collected CTE producers.
     *
     * @param statementContext statement context stored as-is on this context
     * @param cascadesContext cascades context stored as-is on this context
     */
    public DistinctSplitContext(StatementContext statementContext, CascadesContext cascadesContext) {
        this.cascadesContext = cascadesContext;
        this.statementContext = statementContext;
    }
}

private static boolean isDistinctMultiColumns(AggregateFunction func) {
if (func.arity() <= 1) {
return false;
@Override
public Plan rewriteRoot(Plan plan, JobContext jobContext) {
DistinctSplitContext ctx = new DistinctSplitContext(
jobContext.getCascadesContext().getStatementContext(), jobContext.getCascadesContext());
plan = plan.accept(this, ctx);
for (int i = ctx.cteProducerList.size() - 1; i >= 0; i--) {
LogicalCTEProducer<? extends Plan> producer = ctx.cteProducerList.get(i);
plan = new LogicalCTEAnchor<>(producer.getCteId(), producer, plan);
}
for (int i = 1; i < func.arity(); ++i) {
// think about group_concat(distinct col_1, ',')
if (!(func.child(i) instanceof OrderExpression) && !func.child(i).getInputSlots().isEmpty()) {
return true;
}
return plan;
}

@Override
public Plan visitLogicalCTEAnchor(
LogicalCTEAnchor<? extends Plan, ? extends Plan> anchor, DistinctSplitContext ctx) {
Plan child1 = anchor.child(0).accept(this, ctx);
DistinctSplitContext consumerContext =
new DistinctSplitContext(ctx.statementContext, ctx.cascadesContext);
Plan child2 = anchor.child(1).accept(this, consumerContext);
for (int i = consumerContext.cteProducerList.size() - 1; i >= 0; i--) {
LogicalCTEProducer<? extends Plan> producer = consumerContext.cteProducerList.get(i);
child2 = new LogicalCTEAnchor<>(producer.getCteId(), producer, child2);
}
return false;
return anchor.withChildren(ImmutableList.of(child1, child2));
}

private static Plan doSplit(LogicalAggregate<Plan> agg, CascadesContext ctx) {
@Override
public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> agg, DistinctSplitContext ctx) {
List<Alias> distinctFuncWithAlias = new ArrayList<>();
List<Alias> otherAggFuncs = new ArrayList<>();
if (!needTransform(agg, distinctFuncWithAlias, otherAggFuncs)) {
return null;
if (!needTransform((LogicalAggregate<Plan>) agg, distinctFuncWithAlias, otherAggFuncs)) {
return agg;
}

LogicalAggregate<Plan> cloneAgg = (LogicalAggregate<Plan>) LogicalPlanDeepCopier.INSTANCE
.deepCopy(agg, new DeepCopierContext());
LogicalCTEProducer<Plan> producer = new LogicalCTEProducer<>(ctx.getStatementContext().getNextCTEId(),
LogicalCTEProducer<Plan> producer = new LogicalCTEProducer<>(ctx.statementContext.getNextCTEId(),
cloneAgg.child());
ctx.cteProducerList.add(producer);
Map<Slot, Slot> originToProducerSlot = new HashMap<>();
for (int i = 0; i < agg.child().getOutput().size(); ++i) {
Slot originSlot = agg.child().getOutput().get(i);
Expand All @@ -106,14 +132,14 @@ private static Plan doSplit(LogicalAggregate<Plan> agg, CascadesContext ctx) {
otherAggFuncs = ExpressionUtils.replace((List) otherAggFuncs, originToProducerSlot);
// construct cte consumer and aggregate
List<LogicalAggregate<Plan>> newAggs = new ArrayList<>();
// All aggFunc except count distinct are placed in the first one
// All otherAggFuncs are placed in the first one
Map<Alias, Alias> newToOriginDistinctFuncAlias = new HashMap<>();
List<Expression> outputJoinGroupBys = new ArrayList<>();
for (int i = 0; i < distinctFuncWithAlias.size(); ++i) {
Expression distinctAggFunc = distinctFuncWithAlias.get(i).child(0);
LogicalCTEConsumer consumer = new LogicalCTEConsumer(ctx.getStatementContext().getNextRelationId(),
LogicalCTEConsumer consumer = new LogicalCTEConsumer(ctx.statementContext.getNextRelationId(),
producer.getCteId(), "", producer);
ctx.putCTEIdToConsumer(consumer);
ctx.cascadesContext.putCTEIdToConsumer(consumer);
Map<Slot, Slot> producerToConsumerSlotMap = new HashMap<>();
for (Map.Entry<Slot, Slot> entry : consumer.getConsumerToProducerOutputMap().entrySet()) {
producerToConsumerSlotMap.put(entry.getValue(), entry.getKey());
Expand Down Expand Up @@ -143,12 +169,28 @@ private static Plan doSplit(LogicalAggregate<Plan> agg, CascadesContext ctx) {
}
List<Expression> groupBy = agg.getGroupByExpressions();
LogicalJoin<Plan, Plan> join = constructJoin(newAggs, groupBy);
LogicalProject<Plan> project = constructProject(groupBy, newToOriginDistinctFuncAlias,
return constructProject(groupBy, newToOriginDistinctFuncAlias,
outputJoinGroupBys, join);
return new LogicalCTEAnchor<Plan, Plan>(producer.getCteId(), producer, project);
}

/**
 * Reports whether any argument after the first one refers to input slots.
 *
 * <p>Arguments that are {@link OrderExpression}s are ignored, as are arguments without
 * input slots; think about group_concat(distinct col_1, ',') where the separator is a
 * constant.
 *
 * @param func aggregate function whose arguments are inspected
 * @return {@code true} if some argument at index 1 or later is not an order expression
 *         and has a non-empty input slot set; {@code false} otherwise, including when
 *         the function has at most one argument
 */
private static boolean isDistinctMultiColumns(AggregateFunction func) {
    // Index 0 is never examined, so a function with zero or one argument skips the loop.
    for (int argIndex = 1; argIndex < func.arity(); argIndex++) {
        Expression argument = func.child(argIndex);
        if (argument instanceof OrderExpression) {
            continue;
        }
        if (argument.getInputSlots().isEmpty()) {
            continue;
        }
        return true;
    }
    return false;
}

private static boolean needTransform(LogicalAggregate<Plan> agg, List<Alias> aliases, List<Alias> otherAggFuncs) {
// TODO with source repeat aggregate need to be supported in future
if (agg.getSourceRepeat().isPresent()) {
return false;
}
Set<Expression> distinctFunc = new HashSet<>();
boolean distinctMultiColumns = false;
for (NamedExpression namedExpression : agg.getOutputExpressions()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ public class OrExpansion extends DefaultPlanRewriter<OrExpandsionContext> implem

@Override
public Plan rewriteRoot(Plan plan, JobContext jobContext) {
OrExpandsionContext ctx = new OrExpandsionContext(
OrExpandsionContext ctx = new OrExpandsionContext(
jobContext.getCascadesContext().getStatementContext(), jobContext.getCascadesContext());
plan = plan.accept(this, ctx);
for (int i = ctx.cteProducerList.size() - 1; i >= 0; i--) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ void multiCountWithoutGby() {
String sql = "select count(distinct b), count(distinct a) from test_distinct_multi";
PlanChecker.from(connectContext).checkExplain(sql, planner -> {
Plan plan = planner.getOptimizedPlan();
MatchingUtils.assertMatches(plan, physicalResultSink(physicalCTEAnchor(physicalCTEProducer(any()), physicalProject(physicalNestedLoopJoin(
MatchingUtils.assertMatches(plan, physicalCTEAnchor(physicalCTEProducer(any()), physicalResultSink(physicalProject(physicalNestedLoopJoin(
physicalHashAggregate(physicalDistribute(physicalHashAggregate(physicalHashAggregate(physicalDistribute(physicalHashAggregate(any())))))),
physicalDistribute(physicalHashAggregate(physicalDistribute(physicalHashAggregate(physicalHashAggregate(physicalDistribute(physicalHashAggregate(any())))))))
)))));
Expand All @@ -52,7 +52,7 @@ void multiSumWithoutGby() {
String sql = "select sum(distinct b), sum(distinct a) from test_distinct_multi";
PlanChecker.from(connectContext).checkExplain(sql, planner -> {
Plan plan = planner.getOptimizedPlan();
MatchingUtils.assertMatches(plan, physicalResultSink(physicalCTEAnchor(physicalCTEProducer(any()), physicalProject(physicalNestedLoopJoin(
MatchingUtils.assertMatches(plan, physicalCTEAnchor(physicalCTEProducer(any()), physicalResultSink(physicalProject(physicalNestedLoopJoin(
physicalHashAggregate(physicalDistribute(physicalHashAggregate(physicalHashAggregate(physicalDistribute(physicalHashAggregate(any())))))),
physicalDistribute(physicalHashAggregate(physicalDistribute(physicalHashAggregate(physicalHashAggregate(physicalDistribute(physicalHashAggregate(any())))))))
)))));
Expand All @@ -64,7 +64,7 @@ void SumCountWithoutGby() {
String sql = "select sum(distinct b), count(distinct a) from test_distinct_multi";
PlanChecker.from(connectContext).checkExplain(sql, planner -> {
Plan plan = planner.getOptimizedPlan();
MatchingUtils.assertMatches(plan, physicalResultSink(physicalCTEAnchor(physicalCTEProducer(any()), physicalProject(physicalNestedLoopJoin(
MatchingUtils.assertMatches(plan, physicalCTEAnchor(physicalCTEProducer(any()), physicalResultSink(physicalProject(physicalNestedLoopJoin(
physicalHashAggregate(physicalDistribute(physicalHashAggregate(physicalHashAggregate(physicalDistribute(physicalHashAggregate(any())))))),
physicalDistribute(physicalHashAggregate(physicalDistribute(physicalHashAggregate(physicalHashAggregate(physicalDistribute(physicalHashAggregate(any())))))))
)))));
Expand All @@ -76,7 +76,7 @@ void CountMultiColumnsWithoutGby() {
String sql = "select count(distinct b,c), count(distinct a,b) from test_distinct_multi";
PlanChecker.from(connectContext).checkExplain(sql, planner -> {
Plan plan = planner.getOptimizedPlan();
MatchingUtils.assertMatches(plan, physicalResultSink(physicalCTEAnchor(physicalCTEProducer(any()), physicalProject(physicalNestedLoopJoin(
MatchingUtils.assertMatches(plan, physicalCTEAnchor(physicalCTEProducer(any()), physicalResultSink(physicalProject(physicalNestedLoopJoin(
physicalHashAggregate(physicalHashAggregate(physicalDistribute(physicalHashAggregate(any())))),
physicalDistribute(physicalHashAggregate(physicalHashAggregate(physicalDistribute(physicalHashAggregate(any())))))
)))));
Expand All @@ -88,7 +88,7 @@ void CountMultiColumnsWithGby() {
String sql = "select count(distinct b,c), count(distinct a,b) from test_distinct_multi group by d";
PlanChecker.from(connectContext).checkExplain(sql, planner -> {
Plan plan = planner.getOptimizedPlan();
MatchingUtils.assertMatches(plan, physicalResultSink(physicalCTEAnchor(physicalCTEProducer(any()), physicalDistribute(physicalProject(physicalHashJoin(
MatchingUtils.assertMatches(plan, physicalCTEAnchor(physicalCTEProducer(any()), physicalResultSink(physicalDistribute(physicalProject(physicalHashJoin(
physicalHashAggregate(physicalHashAggregate(physicalDistribute(physicalHashAggregate(any())))),
physicalHashAggregate(physicalHashAggregate(physicalDistribute(physicalHashAggregate(any()))))
))))));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,10 +255,10 @@
1 4 25 2

-- !multi_count_without_gby --
PhysicalResultSink
--PhysicalCteAnchor ( cteId=CTEId#0 )
----PhysicalCteProducer ( cteId=CTEId#0 )
------PhysicalOlapScan[test_distinct_multi]
PhysicalCteAnchor ( cteId=CTEId#0 )
--PhysicalCteProducer ( cteId=CTEId#0 )
----PhysicalOlapScan[test_distinct_multi]
--PhysicalResultSink
----NestedLoopJoin[CROSS_JOIN]
------hashAgg[DISTINCT_GLOBAL]
--------hashAgg[DISTINCT_LOCAL]
Expand All @@ -272,10 +272,10 @@ PhysicalResultSink
--------------PhysicalCteConsumer ( cteId=CTEId#0 )

-- !multi_sum_without_gby --
PhysicalResultSink
--PhysicalCteAnchor ( cteId=CTEId#0 )
----PhysicalCteProducer ( cteId=CTEId#0 )
------PhysicalOlapScan[test_distinct_multi]
PhysicalCteAnchor ( cteId=CTEId#0 )
--PhysicalCteProducer ( cteId=CTEId#0 )
----PhysicalOlapScan[test_distinct_multi]
--PhysicalResultSink
----NestedLoopJoin[CROSS_JOIN]
------hashAgg[DISTINCT_GLOBAL]
--------hashAgg[DISTINCT_LOCAL]
Expand All @@ -289,10 +289,10 @@ PhysicalResultSink
--------------PhysicalCteConsumer ( cteId=CTEId#0 )

-- !sum_count_without_gby --
PhysicalResultSink
--PhysicalCteAnchor ( cteId=CTEId#0 )
----PhysicalCteProducer ( cteId=CTEId#0 )
------PhysicalOlapScan[test_distinct_multi]
PhysicalCteAnchor ( cteId=CTEId#0 )
--PhysicalCteProducer ( cteId=CTEId#0 )
----PhysicalOlapScan[test_distinct_multi]
--PhysicalResultSink
----NestedLoopJoin[CROSS_JOIN]
------hashAgg[DISTINCT_GLOBAL]
--------hashAgg[DISTINCT_LOCAL]
Expand All @@ -306,10 +306,10 @@ PhysicalResultSink
--------------PhysicalCteConsumer ( cteId=CTEId#0 )

-- !multi_count_mulitcols_without_gby --
PhysicalResultSink
--PhysicalCteAnchor ( cteId=CTEId#0 )
----PhysicalCteProducer ( cteId=CTEId#0 )
------PhysicalOlapScan[test_distinct_multi]
PhysicalCteAnchor ( cteId=CTEId#0 )
--PhysicalCteProducer ( cteId=CTEId#0 )
----PhysicalOlapScan[test_distinct_multi]
--PhysicalResultSink
----NestedLoopJoin[CROSS_JOIN]
------hashAgg[DISTINCT_LOCAL]
--------hashAgg[GLOBAL]
Expand All @@ -321,10 +321,10 @@ PhysicalResultSink
------------PhysicalCteConsumer ( cteId=CTEId#0 )

-- !multi_count_mulitcols_with_gby --
PhysicalResultSink
--PhysicalCteAnchor ( cteId=CTEId#0 )
----PhysicalCteProducer ( cteId=CTEId#0 )
------PhysicalOlapScan[test_distinct_multi]
PhysicalCteAnchor ( cteId=CTEId#0 )
--PhysicalCteProducer ( cteId=CTEId#0 )
----PhysicalOlapScan[test_distinct_multi]
--PhysicalResultSink
----hashJoin[INNER_JOIN] hashCondition=((.d = .d)) otherCondition=()
------hashAgg[DISTINCT_LOCAL]
--------hashAgg[GLOBAL]
Expand All @@ -336,10 +336,10 @@ PhysicalResultSink
------------PhysicalCteConsumer ( cteId=CTEId#0 )

-- !three_count_mulitcols_without_gby --
PhysicalResultSink
--PhysicalCteAnchor ( cteId=CTEId#0 )
----PhysicalCteProducer ( cteId=CTEId#0 )
------PhysicalOlapScan[test_distinct_multi]
PhysicalCteAnchor ( cteId=CTEId#0 )
--PhysicalCteProducer ( cteId=CTEId#0 )
----PhysicalOlapScan[test_distinct_multi]
--PhysicalResultSink
----NestedLoopJoin[CROSS_JOIN]
------NestedLoopJoin[CROSS_JOIN]
--------hashAgg[DISTINCT_LOCAL]
Expand All @@ -356,10 +356,10 @@ PhysicalResultSink
------------PhysicalCteConsumer ( cteId=CTEId#0 )

-- !four_count_mulitcols_with_gby --
PhysicalResultSink
--PhysicalCteAnchor ( cteId=CTEId#0 )
----PhysicalCteProducer ( cteId=CTEId#0 )
------PhysicalOlapScan[test_distinct_multi]
PhysicalCteAnchor ( cteId=CTEId#0 )
--PhysicalCteProducer ( cteId=CTEId#0 )
----PhysicalOlapScan[test_distinct_multi]
--PhysicalResultSink
----hashJoin[INNER_JOIN] hashCondition=((.d = .d)) otherCondition=()
------hashJoin[INNER_JOIN] hashCondition=((.d = .d)) otherCondition=()
--------hashAgg[DISTINCT_LOCAL]
Expand Down
Loading

0 comments on commit 1fe5f2c

Please sign in to comment.