Skip to content

Commit

Permalink
refactor: unify aggregate and window functions in window handling (#170)
Browse files Browse the repository at this point in the history
feat: allow window function operands beyond just column references
fix: invocation was not set when building WindowFunction proto message

BREAKING CHANGE:
* Window and WindowFunction have been merged into WindowFunctionInvocation
* WindowFunctionInvocation is now a nested class of Expression
* WindowFunctionInvocation now implements Expression.
* Expression visitor now visits WindowFunctionInvocation
* WindowFunctionConvert constructor requires fewer arguments
* windowCreator has additional fields
* POJO WindowBound offset must now be a long, better reflecting spec
  • Loading branch information
vbarua committed Sep 6, 2023
1 parent 24aefbd commit 8fd7c86
Show file tree
Hide file tree
Showing 17 changed files with 313 additions and 425 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,11 @@ public OUTPUT visit(Expression.ScalarFunctionInvocation expr) throws EXCEPTION {
return visitFallback(expr);
}

/**
 * Visits a window function invocation.
 *
 * <p>This implementation applies no window-specific handling; it hands the expression to
 * {@code visitFallback(expr)} and returns whatever that produces.
 *
 * @param expr the window function invocation being visited
 * @return the result of {@code visitFallback(expr)}
 * @throws EXCEPTION if {@code visitFallback(expr)} throws
 */
@Override
public OUTPUT visit(Expression.WindowFunctionInvocation expr) throws EXCEPTION {
  return this.visitFallback(expr);
}

@Override
public OUTPUT visit(Expression.Cast expr) throws EXCEPTION {
return visitFallback(expr);
Expand Down Expand Up @@ -173,9 +178,4 @@ public OUTPUT visit(Expression.ScalarSubquery expr) throws EXCEPTION {
public OUTPUT visit(Expression.InPredicate expr) throws EXCEPTION {
return visitFallback(expr);
}

/**
 * Visits a window expression.
 *
 * <p>This implementation applies no window-specific handling; it hands the expression to
 * {@code visitFallback(expr)} and returns whatever that produces.
 *
 * @param expr the window expression being visited
 * @return the result of {@code visitFallback(expr)}
 * @throws EXCEPTION if {@code visitFallback(expr)} throws
 */
@Override
public OUTPUT visit(Expression.Window expr) throws EXCEPTION {
  return this.visitFallback(expr);
}
}
74 changes: 36 additions & 38 deletions core/src/main/java/io/substrait/expression/Expression.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,12 @@
import com.google.protobuf.ByteString;
import io.substrait.extension.SimpleExtension;
import io.substrait.proto.AggregateFunction;
import io.substrait.relation.Aggregate;
import io.substrait.relation.Rel;
import io.substrait.type.Type;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import javax.annotation.Nullable;
import org.immutables.value.Value;

@Value.Enclosing
Expand Down Expand Up @@ -573,6 +570,42 @@ public <R, E extends Throwable> R accept(ExpressionVisitor<R, E> visitor) throws
}
}

/**
 * Expression node for a single window function call: the function being invoked and its
 * arguments, together with the window it is evaluated over (partitioning expressions, sort
 * fields and lower/upper bounds).
 *
 * <p>The type is annotated with {@code @Value.Immutable}; instances are obtained through
 * {@link #builder()}, which returns the builder of
 * {@code ImmutableExpression.WindowFunctionInvocation}.
 */
@Value.Immutable
abstract class WindowFunctionInvocation implements Expression {

/** Returns the extension declaration of the window function variant being invoked. */
public abstract SimpleExtension.WindowFunctionVariant declaration();

/** Returns the arguments passed to the function. */
public abstract List<FunctionArg> arguments();

/** Returns the function options, keyed by string (presumably the option name; confirm). */
public abstract Map<String, FunctionOption> options();

/** Returns the aggregation phase of this invocation. */
public abstract AggregationPhase aggregationPhase();

/** Returns the expressions the window is partitioned by. */
public abstract List<Expression> partitionBy();

/** Returns the sort fields of the window. */
public abstract List<SortField> sort();

/** Returns the lower bound of the window. */
public abstract WindowBound lowerBound();

/** Returns the upper bound of the window. */
public abstract WindowBound upperBound();

/** Returns the type produced by this invocation; also exposed through {@link #getType()}. */
public abstract Type outputType();

/** Returns the type of this expression, which is exactly {@link #outputType()}. */
public Type getType() {
return outputType();
}

/** Returns the aggregation invocation setting of this call. */
public abstract AggregationInvocation invocation();

/** Returns a new builder for {@code ImmutableExpression.WindowFunctionInvocation}. */
public static ImmutableExpression.WindowFunctionInvocation.Builder builder() {
return ImmutableExpression.WindowFunctionInvocation.builder();
}

/**
 * Dispatches to {@code visitor.visit(this)} so the visitor overload for
 * {@code WindowFunctionInvocation} is selected.
 *
 * @param visitor the visitor to dispatch to
 * @return whatever the visitor returns
 * @throws E if the visitor throws
 */
public <R, E extends Throwable> R accept(ExpressionVisitor<R, E> visitor) throws E {
return visitor.visit(this);
}
}

