Skip to content

Commit

Permalink
branch-2.1: [fix](function) fixed some nested type func's param type …
Browse files Browse the repository at this point in the history
…which is not suitable and make result wrong #44923 (#45798)

Cherry-picked from #44923
  • Loading branch information
amorynan authored Dec 24, 2024
1 parent 69704df commit 8b35b0e
Show file tree
Hide file tree
Showing 12 changed files with 208 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,16 @@
*/
public class ArrayApply extends ScalarFunction
implements BinaryExpression, ExplicitlyCastableSignature, PropagateNullable {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
public static final List<FunctionSignature> FOLLOW_DATATYPE_SIGNATURE = ImmutableList.of(
FunctionSignature.retArgType(0)
.args(ArrayType.of(new AnyDataType(0)), VarcharType.SYSTEM_DEFAULT,
new FollowToAnyDataType(0)));

public static final List<FunctionSignature> MIN_COMMON_TYPE_SIGNATURES = ImmutableList.of(
FunctionSignature.retArgType(0)
.args(ArrayType.of(new AnyDataType(0)), VarcharType.SYSTEM_DEFAULT,
new AnyDataType(0)));

/**
* constructor
*/
Expand Down Expand Up @@ -93,6 +98,13 @@ public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {

@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
if (getArgument(0).getDataType().isArrayType()
&&
((ArrayType) getArgument(0).getDataType()).getItemType()
.isSameTypeForComplexTypeParam(getArgument(2).getDataType())) {
// return least common type
return MIN_COMMON_TYPE_SIGNATURES;
}
return FOLLOW_DATATYPE_SIGNATURE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,14 @@
public class ArrayContains extends ScalarFunction
implements BinaryExpression, ExplicitlyCastableSignature {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
public static final List<FunctionSignature> FOLLOW_DATATYPE_SIGNATURE = ImmutableList.of(
FunctionSignature.ret(BooleanType.INSTANCE)
.args(ArrayType.of(new AnyDataType(0)), new FollowToAnyDataType(0))
);
public static final List<FunctionSignature> MIN_COMMON_TYPE_SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BooleanType.INSTANCE)
.args(ArrayType.of(new AnyDataType(0)), new AnyDataType(0))
);

/**
* constructor with 2 arguments.
Expand Down Expand Up @@ -71,6 +75,13 @@ public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {

@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
if (getArgument(0).getDataType().isArrayType()
&&
((ArrayType) getArgument(0).getDataType()).getItemType()
.isSameTypeForComplexTypeParam(getArgument(1).getDataType())) {
// return least common type
return MIN_COMMON_TYPE_SIGNATURES;
}
return FOLLOW_DATATYPE_SIGNATURE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,16 @@
public class ArrayPosition extends ScalarFunction
implements BinaryExpression, ExplicitlyCastableSignature {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
public static final List<FunctionSignature> FOLLOW_DATATYPE_SIGNATURE = ImmutableList.of(
FunctionSignature.ret(BigIntType.INSTANCE)
.args(ArrayType.of(new AnyDataType(0)), new FollowToAnyDataType(0))
);

public static final List<FunctionSignature> MIN_COMMON_TYPE_SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BigIntType.INSTANCE)
.args(ArrayType.of(new AnyDataType(0)), new AnyDataType(0))
);

/**
* constructor with 2 arguments.
*/
Expand Down Expand Up @@ -71,6 +76,13 @@ public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {

@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
if (getArgument(0).getDataType().isArrayType()
&&
((ArrayType) getArgument(0).getDataType()).getItemType()
.isSameTypeForComplexTypeParam(getArgument(1).getDataType())) {
// return least common type
return MIN_COMMON_TYPE_SIGNATURES;
}
return FOLLOW_DATATYPE_SIGNATURE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,16 @@
public class ArrayPushBack extends ScalarFunction
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNullable {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
public static final List<FunctionSignature> FOLLOW_DATATYPE_SIGNATURE = ImmutableList.of(
FunctionSignature.retArgType(0)
.args(ArrayType.of(new AnyDataType(0)), new FollowToAnyDataType(0))
);

public static final List<FunctionSignature> MIN_COMMON_TYPE_SIGNATURES = ImmutableList.of(
FunctionSignature.retArgType(0)
.args(ArrayType.of(new AnyDataType(0)), new AnyDataType(0))
);

/**
* constructor with 1 argument.
*/
Expand All @@ -66,6 +71,13 @@ public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {

@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
if (getArgument(0).getDataType().isArrayType()
&&
((ArrayType) getArgument(0).getDataType()).getItemType()
.isSameTypeForComplexTypeParam(getArgument(1).getDataType())) {
// return least common type
return MIN_COMMON_TYPE_SIGNATURES;
}
return FOLLOW_DATATYPE_SIGNATURE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,16 @@
public class ArrayPushFront extends ScalarFunction
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNullable {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
public static final List<FunctionSignature> FOLLOW_DATATYPE_SIGNATURE = ImmutableList.of(
FunctionSignature.retArgType(0)
.args(ArrayType.of(new AnyDataType(0)), new FollowToAnyDataType(0))
);

public static final List<FunctionSignature> MIN_COMMON_TYPE_SIGNATURES = ImmutableList.of(
FunctionSignature.retArgType(0)
.args(ArrayType.of(new AnyDataType(0)), new AnyDataType(0))
);

/**
* constructor with 1 argument.
*/
Expand All @@ -66,6 +71,13 @@ public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {

@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
if (getArgument(0).getDataType().isArrayType()
&&
((ArrayType) getArgument(0).getDataType()).getItemType()
.isSameTypeForComplexTypeParam(getArgument(1).getDataType())) {
// return least common type
return MIN_COMMON_TYPE_SIGNATURES;
}
return FOLLOW_DATATYPE_SIGNATURE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,16 @@
public class ArrayRemove extends ScalarFunction
implements BinaryExpression, ExplicitlyCastableSignature, PropagateNullable {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
public static final List<FunctionSignature> FOLLOW_DATATYPE_SIGNATURE = ImmutableList.of(
FunctionSignature.retArgType(0).args(
ArrayType.of(new AnyDataType(0)), new FollowToAnyDataType(0))
);

public static final List<FunctionSignature> MIN_COMMON_TYPE_SIGNATURES = ImmutableList.of(
FunctionSignature.retArgType(0).args(
ArrayType.of(new AnyDataType(0)), new AnyDataType(0))
);

/**
* constructor with 2 arguments.
*/
Expand All @@ -66,6 +71,13 @@ public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {

@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
if (getArgument(0).getDataType().isArrayType()
&&
((ArrayType) getArgument(0).getDataType()).getItemType()
.isSameTypeForComplexTypeParam(getArgument(1).getDataType())) {
// return least common type
return MIN_COMMON_TYPE_SIGNATURES;
}
return FOLLOW_DATATYPE_SIGNATURE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,16 @@
public class CountEqual extends ScalarFunction
implements BinaryExpression, ExplicitlyCastableSignature {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
public static final List<FunctionSignature> FOLLOW_DATATYPE_SIGNATURE = ImmutableList.of(
FunctionSignature.ret(BigIntType.INSTANCE)
.args(ArrayType.of(new AnyDataType(0)), new FollowToAnyDataType(0))
);

public static final List<FunctionSignature> MIN_COMMON_TYPE_SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BigIntType.INSTANCE)
.args(ArrayType.of(new AnyDataType(0)), new AnyDataType(0))
);

/**
* constructor with 2 arguments.
*/
Expand Down Expand Up @@ -71,6 +76,16 @@ public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {

@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
// to find out element type in array vs param type,
// if they are different, return first array element type,
// else return least common type between element type and param
if (getArgument(0).getDataType().isArrayType()
&&
((ArrayType) getArgument(0).getDataType()).getItemType()
.isSameTypeForComplexTypeParam(getArgument(1).getDataType())) {
// return least common type
return MIN_COMMON_TYPE_SIGNATURES;
}
return FOLLOW_DATATYPE_SIGNATURE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,18 @@
public class MapContainsKey extends ScalarFunction
implements BinaryExpression, ExplicitlyCastableSignature {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
public static final List<FunctionSignature> FOLLOW_DATATYPE_SIGNATURE = ImmutableList.of(
FunctionSignature.ret(BooleanType.INSTANCE)
.args(MapType.of(new AnyDataType(0), AnyDataType.INSTANCE_WITHOUT_INDEX),
new FollowToAnyDataType(0))
);

public static final List<FunctionSignature> MIN_COMMON_TYPE_SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BooleanType.INSTANCE)
.args(MapType.of(new AnyDataType(0), AnyDataType.INSTANCE_WITHOUT_INDEX),
new AnyDataType(0))
);

/**
* constructor with 2 arguments.
*/
Expand Down Expand Up @@ -72,6 +78,13 @@ public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {

@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
if (getArgument(0).getDataType().isMapType()
&&
((MapType) getArgument(0).getDataType()).getKeyType()
.isSameTypeForComplexTypeParam(getArgument(1).getDataType())) {
// return least common type
return MIN_COMMON_TYPE_SIGNATURES;
}
return FOLLOW_DATATYPE_SIGNATURE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,18 @@
public class MapContainsValue extends ScalarFunction
implements BinaryExpression, ExplicitlyCastableSignature {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
public static final List<FunctionSignature> FOLLOW_DATATYPE_SIGNATURE = ImmutableList.of(
FunctionSignature.ret(BooleanType.INSTANCE)
.args(MapType.of(AnyDataType.INSTANCE_WITHOUT_INDEX, new AnyDataType(0)),
new FollowToAnyDataType(0))
);

public static final List<FunctionSignature> MIN_COMMON_TYPE_SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BooleanType.INSTANCE)
.args(MapType.of(AnyDataType.INSTANCE_WITHOUT_INDEX, new AnyDataType(0)),
new AnyDataType(0))
);

/**
* constructor with 2 arguments.
*/
Expand Down Expand Up @@ -72,6 +78,13 @@ public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {

@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
if (getArgument(0).getDataType().isMapType()
&&
((MapType) getArgument(0).getDataType()).getValueType()
.isSameTypeForComplexTypeParam(getArgument(1).getDataType())) {
// return least common type
return MIN_COMMON_TYPE_SIGNATURES;
}
return FOLLOW_DATATYPE_SIGNATURE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,41 @@ public DataType promotion() {
}
}

/**
* whether the param dataType is same-like type for nested in complex type
* same-like type means: string-like, date-like, number type
*/
public boolean isSameTypeForComplexTypeParam(DataType paramType) {
if (this.isArrayType() && paramType.isArrayType()) {
return ((ArrayType) this).getItemType()
.isSameTypeForComplexTypeParam(((ArrayType) paramType).getItemType());
} else if (this.isMapType() && paramType.isMapType()) {
MapType thisMapType = (MapType) this;
MapType otherMapType = (MapType) paramType;
return thisMapType.getKeyType().isSameTypeForComplexTypeParam(otherMapType.getKeyType())
&& thisMapType.getValueType().isSameTypeForComplexTypeParam(otherMapType.getValueType());
} else if (this.isStructType() && paramType.isStructType()) {
StructType thisStructType = (StructType) this;
StructType otherStructType = (StructType) paramType;
if (thisStructType.getFields().size() != otherStructType.getFields().size()) {
return false;
}
for (int i = 0; i < thisStructType.getFields().size(); i++) {
if (!thisStructType.getFields().get(i).getDataType().isSameTypeForComplexTypeParam(
otherStructType.getFields().get(i).getDataType())) {
return false;
}
}
return true;
} else if (this.isStringLikeType() && paramType.isStringLikeType()) {
return true;
} else if (this.isDateLikeType() && paramType.isDateLikeType()) {
return true;
} else {
return this.isNumericType() && paramType.isNumericType();
}
}

/** getAllPromotions */
public List<DataType> getAllPromotions() {
if (this instanceof ArrayType) {
Expand Down
27 changes: 27 additions & 0 deletions regression-test/data/nereids_function_p0/scalar_function/Array.out
Original file line number Diff line number Diff line change
Expand Up @@ -15579,3 +15579,30 @@ false
\N
\N

-- !sql --
0 0

-- !sql --
[258] []

-- !sql --
false false

-- !sql --
[257, 258] [258, 1, 2, 3]

-- !sql --
[1, 258, 257] [1, 2, 3, 258]

-- !sql --
[1, 258] [1, 2, 3]

-- !sql --
0 0

-- !sql --
false false

-- !sql --
false false

Original file line number Diff line number Diff line change
Expand Up @@ -1375,4 +1375,20 @@ suite("nereids_scalar_fn_Array") {
order_qt_sql_array_overlaps_5 """select arrays_overlap(b, c) from fn_test_array_with_large_decimal order by id"""
order_qt_sql_array_overlaps_6 """select arrays_overlap(c, b) from fn_test_array_with_large_decimal order by id"""

// tests for nereids array functions for number overflow cases
qt_sql """ SELECT array_position([1,258],257),array_position([2],258);"""
qt_sql """ select array_apply([258], '>' , 257), array_apply([1,2,3], '>', 258);"""
qt_sql """ select array_contains([258], 257), array_contains([1,2,3], 258);"""
// pushfront and pushback
qt_sql """ select array_pushfront([258], 257), array_pushfront([1,2,3], 258);"""
qt_sql """ select array_pushback([1,258], 257), array_pushback([1,2,3], 258);"""
// array_remove
qt_sql """ select array_remove([1,258], 257), array_remove([1,2,3], 258);"""
// countequal
qt_sql """ select countequal([1,258], 257), countequal([1,2,3], 258);"""
// map_contains_key
qt_sql """ select map_contains_key(map(1,258), 257), map_contains_key(map(2,1), 258);"""
// map_contains_value
qt_sql """ select map_contains_value(map(1,1), 257), map_contains_value(map(1,2), 258);"""

}

0 comments on commit 8b35b0e

Please sign in to comment.