Skip to content

Commit

Permalink
[opt](nereids)flattern and/or (#44574)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?
And/or expressions are represented by binary trees. The depth of the
tree causes stack overflow in recursive program calls. To solve this
problem, this pr will flatten the binary tree when processing and/or,
reducing the number of recursions. At the same time, we also rewrite
some recursive programs into non-recursive forms to avoid stack
overflow.

Issue
  • Loading branch information
englefly authored Nov 30, 2024
1 parent cf40dba commit 42a7734
Show file tree
Hide file tree
Showing 380 changed files with 1,300 additions and 930 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
import org.apache.doris.load.loadv2.JobState;
import org.apache.doris.mysql.privilege.PrivPredicate;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.BinaryOperator;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.qe.ConnectContext;

import com.google.common.collect.Lists;
Expand Down Expand Up @@ -528,7 +528,8 @@ private static void addNeedCancelLoadJob(String label, String state,
/**
* used for nereids planner
*/
public void cancelLoadJob(String dbName, String label, String state, BinaryOperator operator)
public void cancelLoadJob(String dbName, String label, String state,
Expression operator)
throws JobException, AnalysisException, DdlException {
Database db = Env.getCurrentInternalCatalog().getDbOrDdlException(dbName);
// List of load jobs waiting to be cancelled
Expand Down Expand Up @@ -582,7 +583,8 @@ public void cancelLoadJob(String dbName, String label, String state, BinaryOpera
}

private static void addNeedCancelLoadJob(String label, String state,
BinaryOperator operator, List<InsertJob> loadJobs,
Expression operator,
List<InsertJob> loadJobs,
List<InsertJob> matchLoadJobs)
throws AnalysisException {
PatternMatcher matcher = PatternMatcherWrapper.createMysqlPattern(label,
Expand Down
15 changes: 11 additions & 4 deletions fe/fe-core/src/main/java/org/apache/doris/load/ExportMgr.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
import org.apache.doris.common.util.TimeUtils;
import org.apache.doris.datasource.InternalCatalog;
import org.apache.doris.mysql.privilege.PrivPredicate;
import org.apache.doris.nereids.trees.expressions.BinaryOperator;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.scheduler.exception.JobException;
Expand Down Expand Up @@ -162,7 +162,9 @@ public void cancelExportJob(CancelExportStmt stmt) throws DdlException, Analysis
}
}

private List<ExportJob> getWaitingCancelJobs(String label, String state, BinaryOperator operator)
private List<ExportJob> getWaitingCancelJobs(
String label, String state,
Expression operator)
throws AnalysisException {
Predicate<ExportJob> jobFilter = buildCancelJobFilter(label, state, operator);
readLock();
Expand All @@ -174,7 +176,9 @@ private List<ExportJob> getWaitingCancelJobs(String label, String state, BinaryO
}

@VisibleForTesting
public static Predicate<ExportJob> buildCancelJobFilter(String label, String state, BinaryOperator operator)
public static Predicate<ExportJob> buildCancelJobFilter(
String label, String state,
Expression operator)
throws AnalysisException {
PatternMatcher matcher = PatternMatcherWrapper.createMysqlPattern(label,
CaseSensibility.LABEL.getCaseSensibility());
Expand All @@ -201,7 +205,10 @@ public static Predicate<ExportJob> buildCancelJobFilter(String label, String sta
/**
* used for Nereids planner
*/
public void cancelExportJob(String label, String state, BinaryOperator operator, String dbName)
public void cancelExportJob(
String label,
String state,
Expression operator, String dbName)
throws DdlException, AnalysisException {
// List of export jobs waiting to be cancelled
List<ExportJob> matchExportJobs = getWaitingCancelJobs(label, state, operator);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
import org.apache.doris.load.Load;
import org.apache.doris.mysql.privilege.PrivPredicate;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.BinaryOperator;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.persist.CleanLabelOperationLog;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.OriginStatement;
Expand Down Expand Up @@ -248,7 +248,7 @@ public void recordFinishedLoadJob(String label, long transactionId, String dbNam
* Match need cancel loadJob by stmt.
**/
@VisibleForTesting
public static void addNeedCancelLoadJob(String label, String state, BinaryOperator operator,
public static void addNeedCancelLoadJob(String label, String state, Expression operator,
List<LoadJob> loadJobs, List<LoadJob> matchLoadJobs)
throws AnalysisException {
PatternMatcher matcher = PatternMatcherWrapper.createMysqlPattern(label,
Expand Down Expand Up @@ -281,7 +281,7 @@ public static void addNeedCancelLoadJob(String label, String state, BinaryOperat
/**
* Cancel load job by stmt.
**/
public void cancelLoadJob(String dbName, String label, String state, BinaryOperator operator)
public void cancelLoadJob(String dbName, String label, String state, Expression operator)
throws DdlException, AnalysisException {
Database db = Env.getCurrentInternalCatalog().getDbOrDdlException(dbName);
// List of load jobs waiting to be cancelled
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
package org.apache.doris.nereids.cost;

import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
Expand Down Expand Up @@ -70,4 +72,14 @@ public Double visitLiteral(Literal literal, Void context) {
return 0.0;
}

@Override
public Double visitAnd(And and, Void context) {
return 0.0;
}

@Override
public Double visitOr(Or or, Void context) {
return 0.0;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -322,22 +324,85 @@ public Expr visitNullLiteral(NullLiteral nullLiteral, PlanTranslatorContext cont
return nullLit;
}

private static class Frame {
int low;
int high;
CompoundPredicate.Operator op;
boolean processed;

Frame(int low, int high, CompoundPredicate.Operator op) {
this.low = low;
this.high = high;
this.op = op;
this.processed = false;
}
}

private Expr toBalancedTree(int low, int high, List<Expr> children,
CompoundPredicate.Operator op) {
Deque<Frame> stack = new ArrayDeque<>();
Deque<Expr> results = new ArrayDeque<>();

stack.push(new Frame(low, high, op));

while (!stack.isEmpty()) {
Frame currentFrame = stack.peek();

if (!currentFrame.processed) {
int l = currentFrame.low;
int h = currentFrame.high;
int diff = h - l;

if (diff == 0) {
results.push(children.get(l));
stack.pop();
} else if (diff == 1) {
Expr left = children.get(l);
Expr right = children.get(h);
CompoundPredicate cp = new CompoundPredicate(op, left, right);
results.push(cp);
stack.pop();
} else {
int mid = l + (h - l) / 2;

currentFrame.processed = true;

stack.push(new Frame(mid + 1, h, op));
stack.push(new Frame(l, mid, op));
}
} else {
stack.pop();
if (results.size() >= 2) {
Expr right = results.pop();
Expr left = results.pop();
CompoundPredicate cp = new CompoundPredicate(currentFrame.op, left, right);
results.push(cp);
}
}
}
return results.pop();
}

@Override
public Expr visitAnd(And and, PlanTranslatorContext context) {
org.apache.doris.analysis.CompoundPredicate cp = new org.apache.doris.analysis.CompoundPredicate(
org.apache.doris.analysis.CompoundPredicate.Operator.AND,
and.child(0).accept(this, context),
and.child(1).accept(this, context));
List<Expr> children = and.children().stream().map(
e -> e.accept(this, context)
).collect(Collectors.toList());
CompoundPredicate cp = (CompoundPredicate) toBalancedTree(0, children.size() - 1,
children, CompoundPredicate.Operator.AND);

cp.setNullableFromNereids(and.nullable());
return cp;
}

@Override
public Expr visitOr(Or or, PlanTranslatorContext context) {
org.apache.doris.analysis.CompoundPredicate cp = new org.apache.doris.analysis.CompoundPredicate(
org.apache.doris.analysis.CompoundPredicate.Operator.OR,
or.child(0).accept(this, context),
or.child(1).accept(this, context));
List<Expr> children = or.children().stream().map(
e -> e.accept(this, context)
).collect(Collectors.toList());
CompoundPredicate cp = (CompoundPredicate) toBalancedTree(0, children.size() - 1,
children, CompoundPredicate.Operator.OR);

cp.setNullableFromNereids(or.nullable());
return cp;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2081,8 +2081,14 @@ public Expression visitLogicalBinary(LogicalBinaryContext ctx) {
// into expressions.
Collections.reverse(contexts);
List<Expression> expressions = contexts.stream().map(this::getExpression).collect(Collectors.toList());
// Create a balanced tree.
return reduceToExpressionTree(0, expressions.size() - 1, expressions, ctx);
if (ctx.operator.getType() == DorisParser.AND) {
return new And(expressions);
} else if (ctx.operator.getType() == DorisParser.OR) {
return new Or(expressions);
} else {
// Create a balanced tree.
return reduceToExpressionTree(0, expressions.size() - 1, expressions, ctx);
}
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,14 @@
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.rules.FoldConstantRuleOnFE;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.ArrayItemReference;
import org.apache.doris.nereids.trees.expressions.BinaryArithmetic;
import org.apache.doris.nereids.trees.expressions.BitNot;
import org.apache.doris.nereids.trees.expressions.BoundStar;
import org.apache.doris.nereids.trees.expressions.CaseWhen;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.Divide;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
Expand All @@ -57,6 +57,7 @@
import org.apache.doris.nereids.trees.expressions.Match;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.Placeholder;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
Expand All @@ -73,6 +74,7 @@
import org.apache.doris.nereids.trees.expressions.functions.udf.UdfBuilder;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
import org.apache.doris.nereids.trees.plans.Plan;
Expand All @@ -81,6 +83,7 @@
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.TypeCoercionUtils;
import org.apache.doris.nereids.util.Utils;
import org.apache.doris.qe.ConnectContext;
Expand All @@ -95,6 +98,7 @@
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.Lists;
import org.apache.commons.lang3.StringUtils;

import java.util.ArrayList;
Expand Down Expand Up @@ -498,11 +502,59 @@ public Expression visitBinaryArithmetic(BinaryArithmetic binaryArithmetic, Expre
}

@Override
public Expression visitCompoundPredicate(CompoundPredicate compoundPredicate, ExpressionRewriteContext context) {
Expression left = compoundPredicate.left().accept(this, context);
Expression right = compoundPredicate.right().accept(this, context);
CompoundPredicate ret = (CompoundPredicate) compoundPredicate.withChildren(left, right);
return TypeCoercionUtils.processCompoundPredicate(ret);
public Expression visitOr(Or or, ExpressionRewriteContext context) {
List<Expression> children = ExpressionUtils.extractDisjunction(or);
List<Expression> newChildren = Lists.newArrayListWithCapacity(children.size());
boolean hasNewChild = false;
for (Expression child : children) {
Expression newChild = child.accept(this, context);
if (newChild == null) {
newChild = child;
}
if (newChild.getDataType().isNullType()) {
newChild = new NullLiteral(BooleanType.INSTANCE);
} else {
newChild = TypeCoercionUtils.castIfNotSameType(newChild, BooleanType.INSTANCE);
}

if (! child.equals(newChild)) {
hasNewChild = true;
}
newChildren.add(newChild);
}
if (hasNewChild) {
return ExpressionUtils.or(newChildren);
} else {
return or;
}
}

@Override
public Expression visitAnd(And and, ExpressionRewriteContext context) {
List<Expression> children = ExpressionUtils.extractConjunction(and);
List<Expression> newChildren = Lists.newArrayListWithCapacity(children.size());
boolean hasNewChild = false;
for (Expression child : children) {
Expression newChild = child.accept(this, context);
if (newChild == null) {
newChild = child;
}
if (newChild.getDataType().isNullType()) {
newChild = new NullLiteral(BooleanType.INSTANCE);
} else {
newChild = TypeCoercionUtils.castIfNotSameType(newChild, BooleanType.INSTANCE);
}

if (! child.equals(newChild)) {
hasNewChild = true;
}
newChildren.add(newChild);
}
if (hasNewChild) {
return ExpressionUtils.and(newChildren);
} else {
return and;
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import org.apache.doris.nereids.trees.TreeNode;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.BinaryOperator;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.Exists;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InSubquery;
Expand Down Expand Up @@ -589,17 +589,19 @@ public Expression visitScalarSubquery(ScalarSubquery scalar, SubqueryContext con
}

@Override
public Expression visitBinaryOperator(BinaryOperator binaryOperator, SubqueryContext context) {
public Expression visitCompoundPredicate(CompoundPredicate compound, SubqueryContext context) {
// update isMarkJoin flag
isMarkJoin =
isMarkJoin || ((binaryOperator.left().anyMatch(SubqueryExpr.class::isInstance)
|| binaryOperator.right().anyMatch(SubqueryExpr.class::isInstance))
&& (binaryOperator instanceof Or));

Expression left = replace(binaryOperator.left(), context);
Expression right = replace(binaryOperator.right(), context);

return binaryOperator.withChildren(left, right);
if (compound instanceof Or) {
for (Expression child : compound.children()) {
if (child.anyMatch(SubqueryExpr.class::isInstance)) {
isMarkJoin = true;
break;
}
}
}
return compound.withChildren(
compound.children().stream().map(c -> replace(c, context)).collect(Collectors.toList())
);
}
}

Expand Down
Loading

0 comments on commit 42a7734

Please sign in to comment.