From 820156138a2937ffba1e0e19f265b603c41e3a09 Mon Sep 17 00:00:00 2001 From: yujun Date: Fri, 29 Nov 2024 21:33:44 +0800 Subject: [PATCH] add session variable disable_nereids_expression_rules --- .../rules/SimplifyComparisonPredicate.java | 184 ++++++++---- .../doris/nereids/util/ExpressionUtils.java | 16 ++ .../doris/nereids/util/TypeCoercionUtils.java | 44 +++ .../SimplifyComparisonPredicateTest.java | 263 +++++++++++++++++- .../pull_up_predicate_literal.out | 110 +------- .../predicate_infer/infer_predicate.out | 10 +- .../predicate_infer/infer_predicate.groovy | 7 +- 7 files changed, 462 insertions(+), 172 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java index cb61795865239b..5f14d9625e0849 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java @@ -17,18 +17,17 @@ package org.apache.doris.nereids.rules.expression.rules; +import org.apache.doris.common.Pair; import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; -import org.apache.doris.nereids.trees.expressions.And; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.GreaterThan; import org.apache.doris.nereids.trees.expressions.GreaterThanEqual; -import org.apache.doris.nereids.trees.expressions.IsNull; import org.apache.doris.nereids.trees.expressions.LessThan; import org.apache.doris.nereids.trees.expressions.LessThanEqual; import org.apache.doris.nereids.trees.expressions.NullSafeEqual; @@ -44,16 +43,16 @@ import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral; import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; import org.apache.doris.nereids.trees.expressions.literal.Literal; -import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; +import org.apache.doris.nereids.trees.expressions.literal.NumericLiteral; import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral; import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral; -import org.apache.doris.nereids.types.BooleanType; import org.apache.doris.nereids.types.DateTimeType; import org.apache.doris.nereids.types.DateTimeV2Type; import org.apache.doris.nereids.types.DateType; import org.apache.doris.nereids.types.DateV2Type; import org.apache.doris.nereids.types.DecimalV3Type; import org.apache.doris.nereids.types.coercion.DateLikeType; +import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.TypeCoercionUtils; import com.google.common.base.Preconditions; @@ -62,9 +61,10 @@ import java.math.BigDecimal; import java.math.RoundingMode; import java.util.List; +import java.util.Optional; /** - * simplify comparison + * simplify comparison, not support large int. * such as: cast(c1 as DateV2) >= DateV2Literal --> c1 >= DateLiteral * cast(c1 AS double) > 2.0 --> c1 >= 2 (c1 is integer like type) */ @@ -98,22 +98,25 @@ public static Expression simplify(ComparisonPredicate cp) { Expression left = cp.left(); Expression right = cp.right(); - // float like type: float, double - if (left.getDataType().isFloatLikeType() && right.getDataType().isFloatLikeType()) { - return processFloatLikeTypeCoercion(cp, left, right); - } + Expression result; - // decimalv3 type - if (left.getDataType() instanceof DecimalV3Type && right.getDataType() instanceof DecimalV3Type) { - return processDecimalV3TypeCoercion(cp, left, right); + // process type coercion + if (left.getDataType().isFloatLikeType() && right.getDataType().isFloatLikeType()) { + result = processFloatLikeTypeCoercion(cp, left, right); + } else if (left.getDataType() instanceof DecimalV3Type && right.getDataType() instanceof DecimalV3Type) { + result = processDecimalV3TypeCoercion(cp, left, right); + } else if (left.getDataType() instanceof DateLikeType && right.getDataType() instanceof DateLikeType) { + result = processDateLikeTypeCoercion(cp, left, right); + } else { + result = cp; } - // date like type - if (left.getDataType() instanceof DateLikeType && right.getDataType() instanceof DateLikeType) { - return processDateLikeTypeCoercion(cp, left, right); + if (result instanceof ComparisonPredicate && ((ComparisonPredicate) result).right() instanceof NumericLiteral) { + ComparisonPredicate cmp = (ComparisonPredicate) result; + result = processTypeRangeLimitComparison(cmp, cmp.left(), (NumericLiteral) cmp.right()); } - return cp; + return result; } private static Expression processComparisonPredicateDateTimeV2Literal( @@ -128,17 +131,13 @@ private static Expression processComparisonPredicateDateTimeV2Literal( if (right.getMicroSecond() == originValue) { return comparisonPredicate.withChildren(left, right); } else { - if (left.nullable()) { - // TODO: the ideal way is to return an If expr like: - // return new If(new IsNull(left), new NullLiteral(BooleanType.INSTANCE), - // BooleanLiteral.of(false)); - // but current fold constant rule can't handle such complex expr with null literal - // before supporting complex conjuncts with null literal folding rules, - // we use a trick way like this: - return new And(new IsNull(left), new NullLiteral(BooleanType.INSTANCE)); - } else { - return BooleanLiteral.of(false); - } + // TODO: the ideal way is to return an If expr like: + // return new If(new IsNull(left), new NullLiteral(BooleanType.INSTANCE), + // BooleanLiteral.of(false)); + // but current fold constant rule can't handle such complex expr with null literal + // before supporting complex conjuncts with null literal folding rules, + // we use a trick way like this: + return ExpressionUtils.falseOrNull(left); } } else if (comparisonPredicate instanceof NullSafeEqual) { long originValue = right.getMicroSecond(); @@ -239,18 +238,13 @@ private static Expression processDecimalV3TypeCoercion(ComparisonPredicate compa comparisonPredicate.withChildren(left, new DecimalV3Literal( literal.getValue().setScale(toScale, RoundingMode.UNNECESSARY)))); } catch (ArithmeticException e) { - if (left.nullable()) { - // TODO: the ideal way is to return an If expr like: - // return new If(new IsNull(left), new NullLiteral(BooleanType.INSTANCE), - // BooleanLiteral.of(false)); - // but current fold constant rule can't handle such complex expr with null literal - // before supporting complex conjuncts with null literal folding rules, - // we use a trick way like this: - return new And(new IsNull(left), - new NullLiteral(BooleanType.INSTANCE)); - } else { - return BooleanLiteral.of(false); - } + // TODO: the ideal way is to return an If expr like: + // return new If(new IsNull(left), new NullLiteral(BooleanType.INSTANCE), + // BooleanLiteral.of(false)); + // but current fold constant rule can't handle such complex expr with null literal + // before supporting complex conjuncts with null literal folding rules, + // we use a trick way like this: + return ExpressionUtils.falseOrNull(left); } } else if (comparisonPredicate instanceof NullSafeEqual) { try { @@ -281,21 +275,18 @@ private static Expression processDecimalV3TypeCoercion(ComparisonPredicate compa private static Expression processIntegerDecimalLiteralComparison( ComparisonPredicate comparisonPredicate, Expression left, BigDecimal literal) { // we only process isIntegerLikeType, which are tinyint, smallint, int, bigint - if (literal.compareTo(new BigDecimal(Long.MAX_VALUE)) <= 0) { + if (literal.compareTo(new BigDecimal(Long.MIN_VALUE)) >= 0 + && literal.compareTo(new BigDecimal(Long.MAX_VALUE)) <= 0) { literal = literal.stripTrailingZeros(); if (literal.scale() > 0) { if (comparisonPredicate instanceof EqualTo) { - if (left.nullable()) { - // TODO: the ideal way is to return an If expr like: - // return new If(new IsNull(left), new NullLiteral(BooleanType.INSTANCE), - // BooleanLiteral.of(false)); - // but current fold constant rule can't handle such complex expr with null literal - // before supporting complex conjuncts with null literal folding rules, - // we use a trick way like this: - return new And(new IsNull(left), new NullLiteral(BooleanType.INSTANCE)); - } else { - return BooleanLiteral.of(false); - } + // TODO: the ideal way is to return an If expr like: + // return new If(new IsNull(left), new NullLiteral(BooleanType.INSTANCE), + // BooleanLiteral.of(false)); + // but current fold constant rule can't handle such complex expr with null literal + // before supporting complex conjuncts with null literal folding rules, + // we use a trick way like this: + return ExpressionUtils.falseOrNull(left); } else if (comparisonPredicate instanceof NullSafeEqual) { return BooleanLiteral.of(false); } else if (comparisonPredicate instanceof GreaterThan @@ -320,10 +311,95 @@ private static Expression processIntegerDecimalLiteralComparison( return comparisonPredicate; } + private static Expression processTypeRangeLimitComparison(ComparisonPredicate cp, Expression left, + NumericLiteral right) { + BigDecimal typeMinValue = null; + BigDecimal typeMaxValue = null; + // cmp float like have lost precision, for example float.max_value + 0.01 still eval to float.max_value + if (left.getDataType().isIntegerLikeType() || left.getDataType().isDecimalV3Type()) { + Optional> minMaxOpt = + TypeCoercionUtils.getDataTypeMinMaxValue(left.getDataType()); + if (minMaxOpt.isPresent()) { + typeMinValue = minMaxOpt.get().first; + typeMaxValue = minMaxOpt.get().second; + } + } + + // cast(child as dataType2) range should be: + // [ max(childDataType.min_value, dataType2.min_value), min(childDataType.max_value, dataType2.max_value)] + if (left instanceof Cast) { + left = ((Cast) left).child(); + if (left.getDataType().isIntegerLikeType() || left.getDataType().isDecimalV3Type()) { + Optional> minMaxOpt = + TypeCoercionUtils.getDataTypeMinMaxValue(left.getDataType()); + if (minMaxOpt.isPresent()) { + if (typeMinValue == null || typeMinValue.compareTo(minMaxOpt.get().first) < 0) { + typeMinValue = minMaxOpt.get().first; + } + if (typeMaxValue == null || typeMaxValue.compareTo(minMaxOpt.get().second) > 0) { + typeMaxValue = minMaxOpt.get().second; + } + } + } + } + + if (typeMinValue == null || typeMaxValue == null) { + return cp; + } + BigDecimal literal = new BigDecimal(right.getStringValue()); + int cmpMin = literal.compareTo(typeMinValue); + int cmpMax = literal.compareTo(typeMaxValue); + if (cp instanceof EqualTo) { + if (cmpMin < 0 || cmpMax > 0) { + return ExpressionUtils.falseOrNull(left); + } + } else if (cp instanceof NullSafeEqual) { + if (cmpMin < 0 || cmpMax > 0) { + return BooleanLiteral.of(false); + } + } else if (cp instanceof GreaterThan) { + if (cmpMin < 0) { + return ExpressionUtils.trueOrNull(left); + } + if (cmpMax >= 0) { + return ExpressionUtils.falseOrNull(left); + } + } else if (cp instanceof GreaterThanEqual) { + if (cmpMin <= 0) { + return ExpressionUtils.trueOrNull(left); + } + if (cmpMax == 0) { + return new EqualTo(cp.left(), cp.right()); + } + if (cmpMax > 0) { + return ExpressionUtils.falseOrNull(left); + } + } else if (cp instanceof LessThan) { + if (cmpMin <= 0) { + return ExpressionUtils.falseOrNull(left); + } + if (cmpMax > 0) { + return ExpressionUtils.trueOrNull(left); + } + } else if (cp instanceof LessThanEqual) { + if (cmpMin < 0) { + return ExpressionUtils.falseOrNull(left); + } + if (cmpMin == 0) { + return new EqualTo(cp.left(), cp.right()); + } + if (cmpMax >= 0) { + return ExpressionUtils.trueOrNull(left); + } + } + return cp; + } + private static IntegerLikeLiteral convertDecimalToIntegerLikeLiteral(BigDecimal decimal) { - Preconditions.checkArgument( - decimal.scale() <= 0 && decimal.compareTo(new BigDecimal(Long.MAX_VALUE)) <= 0, - "decimal literal must have 0 scale and smaller than Long.MAX_VALUE"); + Preconditions.checkArgument(decimal.scale() <= 0 + && decimal.compareTo(new BigDecimal(Long.MIN_VALUE)) >= 0 + && decimal.compareTo(new BigDecimal(Long.MAX_VALUE)) <= 0, + "decimal literal must have 0 scale and in range [Long.MIN_VALUE, Long.MAX_VALUE]"); long val = decimal.longValue(); if (val >= Byte.MIN_VALUE && val <= Byte.MAX_VALUE) { return new TinyIntLiteral((byte) val); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java index 22b681a6246d92..25637d1b816656 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java @@ -260,6 +260,22 @@ public static Expression or(Collection expressions) { } } + public static Expression falseOrNull(Expression expression) { + if (expression.nullable()) { + return new And(new IsNull(expression), new NullLiteral(BooleanType.INSTANCE)); + } else { + return BooleanLiteral.FALSE; + } + } + + public static Expression trueOrNull(Expression expression) { + if (expression.nullable()) { + return new Or(new Not(new IsNull(expression)), new NullLiteral(BooleanType.INSTANCE)); + } else { + return BooleanLiteral.TRUE; + } + } + /** * Use AND/OR to combine expressions together. */ diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java index 603a891d2d2a49..1da4353d20da33 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java @@ -21,6 +21,7 @@ import org.apache.doris.catalog.ScalarType; import org.apache.doris.catalog.Type; import org.apache.doris.common.Config; +import org.apache.doris.common.Pair; import org.apache.doris.nereids.annotation.Developing; import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.trees.expressions.Add; @@ -116,6 +117,7 @@ import com.google.common.collect.ImmutableList.Builder; import com.google.common.collect.Lists; import com.google.common.collect.Sets; +import org.apache.commons.lang3.StringUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -1773,6 +1775,48 @@ private static Expression processDecimalV3BinaryArithmetic(BinaryArithmetic bina castIfNotSameType(right, dt2)); } + /** + * get min and max value of a data type + * + * @param dataType specific data type + * @return min and max values pair + */ + public static Optional> getDataTypeMinMaxValue(DataType dataType) { + if (dataType.isTinyIntType()) { + return Optional.of(Pair.of(new BigDecimal(Byte.MIN_VALUE), new BigDecimal(Byte.MAX_VALUE))); + } else if (dataType.isSmallIntType()) { + return Optional.of(Pair.of(new BigDecimal(Short.MIN_VALUE), new BigDecimal(Short.MAX_VALUE))); + } else if (dataType.isIntegerType()) { + return Optional.of(Pair.of(new BigDecimal(Integer.MIN_VALUE), new BigDecimal(Integer.MAX_VALUE))); + } else if (dataType.isBigIntType()) { + return Optional.of(Pair.of(new BigDecimal(Long.MIN_VALUE), new BigDecimal(Long.MAX_VALUE))); + } else if (dataType.isLargeIntType()) { + return Optional.of(Pair.of(new BigDecimal(LargeIntType.MIN_VALUE), new BigDecimal(LargeIntType.MAX_VALUE))); + } else if (dataType.isFloatType()) { + return Optional.of(Pair.of(BigDecimal.valueOf(-Float.MAX_VALUE), new BigDecimal(Float.MAX_VALUE))); + } else if (dataType.isDoubleType()) { + return Optional.of(Pair.of(BigDecimal.valueOf(-Double.MAX_VALUE), new BigDecimal(Double.MAX_VALUE))); + } else if (dataType.isDecimalV3Type()) { + DecimalV3Type type = (DecimalV3Type) dataType; + int precision = type.getPrecision(); + int scale = type.getScale(); + if (scale >= 0) { + StringBuilder sb = new StringBuilder(); + sb.append(StringUtils.repeat('9', precision - scale)); + if (sb.length() == 0) { + sb.append('0'); + } + if (scale > 0) { + sb.append('.'); + sb.append(StringUtils.repeat('9', scale)); + } + return Optional.of(Pair.of(new BigDecimal("-" + sb.toString()), new BigDecimal(sb.toString()))); + } + } + + return Optional.empty(); + } + private static boolean supportCompare(DataType dataType) { if (dataType.isArrayType()) { return true; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java index 84ebd7c7250198..9202c1e202469d 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java @@ -17,10 +17,12 @@ package org.apache.doris.nereids.rules.expression.rules; +import org.apache.doris.common.Pair; import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper; import org.apache.doris.nereids.rules.expression.ExpressionRuleExecutor; import org.apache.doris.nereids.trees.expressions.And; import org.apache.doris.nereids.trees.expressions.Cast; +import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.GreaterThan; @@ -38,17 +40,29 @@ import org.apache.doris.nereids.trees.expressions.literal.DateV2Literal; import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal; import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral; +import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; +import org.apache.doris.nereids.trees.expressions.literal.LargeIntLiteral; import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; +import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral; +import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral; +import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.types.BooleanType; +import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.DateTimeV2Type; import org.apache.doris.nereids.types.DecimalV3Type; import org.apache.doris.nereids.types.DoubleType; +import org.apache.doris.nereids.types.IntegerType; +import org.apache.doris.nereids.types.SmallIntType; +import org.apache.doris.nereids.types.TinyIntType; +import org.apache.doris.nereids.util.ExpressionUtils; import com.google.common.collect.ImmutableList; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import java.math.BigDecimal; +import java.math.BigInteger; +import java.util.List; class SimplifyComparisonPredicateTest extends ExpressionRewriteTestHelper { @Test @@ -233,63 +247,292 @@ void testDecimalV3Literal() { Assertions.assertEquals(BooleanLiteral.FALSE, rewrittenExpression); // > right literal should round floor - leftChild = new DecimalV3Literal(new BigDecimal("1.24")); + leftChild = new DecimalV3Literal(new BigDecimal("10.24")); left = new Cast(leftChild, DecimalV3Type.createDecimalV3Type(5, 3)); right = new DecimalV3Literal(new BigDecimal("12.345")); expression = new GreaterThan(left, right); rewrittenExpression = executor.rewrite(expression, context); - Assertions.assertInstanceOf(Cast.class, rewrittenExpression.child(0)); Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(4, 2), rewrittenExpression.child(0).getDataType()); Assertions.assertInstanceOf(DecimalV3Literal.class, rewrittenExpression.child(1)); Assertions.assertEquals(new BigDecimal("12.34"), ((DecimalV3Literal) rewrittenExpression.child(1)).getValue()); // <= right literal should round floor - leftChild = new DecimalV3Literal(new BigDecimal("1.24")); + leftChild = new DecimalV3Literal(new BigDecimal("10.24")); left = new Cast(leftChild, DecimalV3Type.createDecimalV3Type(5, 3)); right = new DecimalV3Literal(new BigDecimal("12.345")); expression = new LessThanEqual(left, right); rewrittenExpression = executor.rewrite(expression, context); - Assertions.assertInstanceOf(Cast.class, rewrittenExpression.child(0)); Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(4, 2), rewrittenExpression.child(0).getDataType()); Assertions.assertInstanceOf(DecimalV3Literal.class, rewrittenExpression.child(1)); Assertions.assertEquals(new BigDecimal("12.34"), ((DecimalV3Literal) rewrittenExpression.child(1)).getValue()); // >= right literal should round ceiling - leftChild = new DecimalV3Literal(new BigDecimal("1.24")); + leftChild = new DecimalV3Literal(new BigDecimal("10.24")); left = new Cast(leftChild, DecimalV3Type.createDecimalV3Type(5, 3)); right = new DecimalV3Literal(new BigDecimal("12.345")); expression = new GreaterThanEqual(left, right); rewrittenExpression = executor.rewrite(expression, context); - Assertions.assertInstanceOf(Cast.class, rewrittenExpression.child(0)); Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(4, 2), rewrittenExpression.child(0).getDataType()); Assertions.assertInstanceOf(DecimalV3Literal.class, rewrittenExpression.child(1)); Assertions.assertEquals(new BigDecimal("12.35"), ((DecimalV3Literal) rewrittenExpression.child(1)).getValue()); // < right literal should round ceiling - leftChild = new DecimalV3Literal(new BigDecimal("1.24")); + leftChild = new DecimalV3Literal(new BigDecimal("10.24")); left = new Cast(leftChild, DecimalV3Type.createDecimalV3Type(5, 3)); right = new DecimalV3Literal(new BigDecimal("12.345")); expression = new LessThan(left, right); rewrittenExpression = executor.rewrite(expression, context); - Assertions.assertInstanceOf(Cast.class, rewrittenExpression.child(0)); Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(4, 2), rewrittenExpression.child(0).getDataType()); Assertions.assertInstanceOf(DecimalV3Literal.class, rewrittenExpression.child(1)); Assertions.assertEquals(new BigDecimal("12.35"), ((DecimalV3Literal) rewrittenExpression.child(1)).getValue()); // left's child range smaller than right literal - leftChild = new DecimalV3Literal(new BigDecimal("1234.12")); + leftChild = new DecimalV3Literal(new BigDecimal("12340.12")); left = new Cast(leftChild, DecimalV3Type.createDecimalV3Type(10, 5)); right = new DecimalV3Literal(new BigDecimal("12345.12000")); expression = new EqualTo(left, right); rewrittenExpression = executor.rewrite(expression, context); - Assertions.assertInstanceOf(Cast.class, rewrittenExpression.child(0)); Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(7, 2), rewrittenExpression.child(0).getDataType()); Assertions.assertInstanceOf(DecimalV3Literal.class, rewrittenExpression.child(1)); Assertions.assertEquals(new BigDecimal("12345.12"), ((DecimalV3Literal) rewrittenExpression.child(1)).getValue()); } + + private enum RangeLimitResult { + TRUE, // eval to true + FALSE, // eval to false + EQUALS, // eval to equals + NO_CHANGE_CP // no change cmp type + } + + @Test + void testTypeRangeLimit() { + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(SimplifyComparisonPredicate.INSTANCE) + )); + + checkTypeRangeLimit(TinyIntType.INSTANCE, + ImmutableList.of( + Pair.of(new SmallIntLiteral((short) -129), null), + Pair.of(new DecimalV3Literal(new BigDecimal("-129")), null), + Pair.of(new DecimalV3Literal(new BigDecimal("-128.1")), null), + Pair.of(new DecimalV3Literal(new BigDecimal("-1000.1")), null), + Pair.of(new DoubleLiteral(-129.0), new SmallIntLiteral((short) -129)), + Pair.of(new DoubleLiteral(-128.1), new DecimalV3Literal(new BigDecimal("-128.1")))), + ImmutableList.of( + Pair.of(new TinyIntLiteral((byte) -128), null), + Pair.of(new SmallIntLiteral((short) -128), null), + Pair.of(new IntegerLiteral(-128), null), + Pair.of(new DecimalV3Literal(new BigDecimal("-128")), new TinyIntLiteral((byte) -128)), + Pair.of(new DoubleLiteral(-128.0), new TinyIntLiteral((byte) -128))), + ImmutableList.of( + Pair.of(new TinyIntLiteral((byte) -127), null), + Pair.of(new DecimalV3Literal(new BigDecimal("-127")), new TinyIntLiteral((byte) -127)), + Pair.of(new DoubleLiteral(-127.0), new TinyIntLiteral((byte) -127)), + Pair.of(new TinyIntLiteral((byte) 126), null), + Pair.of(new DoubleLiteral(126.0), new TinyIntLiteral((byte) 126))), + ImmutableList.of( + Pair.of(new TinyIntLiteral((byte) 127), null), + Pair.of(new DecimalV3Literal(new BigDecimal("127")), new TinyIntLiteral((byte) 127)), + Pair.of(new DecimalV3Literal(new BigDecimal("127.00")), new TinyIntLiteral((byte) 127)), + Pair.of(new DoubleLiteral(127.0), new TinyIntLiteral((byte) 127))), + ImmutableList.of( + Pair.of(new SmallIntLiteral((short) 128), null), + Pair.of(new DecimalV3Literal(new BigDecimal("128.02")), null), + Pair.of(new DoubleLiteral(128.0), new SmallIntLiteral((short) 128)), + Pair.of(new DoubleLiteral(127.1), new DecimalV3Literal(new BigDecimal("127.1"))))); + + checkTypeRangeLimit(SmallIntType.INSTANCE, + ImmutableList.of( + Pair.of(new IntegerLiteral(-32769), null), + Pair.of(new DecimalV3Literal(new BigDecimal("-32769")), null), + Pair.of(new DecimalV3Literal(new BigDecimal("-32768.1")), null), + Pair.of(new DoubleLiteral(-32769.0), new IntegerLiteral(-32769)), + Pair.of(new DoubleLiteral(-32769.1), new DecimalV3Literal(new BigDecimal("-32769.1")))), + ImmutableList.of( + Pair.of(new SmallIntLiteral((short) -32768), null), + Pair.of(new DecimalV3Literal(new BigDecimal("-32768")), new SmallIntLiteral((short) -32768)), + Pair.of(new DoubleLiteral(-32768.0), new SmallIntLiteral((short) -32768))), + ImmutableList.of( + Pair.of(new SmallIntLiteral((short) -32767), null), + Pair.of(new DecimalV3Literal(new BigDecimal("-32767")), new SmallIntLiteral((short) -32767)), + Pair.of(new DoubleLiteral(-32767.0), new SmallIntLiteral((short) -32767)), + Pair.of(new SmallIntLiteral((short) 32766), null), + Pair.of(new DoubleLiteral(32766.0), new SmallIntLiteral((short) 32766))), + ImmutableList.of( + Pair.of(new SmallIntLiteral((short) 32767), null), + Pair.of(new DecimalV3Literal(new BigDecimal("32767")), new SmallIntLiteral((short) 32767)), + Pair.of(new DecimalV3Literal(new BigDecimal("32767.00")), new SmallIntLiteral((short) 32767)), + Pair.of(new DoubleLiteral(32767.0), new SmallIntLiteral((short) 32767))), + ImmutableList.of( + Pair.of(new IntegerLiteral(32768), null), + Pair.of(new DecimalV3Literal(new BigDecimal("32768.02")), null), + Pair.of(new DoubleLiteral(32768.0), new IntegerLiteral(32768)), + Pair.of(new DoubleLiteral(32768.1), new DecimalV3Literal(new BigDecimal("32768.1"))))); + + checkTypeRangeLimit(IntegerType.INSTANCE, + ImmutableList.of( + Pair.of(new BigIntLiteral(-2147483649L), null), + Pair.of(new DecimalV3Literal(new BigDecimal("-2147483649")), null), + Pair.of(new DecimalV3Literal(new BigDecimal("-2147483649.1")), null), + Pair.of(new DoubleLiteral(-2147483649.0), new BigIntLiteral(-2147483649L)), + Pair.of(new DoubleLiteral(-2147483649.1), new DecimalV3Literal(new BigDecimal("-2147483649.1")))), + ImmutableList.of( + Pair.of(new IntegerLiteral(-2147483648), null), + Pair.of(new DecimalV3Literal(new BigDecimal("-2147483648.0")), new IntegerLiteral(-2147483648)), + Pair.of(new DoubleLiteral(-2147483648.0), new IntegerLiteral(-2147483648))), + ImmutableList.of( + Pair.of(new TinyIntLiteral((byte) 0), null), + Pair.of(new SmallIntLiteral((short) 0), null), + Pair.of(new IntegerLiteral(-2147483647), null), + Pair.of(new DecimalV3Literal(new BigDecimal("-2147483647")), new IntegerLiteral(-2147483647)), + Pair.of(new DoubleLiteral(-2147483647.0), new IntegerLiteral(-2147483647)), + Pair.of(new IntegerLiteral(2147483646), null), + Pair.of(new DoubleLiteral(2147483646.0), new IntegerLiteral(2147483646))), + ImmutableList.of( + Pair.of(new IntegerLiteral(2147483647), null), + Pair.of(new DecimalV3Literal(new BigDecimal("2147483647")), new IntegerLiteral(2147483647)), + Pair.of(new DecimalV3Literal(new BigDecimal("2147483647.00")), new IntegerLiteral(2147483647)), + Pair.of(new DoubleLiteral(2147483647.0), new IntegerLiteral(2147483647))), + ImmutableList.of( + Pair.of(new BigIntLiteral(2147483648L), null), + Pair.of(new DecimalV3Literal(new BigDecimal("2147483648.02")), null), + Pair.of(new DoubleLiteral(2147483648.0), new BigIntLiteral(2147483648L)), + Pair.of(new DoubleLiteral(2147483647.1), new DecimalV3Literal(new BigDecimal("2147483647.1"))))); + + checkTypeRangeLimit(BigIntType.INSTANCE, + ImmutableList.of( + Pair.of(new LargeIntLiteral(new BigInteger("-9223372036854775809")), null), + Pair.of(new DecimalV3Literal(new BigDecimal("-9223372036854775809")), null), + Pair.of(new DecimalV3Literal(new BigDecimal("-9223372036854775808.1")), null), + Pair.of(new DoubleLiteral(-9223372036854775809.0), new LargeIntLiteral(new BigInteger("-9223372036854775809"))), + Pair.of(new DoubleLiteral(-9223372036854775808.1), new DecimalV3Literal(new BigDecimal("-9223372036854775808.1")))), + ImmutableList.of( + Pair.of(new BigIntLiteral(-9223372036854775808L), null), + Pair.of(new DecimalV3Literal(new BigDecimal("-9223372036854775808")), new BigIntLiteral(-9223372036854775808L))), + ImmutableList.of( + Pair.of(new BigIntLiteral(-9223372036854775807L), null), + Pair.of(new DecimalV3Literal(new BigDecimal("-9223372036854775807")), new BigIntLiteral(-9223372036854775807L)), + Pair.of(new DoubleLiteral(-9223372036854775000.0), new BigIntLiteral(-9223372036854775000L)), + Pair.of(new BigIntLiteral(9223372036854775806L), null), + Pair.of(new DoubleLiteral(9223372036854775000.0), new BigIntLiteral(9223372036854775000L))), + ImmutableList.of( + Pair.of(new BigIntLiteral(9223372036854775807L), null), + Pair.of(new DecimalV3Literal(new BigDecimal("9223372036854775807")), new BigIntLiteral(9223372036854775807L)), + Pair.of(new DecimalV3Literal(new BigDecimal("9223372036854775807.00")), new BigIntLiteral(9223372036854775807L))), + ImmutableList.of( + Pair.of(new LargeIntLiteral(new BigInteger("9223372036854775808")), null), + Pair.of(new DecimalV3Literal(new BigDecimal("9223372036854775807.02")), null), + Pair.of(new DoubleLiteral(9223372036854775808.0), new LargeIntLiteral(new BigInteger("9223372036854775808"))), + Pair.of(new DoubleLiteral(9223372036854775807.1), new DecimalV3Literal(new BigDecimal("9223372036854775807.1"))))); + + checkTypeRangeLimit(DecimalV3Type.createDecimalV3Type(5, 2), + ImmutableList.of( + Pair.of(new IntegerLiteral(-1000), null), + Pair.of(new DoubleLiteral(-1000.1), null), + Pair.of(new DecimalV3Literal(new BigDecimal("-999.999")), null), + Pair.of(new DecimalV3Literal(new BigDecimal("-1000.00")), null), + Pair.of(new DecimalV3Literal(new BigDecimal("-1000.0123")), null)), + ImmutableList.of( + Pair.of(new DecimalV3Literal(new BigDecimal("-999.99")), null)), + ImmutableList.of( + Pair.of(new DecimalV3Literal(new BigDecimal("100.4")), null), + Pair.of(new DecimalV3Literal(new BigDecimal("100")), null)), + ImmutableList.of( + Pair.of(new DecimalV3Literal(new BigDecimal("999.99")), null)), + ImmutableList.of( + Pair.of(new IntegerLiteral(1000), null), + Pair.of(new DoubleLiteral(1000.1), null), + Pair.of(new DecimalV3Literal(new BigDecimal("1000")), null), + Pair.of(new DecimalV3Literal(new BigDecimal("999.999")), null))); + } + + // each expr list item is: pair + // if rewritten right literal = null, then rewritten right literal = origin right literal + void checkTypeRangeLimit(DataType dataType, List> lessThanMinExpr, + List> minExpr, List> betweenMinMaxExpr, + List> maxExpr, List> greaterThanMaxExpr) { + // due to ComparisonPredicate constructor require not null left and right child, + // use a dummyExpr as ComparisonPredicate's child + Expression dummyExpr = new SmallIntLiteral((short) 100); + // cp -> list of cp with lessThanMinExpr, minExpr, betweenMinMaxExpr, maxExpr, greaterThanMaxExpr + List>> cmpResults = ImmutableList.of( + Pair.of(new EqualTo(dummyExpr, dummyExpr), ImmutableList.of( + RangeLimitResult.FALSE, RangeLimitResult.NO_CHANGE_CP, RangeLimitResult.NO_CHANGE_CP, + RangeLimitResult.NO_CHANGE_CP, RangeLimitResult.FALSE)), + Pair.of(new NullSafeEqual(dummyExpr, dummyExpr), ImmutableList.of( + RangeLimitResult.FALSE, RangeLimitResult.NO_CHANGE_CP, RangeLimitResult.NO_CHANGE_CP, + RangeLimitResult.NO_CHANGE_CP, RangeLimitResult.FALSE)), + Pair.of(new GreaterThan(dummyExpr, dummyExpr), ImmutableList.of( + RangeLimitResult.TRUE, RangeLimitResult.NO_CHANGE_CP, RangeLimitResult.NO_CHANGE_CP, + RangeLimitResult.FALSE, RangeLimitResult.FALSE)), + Pair.of(new GreaterThanEqual(dummyExpr, dummyExpr), ImmutableList.of( + RangeLimitResult.TRUE, RangeLimitResult.TRUE, RangeLimitResult.NO_CHANGE_CP, + RangeLimitResult.EQUALS, RangeLimitResult.FALSE)), + Pair.of(new LessThan(dummyExpr, dummyExpr), ImmutableList.of( + RangeLimitResult.FALSE, RangeLimitResult.FALSE, RangeLimitResult.NO_CHANGE_CP, + RangeLimitResult.NO_CHANGE_CP, RangeLimitResult.TRUE)), + Pair.of(new LessThanEqual(dummyExpr, dummyExpr), ImmutableList.of( + RangeLimitResult.FALSE, RangeLimitResult.EQUALS, RangeLimitResult.NO_CHANGE_CP, + RangeLimitResult.TRUE, RangeLimitResult.TRUE)) + ); + + for (Pair> cmpResult : cmpResults) { + ComparisonPredicate cp = cmpResult.first; + List result = cmpResult.second; + checkTypeRangeLimitWithComparison(dataType, cp, lessThanMinExpr, result.get(0)); + checkTypeRangeLimitWithComparison(dataType, cp, minExpr, result.get(1)); + checkTypeRangeLimitWithComparison(dataType, cp, betweenMinMaxExpr, result.get(2)); + checkTypeRangeLimitWithComparison(dataType, cp, maxExpr, result.get(3)); + checkTypeRangeLimitWithComparison(dataType, cp, greaterThanMaxExpr, result.get(4)); + } + } + + void checkTypeRangeLimitWithComparison(DataType dataType, ComparisonPredicate cp, + List> exprs, RangeLimitResult result) { + Expression slot = new SlotReference("slot", dataType, true); + for (Pair pair : exprs) { + Expression right = pair.first; + Expression rewriteRight = pair.second; + if (rewriteRight == null) { + rewriteRight = right; + } + Expression left = slot; + if (!left.getDataType().equals(right.getDataType())) { + left = new Cast(slot, right.getDataType()); + } + Expression originExpr = cp.withChildren(left, right); + Expression rewrittenExpr = executor.rewrite(originExpr, context); + Expression expectExpr = null; + // System.out.println("slot type: " + slot.getDataType() + ", literal type: " + right.getDataType()); + // System.out.println("origin expr: " + originExpr); + // System.out.println("rewrite expr: " + rewrittenExpr); + switch (result) { + case TRUE: + expectExpr = cp instanceof NullSafeEqual ? BooleanLiteral.TRUE + : ExpressionUtils.trueOrNull(slot); + break; + case FALSE: + expectExpr = cp instanceof NullSafeEqual ? BooleanLiteral.FALSE + : ExpressionUtils.falseOrNull(slot); + break; + case EQUALS: + Expression expectLeft = slot.getDataType().equals(rewriteRight.getDataType()) ? slot : left; + expectExpr = new EqualTo(expectLeft, rewriteRight); + break; + case NO_CHANGE_CP: + Assertions.assertInstanceOf(cp.getClass(), rewrittenExpr); + break; + default: + Assertions.assertTrue(false); + } + if (expectExpr != null) { + Assertions.assertEquals(expectExpr, rewrittenExpr); + } + } + } } diff --git a/regression-test/data/nereids_rules_p0/infer_predicate/pull_up_predicate_literal.out b/regression-test/data/nereids_rules_p0/infer_predicate/pull_up_predicate_literal.out index 524559cabeb34d..920289be66c360 100644 --- a/regression-test/data/nereids_rules_p0/infer_predicate/pull_up_predicate_literal.out +++ b/regression-test/data/nereids_rules_p0/infer_predicate/pull_up_predicate_literal.out @@ -262,15 +262,7 @@ PhysicalResultSink -- !const_value_and_join_column_type16 -- PhysicalResultSink ---PhysicalProject -----hashJoin[INNER_JOIN] hashCondition=((expr_cast(d_tinyint as SMALLINT) = t.c1)) otherCondition=() -------PhysicalLimit[GLOBAL] ---------PhysicalLimit[LOCAL] -----------PhysicalProject -------------PhysicalStorageLayerAggregate[test_pull_up_predicate_literal] -------PhysicalProject ---------filter((cast(d_tinyint as SMALLINT) = 32767)) -----------PhysicalOlapScan[test_types] +--PhysicalEmptyRelation -- !const_value_and_join_column_type17 -- PhysicalResultSink @@ -414,27 +406,11 @@ PhysicalResultSink -- !const_value_and_join_column_type32 -- PhysicalResultSink ---PhysicalProject -----hashJoin[INNER_JOIN] hashCondition=((expr_cast(d_tinyint as INT) = t.c1)) otherCondition=() -------PhysicalLimit[GLOBAL] ---------PhysicalLimit[LOCAL] -----------PhysicalProject -------------PhysicalStorageLayerAggregate[test_pull_up_predicate_literal] -------PhysicalProject ---------filter((cast(d_tinyint as INT) = 32768)) -----------PhysicalOlapScan[test_types] +--PhysicalEmptyRelation -- !const_value_and_join_column_type33 -- PhysicalResultSink ---PhysicalProject -----hashJoin[INNER_JOIN] hashCondition=((expr_cast(d_smallint as INT) = t.c1)) otherCondition=() -------PhysicalLimit[GLOBAL] ---------PhysicalLimit[LOCAL] -----------PhysicalProject -------------PhysicalStorageLayerAggregate[test_pull_up_predicate_literal] -------PhysicalProject ---------filter((cast(d_smallint as INT) = 32768)) -----------PhysicalOlapScan[test_types] +--PhysicalEmptyRelation -- !const_value_and_join_column_type34 -- PhysicalResultSink @@ -566,39 +542,15 @@ PhysicalResultSink -- !const_value_and_join_column_type48 -- PhysicalResultSink ---PhysicalProject -----hashJoin[INNER_JOIN] hashCondition=((expr_cast(d_tinyint as BIGINT) = t.c1)) otherCondition=() -------PhysicalLimit[GLOBAL] ---------PhysicalLimit[LOCAL] -----------PhysicalProject -------------PhysicalStorageLayerAggregate[test_pull_up_predicate_literal] -------PhysicalProject ---------filter((cast(d_tinyint as BIGINT) = 214748364799)) -----------PhysicalOlapScan[test_types] +--PhysicalEmptyRelation -- !const_value_and_join_column_type49 -- PhysicalResultSink ---PhysicalProject -----hashJoin[INNER_JOIN] hashCondition=((expr_cast(d_smallint as BIGINT) = t.c1)) otherCondition=() -------PhysicalLimit[GLOBAL] ---------PhysicalLimit[LOCAL] -----------PhysicalProject -------------PhysicalStorageLayerAggregate[test_pull_up_predicate_literal] -------PhysicalProject ---------filter((cast(d_smallint as BIGINT) = 214748364799)) -----------PhysicalOlapScan[test_types] +--PhysicalEmptyRelation -- !const_value_and_join_column_type50 -- PhysicalResultSink ---PhysicalProject -----hashJoin[INNER_JOIN] hashCondition=((expr_cast(d_int as BIGINT) = t.c1)) otherCondition=() -------PhysicalLimit[GLOBAL] ---------PhysicalLimit[LOCAL] -----------PhysicalProject -------------PhysicalStorageLayerAggregate[test_pull_up_predicate_literal] -------PhysicalProject ---------filter((cast(d_int as BIGINT) = 214748364799)) -----------PhysicalOlapScan[test_types] +--PhysicalEmptyRelation -- !const_value_and_join_column_type51 -- PhysicalResultSink @@ -718,51 +670,19 @@ PhysicalResultSink -- !const_value_and_join_column_type64 -- PhysicalResultSink ---PhysicalProject -----hashJoin[INNER_JOIN] hashCondition=((expr_cast(d_tinyint as LARGEINT) = t.c1)) otherCondition=() -------PhysicalLimit[GLOBAL] ---------PhysicalLimit[LOCAL] -----------PhysicalProject -------------PhysicalStorageLayerAggregate[test_pull_up_predicate_literal] -------PhysicalProject ---------filter((cast(d_tinyint as LARGEINT) = 922337203685477580722)) -----------PhysicalOlapScan[test_types] +--PhysicalEmptyRelation -- !const_value_and_join_column_type65 -- PhysicalResultSink ---PhysicalProject -----hashJoin[INNER_JOIN] hashCondition=((expr_cast(d_smallint as LARGEINT) = t.c1)) otherCondition=() -------PhysicalLimit[GLOBAL] ---------PhysicalLimit[LOCAL] -----------PhysicalProject -------------PhysicalStorageLayerAggregate[test_pull_up_predicate_literal] -------PhysicalProject ---------filter((cast(d_smallint as LARGEINT) = 922337203685477580722)) -----------PhysicalOlapScan[test_types] +--PhysicalEmptyRelation -- !const_value_and_join_column_type66 -- PhysicalResultSink ---PhysicalProject -----hashJoin[INNER_JOIN] hashCondition=((expr_cast(d_int as LARGEINT) = t.c1)) otherCondition=() -------PhysicalLimit[GLOBAL] ---------PhysicalLimit[LOCAL] -----------PhysicalProject -------------PhysicalStorageLayerAggregate[test_pull_up_predicate_literal] -------PhysicalProject ---------filter((cast(d_int as LARGEINT) = 922337203685477580722)) -----------PhysicalOlapScan[test_types] +--PhysicalEmptyRelation -- !const_value_and_join_column_type67 -- PhysicalResultSink ---PhysicalProject -----hashJoin[INNER_JOIN] hashCondition=((expr_cast(d_bigint as LARGEINT) = t.c1)) otherCondition=() -------PhysicalLimit[GLOBAL] ---------PhysicalLimit[LOCAL] -----------PhysicalProject -------------PhysicalStorageLayerAggregate[test_pull_up_predicate_literal] -------PhysicalProject ---------filter((cast(d_bigint as LARGEINT) = 922337203685477580722)) -----------PhysicalOlapScan[test_types] +--PhysicalEmptyRelation -- !const_value_and_join_column_type68 -- PhysicalResultSink @@ -814,15 +734,7 @@ PhysicalResultSink -- !const_value_and_join_column_type72 -- PhysicalResultSink ---PhysicalProject -----hashJoin[INNER_JOIN] hashCondition=((expr_cast(d_decimal as DOUBLE) = expr_cast(c1 as DOUBLE))) otherCondition=() -------PhysicalLimit[GLOBAL] ---------PhysicalLimit[LOCAL] -----------PhysicalProject -------------PhysicalStorageLayerAggregate[test_pull_up_predicate_literal] -------PhysicalProject ---------filter((cast(d_decimal as DOUBLE) = 9.223372036854776E20)) -----------PhysicalOlapScan[test_types] +--PhysicalEmptyRelation -- !const_value_and_join_column_type73 -- PhysicalResultSink diff --git a/regression-test/data/nereids_rules_p0/predicate_infer/infer_predicate.out b/regression-test/data/nereids_rules_p0/predicate_infer/infer_predicate.out index 288c30bb28c1cf..fb99fe6169c1d9 100644 --- a/regression-test/data/nereids_rules_p0/predicate_infer/infer_predicate.out +++ b/regression-test/data/nereids_rules_p0/predicate_infer/infer_predicate.out @@ -374,23 +374,23 @@ PhysicalResultSink -- !infer9 -- PhysicalResultSink --hashJoin[INNER_JOIN] hashCondition=((t1.id = t2.id)) otherCondition=() -----filter((cast(id as BIGINT) = 2147483648)) +----filter((t1.id = 12345)) ------PhysicalOlapScan[t1] -----filter((cast(id as BIGINT) = 2147483648)) +----filter((t2.id = 12345)) ------PhysicalOlapScan[t2] -- !infer10 -- PhysicalResultSink --hashJoin[INNER_JOIN] hashCondition=((expr_cast(id as SMALLINT) = expr_cast(id as SMALLINT))) otherCondition=() -----filter((cast(id as BIGINT) = 2147483648)) +----filter((t1.id = 12345)) ------PhysicalOlapScan[t1] ----PhysicalOlapScan[t2] -- !infer11 -- PhysicalResultSink --hashJoin[INNER_JOIN] hashCondition=((expr_cast(id as LARGEINT) = expr_cast(id as LARGEINT))) otherCondition=() -----filter((cast(id as BIGINT) = 2147483648)) +----filter((t1.id = 12345)) ------PhysicalOlapScan[t1] -----filter((cast(id as BIGINT) = 2147483648)) +----filter((t2.id = 12345)) ------PhysicalOlapScan[t2] diff --git a/regression-test/suites/nereids_rules_p0/predicate_infer/infer_predicate.groovy b/regression-test/suites/nereids_rules_p0/predicate_infer/infer_predicate.groovy index af985ecd7ee10e..6237d5d66e8ef3 100644 --- a/regression-test/suites/nereids_rules_p0/predicate_infer/infer_predicate.groovy +++ b/regression-test/suites/nereids_rules_p0/predicate_infer/infer_predicate.groovy @@ -294,18 +294,17 @@ suite("infer_predicate") { explain shape plan select * from t1 join t2 on t1.id != t2.id where t1.id = 1; """ - // 测试 infer predicate 是否能推出正确类型, 2147483648 超过了 Int32 的最大值, 但是不超过 Int64 的最大值,用这个值测试类型是否能推导正确 qt_infer9 """ - explain shape plan select * from (select * from t1 where t1.id = 2147483648) t1 join t2 on t1.id = t2.id; + explain shape plan select * from (select * from t1 where t1.id = 12345) t1 join t2 on t1.id = t2.id; """ // 测试 cast = cast qt_infer10 """ - explain shape plan select * from (select * from t1 where t1.id = 2147483648) t1 join t2 on cast(t1.id as smallint) = cast(t2.id as smallint); + explain shape plan select * from (select * from t1 where t1.id = 12345) t1 join t2 on cast(t1.id as smallint) = cast(t2.id as smallint); """ // 测试 cast = cast qt_infer11 """ - explain shape plan select * from (select * from t1 where t1.id = 2147483648) t1 join t2 on cast(t1.id as largeint) = cast(t2.id as largeint); + explain shape plan select * from (select * from t1 where t1.id = 12345) t1 join t2 on cast(t1.id as largeint) = cast(t2.id as largeint); """ }