Skip to content

Commit

Permalink
[Feat](nereids) support pull up predicate from set operator (#39450) (#…
Browse files Browse the repository at this point in the history
…44056)

cherry-pick #39450 to branch-2.1
  • Loading branch information
feiniaofeiafei authored Dec 3, 2024
1 parent d9ef316 commit 3e16922
Show file tree
Hide file tree
Showing 37 changed files with 3,817 additions and 1,878 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import org.apache.doris.nereids.rules.analysis.LogicalSubQueryAliasToLogicalProject;
import org.apache.doris.nereids.rules.analysis.NormalizeAggregate;
import org.apache.doris.nereids.rules.expression.CheckLegalityAfterRewrite;
import org.apache.doris.nereids.rules.expression.ExpressionNormalization;
import org.apache.doris.nereids.rules.expression.ExpressionNormalizationAndOptimization;
import org.apache.doris.nereids.rules.expression.ExpressionRewrite;
import org.apache.doris.nereids.rules.rewrite.AddDefaultLimit;
Expand Down Expand Up @@ -286,6 +285,21 @@ public class Rewriter extends AbstractBatchJobExecutor {
topDown(new ConvertInnerOrCrossJoin()),
topDown(new ProjectOtherJoinConditionForNestedLoopJoin())
),
topic("Set operation optimization",
// Do MergeSetOperation first because we hope to match pattern of Distinct SetOperator.
topDown(new PushProjectThroughUnion(), new MergeProjects()),
bottomUp(new MergeSetOperations(), new MergeSetOperationsExcept()),
bottomUp(new PushProjectIntoOneRowRelation()),
topDown(new MergeOneRowRelationIntoUnion()),
costBased(topDown(new InferSetOperatorDistinct())),
topDown(new BuildAggForUnion()),
bottomUp(new EliminateEmptyRelation()),
// when union has empty relation child and constantExprsList is not empty,
// after EliminateEmptyRelation, project can be pushed into union
topDown(new PushProjectIntoUnion())
),
// The "Column pruning and infer predicate" topic is placed after "Set operation optimization"
// because pulling up predicates from a union needs EliminateEmptyRelation to have run on the union's children
topic("Column pruning and infer predicate",
custom(RuleType.COLUMN_PRUNING, ColumnPruning::new),
custom(RuleType.INFER_PREDICATES, InferPredicates::new),
Expand All @@ -299,24 +313,11 @@ public class Rewriter extends AbstractBatchJobExecutor {
// after eliminate outer join, we can move some filters to join.otherJoinConjuncts,
// this can help to translate plan to backend
topDown(new PushFilterInsideJoin()),
topDown(new FindHashConditionForJoin()),
topDown(new ExpressionNormalization())
topDown(new FindHashConditionForJoin())
),

// this rule should invoke after ColumnPruning
custom(RuleType.ELIMINATE_UNNECESSARY_PROJECT, EliminateUnnecessaryProject::new),

topic("Set operation optimization",
// Do MergeSetOperation first because we hope to match pattern of Distinct SetOperator.
topDown(new PushProjectThroughUnion(), new MergeProjects()),
bottomUp(new MergeSetOperations(), new MergeSetOperationsExcept()),
bottomUp(new PushProjectIntoOneRowRelation()),
topDown(new MergeOneRowRelationIntoUnion()),
topDown(new PushProjectIntoUnion()),
costBased(topDown(new InferSetOperatorDistinct())),
topDown(new BuildAggForUnion())
),

topic("Eliminate GroupBy",
topDown(new EliminateGroupBy(),
new MergeAggregate(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
import org.apache.doris.nereids.rules.exploration.mv.MaterializedViewProjectFilterScanRule;
import org.apache.doris.nereids.rules.exploration.mv.MaterializedViewProjectJoinRule;
import org.apache.doris.nereids.rules.exploration.mv.MaterializedViewProjectScanRule;
import org.apache.doris.nereids.rules.expression.ExpressionNormalization;
import org.apache.doris.nereids.rules.expression.ExpressionOptimization;
import org.apache.doris.nereids.rules.implementation.AggregateStrategies;
import org.apache.doris.nereids.rules.implementation.LogicalAssertNumRowsToPhysicalAssertNumRows;
Expand Down Expand Up @@ -96,6 +97,7 @@
import org.apache.doris.nereids.rules.implementation.LogicalWindowToPhysicalWindow;
import org.apache.doris.nereids.rules.rewrite.ConvertOuterJoinToAntiJoin;
import org.apache.doris.nereids.rules.rewrite.CreatePartitionTopNFromWindow;
import org.apache.doris.nereids.rules.rewrite.EliminateFilter;
import org.apache.doris.nereids.rules.rewrite.EliminateOuterJoin;
import org.apache.doris.nereids.rules.rewrite.MaxMinFilterPushDown;
import org.apache.doris.nereids.rules.rewrite.MergeFilters;
Expand Down Expand Up @@ -167,7 +169,12 @@ public class RuleSet {
new PushDownAliasThroughJoin(),
new PushDownFilterThroughWindow(),
new PushDownFilterThroughPartitionTopN(),
new ExpressionOptimization()
new ExpressionOptimization(),
// some useless predicates(e.g. 1=1) can be inferred by InferPredicates,
// the FoldConstantRule in ExpressionNormalization can fold 1=1 to true
// and EliminateFilter can eliminate the useless filter
new ExpressionNormalization(),
new EliminateFilter()
);

public static final List<Rule> IMPLEMENTATION_RULES = planRuleFactories()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.annotation.DependsRules;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.UnaryNode;
Expand Down Expand Up @@ -45,6 +46,9 @@
/**
* try to eliminate sub plan tree which contains EmptyRelation
*/
@DependsRules ({
BuildAggForUnion.class
})
public class EliminateEmptyRelation implements RewriteRuleFactory {

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,23 @@

import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalExcept;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalIntersect;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PlanUtils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;

import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -107,6 +113,45 @@ public Plan visitLogicalFilter(LogicalFilter<? extends Plan> filter, JobContext
return filter;
}

/**
 * Infers predicates for the non-first children of an EXCEPT.
 *
 * <p>The children are rewritten first. Then the predicates pulled up from the
 * except node are re-expressed in terms of each child's regular output slots
 * and handed to {@code inferNewPredicate} together with that child. The first
 * child is passed through as-is.
 *
 * @param except the except node to rewrite
 * @param context the job context forwarded to the child visits
 * @return the except itself when nothing can be pulled up, otherwise a copy
 *         with the rebuilt child list
 */
@Override
public Plan visitLogicalExcept(LogicalExcept except, JobContext context) {
    except = visitChildren(this, except, context);
    Set<Expression> pulledUp = pullUpPredicates(except);
    if (pulledUp.isEmpty()) {
        return except;
    }
    int columnCount = except.getOutput().size();
    ImmutableList.Builder<Plan> newChildren = ImmutableList.builder();
    // child 0 is kept untouched; only the remaining children receive predicates
    newChildren.add(except.child(0));
    for (int childIdx = 1; childIdx < except.arity(); childIdx++) {
        // map every output slot of the except to the slot of this child at the same position
        Map<Expression, Expression> outputToChildSlot = new HashMap<>();
        for (int col = 0; col < columnCount; col++) {
            outputToChildSlot.put(except.getOutput().get(col),
                    except.getRegularChildOutput(childIdx).get(col));
        }
        Set<Expression> childPredicates = ExpressionUtils.replace(pulledUp, outputToChildSlot);
        newChildren.add(inferNewPredicate(except.child(childIdx), childPredicates));
    }
    return except.withChildren(newChildren.build());
}

/**
 * Infers predicates for every child of an INTERSECT.
 *
 * <p>The children are rewritten first. Then the predicates pulled up from the
 * intersect node are re-expressed in terms of each child's regular output
 * slots and handed to {@code inferNewPredicate} together with that child.
 *
 * @param intersect the intersect node to rewrite
 * @param context the job context forwarded to the child visits
 * @return the intersect itself when nothing can be pulled up, otherwise a
 *         copy with the rebuilt child list
 */
@Override
public Plan visitLogicalIntersect(LogicalIntersect intersect, JobContext context) {
    intersect = visitChildren(this, intersect, context);
    Set<Expression> pulledUp = pullUpPredicates(intersect);
    if (pulledUp.isEmpty()) {
        return intersect;
    }
    int columnCount = intersect.getOutput().size();
    ImmutableList.Builder<Plan> newChildren = ImmutableList.builder();
    for (int childIdx = 0; childIdx < intersect.arity(); childIdx++) {
        // map every output slot of the intersect to the slot of this child at the same position
        Map<Expression, Expression> outputToChildSlot = new HashMap<>();
        for (int col = 0; col < columnCount; col++) {
            outputToChildSlot.put(intersect.getOutput().get(col),
                    intersect.getRegularChildOutput(childIdx).get(col));
        }
        Set<Expression> childPredicates = ExpressionUtils.replace(pulledUp, outputToChildSlot);
        newChildren.add(inferNewPredicate(intersect.child(childIdx), childPredicates));
    }
    return intersect.withChildren(newChildren.build());
}

private Set<Expression> getAllExpressions(Plan left, Plan right, Optional<Expression> condition) {
Set<Expression> baseExpressions = pullUpPredicates(left);
baseExpressions.addAll(pullUpPredicates(right));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,22 @@
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalExcept;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalIntersect;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.nereids.util.ExpressionUtils;

Expand All @@ -37,7 +44,10 @@
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;

import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
Expand All @@ -58,6 +68,78 @@ public ImmutableSet<Expression> visit(Plan plan, Void context) {
return ImmutableSet.of();
}

/**
 * Derives {@code slot = literal} predicates from the projections of a one-row relation.
 *
 * <p>Only projections of the form {@code Alias(Literal)} contribute. NULL literals are
 * skipped: {@code slot = NULL} never evaluates to true, so it does not describe the row
 * this relation produces. Pushing such a predicate into a sibling of a set operation
 * (where NULLs compare as equal) would wrongly filter out the matching NULL rows.
 *
 * @param r the one-row relation whose projections are inspected
 * @param context unused
 * @return equality predicates that hold for the single output row; possibly empty
 */
@Override
public ImmutableSet<Expression> visitLogicalOneRowRelation(LogicalOneRowRelation r, Void context) {
    ImmutableSet.Builder<Expression> predicates = ImmutableSet.builder();
    for (NamedExpression expr : r.getProjects()) {
        if (!(expr instanceof Alias)) {
            continue;
        }
        Expression value = expr.child(0);
        // a NULL constant cannot be expressed with '=', so no predicate is derived for it
        if (value instanceof Literal && !(value instanceof NullLiteral)) {
            predicates.add(new EqualTo(expr.toSlot(), value));
        }
    }
    return predicates.build();
}

/**
 * Pulls up predicates from all children of an INTERSECT.
 *
 * <p>The predicates of every child are rewritten from that child's regular
 * output slots to the intersect's own output slots and collected together;
 * the combined set is then passed through {@code getAvailableExpressions}.
 * The result is obtained through {@code cacheOrElse}.
 *
 * @param intersect the intersect node
 * @param context forwarded to the child visits
 * @return the predicates expressed on the intersect's output
 */
@Override
public ImmutableSet<Expression> visitLogicalIntersect(LogicalIntersect intersect, Void context) {
    return cacheOrElse(intersect, () -> {
        ImmutableSet.Builder<Expression> collected = ImmutableSet.builder();
        int childCount = intersect.children().size();
        for (int childIdx = 0; childIdx < childCount; childIdx++) {
            Set<Expression> childPredicates = intersect.child(childIdx).accept(this, context);
            if (!childPredicates.isEmpty()) {
                // translate child slots into the intersect's output slots, position by position
                Map<Expression, Expression> childSlotToOutput = new HashMap<>();
                int columnCount = intersect.getOutput().size();
                for (int col = 0; col < columnCount; col++) {
                    childSlotToOutput.put(intersect.getRegularChildOutput(childIdx).get(col),
                            intersect.getOutput().get(col));
                }
                collected.addAll(ExpressionUtils.replace(childPredicates, childSlotToOutput));
            }
        }
        return getAvailableExpressions(collected.build(), intersect);
    });
}

/**
 * Pulls up predicates from the first child of an EXCEPT.
 *
 * <p>Only child 0 is consulted. Its predicates are rewritten from its regular
 * output slots to the except's own output slots. The result is obtained
 * through {@code cacheOrElse}.
 *
 * @param except the except node
 * @param context forwarded to the child visit
 * @return the first child's predicates expressed on the except's output, or
 *         an empty set when there is no child or the child yields nothing
 */
@Override
public ImmutableSet<Expression> visitLogicalExcept(LogicalExcept except, Void context) {
    return cacheOrElse(except, () -> {
        if (except.arity() < 1) {
            return ImmutableSet.of();
        }
        Set<Expression> leadingPredicates = except.child(0).accept(this, context);
        if (leadingPredicates.isEmpty()) {
            return ImmutableSet.of();
        }
        // translate the first child's slots into the except's output slots, position by position
        Map<Expression, Expression> childSlotToOutput = new HashMap<>();
        int columnCount = except.getOutput().size();
        for (int col = 0; col < columnCount; col++) {
            childSlotToOutput.put(except.getRegularChildOutput(0).get(col), except.getOutput().get(col));
        }
        Set<Expression> rewritten = ExpressionUtils.replace(leadingPredicates, childSlotToOutput);
        return ImmutableSet.copyOf(rewritten);
    });
}

/**
 * Pulls up predicates from a UNION, depending on what the union is made of.
 *
 * <ul>
 *   <li>constant rows only: predicates come from {@code getFiltersFromUnionConstExprs};</li>
 *   <li>children only: predicates come from {@code getFiltersFromUnionChild};</li>
 *   <li>both: the child predicates are returned only if they are non-empty and
 *       {@code ExpressionUtils.unionConstExprsSatisfyConjuncts} accepts them;</li>
 *   <li>neither: the result is empty.</li>
 * </ul>
 * The result is obtained through {@code cacheOrElse}.
 *
 * @param union the union node
 * @param context forwarded to the child visits
 * @return the predicates expressed on the union's output; possibly empty
 */
@Override
public ImmutableSet<Expression> visitLogicalUnion(LogicalUnion union, Void context) {
    return cacheOrElse(union, () -> {
        boolean hasConstantRows = !union.getConstantExprsList().isEmpty();
        boolean hasChildren = union.arity() != 0;
        if (!hasConstantRows && !hasChildren) {
            return ImmutableSet.of();
        }
        if (!hasChildren) {
            return getFiltersFromUnionConstExprs(union);
        }
        if (!hasConstantRows) {
            return getFiltersFromUnionChild(union, context);
        }
        // both constant rows and children are present
        HashSet<Expression> childPredicates = new HashSet<>(getFiltersFromUnionChild(union, context));
        if (childPredicates.isEmpty()
                || !ExpressionUtils.unionConstExprsSatisfyConjuncts(union, childPredicates)) {
            return ImmutableSet.of();
        }
        return ImmutableSet.copyOf(childPredicates);
    });
}

@Override
public ImmutableSet<Expression> visitLogicalFilter(LogicalFilter<? extends Plan> filter, Void context) {
return cacheOrElse(filter, () -> {
Expand All @@ -75,6 +157,10 @@ public ImmutableSet<Expression> visitLogicalJoin(LogicalJoin<? extends Plan, ? e
ImmutableSet<Expression> rightPredicates = join.right().accept(this, context);
predicates.addAll(leftPredicates);
predicates.addAll(rightPredicates);
if (join.getJoinType() == JoinType.CROSS_JOIN || join.getJoinType() == JoinType.INNER_JOIN) {
predicates.addAll(join.getHashJoinConjuncts());
predicates.addAll(join.getOtherJoinConjuncts());
}
return getAvailableExpressions(predicates, join);
});
}
Expand Down Expand Up @@ -122,6 +208,9 @@ private ImmutableSet<Expression> cacheOrElse(Plan plan, Supplier<ImmutableSet<Ex
}

private ImmutableSet<Expression> getAvailableExpressions(Set<Expression> predicates, Plan plan) {
if (predicates.isEmpty()) {
return ImmutableSet.of();
}
Set<Expression> inferPredicates = PredicatePropagation.infer(predicates);
Builder<Expression> newPredicates = ImmutableSet.builderWithExpectedSize(predicates.size() + 10);
Set<Slot> outputSet = plan.getOutputSet();
Expand All @@ -143,4 +232,55 @@ private ImmutableSet<Expression> getAvailableExpressions(Set<Expression> predica
/** Returns true when any node matched by {@code anyMatch} on the expression is an {@link AggregateFunction}. */
private boolean hasAgg(Expression expression) {
    return expression.anyMatch(node -> node instanceof AggregateFunction);
}

/**
 * Computes the predicates shared by every child of a union.
 *
 * <p>Each child's predicates are rewritten from that child's regular output
 * slots to the union's output slots; the result is the intersection of these
 * rewritten sets across all children. The method returns an empty set as soon
 * as one child contributes nothing or the running intersection becomes empty.
 *
 * @param union the union whose children are visited
 * @param context forwarded to the child visits
 * @return predicates present for all children, expressed on the union's output
 */
private ImmutableSet<Expression> getFiltersFromUnionChild(LogicalUnion union, Void context) {
    Set<Expression> shared = new HashSet<>();
    for (int childIdx = 0; childIdx < union.getArity(); childIdx++) {
        Set<Expression> childPredicates = union.child(childIdx).accept(this, context);
        if (childPredicates.isEmpty()) {
            return ImmutableSet.of();
        }
        // translate child slots into the union's output slots, position by position
        Map<Expression, Expression> childSlotToOutput = new HashMap<>();
        int columnCount = union.getOutput().size();
        for (int col = 0; col < columnCount; col++) {
            childSlotToOutput.put(union.getRegularChildOutput(childIdx).get(col), union.getOutput().get(col));
        }
        Set<Expression> rewritten = ExpressionUtils.replace(childPredicates, childSlotToOutput);
        if (childIdx > 0) {
            // keep only what this child has in common with all previous ones
            shared.retainAll(rewritten);
        } else {
            shared.addAll(rewritten);
        }
        if (shared.isEmpty()) {
            return ImmutableSet.of();
        }
    }
    return ImmutableSet.copyOf(shared);
}

/**
 * Derives per-column predicates from the constant rows of a union.
 *
 * <p>For each output column, if every constant row supplies a non-NULL literal
 * (wrapped in an {@link Alias}) at that position, the column is constrained to those
 * literals: {@code col = x} for a single distinct value, {@code col IN (...)} for several.
 *
 * <p>No predicate is derived for a column when any row holds a non-literal or a NULL
 * literal there. A NULL constant is still produced by the union but satisfies neither
 * {@code =} nor {@code IN}, so dropping it and constraining on the rest would yield a
 * predicate that does not hold for every output row. Such a predicate, pushed into a
 * sibling of a set operation (where NULLs compare as equal), changes the result.
 *
 * @param union the union whose constant expression rows are inspected
 * @return predicates on the union's output slots that hold for all constant rows
 */
private ImmutableSet<Expression> getFiltersFromUnionConstExprs(LogicalUnion union) {
    List<List<NamedExpression>> constExprs = union.getConstantExprsList();
    ImmutableSet.Builder<Expression> filtersFromConstExprs = ImmutableSet.builder();
    for (int col = 0; col < union.getOutput().size(); ++col) {
        Expression compareExpr = union.getOutput().get(col);
        Set<Expression> options = new HashSet<>();
        for (List<NamedExpression> constExpr : constExprs) {
            NamedExpression constant = constExpr.get(col);
            Expression value = constant instanceof Alias ? ((Alias) constant).child() : null;
            // bail out for this column on anything that is not a non-NULL literal
            if (!(value instanceof Literal) || value instanceof NullLiteral) {
                options.clear();
                break;
            }
            options.add(value);
        }
        if (options.size() > 1) {
            filtersFromConstExprs.add(new InPredicate(compareExpr, options));
        } else if (options.size() == 1) {
            filtersFromConstExprs.add(new EqualTo(compareExpr, options.iterator().next()));
        }
    }
    return filtersFromConstExprs.build();
}
}
Loading

0 comments on commit 3e16922

Please sign in to comment.