Skip to content

Commit

Permalink
[feature](Nereids) covert predicate to SARGABLE (apache#25180)
Browse files Browse the repository at this point in the history
covert predicate to SARGABLE 
1. support format like `1 - a`
2. support rearrange `year/month/week/day/minutes/seconds_sub/add` function
  • Loading branch information
keanji-x authored Oct 12, 2023
1 parent c63bf24 commit d6ff974
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,34 @@
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Divide;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
import org.apache.doris.nereids.trees.expressions.Multiply;
import org.apache.doris.nereids.trees.expressions.Subtract;
import org.apache.doris.nereids.trees.expressions.functions.scalar.DaysAdd;
import org.apache.doris.nereids.trees.expressions.functions.scalar.DaysSub;
import org.apache.doris.nereids.trees.expressions.functions.scalar.HoursAdd;
import org.apache.doris.nereids.trees.expressions.functions.scalar.HoursSub;
import org.apache.doris.nereids.trees.expressions.functions.scalar.MinutesAdd;
import org.apache.doris.nereids.trees.expressions.functions.scalar.MinutesSub;
import org.apache.doris.nereids.trees.expressions.functions.scalar.MonthsAdd;
import org.apache.doris.nereids.trees.expressions.functions.scalar.MonthsSub;
import org.apache.doris.nereids.trees.expressions.functions.scalar.SecondsAdd;
import org.apache.doris.nereids.trees.expressions.functions.scalar.SecondsSub;
import org.apache.doris.nereids.trees.expressions.functions.scalar.WeeksAdd;
import org.apache.doris.nereids.trees.expressions.functions.scalar.WeeksSub;
import org.apache.doris.nereids.trees.expressions.functions.scalar.YearsAdd;
import org.apache.doris.nereids.trees.expressions.functions.scalar.YearsSub;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.util.TypeCoercionUtils;
import org.apache.doris.nereids.util.TypeUtils;

import com.google.common.collect.ImmutableMap;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;

/**
* Simplify arithmetic comparison rule.
Expand All @@ -40,68 +58,90 @@
public class SimplifyArithmeticComparisonRule extends AbstractExpressionRewriteRule {
public static final SimplifyArithmeticComparisonRule INSTANCE = new SimplifyArithmeticComparisonRule();

@Override
public Expression visit(Expression expr, ExpressionRewriteContext context) {
return expr;
}
// don't rearrange multiplication because divide may loss precision
final Map<Class<? extends Expression>, Class<? extends Expression>> rearrangementMap = ImmutableMap
.<Class<? extends Expression>, Class<? extends Expression>>builder()
.put(Add.class, Subtract.class)
.put(Subtract.class, Add.class)
.put(Divide.class, Multiply.class)
.put(YearsSub.class, YearsAdd.class)
.put(YearsAdd.class, YearsSub.class)
.put(MonthsSub.class, MonthsAdd.class)
.put(MonthsAdd.class, MonthsSub.class)
.put(WeeksSub.class, WeeksAdd.class)
.put(WeeksAdd.class, WeeksSub.class)
.put(DaysSub.class, DaysAdd.class)
.put(DaysAdd.class, DaysSub.class)
.put(HoursSub.class, HoursAdd.class)
.put(HoursAdd.class, HoursSub.class)
.put(MinutesSub.class, MinutesAdd.class)
.put(MinutesAdd.class, MinutesSub.class)
.put(SecondsSub.class, SecondsAdd.class)
.put(SecondsAdd.class, SecondsSub.class)
.build();

private Expression process(ComparisonPredicate predicate) {
Expression left = predicate.left();
Expression right = predicate.right();
if (TypeUtils.isAddOrSubtract(left)) {
Expression p = left.child(1);
if (p.isConstant()) {
if (TypeUtils.isAdd(left)) {
right = new Subtract(right, p);
}
if (TypeUtils.isSubtract(left)) {
right = new Add(right, p);
}
left = left.child(0);
@Override
public Expression visitComparisonPredicate(ComparisonPredicate comparison, ExpressionRewriteContext context) {
ComparisonPredicate newComparison = comparison;
if (couldRearrange(comparison)) {
newComparison = normalize(comparison);
if (newComparison == null) {
return comparison;
}
}
if (TypeUtils.isDivide(left)) {
Expression p = left.child(1);
if (p.isLiteral()) {
right = new Multiply(right, p);
left = left.child(0);
if (p.toString().startsWith("-")) {
Expression tmp = right;
right = left;
left = tmp;
}
try {
List<Expression> children = tryRearrangeChildren(newComparison.left(), newComparison.right());
newComparison = (ComparisonPredicate) newComparison.withChildren(children);
} catch (Exception e) {
return comparison;
}
}
if (left != predicate.left() || right != predicate.right()) {
predicate = (ComparisonPredicate) predicate.withChildren(left, right);
return TypeCoercionUtils.processComparisonPredicate(predicate);
} else {
return predicate;
}
return TypeCoercionUtils.processComparisonPredicate(newComparison);
}

@Override
public Expression visitGreaterThan(GreaterThan greaterThan, ExpressionRewriteContext context) {
return process(greaterThan);
private boolean couldRearrange(ComparisonPredicate cmp) {
return rearrangementMap.containsKey(cmp.left().getClass())
&& !cmp.left().isConstant()
&& cmp.left().children().stream().anyMatch(Expression::isConstant);
}

@Override
public Expression visitGreaterThanEqual(GreaterThanEqual greaterThanEqual, ExpressionRewriteContext context) {
return process(greaterThanEqual);
}
private List<Expression> tryRearrangeChildren(Expression left, Expression right) throws Exception {
if (!left.child(1).isLiteral()) {
throw new RuntimeException(String.format("Expected literal when arranging children for Expr %s", left));
}
Literal leftLiteral = (Literal) left.child(1);
Expression leftExpr = left.child(0);

@Override
public Expression visitEqualTo(EqualTo equalTo, ExpressionRewriteContext context) {
return process(equalTo);
}
Class<? extends Expression> oppositeOperator = rearrangementMap.get(left.getClass());
Expression newChild = oppositeOperator.getConstructor(Expression.class, Expression.class)
.newInstance(right, leftLiteral);

@Override
public Expression visitLessThan(LessThan lessThan, ExpressionRewriteContext context) {
return process(lessThan);
if (left instanceof Divide && leftLiteral.compareTo(new IntegerLiteral(0)) < 0) {
// Multiplying by a negative number will change the operator.
return Arrays.asList(newChild, leftExpr);
}
return Arrays.asList(leftExpr, newChild);
}

@Override
public Expression visitLessThanEqual(LessThanEqual lessThanEqual, ExpressionRewriteContext context) {
return process(lessThanEqual);
// Ensure that the second child must be Literal, such as
private @Nullable ComparisonPredicate normalize(ComparisonPredicate comparison) {
if (!(comparison.left().child(1) instanceof Literal)) {
Expression left = comparison.left();
if (comparison.left() instanceof Add) {
// 1 + a > 1 => a + 1 > 1
Expression newLeft = left.withChildren(left.child(1), left.child(0));
comparison = (ComparisonPredicate) comparison.withChildren(newLeft, comparison.right());
} else if (comparison.left() instanceof Subtract) {
// 1 - a > 1 => a + 1 < 1
Expression newLeft = left.child(0);
Expression newRight = new Add(left.child(1), comparison.right());
comparison = (ComparisonPredicate) comparison.withChildren(newLeft, newRight);
comparison = comparison.commute();
} else {
// Don't normalize division/multiplication because the slot sign is undecided.
return null;
}
}
return comparison;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Test;

public class SimplifyArithmeticRuleTest extends ExpressionRewriteTestHelper {
class SimplifyArithmeticRuleTest extends ExpressionRewriteTestHelper {
@Test
public void testSimplifyArithmetic() {
void testSimplifyArithmetic() {
executor = new ExpressionRuleExecutor(ImmutableList.of(
SimplifyArithmeticRule.INSTANCE,
FunctionBinder.INSTANCE,
Expand All @@ -53,7 +53,7 @@ public void testSimplifyArithmetic() {
}

@Test
public void testSimplifyArithmeticComparison() {
void testSimplifyArithmeticComparison() {
executor = new ExpressionRuleExecutor(ImmutableList.of(
SimplifyArithmeticRule.INSTANCE,
FoldConstantRule.INSTANCE,
Expand Down Expand Up @@ -88,7 +88,35 @@ public void testSimplifyArithmeticComparison() {
assertRewriteAfterTypeCoercion("IA * ID > IB * IC", "IA * ID > IB * IC");
assertRewriteAfterTypeCoercion("IA * ID / 2 > IB * IC", "cast((IA * ID) as DOUBLE) > cast((IB * IC) as DOUBLE) * 2");
assertRewriteAfterTypeCoercion("IA * ID / -2 > IB * IC", "cast((IB * IC) as DOUBLE) * -2 > cast((IA * ID) as DOUBLE)");
assertRewriteAfterTypeCoercion("1 - IA > 1", "(cast(IA as BIGINT) < 0)");
assertRewriteAfterTypeCoercion("1 - IA + 1 * 3 - 5 > 1", "(cast(IA as BIGINT) < -2)");
}

@Test
void testSimplifyDateTimeComparison() {
executor = new ExpressionRuleExecutor(ImmutableList.of(
SimplifyArithmeticRule.INSTANCE,
FoldConstantRule.INSTANCE,
SimplifyArithmeticComparisonRule.INSTANCE,
SimplifyArithmeticRule.INSTANCE,
FunctionBinder.INSTANCE,
FoldConstantRule.INSTANCE
));
assertRewriteAfterTypeCoercion("years_add(IA, 1) > '2021-01-01 00:00:00'", "(cast(IA as DATETIMEV2(0)) > '2020-01-01 00:00:00')");
assertRewriteAfterTypeCoercion("years_sub(IA, 1) > '2021-01-01 00:00:00'", "(cast(IA as DATETIMEV2(0)) > '2022-01-01 00:00:00')");
assertRewriteAfterTypeCoercion("months_add(IA, 1) > '2021-01-01 00:00:00'", "(cast(IA as DATETIMEV2(0)) > '2020-12-01 00:00:00')");
assertRewriteAfterTypeCoercion("months_sub(IA, 1) > '2021-01-01 00:00:00'", "(cast(IA as DATETIMEV2(0)) > '2021-02-01 00:00:00')");
assertRewriteAfterTypeCoercion("weeks_add(IA, 1) > '2021-01-01 00:00:00'", "(cast(IA as DATETIMEV2(0)) > '2020-12-25 00:00:00')");
assertRewriteAfterTypeCoercion("weeks_sub(IA, 1) > '2021-01-01 00:00:00'", "(cast(IA as DATETIMEV2(0)) > '2021-01-08 00:00:00')");
assertRewriteAfterTypeCoercion("days_add(IA, 1) > '2021-01-01 00:00:00'", "(cast(IA as DATETIMEV2(0)) > '2020-12-31 00:00:00')");
assertRewriteAfterTypeCoercion("days_sub(IA, 1) > '2021-01-01 00:00:00'", "(cast(IA as DATETIMEV2(0)) > '2021-01-02 00:00:00')");
assertRewriteAfterTypeCoercion("hours_add(IA, 1) > '2021-01-01 00:00:00'", "(cast(IA as DATETIMEV2(0)) > '2020-12-31 23:00:00')");
assertRewriteAfterTypeCoercion("hours_sub(IA, 1) > '2021-01-01 00:00:00'", "(cast(IA as DATETIMEV2(0)) > '2021-01-01 01:00:00')");
assertRewriteAfterTypeCoercion("minutes_add(IA, 1) > '2021-01-01 00:00:00'", "(cast(IA as DATETIMEV2(0)) > '2020-12-31 23:59:00')");
assertRewriteAfterTypeCoercion("minutes_sub(IA, 1) > '2021-01-01 00:00:00'", "(cast(IA as DATETIMEV2(0)) > '2021-01-01 00:01:00')");
assertRewriteAfterTypeCoercion("seconds_add(IA, 1) > '2021-01-01 00:00:00'", "(cast(IA as DATETIMEV2(0)) > '2020-12-31 23:59:59')");
assertRewriteAfterTypeCoercion("seconds_sub(IA, 1) > '2021-01-01 00:00:00'", "(cast(IA as DATETIMEV2(0)) > '2021-01-01 00:00:01')");

}
}

0 comments on commit d6ff974

Please sign in to comment.