Skip to content

Commit

Permalink
fix: code review core module testing side
Browse files Browse the repository at this point in the history
  • Loading branch information
davisusanibar committed Nov 30, 2023
1 parent b1c96bd commit 3d9b927
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import io.substrait.expression.Expression;
import io.substrait.proto.AdvancedExtension;
import io.substrait.proto.AggregateFunction;
import io.substrait.type.NamedStruct;
import java.util.List;
import java.util.Optional;
Expand All @@ -21,8 +22,20 @@ public abstract class ExtendedExpression {

@Value.Immutable
public abstract static class ExpressionReference {
public abstract Expression getExpression();
public abstract ExpressionTypeReference getExpressionType();

public abstract List<String> getOutputNames();
}

public abstract static class ExpressionTypeReference {}

@Value.Immutable
public abstract static class ExpressionType extends ExpressionTypeReference {
public abstract Expression getExpression();
}

@Value.Immutable
public abstract static class AggregateFunctionType extends ExpressionTypeReference {
public abstract AggregateFunction getMeasure();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,29 @@ public ExtendedExpression toProto(

for (io.substrait.extendedexpression.ExtendedExpression.ExpressionReference
expressionReference : extendedExpression.getReferredExpressions()) {

io.substrait.proto.Expression expressionProto =
expressionProtoConverter.visit(
(Expression.ScalarFunctionInvocation) expressionReference.getExpression());

ExpressionReference.Builder expressionReferenceBuilder =
ExpressionReference.newBuilder()
.setExpression(expressionProto)
.addAllOutputNames(expressionReference.getOutputNames());

builder.addReferredExpr(expressionReferenceBuilder);
io.substrait.extendedexpression.ExtendedExpression.ExpressionTypeReference expressionType =
expressionReference.getExpressionType();
if (expressionType
instanceof io.substrait.extendedexpression.ExtendedExpression.ExpressionType) {
io.substrait.proto.Expression expressionProto =
expressionProtoConverter.visit(
(Expression.ScalarFunctionInvocation)
((io.substrait.extendedexpression.ExtendedExpression.ExpressionType)
expressionType)
.getExpression());
ExpressionReference.Builder expressionReferenceBuilder =
ExpressionReference.newBuilder()
.setExpression(expressionProto)
.addAllOutputNames(expressionReference.getOutputNames());
builder.addReferredExpr(expressionReferenceBuilder);
} else if (expressionType
instanceof io.substrait.extendedexpression.ExtendedExpression.AggregateFunctionType) {
throw new UnsupportedOperationException(
"Aggregate function types are not supported in conversion to proto Extended Expressions for now");
} else {
throw new UnsupportedOperationException(
"Only Expression or Aggregate Function type are supported in conversion to proto Extended Expressions for now");
}
}
builder.setBaseSchema(
extendedExpression.getBaseSchema().toProto(new TypeProtoConverter(functionCollector)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,22 @@ public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExp

List<ExtendedExpression.ExpressionReference> expressionReferences = new ArrayList<>();
for (ExpressionReference expressionReference : extendedExpression.getReferredExprList()) {
Expression expressionPojo =
protoExpressionConverter.from(expressionReference.getExpression());
expressionReferences.add(
ImmutableExpressionReference.builder()
.expression(expressionPojo)
.addAllOutputNames(expressionReference.getOutputNamesList())
.build());
if (expressionReference.getExprTypeCase().getNumber() == 1) { // Expression
Expression expressionPojo =
protoExpressionConverter.from(expressionReference.getExpression());
expressionReferences.add(
ImmutableExpressionReference.builder()
.expressionType(
ImmutableExpressionType.builder().expression(expressionPojo).build())
.addAllOutputNames(expressionReference.getOutputNamesList())
.build());
} else if (expressionReference.getExprTypeCase().getNumber() == 2) { // AggregateFunction
throw new UnsupportedOperationException(
"Aggregate function types are not supported in conversion from proto Extended Expressions for now");
} else {
throw new UnsupportedOperationException(
"Only Expression or Aggregate Function type are supported in conversion from proto Extended Expressions for now");
}
}

ImmutableExtendedExpression.Builder builder =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,40 +3,35 @@
import static org.junit.jupiter.api.Assertions.assertEquals;

import io.substrait.TestBase;
import io.substrait.expression.Expression;
import io.substrait.expression.ExpressionCreator;
import io.substrait.expression.FieldReference;
import io.substrait.expression.ImmutableFieldReference;
import io.substrait.expression.*;
import io.substrait.type.ImmutableNamedStruct;
import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import org.junit.jupiter.api.Test;

public class ExtendedExpressionProtoConverterTest extends TestBase {
static final String NAMESPACE = "/functions_arithmetic_decimal.yaml";

@Test
public void toProtoTest() {
// create predefined POJO extended expression
Optional<Expression.ScalarFunctionInvocation> scalarFunctionExpression =
defaultExtensionCollection.scalarFunctions().stream()
.filter(s -> s.name().equalsIgnoreCase("add"))
.findFirst()
.map(
declaration ->
ExpressionCreator.scalarFunction(
declaration,
TypeCreator.REQUIRED.BOOLEAN,
ImmutableFieldReference.builder()
.addSegments(FieldReference.StructField.of(0))
.type(TypeCreator.REQUIRED.decimal(10, 2))
.build(),
ExpressionCreator.i32(false, 183)));
Expression.ScalarFunctionInvocation scalarFunctionInvocation =
b.scalarFn(
NAMESPACE,
"add:dec_dec",
TypeCreator.REQUIRED.BOOLEAN,
ImmutableFieldReference.builder()
.addSegments(FieldReference.StructField.of(0))
.type(TypeCreator.REQUIRED.decimal(10, 2))
.build(),
ExpressionCreator.i32(false, 183));

ImmutableExpressionReference expressionReference =
ImmutableExpressionReference.builder()
.expression(scalarFunctionExpression.get())
.expressionType(
ImmutableExpressionType.builder().expression(scalarFunctionInvocation).build())
.addOutputNames("new-column")
.build();

Expand Down Expand Up @@ -66,8 +61,7 @@ public void toProtoTest() {
io.substrait.proto.ExtendedExpression proto =
new ExtendedExpressionProtoConverter().toProto(extendedExpression.build());

assertEquals(
"/functions_arithmetic_decimal.yaml", proto.getExtensionUrisList().get(0).getUri());
assertEquals(NAMESPACE, proto.getExtensionUrisList().get(0).getUri());
assertEquals("add:dec_dec", proto.getExtensionsList().get(0).getExtensionFunction().getName());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,32 +12,30 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

public class ProtoExtendedExpressionConverterTest extends TestBase {
static final String NAMESPACE = "/functions_arithmetic_decimal.yaml";

@Test
public void fromTest() throws IOException {
// create predefined POJO extended expression
Optional<Expression.ScalarFunctionInvocation> scalarFunctionExpression =
defaultExtensionCollection.scalarFunctions().stream()
.filter(s -> s.name().equalsIgnoreCase("add"))
.findFirst()
.map(
declaration ->
ExpressionCreator.scalarFunction(
declaration,
TypeCreator.REQUIRED.BOOLEAN,
ImmutableFieldReference.builder()
.addSegments(FieldReference.StructField.of(0))
.type(TypeCreator.REQUIRED.decimal(10, 2))
.build(),
ExpressionCreator.i32(false, 183)));
Expression.ScalarFunctionInvocation scalarFunctionInvocation =
b.scalarFn(
NAMESPACE,
"add:dec_dec",
TypeCreator.REQUIRED.BOOLEAN,
ImmutableFieldReference.builder()
.addSegments(FieldReference.StructField.of(0))
.type(TypeCreator.REQUIRED.decimal(10, 2))
.build(),
ExpressionCreator.i32(false, 183));

ImmutableExpressionReference expressionReference =
ImmutableExpressionReference.builder()
.expression(scalarFunctionExpression.get())
.expressionType(
ImmutableExpressionType.builder().expression(scalarFunctionInvocation).build())
.addOutputNames("new-column")
.build();

Expand Down

0 comments on commit 3d9b927

Please sign in to comment.