Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
englefly committed Dec 12, 2024
1 parent 09619fc commit 3d47101
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 91 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.doris.nereids.properties.DistributionSpecGather;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.AggMode;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
Expand All @@ -46,7 +47,7 @@
* plan: topn(1) -> aggGlobal -> shuffle -> aggLocal -> scan
* optimization: aggLocal and aggGlobal only need to generate the smallest row with respect to o_clerk.
*
* TODO: the following case is not covered:
* Attention: the following case is error-prone
* sql: select sum(o_shippriority) from orders group by o_clerk limit 1;
* plan: limit -> aggGlobal -> shuffle -> aggLocal -> scan
* aggGlobal may receive partial aggregate results, and hence is not supported now
Expand All @@ -55,18 +56,13 @@
* (2,1),(1,1) => limit => may output (2, 1), which is not complete, missing (2, 2) in instance2
*
*TOPN:
* Precondition: topn orderkeys are the prefix of group keys
* TODO: topnKeys could be subset of groupKeys. This will be implemented in future
* Pattern 2-phase agg:
* topn -> aggGlobal -> distribute -> aggLocal
* =>
* topn(n) -> aggGlobal(topn=n) -> distribute -> aggLocal(topn=n)
* topn(n) -> aggGlobal(topNInfo) -> distribute -> aggLocal(topNInfo)
* Pattern 1-phase agg:
* topn->agg->Any(not agg) -> topn -> agg(topn=n) -> any
*
* LIMIT:
* Pattern 1: limit->agg(1phase)->any
* Pattern 2: limit->agg(global)->gather->agg(local)
* topn->agg->Any(not agg) -> topn -> agg(topNInfo) -> any
*/
public class PushTopnToAgg extends PlanPostProcessor {
@Override
Expand All @@ -81,9 +77,8 @@ public Plan visitPhysicalTopN(PhysicalTopN<? extends Plan> topN, CascadesContext
}
if (topnChild instanceof PhysicalHashAggregate) {
PhysicalHashAggregate<? extends Plan> upperAgg = (PhysicalHashAggregate<? extends Plan>) topnChild;
List<OrderKey> orderKeys = tryGenerateOrderKeyByGroupKeyAndTopnKey(topN, upperAgg);
List<OrderKey> orderKeys = generateOrderKeyByGroupKeyAndTopnKey(topN, upperAgg);
if (!orderKeys.isEmpty()) {

if (upperAgg.getAggPhase().isGlobal() && upperAgg.getAggMode() == AggMode.BUFFER_TO_RESULT) {
upperAgg.setTopnPushInfo(new TopnPushInfo(
orderKeys,
Expand All @@ -107,6 +102,24 @@ public Plan visitPhysicalTopN(PhysicalTopN<? extends Plan> topN, CascadesContext
return topN;
}

private List<OrderKey> generateOrderKeyByGroupKeyAndTopnKey(PhysicalTopN<? extends Plan> topN,
PhysicalHashAggregate<? extends Plan> agg) {
List<OrderKey> orderKeys = Lists.newArrayListWithCapacity(agg.getGroupByExpressions().size());
if (topN.getOrderKeys().size() < agg.getGroupByExpressions().size()) {
return Lists.newArrayList();
}
for (int i = 0; i < agg.getGroupByExpressions().size(); i++) {
Expression groupByKey = agg.getGroupByExpressions().get(i);
Expression orderKey = topN.getOrderKeys().get(i).getExpr();
if (groupByKey.equals(orderKey)) {
orderKeys.add(topN.getOrderKeys().get(i));
} else {
orderKeys.clear();
break;
}
}
return orderKeys;
}
/**
return true, if topn order-key is prefix of agg group-key, ignore asc/desc and null_first
TODO order-key can be subset of group-key. BE does not support now.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
Expand All @@ -33,8 +34,9 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
Expand All @@ -56,107 +58,88 @@ public List<Rule> buildRules() {
>= limit.getLimit() + limit.getOffset())
.then(limit -> {
LogicalAggregate<? extends Plan> agg = limit.child();
Optional<OrderKey> orderKeysOpt = tryGenerateOrderKeyByTheFirstGroupKey(agg);
if (!orderKeysOpt.isPresent()) {
return null;
}
List<OrderKey> orderKeys = Lists.newArrayList(orderKeysOpt.get());
List<OrderKey> orderKeys = generateOrderKeyByGroupKey(agg);
return new LogicalTopN<>(orderKeys, limit.getLimit(), limit.getOffset(), agg);
}).toRule(RuleType.LIMIT_AGG_TO_TOPN_AGG),
//limit->project->agg to topn->project->agg
//limit->project->agg to project->topn->agg
logicalLimit(logicalProject(logicalAggregate()))
.when(limit -> ConnectContext.get() != null
&& ConnectContext.get().getSessionVariable().pushTopnToAgg
&& ConnectContext.get().getSessionVariable().topnOptLimitThreshold
>= limit.getLimit() + limit.getOffset())
.when(limit -> limit.child().isAllSlots())
.then(limit -> {
LogicalProject<? extends Plan> project = limit.child();
LogicalAggregate<? extends Plan> agg
= (LogicalAggregate<? extends Plan>) project.child();
Optional<OrderKey> orderKeysOpt = tryGenerateOrderKeyByTheFirstGroupKey(agg);
if (!orderKeysOpt.isPresent()) {
return null;
}
List<OrderKey> orderKeys = Lists.newArrayList(orderKeysOpt.get());
Plan result;

if (outputAllGroupKeys(limit, agg)) {
result = new LogicalTopN<>(orderKeys, limit.getLimit(),
limit.getOffset(), project);
} else {
// add the first group by key to topn, and prune this key by upper project
// topn order keys are prefix of group by keys
// refer to PushTopnToAgg.tryGenerateOrderKeyByGroupKeyAndTopnKey()
Expression firstGroupByKey = agg.getGroupByExpressions().get(0);
if (!(firstGroupByKey instanceof SlotReference)) {
return null;
}
boolean shouldPruneFirstGroupByKey = true;
if (project.getOutputs().contains(firstGroupByKey)) {
shouldPruneFirstGroupByKey = false;
} else {
List<NamedExpression> bottomProjections = Lists.newArrayList(project.getProjects());
bottomProjections.add((SlotReference) firstGroupByKey);
project = project.withProjects(bottomProjections);
}
LogicalTopN topn = new LogicalTopN<>(orderKeys, limit.getLimit(),
limit.getOffset(), project);
if (shouldPruneFirstGroupByKey) {
List<NamedExpression> limitOutput = limit.getOutput().stream()
.map(e -> (NamedExpression) e).collect(Collectors.toList());
result = new LogicalProject<>(limitOutput, topn);
} else {
result = topn;
}
}
return result;
List<OrderKey> orderKeys = generateOrderKeyByGroupKey(agg);
LogicalTopN topn = new LogicalTopN<>(orderKeys, limit.getLimit(),
limit.getOffset(), agg);
project = (LogicalProject<? extends Plan>) project.withChildren(topn);
return project;
}).toRule(RuleType.LIMIT_AGG_TO_TOPN_AGG),
// topn -> agg: add all group key to sort key, if sort key is prefix of group key
// topn -> agg: add all group key to sort key
logicalTopN(logicalAggregate())
.when(topn -> ConnectContext.get() != null
&& ConnectContext.get().getSessionVariable().pushTopnToAgg
&& ConnectContext.get().getSessionVariable().topnOptLimitThreshold
>= topn.getLimit() + topn.getOffset())
.then(topn -> {
LogicalAggregate<? extends Plan> agg = (LogicalAggregate<? extends Plan>) topn.child();
List<OrderKey> newOrders = tryGenerateOrderKeyByGroupKeyAndTopnKey(topn, agg);
if (newOrders.isEmpty()) {
return topn;
List<OrderKey> newOrders = Lists.newArrayList(topn.getOrderKeys());
Set<Expression> orderExprs = topn.getOrderKeys().stream()
.map(orderKey -> orderKey.getExpr()).collect(Collectors.toSet());
boolean orderKeyChanged = false;
for (Expression expr : agg.getGroupByExpressions()) {
if (!orderExprs.contains(expr)) {
// after NormalizeAggregate, expr should be SlotReference
if (expr instanceof SlotReference) {
orderKeyChanged = true;
newOrders.add(new OrderKey(expr, true, true));
}
}
}
return orderKeyChanged ? topn.withOrderKeys(newOrders) : topn;
}).toRule(RuleType.LIMIT_AGG_TO_TOPN_AGG),
//topn -> project ->agg: add all group key to sort key, and prune column
logicalTopN(logicalProject(logicalAggregate()))
.when(topn -> ConnectContext.get() != null
&& ConnectContext.get().getSessionVariable().pushTopnToAgg
&& ConnectContext.get().getSessionVariable().topnOptLimitThreshold
>= topn.getLimit() + topn.getOffset())
.when(topn -> topn.child().isAllSlots())
.then(topn -> {
LogicalProject project = topn.child();
LogicalAggregate<? extends Plan> agg = (LogicalAggregate) project.child();
List<OrderKey> newOrders = Lists.newArrayList(topn.getOrderKeys());
Set<Expression> orderExprs = topn.getOrderKeys().stream()
.map(orderKey -> orderKey.getExpr()).collect(Collectors.toSet());
boolean orderKeyChanged = false;
for (Expression expr : agg.getGroupByExpressions()) {
if (!orderExprs.contains(expr)) {
// after NormalizeAggregate, expr should be SlotReference
if (expr instanceof SlotReference) {
orderKeyChanged = true;
newOrders.add(new OrderKey(expr, true, true));
}
}
}
Plan result;
if (orderKeyChanged) {
topn = (LogicalTopN) topn.withChildren(agg);
topn.withOrderKeys(newOrders);
result = (Plan) project.withChildren(topn);
} else {
return topn.withOrderKeys(newOrders);
result = topn;
}
}).toRule(RuleType.LIMIT_AGG_TO_TOPN_AGG));
}

private List<OrderKey> tryGenerateOrderKeyByGroupKeyAndTopnKey(LogicalTopN<? extends Plan> topN,
LogicalAggregate<? extends Plan> agg) {
List<OrderKey> orderKeys = Lists.newArrayListWithCapacity(agg.getGroupByExpressions().size());
if (topN.getOrderKeys().size() > agg.getGroupByExpressions().size()) {
return orderKeys;
}
List<Expression> topnKeys = topN.getOrderKeys().stream()
.map(OrderKey::getExpr).collect(Collectors.toList());
for (int i = 0; i < topN.getOrderKeys().size(); i++) {
// prefix check
if (!topnKeys.get(i).equals(agg.getGroupByExpressions().get(i))) {
return Lists.newArrayList();
}
orderKeys.add(topN.getOrderKeys().get(i));
}
for (int i = topN.getOrderKeys().size(); i < agg.getGroupByExpressions().size(); i++) {
orderKeys.add(new OrderKey(agg.getGroupByExpressions().get(i), true, false));
}
return orderKeys;
}

private boolean outputAllGroupKeys(LogicalLimit limit, LogicalAggregate agg) {
return limit.getOutputSet().containsAll(agg.getGroupByExpressions());
return result;
}).toRule(RuleType.LIMIT_AGG_TO_TOPN_AGG)
);
}

private Optional<OrderKey> tryGenerateOrderKeyByTheFirstGroupKey(LogicalAggregate<? extends Plan> agg) {
if (agg.getGroupByExpressions().isEmpty()) {
return Optional.empty();
}
return Optional.of(new OrderKey(agg.getGroupByExpressions().get(0), true, false));
private List<OrderKey> generateOrderKeyByGroupKey(LogicalAggregate<? extends Plan> agg) {
return agg.getGroupByExpressions().stream()
.map(key -> new OrderKey(key, true, false))
.collect(Collectors.toList());
}
}

0 comments on commit 3d47101

Please sign in to comment.