Skip to content

Commit

Permalink
fix: support any<n>? type syntax in function extensions (#184)
Browse files Browse the repository at this point in the history
  • Loading branch information
guiyanakuang authored Sep 22, 2023
1 parent 6e82f39 commit 16e5604
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 11 deletions.
2 changes: 1 addition & 1 deletion core/src/main/antlr/SubstraitType.g4
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ expr
| Identifier Eq expr Newline+ (Identifier Eq expr Newline+)* finalType=type Newline* #MultilineDefinition
| type #TypeLiteral
| number=Number #LiteralNumber
| identifier=Identifier #TypeParam
| identifier=Identifier isnull='?'? #TypeParam
| Identifier OParen (expr (Comma expr)*)? CParen #FunctionCall
| left=expr op=(And | Or | Plus | Minus | Lt | Gt | Eq | NotEquals | Lte | Gte | Asterisk | ForwardSlash) right=expr #BinaryExpr
| If ifExpr=expr Then thenExpr=expr Else elseExpr=expr #IfExpr
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ <R, E extends Throwable> R accept(final ParameterizedTypeVisitor<R, E> parameter
}

@Value.Immutable
abstract static class StringLiteral extends BaseParameterizedType {
abstract static class StringLiteral extends BaseParameterizedType implements NullableType {
public abstract String value();

public static ImmutableParameterizedType.StringLiteral.Builder builder() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,40 @@ protected ParameterizedTypeCreator(boolean nullable) {
super(nullable);
}

private static ParameterizedType.StringLiteral parameter(String literal, boolean nullable) {
return ParameterizedType.StringLiteral.builder().nullable(nullable).value(literal).build();
}

public ParameterizedType.StringLiteral parameter(String literal) {
return ParameterizedType.StringLiteral.builder().value(literal).build();
return parameter(literal, nullable);
}

public ParameterizedType fixedCharE(String len) {
return ParameterizedType.FixedChar.builder().nullable(nullable).length(parameter(len)).build();
return ParameterizedType.FixedChar.builder()
.nullable(nullable)
.length(parameter(len, false))
.build();
}

public ParameterizedType varCharE(String len) {
return ParameterizedType.VarChar.builder().nullable(nullable).length(parameter(len)).build();
return ParameterizedType.VarChar.builder()
.nullable(nullable)
.length(parameter(len, false))
.build();
}

public ParameterizedType fixedBinaryE(String len) {
return ParameterizedType.FixedBinary.builder()
.nullable(nullable)
.length(parameter(len))
.length(parameter(len, false))
.build();
}

public ParameterizedType decimalE(String precision, String scale) {
return ParameterizedType.Decimal.builder()
.nullable(nullable)
.precision(parameter(precision))
.scale(parameter(scale))
.precision(parameter(precision, false))
.scale(parameter(scale, false))
.build();
}

Expand Down
8 changes: 6 additions & 2 deletions core/src/main/java/io/substrait/type/parser/ParseToPojo.java
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,11 @@ public TypeExpression visitType(final SubstraitTypeParser.TypeContext ctx) {
@Override
public TypeExpression visitTypeParam(final SubstraitTypeParser.TypeParamContext ctx) {
checkParameterizedOrExpression();
return ParameterizedType.StringLiteral.builder().value(ctx.getText()).build();
boolean nullable = ctx.isnull != null;
return ParameterizedType.StringLiteral.builder()
.nullable(nullable)
.value(ctx.getText())
.build();
}

@Override
Expand Down Expand Up @@ -436,7 +440,7 @@ public TypeExpression visitNumericLiteral(final SubstraitTypeParser.NumericLiter
public TypeExpression visitNumericParameterName(
final SubstraitTypeParser.NumericParameterNameContext ctx) {
checkParameterizedOrExpression();
return ParameterizedType.StringLiteral.builder().value(ctx.getText()).build();
return ParameterizedType.StringLiteral.builder().nullable(false).value(ctx.getText()).build();
}

@Override
Expand Down
27 changes: 27 additions & 0 deletions core/src/test/java/io/substrait/extension/TypeExtensionTest.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.substrait.extension;

import static io.substrait.type.TypeCreator.REQUIRED;
import static org.junit.jupiter.api.Assertions.assertEquals;

import io.substrait.dsl.SubstraitBuilder;
Expand Down Expand Up @@ -76,4 +77,30 @@ void roundtripCustomType() {
var planReturned = protoPlanConverter.from(protoPlan);
assertEquals(plan, planReturned);
}

@Test
void roundtripNumberedAnyTypes() {
List<String> tableName = Stream.of("example").collect(Collectors.toList());
List<String> columnNames =
Stream.of("array_i64_type_column", "array_i64_column").collect(Collectors.toList());
List<io.substrait.type.Type> types =
Stream.of(REQUIRED.list(R.I64)).collect(Collectors.toList());

Plan plan =
b.plan(
b.root(
b.project(
input ->
Stream.of(
b.scalarFn(
NAMESPACE,
"array_index:list_i64",
R.I64,
b.fieldReference(input, 0)))
.collect(Collectors.toList()),
b.namedScan(tableName, columnNames, types))));
var protoPlan = planProtoConverter.toProto(plan);
var planReturned = protoPlanConverter.from(protoPlan);
assertEquals(plan, planReturned);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ private void compoundTests(ParseToPojo.Visitor v) {
private <T> void parameterizedTests(ParseToPojo.Visitor v) {
test(v, pn.listE(pr.parameter("K")), "List?<K>");
test(v, pr.structE(r.I8, r.I16, n.I8, pr.parameter("K")), "STRUCT<i8, i16, i8?, K>");
test(v, pn.parameter("any"), "any");
test(v, pr.parameter("any"), "any");
test(v, pn.parameter("any"), "any?");
test(v, pn.listE(pr.parameter("any")), "list?<any>");
test(v, pn.listE(pn.parameter("any")), "list?<any?>");
test(v, pn.structE(r.I8, r.I16, n.I8, pr.parameter("K")), "STRUCT?<i8, i16, i8?, K>");
Expand Down
9 changes: 9 additions & 0 deletions core/src/test/resources/extensions/custom_extensions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,12 @@ scalar_functions:
- name: arg1
value: i64
return: u!customType2
- name: "array_index"
description: "returns the element in the array at index, or NULL if index is out of bounds"
impls:
- args:
- name: array
value: list<any1>
- name: index
value: i64
return: any1?

0 comments on commit 16e5604

Please sign in to comment.