@Value.Immutable
abstract static class SingleOrList implements Expression {
public abstract Expression condition();
Expand Down Expand Up @@ -684,41 +717,6 @@ public <R, E extends Throwable> R accept(ExpressionVisitor<R, E> visitor) throws
}
}

/**
 * Window expression that carries either a regular aggregate measure or a dedicated window
 * function, plus the partitioning expressions, ordering and bounds of the window.
 *
 * <p>Both function accessors are {@code @Nullable}; {@link #hasNormalAggregateFunction()} tells
 * which of the two should be read.
 */
@Value.Immutable
abstract static class Window implements Expression {
/** Returns the aggregate measure; may be {@code null}. */
@Nullable
public abstract Aggregate.Measure aggregateFunction();

/** Returns the window function; may be {@code null}. */
@Nullable
public abstract WindowFunction windowFunction();

/** Returns the expressions the window is partitioned by. */
public abstract List<Expression> partitionBy();

/** Returns the sort fields the window is ordered by. */
public abstract List<SortField> orderBy();

/** Returns the lower bound of the window. */
public abstract WindowBound lowerBound();

/** Returns the upper bound of the window. */
public abstract WindowBound upperBound();

/**
 * Returns whether {@link #aggregateFunction()} (rather than {@link #windowFunction()}) is the
 * function to use; check this before dereferencing either nullable accessor.
 */
public abstract boolean hasNormalAggregateFunction();

/** Returns a new builder for {@code ImmutableExpression.Window}. */
public static ImmutableExpression.Window.Builder builder() {
return ImmutableExpression.Window.builder();
}

/**
 * Dispatches to {@code visitor.visit(this)} so the visitor overload for {@code Window} is
 * selected.
 *
 * @param visitor the visitor to dispatch to
 * @return whatever the visitor returns
 * @throws E if the visitor throws
 */
public <R, E extends Throwable> R accept(ExpressionVisitor<R, E> visitor) throws E {
return visitor.visit(this);
}
}

/** Pairs a window function invocation with an optional filter expression. */
@Value.Immutable
public abstract static class WindowFunction {
/** Returns the window function invocation. */
public abstract WindowFunctionInvocation getFunction();

/**
 * Returns the optional filter expression. Presumably it is applied before the function is
 * evaluated, as the name suggests; confirm against the callers.
 */
public abstract Optional<Expression> getPreMeasureFilter();
// NOTE(review): no builder() factory is declared on this type. The disabled snippet below
// refers to ImmutableMeasure rather than to this class, so it looks copied from elsewhere;
// confirm the intended builder type before enabling it.
/** public static ImmutableMeasure.Builder builder() { return ImmutableMeasure.builder(); } */
}

