From bbf8a819de19d36a31aa057e6f10c60c2466eb58 Mon Sep 17 00:00:00 2001 From: yujun Date: Thu, 9 Jan 2025 12:06:45 +0800 Subject: [PATCH] [opt](nereids) opt range inference for or expression when out of order (#46303) ### What problem does this PR solve? Problem Summary: For range inference, it will merge multiple value desc whose reference are the same. It will merge two value desc step by step. Diff merge order may get diff result. For range Inference: `x1 op x2 op x3 op x4` If op is `AND`, then the merge order doesn't matter. It will always get the same result. But if op is `OR`, then the merge order does matter. For example: `(a < 10) or ( a > 30) or (a >= 15 and a <= 35)`. When merge the first OP, it will get an UnknownValue: and its source is: `[ (-00, 10), (30, +00) ]`, latter will merge this UnknowValue with RangeValue `[15, 35]`. Since UnknowValue union another value desc will get a new UnknownValue, then then final result is UknownValue(UnknowValue(RangeValue(`a<10`) or RangeValue(`a>30`)) or RangeValue(`a>=15 and a <= 35`)). This is bad. It should merge the 1st and 3rd value desc firstly, latter merge the 2nd value desc, Then finally the merge result is 'TRUE'. In order to achieve this, use a RangeSet to record all the ranges, then RangeSet will auto merge the results. What's more, this pr also: 1. opt 'a > 20 or a = 20' to 'a >= 20'; 2. for the discrete value's options, if an option is in one range, then the option will eliminate. for example: `a <= 10 or a in [1, 2, 3, 11, 12, 13]` will opt to `a <= 10 or a in [11, 12, 13]`; 3. delete toExpr in RangeInference; --- .../expression/rules/RangeInference.java | 197 ++++++++++-------- .../rules/expression/rules/SimplifyRange.java | 23 +- .../rules/expression/SimplifyRangeTest.java | 63 ++++-- .../doris/nereids/sqltest/InferTest.java | 4 +- 4 files changed, 171 insertions(+), 116 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangeInference.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangeInference.java index 247856578c29ff..c78ec7a75fbad1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangeInference.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangeInference.java @@ -34,15 +34,17 @@ import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.util.ExpressionUtils; +import com.google.common.collect.BoundType; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; import com.google.common.collect.Multimap; import com.google.common.collect.Multimaps; import com.google.common.collect.Range; +import com.google.common.collect.RangeSet; import com.google.common.collect.Sets; +import com.google.common.collect.TreeRangeSet; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collection; import java.util.LinkedHashMap; import java.util.List; @@ -118,18 +120,17 @@ public ValueDesc visitInPredicate(InPredicate inPredicate, ExpressionRewriteCont @Override public ValueDesc visitAnd(And and, ExpressionRewriteContext context) { - return simplify(context, and, ExpressionUtils.extractConjunction(and), + return simplify(context, ExpressionUtils.extractConjunction(and), ValueDesc::intersect, true); } @Override public ValueDesc visitOr(Or or, ExpressionRewriteContext context) { - return simplify(context, or, ExpressionUtils.extractDisjunction(or), + return simplify(context, ExpressionUtils.extractDisjunction(or), ValueDesc::union, false); } - private ValueDesc simplify(ExpressionRewriteContext context, - Expression originExpr, List predicates, + private ValueDesc simplify(ExpressionRewriteContext context, List predicates, BinaryOperator op, boolean isAnd) { boolean convertIsNullToEmptyValue = isAnd && predicates.stream().anyMatch(expr -> expr instanceof NullLiteral); @@ -144,7 +145,7 @@ private ValueDesc simplify(ExpressionRewriteContext context, // but we don't consider this case here, we should fold IsNull(a) to FALSE using other rule. ValueDesc valueDesc = null; if (convertIsNullToEmptyValue && predicate instanceof IsNull) { - valueDesc = new EmptyValue(context, ((IsNull) predicate).child(), predicate); + valueDesc = new EmptyValue(context, ((IsNull) predicate).child()); } else { valueDesc = predicate.accept(this, context); } @@ -154,7 +155,11 @@ private ValueDesc simplify(ExpressionRewriteContext context, List valuePerRefs = Lists.newArrayList(); for (Entry> referenceValues : groupByReference.asMap().entrySet()) { + Expression reference = referenceValues.getKey(); List valuePerReference = (List) referenceValues.getValue(); + if (!isAnd) { + valuePerReference = ValueDesc.unionDiscreteAndRange(context, reference, valuePerReference); + } // merge per reference ValueDesc simplifiedValue = valuePerReference.get(0); @@ -170,7 +175,7 @@ private ValueDesc simplify(ExpressionRewriteContext context, } // use UnknownValue to wrap different references - return new UnknownValue(context, originExpr, valuePerRefs, isAnd); + return new UnknownValue(context, valuePerRefs, isAnd); } /** @@ -178,12 +183,10 @@ private ValueDesc simplify(ExpressionRewriteContext context, */ public abstract static class ValueDesc { ExpressionRewriteContext context; - Expression toExpr; Expression reference; - public ValueDesc(ExpressionRewriteContext context, Expression reference, Expression toExpr) { + public ValueDesc(ExpressionRewriteContext context, Expression reference) { this.context = context; - this.toExpr = toExpr; this.reference = reference; } @@ -191,10 +194,6 @@ public Expression getReference() { return reference; } - public Expression getOriginExpr() { - return toExpr; - } - public ExpressionRewriteContext getExpressionRewriteContext() { return context; } @@ -204,16 +203,62 @@ public ExpressionRewriteContext getExpressionRewriteContext() { /** or */ public static ValueDesc union(ExpressionRewriteContext context, RangeValue range, DiscreteValue discrete, boolean reverseOrder) { - long count = discrete.values.stream().filter(x -> range.range.test(x)).count(); - if (count == discrete.values.size()) { + if (discrete.values.stream().allMatch(x -> range.range.test(x))) { return range; } - Expression toExpr = FoldConstantRuleOnFE.evaluate( - new Or(range.toExpr, discrete.toExpr), context); List sourceValues = reverseOrder ? ImmutableList.of(discrete, range) : ImmutableList.of(range, discrete); - return new UnknownValue(context, toExpr, sourceValues, false); + return new UnknownValue(context, sourceValues, false); + } + + /** merge discrete and ranges only, no merge other value desc */ + public static List unionDiscreteAndRange(ExpressionRewriteContext context, + Expression reference, List valueDescs) { + Set discreteValues = Sets.newHashSet(); + for (ValueDesc valueDesc : valueDescs) { + if (valueDesc instanceof DiscreteValue) { + discreteValues.addAll(((DiscreteValue) valueDesc).getValues()); + } + } + + // for 'a > 8 or a = 8', then range (8, +00) can convert to [8, +00) + RangeSet rangeSet = TreeRangeSet.create(); + for (ValueDesc valueDesc : valueDescs) { + if (valueDesc instanceof RangeValue) { + Range range = ((RangeValue) valueDesc).range; + rangeSet.add(range); + if (range.hasLowerBound() + && range.lowerBoundType() == BoundType.OPEN + && discreteValues.contains(range.lowerEndpoint())) { + rangeSet.add(Range.singleton(range.lowerEndpoint())); + } + if (range.hasUpperBound() + && range.upperBoundType() == BoundType.OPEN + && discreteValues.contains(range.upperEndpoint())) { + rangeSet.add(Range.singleton(range.upperEndpoint())); + } + } + } + + if (!rangeSet.isEmpty()) { + discreteValues.removeIf(x -> rangeSet.contains(x)); + } + + List result = Lists.newArrayListWithExpectedSize(valueDescs.size()); + if (!discreteValues.isEmpty()) { + result.add(new DiscreteValue(context, reference, discreteValues)); + } + for (Range range : rangeSet.asRanges()) { + result.add(new RangeValue(context, reference, range)); + } + for (ValueDesc valueDesc : valueDescs) { + if (!(valueDesc instanceof DiscreteValue) && !(valueDesc instanceof RangeValue)) { + result.add(valueDesc); + } + } + + return result; } /** intersect */ @@ -221,19 +266,19 @@ public static ValueDesc union(ExpressionRewriteContext context, /** intersect */ public static ValueDesc intersect(ExpressionRewriteContext context, RangeValue range, DiscreteValue discrete) { - DiscreteValue result = new DiscreteValue(context, discrete.reference, discrete.toExpr); - discrete.values.stream().filter(x -> range.range.contains(x)).forEach(result.values::add); - if (!result.values.isEmpty()) { - return result; + Set newValues = discrete.values.stream().filter(x -> range.range.contains(x)) + .collect(Collectors.toSet()); + if (newValues.isEmpty()) { + return new EmptyValue(context, range.reference); + } else { + return new DiscreteValue(context, range.reference, newValues); } - Expression originExpr = FoldConstantRuleOnFE.evaluate(new And(range.toExpr, discrete.toExpr), context); - return new EmptyValue(context, range.reference, originExpr); } private static ValueDesc range(ExpressionRewriteContext context, ComparisonPredicate predicate) { Literal value = (Literal) predicate.right(); if (predicate instanceof EqualTo) { - return new DiscreteValue(context, predicate.left(), predicate, value); + return new DiscreteValue(context, predicate.left(), Sets.newHashSet(value)); } Range range = null; if (predicate instanceof GreaterThanEqual) { @@ -246,13 +291,13 @@ private static ValueDesc range(ExpressionRewriteContext context, ComparisonPredi range = Range.lessThan(value); } - return new RangeValue(context, predicate.left(), predicate, range); + return new RangeValue(context, predicate.left(), range); } public static ValueDesc discrete(ExpressionRewriteContext context, InPredicate in) { // Set literals = (Set) Utils.fastToImmutableSet(in.getOptions()); Set literals = in.getOptions().stream().map(Literal.class::cast).collect(Collectors.toSet()); - return new DiscreteValue(context, in.getCompareExpr(), in, literals); + return new DiscreteValue(context, in.getCompareExpr(), literals); } } @@ -261,8 +306,8 @@ public static ValueDesc discrete(ExpressionRewriteContext context, InPredicate i */ public static class EmptyValue extends ValueDesc { - public EmptyValue(ExpressionRewriteContext context, Expression reference, Expression toExpr) { - super(context, reference, toExpr); + public EmptyValue(ExpressionRewriteContext context, Expression reference) { + super(context, reference); } @Override @@ -284,9 +329,8 @@ public ValueDesc intersect(ValueDesc other) { public static class RangeValue extends ValueDesc { Range range; - public RangeValue(ExpressionRewriteContext context, Expression reference, - Expression toExpr, Range range) { - super(context, reference, toExpr); + public RangeValue(ExpressionRewriteContext context, Expression reference, Range range) { + super(context, reference); this.range = range; } @@ -300,20 +344,16 @@ public ValueDesc union(ValueDesc other) { return other.union(this); } if (other instanceof RangeValue) { - Expression originExpr = FoldConstantRuleOnFE.evaluate(new Or(toExpr, other.toExpr), context); RangeValue o = (RangeValue) other; if (range.isConnected(o.range)) { - return new RangeValue(context, reference, originExpr, range.span(o.range)); + return new RangeValue(context, reference, range.span(o.range)); } - return new UnknownValue(context, originExpr, - ImmutableList.of(this, other), false); + return new UnknownValue(context, ImmutableList.of(this, other), false); } if (other instanceof DiscreteValue) { return union(context, this, (DiscreteValue) other, false); } - Expression originExpr = FoldConstantRuleOnFE.evaluate(new Or(toExpr, other.toExpr), context); - return new UnknownValue(context, originExpr, - ImmutableList.of(this, other), false); + return new UnknownValue(context, ImmutableList.of(this, other), false); } @Override @@ -322,19 +362,16 @@ public ValueDesc intersect(ValueDesc other) { return other.intersect(this); } if (other instanceof RangeValue) { - Expression originExpr = FoldConstantRuleOnFE.evaluate(new And(toExpr, other.toExpr), context); RangeValue o = (RangeValue) other; if (range.isConnected(o.range)) { - return new RangeValue(context, reference, originExpr, range.intersection(o.range)); + return new RangeValue(context, reference, range.intersection(o.range)); } - return new EmptyValue(context, reference, originExpr); + return new EmptyValue(context, reference); } if (other instanceof DiscreteValue) { return intersect(context, this, (DiscreteValue) other); } - Expression originExpr = FoldConstantRuleOnFE.evaluate(new And(toExpr, other.toExpr), context); - return new UnknownValue(context, originExpr, - ImmutableList.of(this, other), true); + return new UnknownValue(context, ImmutableList.of(this, other), true); } @Override @@ -349,17 +386,12 @@ public String toString() { * a in (1,2,3) => [1,2,3] */ public static class DiscreteValue extends ValueDesc { - Set values; + final Set values; public DiscreteValue(ExpressionRewriteContext context, - Expression reference, Expression toExpr, Literal... values) { - this(context, reference, toExpr, Arrays.asList(values)); - } - - public DiscreteValue(ExpressionRewriteContext context, - Expression reference, Expression toExpr, Collection values) { - super(context, reference, toExpr); - this.values = Sets.newHashSet(values); + Expression reference, Set values) { + super(context, reference); + this.values = values; } public Set getValues() { @@ -372,20 +404,15 @@ public ValueDesc union(ValueDesc other) { return other.union(this); } if (other instanceof DiscreteValue) { - Expression originExpr = FoldConstantRuleOnFE.evaluate( - ExpressionUtils.or(toExpr, other.toExpr), context); - DiscreteValue discreteValue = new DiscreteValue(context, reference, originExpr); - discreteValue.values.addAll(((DiscreteValue) other).values); - discreteValue.values.addAll(this.values); - return discreteValue; + Set newValues = Sets.newHashSet(); + newValues.addAll(((DiscreteValue) other).values); + newValues.addAll(this.values); + return new DiscreteValue(context, reference, newValues); } if (other instanceof RangeValue) { return union(context, (RangeValue) other, this, true); } - Expression originExpr = FoldConstantRuleOnFE.evaluate( - ExpressionUtils.or(toExpr, other.toExpr), context); - return new UnknownValue(context, originExpr, - ImmutableList.of(this, other), false); + return new UnknownValue(context, ImmutableList.of(this, other), false); } @Override @@ -394,24 +421,19 @@ public ValueDesc intersect(ValueDesc other) { return other.intersect(this); } if (other instanceof DiscreteValue) { - Expression originExpr = FoldConstantRuleOnFE.evaluate( - ExpressionUtils.and(toExpr, other.toExpr), context); - DiscreteValue discreteValue = new DiscreteValue(context, reference, originExpr); - discreteValue.values.addAll(((DiscreteValue) other).values); - discreteValue.values.retainAll(this.values); - if (discreteValue.values.isEmpty()) { - return new EmptyValue(context, reference, originExpr); + Set newValues = Sets.newHashSet(); + newValues.addAll(((DiscreteValue) other).values); + newValues.retainAll(this.values); + if (newValues.isEmpty()) { + return new EmptyValue(context, reference); } else { - return discreteValue; + return new DiscreteValue(context, reference, newValues); } } if (other instanceof RangeValue) { return intersect(context, (RangeValue) other, this); } - Expression originExpr = FoldConstantRuleOnFE.evaluate( - ExpressionUtils.and(toExpr, other.toExpr), context); - return new UnknownValue(context, originExpr, - ImmutableList.of(this, other), true); + return new UnknownValue(context, ImmutableList.of(this, other), true); } @Override @@ -428,14 +450,14 @@ public static class UnknownValue extends ValueDesc { private final boolean isAnd; private UnknownValue(ExpressionRewriteContext context, Expression expr) { - super(context, expr, expr); + super(context, expr); sourceValues = ImmutableList.of(); isAnd = false; } - public UnknownValue(ExpressionRewriteContext context, Expression toExpr, + private UnknownValue(ExpressionRewriteContext context, List sourceValues, boolean isAnd) { - super(context, getReference(sourceValues, toExpr), toExpr); + super(context, getReference(context, sourceValues, isAnd)); this.sourceValues = ImmutableList.copyOf(sourceValues); this.isAnd = isAnd; } @@ -455,11 +477,12 @@ public UnknownValue(ExpressionRewriteContext context, Expression toExpr, // E union UnknownValue1 = E.union(UnknownValue1) = UnknownValue1, // 2. since E and UnknownValue2's reference not equals, then // E union UnknownValue2 = UnknownValue3(E union UnknownValue2, reference=E union UnknownValue2) - private static Expression getReference(List sourceValues, Expression toExpr) { + private static Expression getReference(ExpressionRewriteContext context, + List sourceValues, boolean isAnd) { Expression reference = sourceValues.get(0).reference; for (int i = 1; i < sourceValues.size(); i++) { if (!reference.equals(sourceValues.get(i).reference)) { - return toExpr; + return SimplifyRange.INSTANCE.getExpression(context, sourceValues, isAnd); } } return reference; @@ -480,10 +503,7 @@ public ValueDesc union(ValueDesc other) { if (other instanceof EmptyValue) { return other.union(this); } - Expression originExpr = FoldConstantRuleOnFE.evaluate( - ExpressionUtils.or(toExpr, other.toExpr), context); - return new UnknownValue(context, originExpr, - ImmutableList.of(this, other), false); + return new UnknownValue(context, ImmutableList.of(this, other), false); } @Override @@ -493,10 +513,7 @@ public ValueDesc intersect(ValueDesc other) { if (other instanceof EmptyValue) { return other.intersect(this); } - Expression originExpr = FoldConstantRuleOnFE.evaluate( - ExpressionUtils.and(toExpr, other.toExpr), context); - return new UnknownValue(context, originExpr, - ImmutableList.of(this, other), true); + return new UnknownValue(context, ImmutableList.of(this, other), true); } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java index 576ef6bbf4d5df..64891882f7d661 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java @@ -38,6 +38,7 @@ import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.util.ExpressionUtils; +import com.google.common.base.Preconditions; import com.google.common.collect.BoundType; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; @@ -150,26 +151,28 @@ private Expression getExpression(DiscreteValue value) { private Expression getExpression(UnknownValue value) { List sourceValues = value.getSourceValues(); - Expression originExpr = value.getOriginExpr(); if (sourceValues.isEmpty()) { - return originExpr; + return value.getReference(); + } else { + return getExpression(value.getExpressionRewriteContext(), sourceValues, value.isAnd()); } + } + + /** getExpression */ + public Expression getExpression(ExpressionRewriteContext context, + List sourceValues, boolean isAnd) { + Preconditions.checkArgument(!sourceValues.isEmpty()); List sourceExprs = Lists.newArrayListWithExpectedSize(sourceValues.size()); for (ValueDesc sourceValue : sourceValues) { Expression expr = getExpression(sourceValue); - if (value.isAnd()) { + if (isAnd) { sourceExprs.addAll(ExpressionUtils.extractConjunction(expr)); } else { sourceExprs.addAll(ExpressionUtils.extractDisjunction(expr)); } } - Expression result = value.isAnd() ? ExpressionUtils.and(sourceExprs) : ExpressionUtils.or(sourceExprs); - result = FoldConstantRuleOnFE.evaluate(result, value.getExpressionRewriteContext()); - // ATTN: we must return original expr, because OrToIn is implemented with MutableState, - // newExpr will lose these states leading to dead loop by OrToIn -> SimplifyRange -> FoldConstantByFE - if (result.equals(originExpr)) { - return originExpr; - } + Expression result = isAnd ? ExpressionUtils.and(sourceExprs) : ExpressionUtils.or(sourceExprs); + result = FoldConstantRuleOnFE.evaluate(result, context); return result; } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java index 784600577c37c6..7393439c5e617f 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java @@ -22,6 +22,10 @@ import org.apache.doris.nereids.analyzer.UnboundSlot; import org.apache.doris.nereids.parser.NereidsParser; import org.apache.doris.nereids.rules.analysis.ExpressionAnalyzer; +import org.apache.doris.nereids.rules.expression.rules.RangeInference; +import org.apache.doris.nereids.rules.expression.rules.RangeInference.EmptyValue; +import org.apache.doris.nereids.rules.expression.rules.RangeInference.UnknownValue; +import org.apache.doris.nereids.rules.expression.rules.RangeInference.ValueDesc; import org.apache.doris.nereids.rules.expression.rules.SimplifyRange; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.InPredicate; @@ -60,6 +64,24 @@ public SimplifyRangeTest() { context = new ExpressionRewriteContext(cascadesContext); } + @Test + public void testRangeInference() { + ValueDesc valueDesc = getValueDesc("TA IS NULL"); + Assertions.assertInstanceOf(UnknownValue.class, valueDesc); + List sourceValues = ((UnknownValue) valueDesc).getSourceValues(); + Assertions.assertEquals(0, sourceValues.size()); + Assertions.assertEquals("TA IS NULL", valueDesc.getReference().toSql()); + + valueDesc = getValueDesc("TA IS NULL AND TB IS NULL AND NULL"); + Assertions.assertInstanceOf(UnknownValue.class, valueDesc); + sourceValues = ((UnknownValue) valueDesc).getSourceValues(); + Assertions.assertEquals(3, sourceValues.size()); + Assertions.assertInstanceOf(EmptyValue.class, sourceValues.get(0)); + Assertions.assertInstanceOf(EmptyValue.class, sourceValues.get(1)); + Assertions.assertEquals("TA", sourceValues.get(0).getReference().toSql()); + Assertions.assertEquals("TB", sourceValues.get(1).getReference().toSql()); + } + @Test public void testSimplify() { executor = new ExpressionRuleExecutor(ImmutableList.of( @@ -69,8 +91,15 @@ public void testSimplify() { assertRewrite("TA > 3 or TA > null", "TA > 3 OR NULL"); assertRewrite("TA > 3 or TA < null", "TA > 3 OR NULL"); assertRewrite("TA > 3 or TA = null", "TA > 3 OR NULL"); + assertRewrite("TA > 3 or TA = 3 or TA < null", "TA >= 3 OR NULL"); + assertRewrite("TA < 10 or TA in (1, 2, 3, 11, 12, 13)", "TA in (11, 12, 13) OR TA < 10"); + assertRewrite("TA < 10 or TA in (1, 2, 3, 10, 11, 12, 13) or TA > 13 or TA < 10 or TA in (1, 2, 3, 10, 11, 12, 13) or TA > 13", + "TA in (11, 12) OR TA <= 10 OR TA >= 13"); assertRewrite("TA > 3 or TA <> null", "TA > 3 or null"); assertRewrite("TA > 3 or TA <=> null", "TA > 3 or TA <=> null"); + assertRewrite("(TA < 1 or TA > 2) or (TA >= 0 and TA <= 3)", "TA IS NOT NULL OR NULL"); + assertRewrite("TA between 10 and 20 or TA between 100 and 120 or TA between 15 and 25 or TA between 115 and 125", + "TA between 10 and 25 or TA between 100 and 125"); assertRewriteNotNull("TA > 3 and TA > null", "TA > 3 and NULL"); assertRewriteNotNull("TA > 3 and TA < null", "TA > 3 and NULL"); assertRewriteNotNull("TA > 3 and TA = null", "TA > 3 and NULL"); @@ -88,13 +117,13 @@ public void testSimplify() { assertRewrite("TA >= 3 and TA < 3", "TA >= 3 and TA < 3"); assertRewriteNotNull("TA = 1 and TA > 10", "FALSE"); assertRewrite("TA = 1 and TA > 10", "TA is null and null"); - assertRewrite("TA > 5 or TA < 1", "TA > 5 or TA < 1"); + assertRewrite("TA > 5 or TA < 1", "TA < 1 or TA > 5"); assertRewrite("TA > 5 or TA > 1 or TA > 10", "TA > 1"); assertRewrite("TA > 5 or TA > 1 or TA < 10", "TA is not null or null"); assertRewriteNotNull("TA > 5 or TA > 1 or TA < 10", "TRUE"); assertRewrite("TA > 5 and TA > 1 and TA > 10", "TA > 10"); assertRewrite("TA > 5 and TA > 1 and TA < 10", "TA > 5 and TA < 10"); - assertRewrite("TA > 1 or TA < 1", "TA > 1 or TA < 1"); + assertRewrite("TA > 1 or TA < 1", "TA < 1 or TA > 1"); assertRewrite("TA > 1 or TA < 10", "TA is not null or null"); assertRewriteNotNull("TA > 1 or TA < 10", "TRUE"); assertRewrite("TA > 5 and TA < 10", "TA > 5 and TA < 10"); @@ -109,7 +138,7 @@ public void testSimplify() { assertRewrite("(TA > 10 or TA > 20) and (TB > 10 and TB > 20)", "TA > 10 and TB > 20"); assertRewrite("((TB > 30 and TA > 40) and TA > 20) and (TB > 10 and TB > 20)", "TB > 30 and TA > 40"); assertRewrite("(TA > 10 and TB > 10) or (TB > 10 and TB > 20)", "TA > 10 and TB > 10 or TB > 20"); - assertRewrite("((TA > 10 or TA > 5) and TB > 10) or (TB > 10 and (TB > 20 or TB < 10))", "(TA > 5 and TB > 10) or (TB > 10 and (TB > 20 or TB < 10))"); + assertRewrite("((TA > 10 or TA > 5) and TB > 10) or (TB > 10 and (TB > 20 or TB < 10))", "(TA > 5 and TB > 10) or (TB > 10 and (TB < 10 or TB > 20))"); assertRewriteNotNull("TA in (1,2,3) and TA > 10", "FALSE"); assertRewrite("TA in (1,2,3) and TA > 10", "TA is null and null"); assertRewrite("TA in (1,2,3) and TA >= 1", "TA in (1,2,3)"); @@ -119,7 +148,7 @@ public void testSimplify() { assertRewrite("TA in (1,2,3) and TA < 10", "TA in (1,2,3)"); assertRewriteNotNull("TA in (1,2,3) and TA < 1", "FALSE"); assertRewrite("TA in (1,2,3) and TA < 1", "TA is null and null"); - assertRewrite("TA in (1,2,3) or TA < 1", "TA in (1,2,3) or TA < 1"); + assertRewrite("TA in (1,2,3) or TA < 1", "TA in (2,3) or TA <= 1"); assertRewrite("TA in (1,2,3) or TA in (2,3,4)", "TA in (1,2,3,4)"); assertRewrite("TA in (1,2,3) or TA in (4,5,6)", "TA in (1,2,3,4,5,6)"); assertRewrite("TA in (1,2,3) and TA in (4,5,6)", "TA is null and null"); @@ -150,12 +179,12 @@ public void testSimplify() { assertRewrite("TA + TC >= 3 and TA + TC < 3", "TA + TC >= 3 and TA + TC < 3"); assertRewriteNotNull("TA + TC = 1 and TA + TC > 10", "FALSE"); assertRewrite("TA + TC = 1 and TA + TC > 10", "(TA + TC) is null and null"); - assertRewrite("TA + TC > 5 or TA + TC < 1", "TA + TC > 5 or TA + TC < 1"); + assertRewrite("TA + TC > 5 or TA + TC < 1", "TA + TC < 1 or TA + TC > 5"); assertRewrite("TA + TC > 5 or TA + TC > 1 or TA + TC > 10", "TA + TC > 1"); assertRewrite("TA + TC > 5 or TA + TC > 1 or TA + TC < 10", "(TA + TC) is not null or null"); assertRewrite("TA + TC > 5 and TA + TC > 1 and TA + TC > 10", "TA + TC > 10"); assertRewrite("TA + TC > 5 and TA + TC > 1 and TA + TC < 10", "TA + TC > 5 and TA + TC < 10"); - assertRewrite("TA + TC > 1 or TA + TC < 1", "TA + TC > 1 or TA + TC < 1"); + assertRewrite("TA + TC > 1 or TA + TC < 1", "TA + TC < 1 or TA + TC > 1"); assertRewrite("TA + TC > 1 or TA + TC < 10", "(TA + TC) is not null or null"); assertRewrite("TA + TC > 5 and TA + TC < 10", "TA + TC > 5 and TA + TC < 10"); assertRewrite("TA + TC > 5 and TA + TC > 10", "TA + TC > 10"); @@ -168,7 +197,7 @@ public void testSimplify() { assertRewrite("(TA + TC > 10 or TA + TC > 20) and (TB > 10 and TB > 20)", "TA + TC > 10 and TB > 20"); assertRewrite("((TB > 30 and TA + TC > 40) and TA + TC > 20) and (TB > 10 and TB > 20)", "TB > 30 and TA + TC > 40"); assertRewrite("(TA + TC > 10 and TB > 10) or (TB > 10 and TB > 20)", "TA + TC > 10 and TB > 10 or TB > 20"); - assertRewrite("((TA + TC > 10 or TA + TC > 5) and TB > 10) or (TB > 10 and (TB > 20 or TB < 10))", "(TA + TC > 5 and TB > 10) or (TB > 10 and (TB > 20 or TB < 10))"); + assertRewrite("((TA + TC > 10 or TA + TC > 5) and TB > 10) or (TB > 10 and (TB > 20 or TB < 10))", "(TA + TC > 5 and TB > 10) or (TB > 10 and (TB < 10 or TB > 20))"); assertRewriteNotNull("TA + TC in (1,2,3) and TA + TC > 10", "FALSE"); assertRewrite("TA + TC in (1,2,3) and TA + TC > 10", "(TA + TC) is null and null"); assertRewrite("TA + TC in (1,2,3) and TA + TC >= 1", "TA + TC in (1,2,3)"); @@ -178,7 +207,7 @@ public void testSimplify() { assertRewrite("TA + TC in (1,2,3) and TA + TC < 10", "TA + TC in (1,2,3)"); assertRewriteNotNull("TA + TC in (1,2,3) and TA + TC < 1", "FALSE"); assertRewrite("TA + TC in (1,2,3) and TA + TC < 1", "(TA + TC) is null and null"); - assertRewrite("TA + TC in (1,2,3) or TA + TC < 1", "TA + TC in (1,2,3) or TA + TC < 1"); + assertRewrite("TA + TC in (1,2,3) or TA + TC < 1", "TA + TC in (2,3) or TA + TC <= 1"); assertRewrite("TA + TC in (1,2,3) or TA + TC in (2,3,4)", "TA + TC in (1,2,3,4)"); assertRewrite("TA + TC in (1,2,3) or TA + TC in (4,5,6)", "TA + TC in (1,2,3,4,5,6)"); assertRewriteNotNull("TA + TC in (1,2,3) and TA + TC in (4,5,6)", "FALSE"); @@ -221,7 +250,7 @@ public void testSimplifyDate() { assertRewriteNotNull("AA = date '2024-01-01' and AA > date '2024-01-10'", "FALSE"); assertRewrite("AA = date '2024-01-01' and AA > date '2024-01-10'", "AA is null and null"); assertRewrite("AA > date '2024-01-05' or AA < date '2024-01-01'", - "AA > date '2024-01-05' or AA < date '2024-01-01'"); + "AA < date '2024-01-01' or AA > date '2024-01-05'"); assertRewrite("AA > date '2024-01-05' or AA > date '2024-01-01' or AA > date '2024-01-10'", "AA > date '2024-01-01'"); assertRewrite("AA > date '2024-01-05' or AA > date '2024-01-01' or AA < date '2024-01-10'", "AA is not null or null"); @@ -231,7 +260,7 @@ public void testSimplifyDate() { assertRewrite("AA > date '2024-01-05' and AA > date '2024-01-01' and AA < date '2024-01-10'", "AA > date '2024-01-05' and AA < date '2024-01-10'"); assertRewrite("AA > date '2024-01-05' or AA < date '2024-01-05'", - "AA > date '2024-01-05' or AA < date '2024-01-05'"); + "AA < date '2024-01-05' or AA > date '2024-01-05'"); assertRewrite("AA > date '2024-01-01' or AA < date '2024-01-10'", "AA is not null or null"); assertRewriteNotNull("AA > date '2024-01-01' or AA < date '2024-01-10'", "TRUE"); assertRewrite("AA > date '2024-01-05' and AA < date '2024-01-10'", @@ -261,7 +290,7 @@ public void testSimplifyDate() { assertRewrite("AA in (date '2024-01-01',date '2024-01-02',date '2024-01-03') and AA < date '2024-01-01'", "AA is null and null"); assertRewrite("AA in (date '2024-01-01',date '2024-01-02',date '2024-01-03') or AA < date '2024-01-01'", - "AA in (date '2024-01-01',date '2024-01-02',date '2024-01-03') or AA < date '2024-01-01'"); + "AA in (date '2024-01-02',date '2024-01-03') or AA <= date '2024-01-01'"); assertRewrite("AA in (date '2024-01-01',date '2024-01-02') or AA in (date '2024-01-02', date '2024-01-03')", "AA in (date '2024-01-01',date '2024-01-02',date '2024-01-03')"); assertRewriteNotNull("AA in (date '2024-01-01',date '2024-01-02') and AA in (date '2024-01-03', date '2024-01-04')", @@ -301,7 +330,7 @@ public void testSimplifyDateTime() { assertRewriteNotNull("CA = timestamp '2024-01-01 10:00:10' and CA > timestamp '2024-01-10 00:00:10'", "FALSE"); assertRewrite("CA = timestamp '2024-01-01 10:00:10' and CA > timestamp '2024-01-10 00:00:10'", "CA is null and null"); assertRewrite("CA > timestamp '2024-01-05 00:00:10' or CA < timestamp '2024-01-01 00:00:10'", - "CA > timestamp '2024-01-05 00:00:10' or CA < timestamp '2024-01-01 00:00:10'"); + "CA < timestamp '2024-01-01 00:00:10' or CA > timestamp '2024-01-05 00:00:10'"); assertRewrite("CA > timestamp '2024-01-05 00:00:10' or CA > timestamp '2024-01-01 00:00:10' or CA > timestamp '2024-01-10 00:00:10'", "CA > timestamp '2024-01-01 00:00:10'"); assertRewrite("CA > timestamp '2024-01-05 00:00:10' or CA > timestamp '2024-01-01 00:00:10' or CA < timestamp '2024-01-10 00:00:10'", "CA is not null or null"); @@ -311,7 +340,7 @@ public void testSimplifyDateTime() { assertRewrite("CA > timestamp '2024-01-05 00:00:10' and CA > timestamp '2024-01-01 00:00:10' and CA < timestamp '2024-01-10 00:00:10'", "CA > timestamp '2024-01-05 00:00:10' and CA < timestamp '2024-01-10 00:00:10'"); assertRewrite("CA > timestamp '2024-01-05 00:00:10' or CA < timestamp '2024-01-05 00:00:10'", - "CA > timestamp '2024-01-05 00:00:10' or CA < timestamp '2024-01-05 00:00:10'"); + "CA < timestamp '2024-01-05 00:00:10' or CA > timestamp '2024-01-05 00:00:10'"); assertRewrite("CA > timestamp '2024-01-01 00:02:10' or CA < timestamp '2024-01-10 00:02:10'", "CA is not null or null"); assertRewriteNotNull("CA > timestamp '2024-01-01 00:00:00' or CA < timestamp '2024-01-10 00:00:00'", "TRUE"); assertRewrite("CA > timestamp '2024-01-05 01:00:00' and CA < timestamp '2024-01-10 01:00:00'", @@ -364,7 +393,13 @@ public void testSimplifyDateTime() { "(CA is null and null) OR CB < timestamp '2024-01-05 00:50:00'"); } - @Test + private ValueDesc getValueDesc(String expression) { + Map mem = Maps.newHashMap(); + Expression parseExpression = replaceUnboundSlot(PARSER.parseExpression(expression), mem); + parseExpression = typeCoercion(parseExpression); + return (new RangeInference()).getValue(parseExpression, context); + } + private void assertRewrite(String expression, String expected) { Map mem = Maps.newHashMap(); Expression needRewriteExpression = replaceUnboundSlot(PARSER.parseExpression(expression), mem); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/InferTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/InferTest.java index 3d88c131c97f15..cdc36164ae9f95 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/InferTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/InferTest.java @@ -58,7 +58,7 @@ void testInferNotNullFromFilterAndEliminateOuter2() { f -> ExpressionUtils.and(f.getConjuncts().stream() .sorted((a, b) -> a.toString().compareTo(b.toString())) .collect(Collectors.toList())) - .toString().equals("AND[(id#0 >= 4),OR[(id#0 = 4),(id#0 > 4)]]")) + .toString().equals("(id#0 >= 4)")) ) ); @@ -76,7 +76,7 @@ void testInferNotNullFromFilterAndEliminateOuter3() { logicalFilter( leftOuterLogicalJoin( logicalFilter().when( - f -> f.getPredicate().toString().equals("AND[(id#0 >= 4),OR[(id#0 = 4),(id#0 > 4)]]")), + f -> f.getPredicate().toString().equals("(id#0 >= 4)")), logicalFilter().when( f -> f.getPredicate().toString().equals("(id#2 >= 4)") )