diff --git a/core/src/main/java/io/substrait/extended/expression/ExtendedExpression.java b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpression.java similarity index 94% rename from core/src/main/java/io/substrait/extended/expression/ExtendedExpression.java rename to core/src/main/java/io/substrait/extendedexpression/ExtendedExpression.java index 2aee599c2..de405f9a3 100644 --- a/core/src/main/java/io/substrait/extended/expression/ExtendedExpression.java +++ b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpression.java @@ -1,4 +1,4 @@ -package io.substrait.extended.expression; +package io.substrait.extendedexpression; import io.substrait.expression.Expression; import io.substrait.proto.AdvancedExtension; diff --git a/core/src/main/java/io/substrait/extended/expression/ExtendedExpressionProtoConverter.java b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java similarity index 66% rename from core/src/main/java/io/substrait/extended/expression/ExtendedExpressionProtoConverter.java rename to core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java index f3d8441ae..a123e4b9f 100644 --- a/core/src/main/java/io/substrait/extended/expression/ExtendedExpressionProtoConverter.java +++ b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java @@ -1,4 +1,4 @@ -package io.substrait.extended.expression; +package io.substrait.extendedexpression; import io.substrait.expression.Expression; import io.substrait.expression.proto.ExpressionProtoConverter; @@ -7,13 +7,10 @@ import io.substrait.proto.ExtendedExpression; import io.substrait.type.proto.TypeProtoConverter; -/** - * Converts from {@link io.substrait.extended.expression.ExtendedExpression} to {@link - * ExtendedExpression} - */ +/** Converts from {@link ExtendedExpression} to {@link ExtendedExpression} */ public class ExtendedExpressionProtoConverter { public ExtendedExpression toProto( - io.substrait.extended.expression.ExtendedExpression extendedExpression) { + io.substrait.extendedexpression.ExtendedExpression extendedExpression) { ExtendedExpression.Builder builder = ExtendedExpression.newBuilder(); ExtensionCollector functionCollector = new ExtensionCollector(); @@ -21,30 +18,29 @@ public ExtendedExpression toProto( final ExpressionProtoConverter expressionProtoConverter = new ExpressionProtoConverter(functionCollector, null); - for (io.substrait.extended.expression.ExtendedExpression.ExpressionReference - expressionReference : extendedExpression.getReferredExpr()) { + for (io.substrait.extendedexpression.ExtendedExpression.ExpressionReference + expressionReference : extendedExpression.getReferredExpressions()) { io.substrait.proto.Expression expressionProto = expressionProtoConverter.visit( - (Expression.ScalarFunctionInvocation) expressionReference.getReferredExpr()); + (Expression.ScalarFunctionInvocation) expressionReference.getExpression()); ExpressionReference.Builder expressionReferenceBuilder = ExpressionReference.newBuilder() .setExpression(expressionProto) .addAllOutputNames(expressionReference.getOutputNames()); - extendedExpressionBuilder.addReferredExpr(expressionReferenceBuilder); + builder.addReferredExpr(expressionReferenceBuilder); } - extendedExpressionBuilder.setBaseSchema( + builder.setBaseSchema( extendedExpression.getBaseSchema().toProto(new TypeProtoConverter(functionCollector))); // the process of adding simple extensions, such as extensionURIs and extensions, is handled on // the fly - functionCollector.addExtensionsToExtendedExpression(extendedExpressionBuilder); + functionCollector.addExtensionsToExtendedExpression(builder); if (extendedExpression.getAdvancedExtension().isPresent()) { - extendedExpressionBuilder.setAdvancedExtensions( - extendedExpression.getAdvancedExtension().get()); + builder.setAdvancedExtensions(extendedExpression.getAdvancedExtension().get()); } - return extendedExpressionBuilder.build(); + return builder.build(); } } diff --git a/core/src/main/java/io/substrait/extended/expression/ProtoExtendedExpressionConverter.java b/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java similarity index 69% rename from core/src/main/java/io/substrait/extended/expression/ProtoExtendedExpressionConverter.java rename to core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java index 14c82b5ac..3af6ee20d 100644 --- a/core/src/main/java/io/substrait/extended/expression/ProtoExtendedExpressionConverter.java +++ b/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java @@ -1,12 +1,10 @@ -package io.substrait.extended.expression; +package io.substrait.extendedexpression; import io.substrait.expression.Expression; import io.substrait.expression.proto.ProtoExpressionConverter; import io.substrait.extension.*; import io.substrait.proto.ExpressionReference; import io.substrait.proto.NamedStruct; -import io.substrait.type.ImmutableNamedStruct; -import io.substrait.type.Type; import io.substrait.type.proto.ProtoTypeConverter; import java.io.IOException; import java.util.ArrayList; @@ -33,12 +31,13 @@ public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExp // fill in simple extension information through a discovery in the current proto-extended // expression ExtensionLookup functionLookup = - ImmutableExtensionLookup.builder() - .from(extendedExpression.getExtensionUrisList(), extendedExpression.getExtensionsList()) - .build(); + ImmutableExtensionLookup.builder().from(extendedExpression).build(); NamedStruct baseSchemaProto = extendedExpression.getBaseSchema(); - io.substrait.type.NamedStruct namedStruct = convertNamedStrutProtoToPojo(baseSchemaProto); + + io.substrait.type.NamedStruct namedStruct = + io.substrait.type.NamedStruct.convertNamedStructProtoToPojo( + baseSchemaProto, protoTypeConverter); ProtoExpressionConverter protoExpressionConverter = new ProtoExpressionConverter( @@ -50,14 +49,14 @@ public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExp protoExpressionConverter.from(expressionReference.getExpression()); expressionReferences.add( ImmutableExpressionReference.builder() - .referredExpr(expressionPojo) + .expression(expressionPojo) .addAllOutputNames(expressionReference.getOutputNamesList()) .build()); } ImmutableExtendedExpression.Builder builder = ImmutableExtendedExpression.builder() - .referredExpr(expressionReferences) + .referredExpressions(expressionReferences) .advancedExtension( Optional.ofNullable( extendedExpression.hasAdvancedExtensions() @@ -66,19 +65,4 @@ public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExp .baseSchema(namedStruct); return builder.build(); } - - private io.substrait.type.NamedStruct convertNamedStrutProtoToPojo(NamedStruct namedStruct) { - var struct = namedStruct.getStruct(); - return ImmutableNamedStruct.builder() - .names(namedStruct.getNamesList()) - .struct( - Type.Struct.builder() - .fields( - struct.getTypesList().stream() - .map(protoTypeConverter::from) - .collect(java.util.stream.Collectors.toList())) - .nullable(ProtoTypeConverter.isNullable(struct.getNullability())) - .build()) - .build(); - } } diff --git a/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java b/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java index c88bafc1c..70034d9b1 100644 --- a/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java +++ b/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java @@ -1,5 +1,7 @@ package io.substrait.extension; +import io.substrait.proto.ExtendedExpression; +import io.substrait.proto.Plan; import io.substrait.proto.SimpleExtensionDeclaration; import io.substrait.proto.SimpleExtensionURI; import java.util.Collections; @@ -31,7 +33,16 @@ public static class Builder { private final Map functionMap = new HashMap<>(); private final Map typeMap = new HashMap<>(); - public Builder from( + public Builder from(Plan plan) { + return from(plan.getExtensionUrisList(), plan.getExtensionsList()); + } + + public Builder from(ExtendedExpression extendedExpression) { + return from( + extendedExpression.getExtensionUrisList(), extendedExpression.getExtensionsList()); + } + + private Builder from( List simpleExtensionURIs, List simpleExtensionDeclarations) { Map namespaceMap = new HashMap<>(); diff --git a/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java b/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java index 7222eb7ed..be4f4ad9f 100644 --- a/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java +++ b/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java @@ -32,10 +32,7 @@ protected ProtoRelConverter getProtoRelConverter(ExtensionLookup functionLookup) } public Plan from(io.substrait.proto.Plan plan) { - ExtensionLookup functionLookup = - ImmutableExtensionLookup.builder() - .from(plan.getExtensionUrisList(), plan.getExtensionsList()) - .build(); + ExtensionLookup functionLookup = ImmutableExtensionLookup.builder().from(plan).build(); ProtoRelConverter relConverter = getProtoRelConverter(functionLookup); List roots = new ArrayList<>(); for (PlanRel planRel : plan.getRelationsList()) { diff --git a/core/src/main/java/io/substrait/type/NamedStruct.java b/core/src/main/java/io/substrait/type/NamedStruct.java index 8bf345aa9..11fdd38ad 100644 --- a/core/src/main/java/io/substrait/type/NamedStruct.java +++ b/core/src/main/java/io/substrait/type/NamedStruct.java @@ -1,5 +1,6 @@ package io.substrait.type; +import io.substrait.type.proto.ProtoTypeConverter; import io.substrait.type.proto.TypeProtoConverter; import java.util.List; import org.immutables.value.Value; @@ -21,4 +22,20 @@ default io.substrait.proto.NamedStruct toProto(TypeProtoConverter typeProtoConve .addAllNames(names()) .build(); } + + static io.substrait.type.NamedStruct convertNamedStructProtoToPojo( + io.substrait.proto.NamedStruct namedStruct, ProtoTypeConverter protoTypeConverter) { + var struct = namedStruct.getStruct(); + return ImmutableNamedStruct.builder() + .names(namedStruct.getNamesList()) + .struct( + Type.Struct.builder() + .fields( + struct.getTypesList().stream() + .map(protoTypeConverter::from) + .collect(java.util.stream.Collectors.toList())) + .nullable(ProtoTypeConverter.isNullable(struct.getNullability())) + .build()) + .build(); + } } diff --git a/core/src/test/java/io/substrait/extended/expression/ExtendedExpressionProtoConverterTest.java b/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverterTest.java similarity index 94% rename from core/src/test/java/io/substrait/extended/expression/ExtendedExpressionProtoConverterTest.java rename to core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverterTest.java index 20079e24f..fbe3526eb 100644 --- a/core/src/test/java/io/substrait/extended/expression/ExtendedExpressionProtoConverterTest.java +++ b/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverterTest.java @@ -1,4 +1,4 @@ -package io.substrait.extended.expression; +package io.substrait.extendedexpression; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -36,7 +36,7 @@ public void toProtoTest() { ImmutableExpressionReference expressionReference = ImmutableExpressionReference.builder() - .referredExpr(scalarFunctionExpression.get()) + .expression(scalarFunctionExpression.get()) .addOutputNames("new-column") .build(); @@ -59,7 +59,7 @@ public void toProtoTest() { ImmutableExtendedExpression.Builder extendedExpression = ImmutableExtendedExpression.builder() - .referredExpr(expressionReferences) + .referredExpressions(expressionReferences) .baseSchema(namedStruct); // convert POJO extended expression into PROTOBUF extended expression diff --git a/core/src/test/java/io/substrait/extended/expression/ProtoExtendedExpressionConverterTest.java b/core/src/test/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverterTest.java similarity index 90% rename from core/src/test/java/io/substrait/extended/expression/ProtoExtendedExpressionConverterTest.java rename to core/src/test/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverterTest.java index 9ab84f274..69a03a90a 100644 --- a/core/src/test/java/io/substrait/extended/expression/ProtoExtendedExpressionConverterTest.java +++ b/core/src/test/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverterTest.java @@ -1,4 +1,4 @@ -package io.substrait.extended.expression; +package io.substrait.extendedexpression; import io.substrait.TestBase; import io.substrait.expression.Expression; @@ -37,11 +37,11 @@ public void fromTest() throws IOException { ImmutableExpressionReference expressionReference = ImmutableExpressionReference.builder() - .referredExpr(scalarFunctionExpression.get()) + .expression(scalarFunctionExpression.get()) .addOutputNames("new-column") .build(); - List + List expressionReferences = new ArrayList<>(); expressionReferences.add(expressionReference); @@ -62,7 +62,7 @@ public void fromTest() throws IOException { // pojo initial extended expression ImmutableExtendedExpression extendedExpressionPojoInitial = ImmutableExtendedExpression.builder() - .referredExpr(expressionReferences) + .referredExpressions(expressionReferences) .baseSchema(namedStruct) .build(); @@ -71,7 +71,7 @@ public void fromTest() throws IOException { new ExtendedExpressionProtoConverter().toProto(extendedExpressionPojoInitial); // pojo final extended expression - io.substrait.extended.expression.ExtendedExpression extendedExpressionPojoFinal = + io.substrait.extendedexpression.ExtendedExpression extendedExpressionPojoFinal = new ProtoExtendedExpressionConverter().from(extendedExpressionProto); // validate extended expression pojo initial equals to final roundtrip