enum PredicateOp {
PREDICATE_OP_UNSPECIFIED(
io.substrait.proto.Expression.Subquery.SetPredicate.PredicateOp.PREDICATE_OP_UNSPECIFIED),
Expand Down
20 changes: 16 additions & 4 deletions core/src/main/java/io/substrait/expression/ExpressionCreator.java
Original file line number Diff line number Diff line change
Expand Up @@ -313,36 +313,48 @@ public static AggregateFunctionInvocation aggregateFunction(
.build();
}

public static WindowFunctionInvocation windowFunction(
public static Expression.WindowFunctionInvocation windowFunction(
SimpleExtension.WindowFunctionVariant declaration,
Type outputType,
Expression.AggregationPhase phase,
List<Expression.SortField> sort,
Expression.AggregationInvocation invocation,
List<Expression> partitionBy,
WindowBound lowerBound,
WindowBound upperBound,
Iterable<? extends FunctionArg> arguments) {
return WindowFunctionInvocation.builder()
return Expression.WindowFunctionInvocation.builder()
.declaration(declaration)
.outputType(outputType)
.aggregationPhase(phase)
.sort(sort)
.partitionBy(partitionBy)
.lowerBound(lowerBound)
.upperBound(upperBound)
.invocation(invocation)
.addAllArguments(arguments)
.build();
}

public static WindowFunctionInvocation windowFunction(
public static Expression.WindowFunctionInvocation windowFunction(
SimpleExtension.WindowFunctionVariant declaration,
Type outputType,
Expression.AggregationPhase phase,
List<Expression.SortField> sort,
Expression.AggregationInvocation invocation,
List<Expression> partitionBy,
WindowBound lowerBound,
WindowBound upperBound,
FunctionArg... arguments) {
return WindowFunctionInvocation.builder()
return Expression.WindowFunctionInvocation.builder()
.declaration(declaration)
.outputType(outputType)
.aggregationPhase(phase)
.sort(sort)
.invocation(invocation)
.partitionBy(partitionBy)
.lowerBound(lowerBound)
.upperBound(upperBound)
.addArguments(arguments)
.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ public interface ExpressionVisitor<R, E extends Throwable> {

R visit(Expression.ScalarFunctionInvocation expr) throws E;

R visit(Expression.WindowFunctionInvocation expr) throws E;

R visit(Expression.Cast expr) throws E;

R visit(Expression.SingleOrList expr) throws E;
Expand All @@ -70,6 +72,4 @@ public interface ExpressionVisitor<R, E extends Throwable> {
R visit(Expression.ScalarSubquery expr) throws E;

R visit(Expression.InPredicate expr) throws E;

R visit(Expression.Window expr) throws E;
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public BoundedKind boundedKind() {

public abstract Direction direction();

public abstract Expression offset();
public abstract long offset();

public static ImmutableWindowBound.BoundedWindowBound.Builder builder() {
return ImmutableWindowBound.BoundedWindowBound.builder();
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
package io.substrait.expression.proto;

import io.substrait.expression.AbstractExpressionVisitor;
import io.substrait.expression.ExpressionVisitor;
import io.substrait.expression.FieldReference;
import io.substrait.expression.FunctionArg;
import io.substrait.expression.WindowBound;
import io.substrait.extension.ExtensionCollector;
import io.substrait.proto.Expression;
import io.substrait.proto.FunctionArgument;
import io.substrait.proto.Rel;
import io.substrait.proto.SortField;
import io.substrait.proto.Type;
import io.substrait.relation.RelVisitor;
import io.substrait.type.proto.TypeProtoConverter;
import java.util.List;
Expand Down Expand Up @@ -415,99 +416,48 @@ public Expression visit(io.substrait.expression.Expression.InPredicate expr)
.build();
}

public Expression visit(io.substrait.expression.Expression.Window expr) throws RuntimeException {
var partExps =
public Expression visit(io.substrait.expression.Expression.WindowFunctionInvocation expr)
throws RuntimeException {
var argVisitor = FunctionArg.toProto(typeProtoConverter, this);
List<FunctionArgument> args =
expr.arguments().stream()
.map(a -> a.accept(expr.declaration(), 0, argVisitor))
.collect(java.util.stream.Collectors.toList());
Type outputType = expr.getType().accept(typeProtoConverter);

List<Expression> partitionExprs =
expr.partitionBy().stream()
.map(e -> e.accept(this))
.collect(java.util.stream.Collectors.toList());
var outputType = expr.getType().accept(typeProtoConverter);
var builder = Expression.WindowFunction.newBuilder().setOutputType(outputType);
if (expr.hasNormalAggregateFunction()) {
var aggMeasureFunc = expr.aggregateFunction().getFunction();
var funcReference = extensionCollector.getFunctionReference(aggMeasureFunc.declaration());
var argVisitor = FunctionArg.toProto(typeProtoConverter, this);
var args =
aggMeasureFunc.arguments().stream()
.map(a -> a.accept(aggMeasureFunc.declaration(), 0, argVisitor))
.collect(java.util.stream.Collectors.toList());
builder
.setFunctionReference(funcReference)
.setPhase(aggMeasureFunc.aggregationPhase().toProto())
.addAllArguments(args);
} else {
var windowFunc = expr.windowFunction().getFunction();
var funcReference = extensionCollector.getFunctionReference(windowFunc.declaration());
var argVisitor = FunctionArg.toProto(typeProtoConverter, this);
var args =
windowFunc.arguments().stream()
.map(a -> a.accept(windowFunc.declaration(), 0, argVisitor))
.collect(java.util.stream.Collectors.toList());
builder
.setFunctionReference(funcReference)
.setPhase(windowFunc.aggregationPhase().toProto())
.addAllArguments(args);
}
var sortFields =
expr.orderBy().stream()

List<SortField> sortFields =
expr.sort().stream()
.map(
s -> {
return SortField.newBuilder()
.setDirection(s.direction().toProto())
.setExpr(s.expr().accept(this))
.build();
})
s ->
SortField.newBuilder()
.setDirection(s.direction().toProto())
.setExpr(s.expr().accept(this))
.build())
.collect(java.util.stream.Collectors.toList());
var upperBound = toBound(expr.upperBound());
var lowerBound = toBound(expr.lowerBound());

Expression.WindowFunction.Bound upperBound = toBound(expr.upperBound());
Expression.WindowFunction.Bound lowerBound = toBound(expr.lowerBound());

return Expression.newBuilder()
.setWindowFunction(
builder
.addAllPartitions(partExps)
Expression.WindowFunction.newBuilder()
.setFunctionReference(extensionCollector.getFunctionReference(expr.declaration()))
.addAllArguments(args)
.setOutputType(outputType)
.setPhase(expr.aggregationPhase().toProto())
.setInvocation(expr.invocation().toProto())
.addAllSorts(sortFields)
.addAllPartitions(partitionExprs)
.setLowerBound(lowerBound)
.setUpperBound(upperBound)
.build())
.setUpperBound(upperBound))
.build();
}

/**
 * Visitor that turns an integer literal expression into a window bound offset.
 *
 * <p>Only the I8, I16, I32 and I64 literal overloads are overridden here. Every other expression
 * kind is expected to reach {@link #visitFallback} through the {@code AbstractExpressionVisitor}
 * defaults and is rejected there. Literal values below 1 are rejected as well.
 */
private static class LiteralToWindowBoundOffset
    extends AbstractExpressionVisitor<Long, RuntimeException> {

  /**
   * Rejects any expression that is not handled by one of the literal overloads.
   *
   * @param expr the unsupported expression
   * @return never returns normally
   * @throws RuntimeException always, naming the offending expression
   */
  @Override
  public Long visitFallback(io.substrait.expression.Expression expr) {
    throw new RuntimeException(
        String.format("Expected positive integer for Window Bound offset, received: %s", expr));
  }

  /**
   * Returns {@code offset} unchanged when it is at least 1.
   *
   * @param offset the candidate window bound offset
   * @return {@code offset}
   * @throws RuntimeException if {@code offset} is less than 1
   */
  private static long offsetIsPositive(long offset) {
    if (offset >= 1) {
      return offset;
    }
    // Spelling fixed ("received") so this message matches the one in visitFallback.
    throw new RuntimeException(
        String.format("Expected positive offset for Window Bound offset, received: %d", offset));
  }

  /** Returns the literal's value as the offset, provided it is at least 1. */
  @Override
  public Long visit(io.substrait.expression.Expression.I8Literal expr) throws RuntimeException {
    return offsetIsPositive(expr.value());
  }

  /** Returns the literal's value as the offset, provided it is at least 1. */
  @Override
  public Long visit(io.substrait.expression.Expression.I16Literal expr) throws RuntimeException {
    return offsetIsPositive(expr.value());
  }

  /** Returns the literal's value as the offset, provided it is at least 1. */
  @Override
  public Long visit(io.substrait.expression.Expression.I32Literal expr) throws RuntimeException {
    return offsetIsPositive(expr.value());
  }

  /** Returns the literal's value as the offset, provided it is at least 1. */
  @Override
  public Long visit(io.substrait.expression.Expression.I64Literal expr) throws RuntimeException {
    return offsetIsPositive(expr.value());
  }
}

private Expression.WindowFunction.Bound toBound(io.substrait.expression.WindowBound windowBound) {
var boundedKind = windowBound.boundedKind();
return switch (boundedKind) {
Expand All @@ -517,7 +467,7 @@ private Expression.WindowFunction.Bound toBound(io.substrait.expression.WindowBo
case BOUNDED -> {
WindowBound.BoundedWindowBound boundedWindowBound =
(WindowBound.BoundedWindowBound) windowBound;
var offset = boundedWindowBound.offset().accept(new LiteralToWindowBoundOffset());
var offset = boundedWindowBound.offset();
yield switch (boundedWindowBound.direction()) {
case PRECEDING -> Expression.WindowFunction.Bound.newBuilder()
.setPreceding(
Expand Down
Loading

0 comments on commit 8fd7c86

Please sign in to comment.