Skip to content

Commit

Permalink
[fix](Nereids) set correct sort key for aggregate #45369 branch-3.0 (#…
Browse files Browse the repository at this point in the history
…45706)

### What problem does this PR solve?

pick #45369
  • Loading branch information
englefly authored Jan 5, 2025
1 parent 71a3cc7 commit e6678e6
Show file tree
Hide file tree
Showing 6 changed files with 307 additions and 317 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,157 +21,84 @@
package org.apache.doris.nereids.processor.post;

import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.properties.DistributionSpecGather;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.AggMode;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate.TopnPushInfo;
import org.apache.doris.nereids.trees.plans.physical.PhysicalLimit;
import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
import org.apache.doris.nereids.trees.plans.physical.PhysicalTopN;
import org.apache.doris.qe.ConnectContext;

import org.apache.hadoop.util.Lists;

import java.util.List;
import java.util.stream.Collectors;

/**
* Add SortInfo to Agg. This SortInfo is used as boundary, not used to sort elements.
* Add TopNInfo to Agg. This TopNInfo is used as boundary, not used to sort elements.
* example
* sql: select count(*) from orders group by o_clerk order by o_clerk limit 1;
* plan: topn(1) -> aggGlobal -> shuffle -> aggLocal -> scan
* optimization: aggLocal and aggGlobal only need to generate the smallest row with respect to o_clerk.
*
* TODO: the following case is not covered:
* sql: select sum(o_shippriority) from orders group by o_clerk limit 1;
* plan: limit -> aggGlobal -> shuffle -> aggLocal -> scan
* aggGlobal may receive partial aggregate results, and hence is not supported now
* instance1: input (key=2, v=1) => localAgg => (2, 1) => aggGlobal inst1 => (2, 1)
* instance2: input (key=1, v=1), (key=2, v=2) => localAgg inst2 => (1, 1)
* (2,1),(1,1) => limit => may output (2, 1), which is not complete, missing (2, 2) in instance2
*
*TOPN:
* Precondition: topn orderkeys are the prefix of group keys
* TODO: topnKeys could be subset of groupKeys. This will be implemented in future
* Pattern 2-phase agg:
* topn -> aggGlobal -> distribute -> aggLocal
* =>
* topn(n) -> aggGlobal(topn=n) -> distribute -> aggLocal(topn=n)
* Pattern 1-phase agg:
* topn->agg->Any(not agg) -> topn -> agg(topn=n) -> any
*
* LIMIT:
* Pattern 1: limit->agg(1phase)->any
* Pattern 2: limit->agg(global)->gather->agg(local)
* This rule only applies to the patterns
* 1. topn->project->agg, or
* 2. topn->agg
* that
* 1. orderKeys and groupkeys are one-one mapping
* 2. aggregate is not scalar agg
* Refer to LimitAggToTopNAgg rule.
*/
public class PushTopnToAgg extends PlanPostProcessor {
@Override
public Plan visitPhysicalTopN(PhysicalTopN<? extends Plan> topN, CascadesContext ctx) {
topN.child().accept(this, ctx);
if (ConnectContext.get().getSessionVariable().topnOptLimitThreshold <= topN.getLimit() + topN.getOffset()) {
if (ConnectContext.get().getSessionVariable().topnOptLimitThreshold <= topN.getLimit() + topN.getOffset()
&& !ConnectContext.get().getSessionVariable().pushTopnToAgg) {
return topN;
}
Plan topnChild = topN.child();
if (topnChild instanceof PhysicalProject) {
topnChild = topnChild.child(0);
Plan topNChild = topN.child();
if (topNChild instanceof PhysicalProject) {
topNChild = topNChild.child(0);
}
if (topnChild instanceof PhysicalHashAggregate) {
PhysicalHashAggregate<? extends Plan> upperAgg = (PhysicalHashAggregate<? extends Plan>) topnChild;
List<OrderKey> orderKeys = tryGenerateOrderKeyByGroupKeyAndTopnKey(topN, upperAgg);
if (!orderKeys.isEmpty()) {

if (upperAgg.getAggPhase().isGlobal() && upperAgg.getAggMode() == AggMode.BUFFER_TO_RESULT) {
upperAgg.setTopnPushInfo(new TopnPushInfo(
orderKeys,
topN.getLimit() + topN.getOffset()));
if (upperAgg.child() instanceof PhysicalDistribute
&& upperAgg.child().child(0) instanceof PhysicalHashAggregate) {
PhysicalHashAggregate<? extends Plan> bottomAgg =
(PhysicalHashAggregate<? extends Plan>) upperAgg.child().child(0);
if (topNChild instanceof PhysicalHashAggregate) {
PhysicalHashAggregate<? extends Plan> upperAgg = (PhysicalHashAggregate<? extends Plan>) topNChild;
if (isGroupKeyIdenticalToOrderKey(topN, upperAgg)) {
upperAgg.setTopnPushInfo(new TopnPushInfo(
topN.getOrderKeys(),
topN.getLimit() + topN.getOffset()));
if (upperAgg.child() instanceof PhysicalDistribute
&& upperAgg.child().child(0) instanceof PhysicalHashAggregate) {
PhysicalHashAggregate<? extends Plan> bottomAgg =
(PhysicalHashAggregate<? extends Plan>) upperAgg.child().child(0);
if (isGroupKeyIdenticalToOrderKey(topN, bottomAgg)) {
bottomAgg.setTopnPushInfo(new TopnPushInfo(
topN.getOrderKeys(),
topN.getLimit() + topN.getOffset()));
}
} else if (upperAgg.child() instanceof PhysicalHashAggregate) {
// multi-distinct plan
PhysicalHashAggregate<? extends Plan> bottomAgg =
(PhysicalHashAggregate<? extends Plan>) upperAgg.child();
if (isGroupKeyIdenticalToOrderKey(topN, bottomAgg)) {
bottomAgg.setTopnPushInfo(new TopnPushInfo(
orderKeys,
topN.getOrderKeys(),
topN.getLimit() + topN.getOffset()));
}
} else if (upperAgg.getAggPhase().isLocal() && upperAgg.getAggMode() == AggMode.INPUT_TO_RESULT) {
// one phase agg
upperAgg.setTopnPushInfo(new TopnPushInfo(
orderKeys,
topN.getLimit() + topN.getOffset()));
}
}
}
return topN;
}

/**
return true, if topn order-key is prefix of agg group-key, ignore asc/desc and null_first
TODO order-key can be subset of group-key. BE does not support now.
*/
private List<OrderKey> tryGenerateOrderKeyByGroupKeyAndTopnKey(PhysicalTopN<? extends Plan> topN,
private boolean isGroupKeyIdenticalToOrderKey(PhysicalTopN<? extends Plan> topN,
PhysicalHashAggregate<? extends Plan> agg) {
List<OrderKey> orderKeys = Lists.newArrayListWithCapacity(agg.getGroupByExpressions().size());
if (topN.getOrderKeys().size() > agg.getGroupByExpressions().size()) {
return orderKeys;
}
List<Expression> topnKeys = topN.getOrderKeys().stream()
.map(OrderKey::getExpr).collect(Collectors.toList());
for (int i = 0; i < topN.getOrderKeys().size(); i++) {
// prefix check
if (!topnKeys.get(i).equals(agg.getGroupByExpressions().get(i))) {
return Lists.newArrayList();
}
orderKeys.add(topN.getOrderKeys().get(i));
}
for (int i = topN.getOrderKeys().size(); i < agg.getGroupByExpressions().size(); i++) {
orderKeys.add(new OrderKey(agg.getGroupByExpressions().get(i), true, false));
}
return orderKeys;
}

@Override
public Plan visitPhysicalLimit(PhysicalLimit<? extends Plan> limit, CascadesContext ctx) {
limit.child().accept(this, ctx);
if (ConnectContext.get().getSessionVariable().topnOptLimitThreshold <= limit.getLimit() + limit.getOffset()) {
return limit;
if (topN.getOrderKeys().size() != agg.getGroupByExpressions().size()) {
return false;
}
Plan limitChild = limit.child();
if (limitChild instanceof PhysicalProject) {
limitChild = limitChild.child(0);
}
if (limitChild instanceof PhysicalHashAggregate) {
PhysicalHashAggregate<? extends Plan> upperAgg = (PhysicalHashAggregate<? extends Plan>) limitChild;
if (upperAgg.getAggPhase().isGlobal() && upperAgg.getAggMode() == AggMode.BUFFER_TO_RESULT) {
Plan child = upperAgg.child();
Plan grandChild = child.child(0);
if (child instanceof PhysicalDistribute
&& ((PhysicalDistribute<?>) child).getDistributionSpec() instanceof DistributionSpecGather
&& grandChild instanceof PhysicalHashAggregate) {
upperAgg.setTopnPushInfo(new TopnPushInfo(
generateOrderKeyByGroupKey(upperAgg),
limit.getLimit() + limit.getOffset()));
PhysicalHashAggregate<? extends Plan> bottomAgg =
(PhysicalHashAggregate<? extends Plan>) grandChild;
bottomAgg.setTopnPushInfo(new TopnPushInfo(
generateOrderKeyByGroupKey(bottomAgg),
limit.getLimit() + limit.getOffset()));
}
} else if (upperAgg.getAggMode() == AggMode.INPUT_TO_RESULT) {
// 1-phase agg
upperAgg.setTopnPushInfo(new TopnPushInfo(
generateOrderKeyByGroupKey(upperAgg),
limit.getLimit() + limit.getOffset()));
for (int i = 0; i < agg.getGroupByExpressions().size(); i++) {
Expression groupByKey = agg.getGroupByExpressions().get(i);
Expression orderKey = topN.getOrderKeys().get(i).getExpr();
if (!groupByKey.equals(orderKey)) {
return false;
}
}
return limit;
}

private List<OrderKey> generateOrderKeyByGroupKey(PhysicalHashAggregate<? extends Plan> agg) {
return agg.getGroupByExpressions().stream()
.map(key -> new OrderKey(key, true, false))
.collect(Collectors.toList());
return true;
}
}
Loading

0 comments on commit e6678e6

Please sign in to comment.