Skip to content

Commit

Permalink
feat: support aggregation function in extended expression from/to poj…
Browse files Browse the repository at this point in the history
…o/proto
  • Loading branch information
davisusanibar committed Dec 6, 2023
1 parent 3d9b927 commit e790492
Show file tree
Hide file tree
Showing 7 changed files with 297 additions and 154 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import io.substrait.expression.Expression;
import io.substrait.expression.proto.ExpressionProtoConverter;
import io.substrait.extension.ExtensionCollector;
import io.substrait.proto.AggregateFunction;
import io.substrait.proto.ExpressionReference;
import io.substrait.proto.ExtendedExpression;
import io.substrait.type.proto.TypeProtoConverter;
Expand Down Expand Up @@ -37,11 +38,18 @@ public ExtendedExpression toProto(
builder.addReferredExpr(expressionReferenceBuilder);
} else if (expressionType
instanceof io.substrait.extendedexpression.ExtendedExpression.AggregateFunctionType) {
throw new UnsupportedOperationException(
"Aggregate function types are not supported in conversion to proto Extended Expressions for now");
AggregateFunction measure =
((io.substrait.extendedexpression.ExtendedExpression.AggregateFunctionType)
expressionType)
.getMeasure();
ExpressionReference.Builder expressionReferenceBuilder =
ExpressionReference.newBuilder()
.setMeasure(measure.toBuilder())
.addAllOutputNames(expressionReference.getOutputNames());
builder.addReferredExpr(expressionReferenceBuilder);
} else {
throw new UnsupportedOperationException(
"Only Expression or Aggregate Function type are supported in conversion to proto Extended Expressions for now");
"Only Expression or Aggregate Function type are supported in conversion to proto Extended Expressions");
}
}
builder.setBaseSchema(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@

import io.substrait.expression.Expression;
import io.substrait.expression.proto.ProtoExpressionConverter;
import io.substrait.extension.*;
import io.substrait.extension.ExtensionCollector;
import io.substrait.extension.ExtensionLookup;
import io.substrait.extension.ImmutableExtensionLookup;
import io.substrait.extension.ImmutableSimpleExtension;
import io.substrait.extension.SimpleExtension;
import io.substrait.proto.AggregateFunction;
import io.substrait.proto.ExpressionReference;
import io.substrait.proto.NamedStruct;
import io.substrait.type.proto.ProtoTypeConverter;
Expand Down Expand Up @@ -36,8 +41,7 @@ public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExp
NamedStruct baseSchemaProto = extendedExpression.getBaseSchema();

io.substrait.type.NamedStruct namedStruct =
io.substrait.type.NamedStruct.convertNamedStructProtoToPojo(
baseSchemaProto, protoTypeConverter);
io.substrait.type.NamedStruct.fromProto(baseSchemaProto, protoTypeConverter);

ProtoExpressionConverter protoExpressionConverter =
new ProtoExpressionConverter(
Expand All @@ -55,8 +59,12 @@ public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExp
.addAllOutputNames(expressionReference.getOutputNamesList())
.build());
} else if (expressionReference.getExprTypeCase().getNumber() == 2) { // AggregateFunction
throw new UnsupportedOperationException(
"Aggregate function types are not supported in conversion from proto Extended Expressions for now");
AggregateFunction measure = expressionReference.getMeasure();
ImmutableExpressionReference.Builder builder =
ImmutableExpressionReference.builder()
.expressionType(ImmutableAggregateFunctionType.builder().measure(measure).build())
.addAllOutputNames(expressionReference.getOutputNamesList());
expressionReferences.add(builder.build());
} else {
throw new UnsupportedOperationException(
"Only Expression or Aggregate Function type are supported in conversion from proto Extended Expressions for now");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package io.substrait.relation;

import io.substrait.expression.FunctionArg;
import io.substrait.expression.proto.ExpressionProtoConverter;
import io.substrait.extension.ExtensionCollector;
import io.substrait.proto.AggregateFunction;
import io.substrait.type.proto.TypeProtoConverter;
import java.util.stream.IntStream;

/**
* Converts from {@link io.substrait.relation.Aggregate.Measure} to {@link
* io.substrait.proto.AggregateFunction}
*/
public class AggregateFunctionProtoController {

private final ExpressionProtoConverter exprProtoConverter;
private final TypeProtoConverter typeProtoConverter;
private final ExtensionCollector functionCollector;

public AggregateFunctionProtoController(ExtensionCollector functionCollector) {
this.functionCollector = functionCollector;
this.exprProtoConverter = new ExpressionProtoConverter(functionCollector, null);
this.typeProtoConverter = new TypeProtoConverter(functionCollector);
}

public AggregateFunction toProto(Aggregate.Measure measure) {
var argVisitor = FunctionArg.toProto(typeProtoConverter, exprProtoConverter);
var args = measure.getFunction().arguments();
var aggFuncDef = measure.getFunction().declaration();

return AggregateFunction.newBuilder()
.setPhase(measure.getFunction().aggregationPhase().toProto())
.setInvocation(measure.getFunction().invocation().toProto())
.setOutputType(measure.getFunction().getType().accept(typeProtoConverter))
.addAllArguments(
IntStream.range(0, args.size())
.mapToObj(i -> args.get(i).accept(aggFuncDef, i, argVisitor))
.collect(java.util.stream.Collectors.toList()))
.setFunctionReference(
functionCollector.getFunctionReference(measure.getFunction().declaration()))
.build();
}
}
2 changes: 1 addition & 1 deletion core/src/main/java/io/substrait/type/NamedStruct.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ default io.substrait.proto.NamedStruct toProto(TypeProtoConverter typeProtoConve
.build();
}

static io.substrait.type.NamedStruct convertNamedStructProtoToPojo(
static io.substrait.type.NamedStruct fromProto(
io.substrait.proto.NamedStruct namedStruct, ProtoTypeConverter protoTypeConverter) {
var struct = namedStruct.getStruct();
return ImmutableNamedStruct.builder()
Expand Down

This file was deleted.

Loading

0 comments on commit e790492

Please sign in to comment.