Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(opt)(nereids) optimize and/or expression #44503

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
package org.apache.doris.nereids.cost;

import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
Expand Down Expand Up @@ -70,4 +72,14 @@ public Double visitLiteral(Literal literal, Void context) {
return 0.0;
}

@Override
public Double visitAnd(And and, Void context) {
return 0.0;
}

@Override
public Double visitOr(Or or, Void context) {
return 0.0;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
package org.apache.doris.nereids.cost;

import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
Expand All @@ -33,6 +35,7 @@

import com.google.common.collect.Maps;

import java.util.List;
import java.util.Map;

/**
Expand Down Expand Up @@ -83,4 +86,26 @@ public Double visitAlias(Alias alias, Void context) {
}
return alias.child().accept(this, context);
}

@Override
public Double visitAnd(And and, Void context) {
List<Expression> children = and.extract();
double sum = 0.0;
for (Expression child : children) {
sum += child.accept(this, context);
sum += dataTypeCost.getOrDefault(child.getDataType().getClass(), 0.1);
}
return sum;
}

@Override
public Double visitOr(Or or, Void context) {
List<Expression> children = or.extract();
double sum = 0.0;
for (Expression child : children) {
sum += child.accept(this, context);
sum += dataTypeCost.getOrDefault(child.getDataType().getClass(), 0.1);
}
return sum;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -322,22 +324,85 @@ public Expr visitNullLiteral(NullLiteral nullLiteral, PlanTranslatorContext cont
return nullLit;
}

private static class Frame {
int low;
int high;
CompoundPredicate.Operator op;
boolean processed;

Frame(int low, int high, CompoundPredicate.Operator op) {
this.low = low;
this.high = high;
this.op = op;
this.processed = false;
}
}

private Expr toBalancedTree(int low, int high, List<Expr> children,
CompoundPredicate.Operator op) {
Deque<Frame> stack = new ArrayDeque<>();
Deque<Expr> results = new ArrayDeque<>();

stack.push(new Frame(low, high, op));

while (!stack.isEmpty()) {
Frame currentFrame = stack.peek();

if (!currentFrame.processed) {
int l = currentFrame.low;
int h = currentFrame.high;
int diff = h - l;

if (diff == 0) {
results.push(children.get(l));
stack.pop();
} else if (diff == 1) {
Expr left = children.get(l);
Expr right = children.get(h);
CompoundPredicate cp = new CompoundPredicate(op, left, right);
results.push(cp);
stack.pop();
} else {
int mid = l + (h - l) / 2;

currentFrame.processed = true;

stack.push(new Frame(mid + 1, h, op));
stack.push(new Frame(l, mid, op));
}
} else {
stack.pop();
if (results.size() >= 2) {
Expr right = results.pop();
Expr left = results.pop();
CompoundPredicate cp = new CompoundPredicate(currentFrame.op, left, right);
results.push(cp);
}
}
}
return results.pop();
}

@Override
public Expr visitAnd(And and, PlanTranslatorContext context) {
org.apache.doris.analysis.CompoundPredicate cp = new org.apache.doris.analysis.CompoundPredicate(
org.apache.doris.analysis.CompoundPredicate.Operator.AND,
and.child(0).accept(this, context),
and.child(1).accept(this, context));
List<Expr> children = and.extract().stream().map(
e -> e.accept(this, context)
).collect(Collectors.toList());
CompoundPredicate cp = (CompoundPredicate) toBalancedTree(0, children.size() - 1,
children, CompoundPredicate.Operator.AND);

cp.setNullableFromNereids(and.nullable());
return cp;
}

@Override
public Expr visitOr(Or or, PlanTranslatorContext context) {
org.apache.doris.analysis.CompoundPredicate cp = new org.apache.doris.analysis.CompoundPredicate(
org.apache.doris.analysis.CompoundPredicate.Operator.OR,
or.child(0).accept(this, context),
or.child(1).accept(this, context));
List<Expr> children = or.extract().stream().map(
e -> e.accept(this, context)
).collect(Collectors.toList());
CompoundPredicate cp = (CompoundPredicate) toBalancedTree(0, children.size() - 1,
children, CompoundPredicate.Operator.OR);

cp.setNullableFromNereids(or.nullable());
return cp;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,14 @@
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.rules.FoldConstantRuleOnFE;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.ArrayItemReference;
import org.apache.doris.nereids.trees.expressions.BinaryArithmetic;
import org.apache.doris.nereids.trees.expressions.BitNot;
import org.apache.doris.nereids.trees.expressions.BoundStar;
import org.apache.doris.nereids.trees.expressions.CaseWhen;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.Divide;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
Expand All @@ -57,6 +57,7 @@
import org.apache.doris.nereids.trees.expressions.Match;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.Placeholder;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
Expand All @@ -73,6 +74,7 @@
import org.apache.doris.nereids.trees.expressions.functions.udf.UdfBuilder;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
import org.apache.doris.nereids.trees.plans.Plan;
Expand All @@ -81,6 +83,7 @@
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.TypeCoercionUtils;
import org.apache.doris.nereids.util.Utils;
import org.apache.doris.qe.ConnectContext;
Expand All @@ -95,6 +98,7 @@
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.Lists;
import org.apache.commons.lang3.StringUtils;

import java.util.ArrayList;
Expand Down Expand Up @@ -498,11 +502,59 @@ public Expression visitBinaryArithmetic(BinaryArithmetic binaryArithmetic, Expre
}

@Override
public Expression visitCompoundPredicate(CompoundPredicate compoundPredicate, ExpressionRewriteContext context) {
Expression left = compoundPredicate.left().accept(this, context);
Expression right = compoundPredicate.right().accept(this, context);
CompoundPredicate ret = (CompoundPredicate) compoundPredicate.withChildren(left, right);
return TypeCoercionUtils.processCompoundPredicate(ret);
public Expression visitOr(Or or, ExpressionRewriteContext context) {
List<Expression> children = ExpressionUtils.extractDisjunction(or);
List<Expression> newChildren = Lists.newArrayListWithCapacity(children.size());
boolean hasNewChild = false;
for (Expression child : children) {
Expression newChild = child.accept(this, context);
if (newChild == null) {
newChild = child;
}
if (newChild.getDataType().isNullType()) {
newChild = new NullLiteral(BooleanType.INSTANCE);
} else {
newChild = TypeCoercionUtils.castIfNotSameType(newChild, BooleanType.INSTANCE);
}

if (! child.equals(newChild)) {
hasNewChild = true;
}
newChildren.add(newChild);
}
if (hasNewChild) {
return ExpressionUtils.or(newChildren);
} else {
return or;
}
}

@Override
public Expression visitAnd(And and, ExpressionRewriteContext context) {
List<Expression> children = ExpressionUtils.extractConjunction(and);
List<Expression> newChildren = Lists.newArrayListWithCapacity(children.size());
boolean hasNewChild = false;
for (Expression child : children) {
Expression newChild = child.accept(this, context);
if (newChild == null) {
newChild = child;
}
if (newChild.getDataType().isNullType()) {
newChild = new NullLiteral(BooleanType.INSTANCE);
} else {
newChild = TypeCoercionUtils.castIfNotSameType(newChild, BooleanType.INSTANCE);
}

if (! child.equals(newChild)) {
hasNewChild = true;
}
newChildren.add(newChild);
}
if (hasNewChild) {
return ExpressionUtils.and(newChildren);
} else {
return and;
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -595,10 +595,8 @@ public Expression visitBinaryOperator(BinaryOperator binaryOperator, SubqueryCon
isMarkJoin || ((binaryOperator.left().anyMatch(SubqueryExpr.class::isInstance)
|| binaryOperator.right().anyMatch(SubqueryExpr.class::isInstance))
&& (binaryOperator instanceof Or));

Expression left = replace(binaryOperator.left(), context);
Expression right = replace(binaryOperator.right(), context);

return binaryOperator.withChildren(left, right);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ public abstract class AbstractExpressionRewriteRule extends DefaultExpressionRew

@Override
public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) {
return expr.accept(this, ctx);
Expression result = expr.accept(this, ctx);
return result == null ? expr : result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,15 @@
import org.apache.doris.nereids.pattern.ExpressionPatternRules;
import org.apache.doris.nereids.pattern.ExpressionPatternTraverseListeners;
import org.apache.doris.nereids.pattern.ExpressionPatternTraverseListeners.CombinedListener;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Or;

import com.google.common.collect.ImmutableList;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.util.List;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
import javax.annotation.Nullable;
Expand Down Expand Up @@ -104,9 +107,19 @@ private static Expression rewriteBottomUp(

private static Expression rewriteChildren(Expression parent, ExpressionRewriteContext context, int currentBatch,
ExpressionPatternRules rules, ExpressionPatternTraverseListeners listeners) {
List<Expression> children;
if (parent instanceof And) {
children = ((And) parent).extract();
} else if (parent instanceof Or) {
children = ((Or) parent).extract();
} else {
children = parent.children();
}

boolean changed = false;
ImmutableList.Builder<Expression> newChildren = ImmutableList.builderWithExpectedSize(parent.arity());
for (Expression child : parent.children()) {
ImmutableList.Builder<Expression> newChildren = ImmutableList.builderWithExpectedSize(
Math.max(parent.arity(), children.size()));
for (Expression child : children) {
Expression newChild = rewriteBottomUp(child, context, currentBatch, parent, rules, listeners);
changed |= !child.equals(newChild);
newChildren.add(newChild);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import org.apache.doris.nereids.rules.expression.rules.ExtractCommonFactorRule;
import org.apache.doris.nereids.rules.expression.rules.LikeToEqualRewrite;
import org.apache.doris.nereids.rules.expression.rules.NullSafeEqualToEqual;
import org.apache.doris.nereids.rules.expression.rules.OrToIn;
import org.apache.doris.nereids.rules.expression.rules.SimplifyComparisonPredicate;
import org.apache.doris.nereids.rules.expression.rules.SimplifyDecimalV3Comparison;
import org.apache.doris.nereids.rules.expression.rules.SimplifyInPredicate;
Expand All @@ -48,7 +47,7 @@ public class ExpressionOptimization extends ExpressionRewrite {
SimplifyInPredicate.INSTANCE,
SimplifyDecimalV3Comparison.INSTANCE,
SimplifyRange.INSTANCE,
OrToIn.INSTANCE,
// OrToIn.INSTANCE,
DateFunctionRewrite.INSTANCE,
ArrayContainToArrayOverlap.INSTANCE,
CaseWhenToIf.INSTANCE,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ public Expression visitConnectionId(ConnectionId connectionId, ExpressionRewrite
public Expression visitAnd(And and, ExpressionRewriteContext context) {
List<Expression> nonTrueLiteral = Lists.newArrayList();
int nullCount = 0;
for (Expression e : and.children()) {
for (Expression e : and.extract()) {
e = deepRewrite ? e.accept(this, context) : e;
if (BooleanLiteral.FALSE.equals(e)) {
return BooleanLiteral.FALSE;
Expand Down Expand Up @@ -418,7 +418,7 @@ public Expression visitAnd(And and, ExpressionRewriteContext context) {
public Expression visitOr(Or or, ExpressionRewriteContext context) {
List<Expression> nonFalseLiteral = Lists.newArrayList();
int nullCount = 0;
for (Expression e : or.children()) {
for (Expression e : or.extract()) {
e = deepRewrite ? e.accept(this, context) : e;
if (BooleanLiteral.TRUE.equals(e)) {
return BooleanLiteral.TRUE;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,14 @@ public EvaluateRangeResult visitNot(Not not, EvaluateRangeInput context) {

private EvaluateRangeResult evaluateChildrenThenThis(Expression expr, EvaluateRangeInput context) {
// evaluate children
List<Expression> children = expr.children();
List<Expression> children;
if (expr instanceof And) {
children = ((And) expr).extract();
} else if (expr instanceof Or) {
children = ((Or) expr).extract();
} else {
children = expr.children();
}
ImmutableList.Builder<Expression> newChildren = ImmutableList.builderWithExpectedSize(children.size());
List<EvaluateRangeResult> childrenResults = new ArrayList<>(children.size());
boolean hasNewChildren = false;
Expand Down
Loading
Loading