Skip to content

Commit

Permalink
[test](nereids) add test simplify comparison predicate (apache#44886)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

Add test simplify comparison predicate
  • Loading branch information
yujun777 authored Dec 23, 2024
1 parent a032ece commit e09bc04
Show file tree
Hide file tree
Showing 2 changed files with 360 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.apache.doris.nereids.trees.expressions.literal.DateV2Literal;
import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal;
import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral;
import org.apache.doris.nereids.trees.expressions.literal.FloatLiteral;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.expressions.literal.LargeIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
Expand All @@ -54,6 +55,7 @@
import org.apache.doris.nereids.types.DateV2Type;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.FloatType;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.types.SmallIntType;
import org.apache.doris.nereids.types.TinyIntType;
Expand Down Expand Up @@ -296,10 +298,197 @@ void testDoubleLiteral() {
Expression rewrittenExpression = executor.rewrite(expression, context);
Assertions.assertEquals(left.child(0).getDataType(), rewrittenExpression.child(1).getDataType());
Assertions.assertEquals(rewrittenExpression.child(0).getDataType(), rewrittenExpression.child(1).getDataType());

Expression tinyIntSlot = new SlotReference("a", TinyIntType.INSTANCE);
Expression smallIntSlot = new SlotReference("a", SmallIntType.INSTANCE);
Expression intSlot = new SlotReference("a", IntegerType.INSTANCE);
Expression bigIntSlot = new SlotReference("a", BigIntType.INSTANCE);

// tiny int, literal not exceeds data type limit
assertRewrite(new EqualTo(new Cast(tinyIntSlot, FloatType.INSTANCE), new FloatLiteral(12.0f)),
new EqualTo(tinyIntSlot, new TinyIntLiteral((byte) 12)));
assertRewrite(new EqualTo(new Cast(tinyIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.0f)),
new EqualTo(tinyIntSlot, new TinyIntLiteral((byte) 12)));
assertRewrite(new EqualTo(new Cast(tinyIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
ExpressionUtils.falseOrNull(tinyIntSlot));
assertRewrite(new NullSafeEqual(new Cast(tinyIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
BooleanLiteral.FALSE);
assertRewrite(new GreaterThan(new Cast(tinyIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
new GreaterThan(tinyIntSlot, new TinyIntLiteral((byte) 12)));
assertRewrite(new GreaterThanEqual(new Cast(tinyIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
new GreaterThanEqual(tinyIntSlot, new TinyIntLiteral((byte) 13)));
assertRewrite(new LessThan(new Cast(tinyIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
new LessThan(tinyIntSlot, new TinyIntLiteral((byte) 13)));
assertRewrite(new LessThanEqual(new Cast(tinyIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
new LessThanEqual(tinyIntSlot, new TinyIntLiteral((byte) 12)));

// tiny int, literal exceeds data type limit
assertRewrite(new EqualTo(new Cast(tinyIntSlot, FloatType.INSTANCE), new FloatLiteral(200.0f)),
ExpressionUtils.falseOrNull(tinyIntSlot));
assertRewrite(new EqualTo(new Cast(tinyIntSlot, DoubleType.INSTANCE), new DoubleLiteral(200.0f)),
ExpressionUtils.falseOrNull(tinyIntSlot));
assertRewrite(new EqualTo(new Cast(tinyIntSlot, DoubleType.INSTANCE), new DoubleLiteral(200.3f)),
ExpressionUtils.falseOrNull(tinyIntSlot));
assertRewrite(new NullSafeEqual(new Cast(tinyIntSlot, DoubleType.INSTANCE), new DoubleLiteral(200.3f)),
BooleanLiteral.FALSE);
assertRewrite(new GreaterThan(new Cast(tinyIntSlot, DoubleType.INSTANCE), new DoubleLiteral(200.3f)),
ExpressionUtils.falseOrNull(tinyIntSlot));
assertRewrite(new GreaterThanEqual(new Cast(tinyIntSlot, DoubleType.INSTANCE), new DoubleLiteral(200.3f)),
ExpressionUtils.falseOrNull(tinyIntSlot));
assertRewrite(new LessThan(new Cast(tinyIntSlot, DoubleType.INSTANCE), new DoubleLiteral(200.3f)),
ExpressionUtils.trueOrNull(tinyIntSlot));
assertRewrite(new LessThanEqual(new Cast(tinyIntSlot, DoubleType.INSTANCE), new DoubleLiteral(200.3f)),
ExpressionUtils.trueOrNull(tinyIntSlot));

// small int
assertRewrite(new EqualTo(new Cast(smallIntSlot, FloatType.INSTANCE), new FloatLiteral(12.0f)),
new EqualTo(smallIntSlot, new SmallIntLiteral((short) 12)));
assertRewrite(new EqualTo(new Cast(smallIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.0f)),
new EqualTo(smallIntSlot, new SmallIntLiteral((short) 12)));
assertRewrite(new EqualTo(new Cast(smallIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
ExpressionUtils.falseOrNull(smallIntSlot));
assertRewrite(new NullSafeEqual(new Cast(smallIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
BooleanLiteral.FALSE);
assertRewrite(new GreaterThan(new Cast(smallIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
new GreaterThan(smallIntSlot, new SmallIntLiteral((short) 12)));
assertRewrite(new GreaterThanEqual(new Cast(smallIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
new GreaterThanEqual(smallIntSlot, new SmallIntLiteral((short) 13)));
assertRewrite(new LessThan(new Cast(smallIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
new LessThan(smallIntSlot, new SmallIntLiteral((short) 13)));
assertRewrite(new LessThanEqual(new Cast(smallIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
new LessThanEqual(smallIntSlot, new SmallIntLiteral((short) 12)));

// int
assertRewrite(new EqualTo(new Cast(intSlot, FloatType.INSTANCE), new FloatLiteral(12.0f)),
new EqualTo(intSlot, new IntegerLiteral(12)));
assertRewrite(new EqualTo(new Cast(intSlot, DoubleType.INSTANCE), new DoubleLiteral(12.0f)),
new EqualTo(intSlot, new IntegerLiteral(12)));
assertRewrite(new EqualTo(new Cast(intSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
ExpressionUtils.falseOrNull(intSlot));
assertRewrite(new NullSafeEqual(new Cast(intSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
BooleanLiteral.FALSE);
assertRewrite(new GreaterThan(new Cast(intSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
new GreaterThan(intSlot, new IntegerLiteral(12)));
assertRewrite(new GreaterThanEqual(new Cast(intSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
new GreaterThanEqual(intSlot, new IntegerLiteral(13)));
assertRewrite(new LessThan(new Cast(intSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
new LessThan(intSlot, new IntegerLiteral(13)));
assertRewrite(new LessThanEqual(new Cast(intSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
new LessThanEqual(intSlot, new IntegerLiteral(12)));

// big int
assertRewrite(new EqualTo(new Cast(bigIntSlot, FloatType.INSTANCE), new FloatLiteral(12.0f)),
new EqualTo(bigIntSlot, new BigIntLiteral(12L)));
assertRewrite(new EqualTo(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.0f)),
new EqualTo(bigIntSlot, new BigIntLiteral(12L)));
assertRewrite(new EqualTo(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
ExpressionUtils.falseOrNull(bigIntSlot));
assertRewrite(new NullSafeEqual(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
BooleanLiteral.FALSE);
assertRewrite(new GreaterThan(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
new GreaterThan(bigIntSlot, new BigIntLiteral(12L)));
assertRewrite(new GreaterThanEqual(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
new GreaterThanEqual(bigIntSlot, new BigIntLiteral(13L)));
assertRewrite(new LessThan(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
new LessThan(bigIntSlot, new BigIntLiteral(13L)));
assertRewrite(new LessThanEqual(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
new LessThanEqual(bigIntSlot, new BigIntLiteral(12L)));
}

@Test
void testIntCmpDecimalV3Literal() {
executor = new ExpressionRuleExecutor(ImmutableList.of(
bottomUp(SimplifyComparisonPredicate.INSTANCE)
));

Expression tinyIntSlot = new SlotReference("a", TinyIntType.INSTANCE);
Expression smallIntSlot = new SlotReference("a", SmallIntType.INSTANCE);
Expression intSlot = new SlotReference("a", IntegerType.INSTANCE);
Expression bigIntSlot = new SlotReference("a", BigIntType.INSTANCE);

// tiny int, literal not exceeds data type limit
assertRewrite(new EqualTo(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.0"))),
new EqualTo(tinyIntSlot, new TinyIntLiteral((byte) 12)));
assertRewrite(new EqualTo(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
ExpressionUtils.falseOrNull(tinyIntSlot));
assertRewrite(new NullSafeEqual(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
BooleanLiteral.FALSE);
assertRewrite(new GreaterThan(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
new GreaterThan(tinyIntSlot, new TinyIntLiteral((byte) 12)));
assertRewrite(new GreaterThanEqual(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
new GreaterThanEqual(tinyIntSlot, new TinyIntLiteral((byte) 13)));
assertRewrite(new LessThan(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
new LessThan(tinyIntSlot, new TinyIntLiteral((byte) 13)));
assertRewrite(new LessThanEqual(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
new LessThanEqual(tinyIntSlot, new TinyIntLiteral((byte) 12)));

// tiny int, literal exceeds data type limit
assertRewrite(new EqualTo(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(4, 1)), new DecimalV3Literal(new BigDecimal("200.0"))),
ExpressionUtils.falseOrNull(tinyIntSlot));
assertRewrite(new EqualTo(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(4, 1)), new DecimalV3Literal(new BigDecimal("200.3"))),
ExpressionUtils.falseOrNull(tinyIntSlot));
assertRewrite(new NullSafeEqual(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(4, 1)), new DecimalV3Literal(new BigDecimal("200.3"))),
BooleanLiteral.FALSE);
assertRewrite(new GreaterThan(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(4, 1)), new DecimalV3Literal(new BigDecimal("200.3"))),
ExpressionUtils.falseOrNull(tinyIntSlot));
assertRewrite(new GreaterThanEqual(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(4, 1)), new DecimalV3Literal(new BigDecimal("200.3"))),
ExpressionUtils.falseOrNull(tinyIntSlot));
assertRewrite(new LessThan(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(4, 1)), new DecimalV3Literal(new BigDecimal("200.3"))),
ExpressionUtils.trueOrNull(tinyIntSlot));
assertRewrite(new LessThanEqual(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(4, 1)), new DecimalV3Literal(new BigDecimal("200.3"))),
ExpressionUtils.trueOrNull(tinyIntSlot));

// small int
assertRewrite(new EqualTo(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.0"))),
new EqualTo(smallIntSlot, new SmallIntLiteral((short) 12)));
assertRewrite(new EqualTo(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
ExpressionUtils.falseOrNull(smallIntSlot));
assertRewrite(new NullSafeEqual(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
BooleanLiteral.FALSE);
assertRewrite(new GreaterThan(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
new GreaterThan(smallIntSlot, new SmallIntLiteral((short) 12)));
assertRewrite(new GreaterThanEqual(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
new GreaterThanEqual(smallIntSlot, new SmallIntLiteral((short) 13)));
assertRewrite(new LessThan(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
new LessThan(smallIntSlot, new SmallIntLiteral((short) 13)));
assertRewrite(new LessThanEqual(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
new LessThanEqual(smallIntSlot, new SmallIntLiteral((short) 12)));

// int
assertRewrite(new EqualTo(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.0"))),
new EqualTo(intSlot, new IntegerLiteral(12)));
assertRewrite(new EqualTo(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
ExpressionUtils.falseOrNull(intSlot));
assertRewrite(new NullSafeEqual(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
BooleanLiteral.FALSE);
assertRewrite(new GreaterThan(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
new GreaterThan(intSlot, new IntegerLiteral(12)));
assertRewrite(new GreaterThanEqual(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
new GreaterThanEqual(intSlot, new IntegerLiteral(13)));
assertRewrite(new LessThan(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
new LessThan(intSlot, new IntegerLiteral(13)));
assertRewrite(new LessThanEqual(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
new LessThanEqual(intSlot, new IntegerLiteral(12)));

// big int
assertRewrite(new EqualTo(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.0"))),
new EqualTo(bigIntSlot, new BigIntLiteral(12L)));
assertRewrite(new EqualTo(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
ExpressionUtils.falseOrNull(bigIntSlot));
assertRewrite(new NullSafeEqual(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
BooleanLiteral.FALSE);
assertRewrite(new GreaterThan(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
new GreaterThan(bigIntSlot, new BigIntLiteral(12L)));
assertRewrite(new GreaterThanEqual(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
new GreaterThanEqual(bigIntSlot, new BigIntLiteral(13L)));
assertRewrite(new LessThan(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
new LessThan(bigIntSlot, new BigIntLiteral(13L)));
assertRewrite(new LessThanEqual(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
new LessThanEqual(bigIntSlot, new BigIntLiteral(12L)));
}

@Test
void testDecimalV3Literal() {
void testDecimalCmpDecimalV3Literal() {
executor = new ExpressionRuleExecutor(ImmutableList.of(
bottomUp(SimplifyComparisonPredicate.INSTANCE)
));
Expand Down
Loading

0 comments on commit e09bc04

Please sign in to comment.