Skip to content

Commit

Permalink
fix: code review core module
Browse files Browse the repository at this point in the history
  • Loading branch information
davisusanibar committed Nov 29, 2023
1 parent a677e47 commit 9c041b6
Show file tree
Hide file tree
Showing 8 changed files with 58 additions and 53 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package io.substrait.extended.expression;
package io.substrait.extendedexpression;

import io.substrait.expression.Expression;
import io.substrait.proto.AdvancedExtension;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package io.substrait.extended.expression;
package io.substrait.extendedexpression;

import io.substrait.expression.Expression;
import io.substrait.expression.proto.ExpressionProtoConverter;
Expand All @@ -7,44 +7,40 @@
import io.substrait.proto.ExtendedExpression;
import io.substrait.type.proto.TypeProtoConverter;

/**
* Converts from {@link io.substrait.extended.expression.ExtendedExpression} to {@link
* ExtendedExpression}
*/
/** Converts from {@link ExtendedExpression} to {@link ExtendedExpression} */
public class ExtendedExpressionProtoConverter {
public ExtendedExpression toProto(
io.substrait.extended.expression.ExtendedExpression extendedExpression) {
io.substrait.extendedexpression.ExtendedExpression extendedExpression) {

ExtendedExpression.Builder builder = ExtendedExpression.newBuilder();
ExtensionCollector functionCollector = new ExtensionCollector();

final ExpressionProtoConverter expressionProtoConverter =
new ExpressionProtoConverter(functionCollector, null);

for (io.substrait.extended.expression.ExtendedExpression.ExpressionReference
expressionReference : extendedExpression.getReferredExpr()) {
for (io.substrait.extendedexpression.ExtendedExpression.ExpressionReference
expressionReference : extendedExpression.getReferredExpressions()) {

io.substrait.proto.Expression expressionProto =
expressionProtoConverter.visit(
(Expression.ScalarFunctionInvocation) expressionReference.getReferredExpr());
(Expression.ScalarFunctionInvocation) expressionReference.getExpression());

ExpressionReference.Builder expressionReferenceBuilder =
ExpressionReference.newBuilder()
.setExpression(expressionProto)
.addAllOutputNames(expressionReference.getOutputNames());

extendedExpressionBuilder.addReferredExpr(expressionReferenceBuilder);
builder.addReferredExpr(expressionReferenceBuilder);
}
extendedExpressionBuilder.setBaseSchema(
builder.setBaseSchema(
extendedExpression.getBaseSchema().toProto(new TypeProtoConverter(functionCollector)));

// the process of adding simple extensions, such as extensionURIs and extensions, is handled on
// the fly
functionCollector.addExtensionsToExtendedExpression(extendedExpressionBuilder);
functionCollector.addExtensionsToExtendedExpression(builder);
if (extendedExpression.getAdvancedExtension().isPresent()) {
extendedExpressionBuilder.setAdvancedExtensions(
extendedExpression.getAdvancedExtension().get());
builder.setAdvancedExtensions(extendedExpression.getAdvancedExtension().get());
}
return extendedExpressionBuilder.build();
return builder.build();
}
}
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
package io.substrait.extended.expression;
package io.substrait.extendedexpression;

import io.substrait.expression.Expression;
import io.substrait.expression.proto.ProtoExpressionConverter;
import io.substrait.extension.*;
import io.substrait.proto.ExpressionReference;
import io.substrait.proto.NamedStruct;
import io.substrait.type.ImmutableNamedStruct;
import io.substrait.type.Type;
import io.substrait.type.proto.ProtoTypeConverter;
import java.io.IOException;
import java.util.ArrayList;
Expand All @@ -33,12 +31,13 @@ public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExp
// fill in simple extension information through a discovery in the current proto-extended
// expression
ExtensionLookup functionLookup =
ImmutableExtensionLookup.builder()
.from(extendedExpression.getExtensionUrisList(), extendedExpression.getExtensionsList())
.build();
ImmutableExtensionLookup.builder().from(extendedExpression).build();

NamedStruct baseSchemaProto = extendedExpression.getBaseSchema();
io.substrait.type.NamedStruct namedStruct = convertNamedStrutProtoToPojo(baseSchemaProto);

io.substrait.type.NamedStruct namedStruct =
io.substrait.type.NamedStruct.convertNamedStructProtoToPojo(
baseSchemaProto, protoTypeConverter);

ProtoExpressionConverter protoExpressionConverter =
new ProtoExpressionConverter(
Expand All @@ -50,14 +49,14 @@ public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExp
protoExpressionConverter.from(expressionReference.getExpression());
expressionReferences.add(
ImmutableExpressionReference.builder()
.referredExpr(expressionPojo)
.expression(expressionPojo)
.addAllOutputNames(expressionReference.getOutputNamesList())
.build());
}

ImmutableExtendedExpression.Builder builder =
ImmutableExtendedExpression.builder()
.referredExpr(expressionReferences)
.referredExpressions(expressionReferences)
.advancedExtension(
Optional.ofNullable(
extendedExpression.hasAdvancedExtensions()
Expand All @@ -66,19 +65,4 @@ public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExp
.baseSchema(namedStruct);
return builder.build();
}

private io.substrait.type.NamedStruct convertNamedStrutProtoToPojo(NamedStruct namedStruct) {
var struct = namedStruct.getStruct();
return ImmutableNamedStruct.builder()
.names(namedStruct.getNamesList())
.struct(
Type.Struct.builder()
.fields(
struct.getTypesList().stream()
.map(protoTypeConverter::from)
.collect(java.util.stream.Collectors.toList()))
.nullable(ProtoTypeConverter.isNullable(struct.getNullability()))
.build())
.build();
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package io.substrait.extension;

import io.substrait.proto.ExtendedExpression;
import io.substrait.proto.Plan;
import io.substrait.proto.SimpleExtensionDeclaration;
import io.substrait.proto.SimpleExtensionURI;
import java.util.Collections;
Expand Down Expand Up @@ -31,7 +33,16 @@ public static class Builder {
private final Map<Integer, SimpleExtension.FunctionAnchor> functionMap = new HashMap<>();
private final Map<Integer, SimpleExtension.TypeAnchor> typeMap = new HashMap<>();

public Builder from(
public Builder from(Plan plan) {
return from(plan.getExtensionUrisList(), plan.getExtensionsList());
}

public Builder from(ExtendedExpression extendedExpression) {
return from(
extendedExpression.getExtensionUrisList(), extendedExpression.getExtensionsList());
}

private Builder from(
List<SimpleExtensionURI> simpleExtensionURIs,
List<SimpleExtensionDeclaration> simpleExtensionDeclarations) {
Map<Integer, String> namespaceMap = new HashMap<>();
Expand Down
5 changes: 1 addition & 4 deletions core/src/main/java/io/substrait/plan/ProtoPlanConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,7 @@ protected ProtoRelConverter getProtoRelConverter(ExtensionLookup functionLookup)
}

public Plan from(io.substrait.proto.Plan plan) {
ExtensionLookup functionLookup =
ImmutableExtensionLookup.builder()
.from(plan.getExtensionUrisList(), plan.getExtensionsList())
.build();
ExtensionLookup functionLookup = ImmutableExtensionLookup.builder().from(plan).build();
ProtoRelConverter relConverter = getProtoRelConverter(functionLookup);
List<Plan.Root> roots = new ArrayList<>();
for (PlanRel planRel : plan.getRelationsList()) {
Expand Down
17 changes: 17 additions & 0 deletions core/src/main/java/io/substrait/type/NamedStruct.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.substrait.type;

import io.substrait.type.proto.ProtoTypeConverter;
import io.substrait.type.proto.TypeProtoConverter;
import java.util.List;
import org.immutables.value.Value;
Expand All @@ -21,4 +22,20 @@ default io.substrait.proto.NamedStruct toProto(TypeProtoConverter typeProtoConve
.addAllNames(names())
.build();
}

static io.substrait.type.NamedStruct convertNamedStructProtoToPojo(
io.substrait.proto.NamedStruct namedStruct, ProtoTypeConverter protoTypeConverter) {
var struct = namedStruct.getStruct();
return ImmutableNamedStruct.builder()
.names(namedStruct.getNamesList())
.struct(
Type.Struct.builder()
.fields(
struct.getTypesList().stream()
.map(protoTypeConverter::from)
.collect(java.util.stream.Collectors.toList()))
.nullable(ProtoTypeConverter.isNullable(struct.getNullability()))
.build())
.build();
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package io.substrait.extended.expression;
package io.substrait.extendedexpression;

import static org.junit.jupiter.api.Assertions.assertEquals;

Expand Down Expand Up @@ -36,7 +36,7 @@ public void toProtoTest() {

ImmutableExpressionReference expressionReference =
ImmutableExpressionReference.builder()
.referredExpr(scalarFunctionExpression.get())
.expression(scalarFunctionExpression.get())
.addOutputNames("new-column")
.build();

Expand All @@ -59,7 +59,7 @@ public void toProtoTest() {

ImmutableExtendedExpression.Builder extendedExpression =
ImmutableExtendedExpression.builder()
.referredExpr(expressionReferences)
.referredExpressions(expressionReferences)
.baseSchema(namedStruct);

// convert POJO extended expression into PROTOBUF extended expression
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package io.substrait.extended.expression;
package io.substrait.extendedexpression;

import io.substrait.TestBase;
import io.substrait.expression.Expression;
Expand Down Expand Up @@ -37,11 +37,11 @@ public void fromTest() throws IOException {

ImmutableExpressionReference expressionReference =
ImmutableExpressionReference.builder()
.referredExpr(scalarFunctionExpression.get())
.expression(scalarFunctionExpression.get())
.addOutputNames("new-column")
.build();

List<io.substrait.extended.expression.ExtendedExpression.ExpressionReference>
List<io.substrait.extendedexpression.ExtendedExpression.ExpressionReference>
expressionReferences = new ArrayList<>();
expressionReferences.add(expressionReference);

Expand All @@ -62,7 +62,7 @@ public void fromTest() throws IOException {
// pojo initial extended expression
ImmutableExtendedExpression extendedExpressionPojoInitial =
ImmutableExtendedExpression.builder()
.referredExpr(expressionReferences)
.referredExpressions(expressionReferences)
.baseSchema(namedStruct)
.build();

Expand All @@ -71,7 +71,7 @@ public void fromTest() throws IOException {
new ExtendedExpressionProtoConverter().toProto(extendedExpressionPojoInitial);

// pojo final extended expression
io.substrait.extended.expression.ExtendedExpression extendedExpressionPojoFinal =
io.substrait.extendedexpression.ExtendedExpression extendedExpressionPojoFinal =
new ProtoExtendedExpressionConverter().from(extendedExpressionProto);

// validate extended expression pojo initial equals to final roundtrip
Expand Down

0 comments on commit 9c041b6

Please sign in to comment.