Skip to content

Commit

Permalink
[fix](Nereids) some expression not cast in in predicate (apache#24680)
Browse files Browse the repository at this point in the history
1. should use castIfNotSameType in InPredicate and CaseWhen
2. StringLikeLiteral should override equals to ignore type
  • Loading branch information
morrySnow authored Sep 22, 2023
1 parent 034582b commit 320fc14
Show file tree
Hide file tree
Showing 43 changed files with 118 additions and 86 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.literal.DateTimeV2Literal;
import org.apache.doris.nereids.trees.expressions.literal.DateV2Literal;
import org.apache.doris.nereids.types.DateTimeV2Type;

import com.google.common.collect.Lists;
import com.google.common.collect.ImmutableList;

import java.util.List;

Expand All @@ -46,12 +47,25 @@ public Expression visitInPredicate(InPredicate expr, ExpressionRewriteContext co
List<Expression> literals = expr.children().subList(1, expr.children().size());
if (literals.stream().allMatch(literal -> literal instanceof DateTimeV2Literal
&& canLosslessConvertToDateV2Literal((DateTimeV2Literal) literal))) {
List<Expression> children = Lists.newArrayList();
ImmutableList.Builder<Expression> children = ImmutableList.builder();
children.add(cast.child());
literals.stream().forEach(
l -> children.add(convertToDateV2Literal((DateTimeV2Literal) l)));
return expr.withChildren(children);
literals.forEach(l -> children.add(convertToDateV2Literal((DateTimeV2Literal) l)));
return expr.withChildren(children.build());
}
} else if (cast.child().getDataType().isDateTimeV2Type()
&& expr.child(1) instanceof DateTimeV2Literal) {
List<Expression> literals = expr.children().subList(1, expr.children().size());
DateTimeV2Type compareType = (DateTimeV2Type) cast.child().getDataType();
if (literals.stream().allMatch(literal -> literal instanceof DateTimeV2Literal
&& canLosslessConvertToLowScaleLiteral(
(DateTimeV2Literal) literal, compareType.getScale()))) {
ImmutableList.Builder<Expression> children = ImmutableList.builder();
children.add(cast.child());
literals.forEach(l -> children.add(new DateTimeV2Literal(compareType,
((DateTimeV2Literal) l).getStringValue())));
return expr.withChildren(children.build());
}

}
}
}
Expand All @@ -75,4 +89,8 @@ private static boolean canLosslessConvertToDateV2Literal(DateTimeV2Literal liter
private DateV2Literal convertToDateV2Literal(DateTimeV2Literal literal) {
return new DateV2Literal(literal.getYear(), literal.getMonth(), literal.getDay());
}

private static boolean canLosslessConvertToLowScaleLiteral(DateTimeV2Literal literal, int targetScale) {
return literal.getMicroSecond() % (1L << (DateTimeV2Type.MAX_SCALE - targetScale)) == 0;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

import org.apache.doris.nereids.types.DataType;

import java.util.Objects;

/** StringLikeLiteral. */
public abstract class StringLikeLiteral extends Literal {
public final String value;
Expand All @@ -44,6 +46,23 @@ public double getDouble() {
return (double) v;
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (!(o instanceof StringLikeLiteral)) {
return false;
}
StringLikeLiteral that = (StringLikeLiteral) o;
return Objects.equals(value, that.value);
}

@Override
public int hashCode() {
return Objects.hash(value);
}

@Override
public String toString() {
return "'" + value + "'";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,6 @@ public Type toCatalogDataType() {
return ScalarType.createChar(len);
}

@Override
public boolean acceptsType(DataType other) {
return other instanceof CharType;
}

@Override
public String simpleString() {
return "char";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,6 @@ public Type toCatalogDataType() {
return Type.STRING;
}

@Override
public boolean acceptsType(DataType other) {
return other instanceof StringType || other instanceof VarcharType;
}

@Override
public String simpleString() {
return "string";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,6 @@ public Type toCatalogDataType() {
return catalogDataType;
}

@Override
public boolean acceptsType(DataType other) {
return other instanceof VarcharType || other instanceof StringType;
}

@Override
public String simpleString() {
return "varchar";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -339,8 +339,7 @@ public static Expression castIfNotSameType(Expression input, DataType targetType
if (input.isNullLiteral()) {
return new NullLiteral(targetType);
} else if (input.getDataType().equals(targetType) || isSubqueryAndDataTypeIsBitmap(input)
|| (isVarCharOrStringType(input.getDataType())
&& isVarCharOrStringType(targetType))) {
|| (input.getDataType().isStringLikeType()) && targetType.isStringLikeType()) {
return input;
} else {
checkCanCastTo(input.getDataType(), targetType);
Expand All @@ -352,10 +351,6 @@ private static boolean isSubqueryAndDataTypeIsBitmap(Expression input) {
return input instanceof SubqueryExpr && input.getDataType().isBitmapType();
}

private static boolean isVarCharOrStringType(DataType dataType) {
return dataType instanceof VarcharType || dataType instanceof StringType;
}

private static boolean canCastTo(DataType input, DataType target) {
return Type.canCastTo(input.toCatalogDataType(), target.toCatalogDataType());
}
Expand Down Expand Up @@ -857,7 +852,7 @@ public static Expression processInPredicate(InPredicate inPredicate) {
return optionalCommonType
.map(commonType -> {
List<Expression> newChildren = inPredicate.children().stream()
.map(e -> TypeCoercionUtils.castIfNotMatchType(e, commonType))
.map(e -> TypeCoercionUtils.castIfNotSameType(e, commonType))
.collect(Collectors.toList());
return inPredicate.withChildren(newChildren);
})
Expand Down Expand Up @@ -886,7 +881,7 @@ public static Expression processCaseWhen(CaseWhen caseWhen) {
List<Expression> newChildren
= caseWhen.getWhenClauses().stream()
.map(wc -> {
Expression valueExpr = TypeCoercionUtils.castIfNotMatchType(
Expression valueExpr = TypeCoercionUtils.castIfNotSameType(
wc.getResult(), commonType);
// we must cast every child to the common type, and then
// FoldConstantRuleOnFe can eliminate some branches and direct
Expand All @@ -899,7 +894,7 @@ public static Expression processCaseWhen(CaseWhen caseWhen) {
.collect(Collectors.toList());
caseWhen.getDefaultValue()
.map(dv -> {
Expression defaultExpr = TypeCoercionUtils.castIfNotMatchType(dv, commonType);
Expression defaultExpr = TypeCoercionUtils.castIfNotSameType(dv, commonType);
if (!defaultExpr.getDataType().equals(commonType)) {
defaultExpr = new Cast(defaultExpr, commonType);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public void q4_1() {
ImmutableList.of(
"(c_region = 'AMERICA')",
"(s_region = 'AMERICA')",
"p_mfgr IN ('MFGR#2', 'MFGR#1')"
"p_mfgr IN ('MFGR#1', 'MFGR#2')"
)
);
}
Expand All @@ -58,7 +58,7 @@ public void q4_2() {
"d_year IN (1997, 1998)",
"(c_region = 'AMERICA')",
"(s_region = 'AMERICA')",
"p_mfgr IN ('MFGR#2', 'MFGR#1')"
"p_mfgr IN ('MFGR#1', 'MFGR#2')"
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,18 @@ public void test() {
));
Map<String, Slot> mem = Maps.newHashMap();
Expression rewrittenExpression = PARSER.parseExpression("cast(CA as DATETIME) in ('1992-01-31 00:00:00', '1992-02-01 00:00:00')");
// after parse and type coercion: CAST(CAST(CA AS DATETIMEV2(0)) AS DATETIMEV2(6)) IN ('1992-01-31 00:00:00.000000', '1992-02-01 00:00:00.000000')
rewrittenExpression = typeCoercion(replaceUnboundSlot(rewrittenExpression, mem));
// after first rewrite: CAST(CA AS DATETIMEV2(0)) IN ('1992-01-31 00:00:00', '1992-02-01 00:00:00')
rewrittenExpression = executor.rewrite(rewrittenExpression, context);
// after second rewrite: CA IN ('1992-01-31', '1992-02-01')
rewrittenExpression = executor.rewrite(rewrittenExpression, context);
Expression expectedExpression = PARSER.parseExpression("CA in (cast('1992-01-31' as date), cast('1992-02-01' as date))");
expectedExpression = replaceUnboundSlot(expectedExpression, mem);
executor = new ExpressionRuleExecutor(ImmutableList.of(
FoldConstantRule.INSTANCE
));
expectedExpression = executor.rewrite(expectedExpression, context);
Assertions.assertEquals(expectedExpression.toSql(), rewrittenExpression.toSql());
Assertions.assertEquals(expectedExpression, rewrittenExpression);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -453,8 +453,8 @@ public void testCharAccept() {
int scale = Math.min(precision, Math.abs(new Random().nextInt() % DecimalV2Type.MAX_SCALE));
Assertions.assertFalse(dataType.acceptsType(DecimalV2Type.createDecimalV2Type(precision, scale)));
Assertions.assertTrue(dataType.acceptsType(new CharType(new Random().nextInt())));
Assertions.assertFalse(dataType.acceptsType(new VarcharType(new Random().nextInt())));
Assertions.assertFalse(dataType.acceptsType(StringType.INSTANCE));
Assertions.assertTrue(dataType.acceptsType(new VarcharType(new Random().nextInt())));
Assertions.assertTrue(dataType.acceptsType(StringType.INSTANCE));
Assertions.assertFalse(dataType.acceptsType(DateType.INSTANCE));
Assertions.assertFalse(dataType.acceptsType(DateTimeType.INSTANCE));
}
Expand All @@ -474,7 +474,7 @@ public void testVarcharAccept() {
int precision = Math.abs(new Random().nextInt() % (DecimalV2Type.MAX_PRECISION - 1)) + 1;
int scale = Math.min(precision, Math.abs(new Random().nextInt() % DecimalV2Type.MAX_SCALE));
Assertions.assertFalse(dataType.acceptsType(DecimalV2Type.createDecimalV2Type(precision, scale)));
Assertions.assertFalse(dataType.acceptsType(new CharType(new Random().nextInt())));
Assertions.assertTrue(dataType.acceptsType(new CharType(new Random().nextInt())));
Assertions.assertTrue(dataType.acceptsType(new VarcharType(new Random().nextInt())));
Assertions.assertTrue(dataType.acceptsType(StringType.INSTANCE));
Assertions.assertFalse(dataType.acceptsType(DateType.INSTANCE));
Expand All @@ -496,7 +496,7 @@ public void testStringAccept() {
int precision = Math.abs(new Random().nextInt() % (DecimalV2Type.MAX_PRECISION - 1)) + 1;
int scale = Math.min(precision, Math.abs(new Random().nextInt() % DecimalV2Type.MAX_SCALE));
Assertions.assertFalse(dataType.acceptsType(DecimalV2Type.createDecimalV2Type(precision, scale)));
Assertions.assertFalse(dataType.acceptsType(new CharType(new Random().nextInt())));
Assertions.assertTrue(dataType.acceptsType(new CharType(new Random().nextInt())));
Assertions.assertTrue(dataType.acceptsType(new VarcharType(new Random().nextInt())));
Assertions.assertTrue(dataType.acceptsType(StringType.INSTANCE));
Assertions.assertFalse(dataType.acceptsType(DateType.INSTANCE));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,12 @@
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Multiply;
import org.apache.doris.nereids.trees.expressions.Subtract;
import org.apache.doris.nereids.trees.expressions.literal.CharLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal;
import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral;
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.BitmapType;
Expand Down Expand Up @@ -692,9 +695,18 @@ private void testFindCommonPrimitiveTypeForCaseWhen(DataType commonType, DataTyp
@Test
public void testCastIfNotSameType() {
Assertions.assertEquals(new DoubleLiteral(5L),
TypeCoercionUtils.castIfNotMatchType(new DoubleLiteral(5L), DoubleType.INSTANCE));
TypeCoercionUtils.castIfNotSameType(new DoubleLiteral(5L), DoubleType.INSTANCE));
Assertions.assertEquals(new Cast(new DoubleLiteral(5L), BooleanType.INSTANCE),
TypeCoercionUtils.castIfNotMatchType(new DoubleLiteral(5L), BooleanType.INSTANCE));
TypeCoercionUtils.castIfNotSameType(new DoubleLiteral(5L), BooleanType.INSTANCE));
Assertions.assertEquals(new StringLiteral("varchar"),
TypeCoercionUtils.castIfNotSameType(new VarcharLiteral("varchar"), StringType.INSTANCE));
Assertions.assertEquals(new StringLiteral("char"),
TypeCoercionUtils.castIfNotSameType(new CharLiteral("char", 4), StringType.INSTANCE));
Assertions.assertEquals(new CharLiteral("char", 4),
TypeCoercionUtils.castIfNotSameType(new CharLiteral("char", 4), VarcharType.createVarcharType(100)));
Assertions.assertEquals(new StringLiteral("string"),
TypeCoercionUtils.castIfNotSameType(new StringLiteral("string"), VarcharType.createVarcharType(100)));

}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ PhysicalResultSink
--------------------------------PhysicalOlapScan[customer]
--------------------PhysicalDistribute
----------------------PhysicalProject
------------------------filter(p_mfgr IN ('MFGR#2', 'MFGR#1'))
------------------------filter(p_mfgr IN ('MFGR#1', 'MFGR#2'))
--------------------------PhysicalOlapScan[part]
------------------PhysicalDistribute
--------------------PhysicalProject
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,6 @@ PhysicalResultSink
------------------------------PhysicalOlapScan[customer]
------------------PhysicalDistribute
--------------------PhysicalProject
----------------------filter(p_mfgr IN ('MFGR#2', 'MFGR#1'))
----------------------filter(p_mfgr IN ('MFGR#1', 'MFGR#2'))
------------------------PhysicalOlapScan[part]

Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
------------------------PhysicalCteConsumer ( cteId=CTEId#0 )
----------------------PhysicalDistribute
------------------------PhysicalProject
--------------------------filter((cast(s_state as VARCHAR(*)) = 'SD'))
--------------------------filter((store.s_state = 'SD'))
----------------------------PhysicalOlapScan[store]
------------------hashAgg[GLOBAL]
--------------------PhysicalDistribute
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ PhysicalResultSink
--------------------PhysicalOlapScan[store]
------------------PhysicalDistribute
--------------------PhysicalProject
----------------------hashJoin[INNER_JOIN](customer_demographics.cd_demo_sk = store_sales.ss_cdemo_sk)((((((cast(cd_marital_status as VARCHAR(*)) = 'D') AND (cast(cd_education_status as VARCHAR(*)) = 'Unknown')) AND ((store_sales.ss_sales_price >= 100.00) AND (store_sales.ss_sales_price <= 150.00))) AND (household_demographics.hd_dep_count = 3)) OR ((((cast(cd_marital_status as VARCHAR(*)) = 'S') AND (cast(cd_education_status as VARCHAR(*)) = 'College')) AND ((store_sales.ss_sales_price >= 50.00) AND (store_sales.ss_sales_price <= 100.00))) AND (household_demographics.hd_dep_count = 1))) OR ((((cast(cd_marital_status as VARCHAR(*)) = 'M') AND (cast(cd_education_status as VARCHAR(*)) = '4 yr Degree')) AND ((store_sales.ss_sales_price >= 150.00) AND (store_sales.ss_sales_price <= 200.00))) AND (household_demographics.hd_dep_count = 1)))
----------------------hashJoin[INNER_JOIN](customer_demographics.cd_demo_sk = store_sales.ss_cdemo_sk)((((((customer_demographics.cd_marital_status = 'D') AND (customer_demographics.cd_education_status = 'Unknown')) AND ((store_sales.ss_sales_price >= 100.00) AND (store_sales.ss_sales_price <= 150.00))) AND (household_demographics.hd_dep_count = 3)) OR ((((customer_demographics.cd_marital_status = 'S') AND (customer_demographics.cd_education_status = 'College')) AND ((store_sales.ss_sales_price >= 50.00) AND (store_sales.ss_sales_price <= 100.00))) AND (household_demographics.hd_dep_count = 1))) OR ((((customer_demographics.cd_marital_status = 'M') AND (customer_demographics.cd_education_status = '4 yr Degree')) AND ((store_sales.ss_sales_price >= 150.00) AND (store_sales.ss_sales_price <= 200.00))) AND (household_demographics.hd_dep_count = 1)))
------------------------PhysicalProject
--------------------------filter(((((cast(cd_marital_status as VARCHAR(*)) = 'D') AND (cast(cd_education_status as VARCHAR(*)) = 'Unknown')) OR ((cast(cd_marital_status as VARCHAR(*)) = 'S') AND (cast(cd_education_status as VARCHAR(*)) = 'College'))) OR ((cast(cd_marital_status as VARCHAR(*)) = 'M') AND (cast(cd_education_status as VARCHAR(*)) = '4 yr Degree'))))
--------------------------filter(((((customer_demographics.cd_marital_status = 'D') AND (customer_demographics.cd_education_status = 'Unknown')) OR ((customer_demographics.cd_marital_status = 'S') AND (customer_demographics.cd_education_status = 'College'))) OR ((customer_demographics.cd_marital_status = 'M') AND (customer_demographics.cd_education_status = '4 yr Degree'))))
----------------------------PhysicalOlapScan[customer_demographics]
------------------------PhysicalDistribute
--------------------------hashJoin[INNER_JOIN](store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ PhysicalResultSink
--------------------------------PhysicalOlapScan[catalog_sales]
------------------------------PhysicalDistribute
--------------------------------PhysicalProject
----------------------------------filter((cast(ca_state as VARCHAR(*)) = 'WV'))
----------------------------------filter((customer_address.ca_state = 'WV'))
------------------------------------PhysicalOlapScan[customer_address]
----------------------------PhysicalDistribute
------------------------------PhysicalProject
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ PhysicalResultSink
----------------------------------------------PhysicalOlapScan[store_sales]
--------------------------------------------PhysicalDistribute
----------------------------------------------PhysicalProject
------------------------------------------------filter((cast(d_quarter_name as VARCHAR(*)) = '2001Q1'))
------------------------------------------------filter((d1.d_quarter_name = '2001Q1'))
--------------------------------------------------PhysicalOlapScan[date_dim]
--------------------------------------PhysicalDistribute
----------------------------------------PhysicalProject
Expand Down
Loading

0 comments on commit 320fc14

Please sign in to comment.