Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: unify aggregate and window functions in window handling #170

Merged
merged 8 commits into from
Aug 30, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,11 @@ public OUTPUT visit(Expression.ScalarFunctionInvocation expr) throws EXCEPTION {
return visitFallback(expr);
}

/**
 * Default handling for window function invocations: delegates to {@code visitFallback(expr)} and
 * returns its result.
 */
@Override
public OUTPUT visit(Expression.WindowFunctionInvocation expr) throws EXCEPTION {
return visitFallback(expr);
}

@Override
public OUTPUT visit(Expression.Cast expr) throws EXCEPTION {
return visitFallback(expr);
Expand Down Expand Up @@ -173,9 +178,4 @@ public OUTPUT visit(Expression.ScalarSubquery expr) throws EXCEPTION {
public OUTPUT visit(Expression.InPredicate expr) throws EXCEPTION {
return visitFallback(expr);
}

@Override
public OUTPUT visit(Expression.Window expr) throws EXCEPTION {
return visitFallback(expr);
}
}
74 changes: 36 additions & 38 deletions core/src/main/java/io/substrait/expression/Expression.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,12 @@
import com.google.protobuf.ByteString;
import io.substrait.extension.SimpleExtension;
import io.substrait.proto.AggregateFunction;
import io.substrait.relation.Aggregate;
import io.substrait.relation.Rel;
import io.substrait.type.Type;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import javax.annotation.Nullable;
import org.immutables.value.Value;

@Value.Enclosing
Expand Down Expand Up @@ -573,6 +570,42 @@ public <R, E extends Throwable> R accept(ExpressionVisitor<R, E> visitor) throws
}
}

/**
 * Expression describing a single window function invocation: the function declaration, its
 * arguments and options, plus the window it is evaluated over (partitioning expressions, sort
 * fields, and lower/upper bounds).
 *
 * <p>Instances are created through {@link #builder()}, which returns the builder of the
 * generated {@code ImmutableExpression.WindowFunctionInvocation}.
 */
@Value.Immutable
abstract class WindowFunctionInvocation implements Expression {

/** The window function variant being invoked. */
public abstract SimpleExtension.WindowFunctionVariant declaration();

/** Arguments passed to the function. */
public abstract List<FunctionArg> arguments();

/** Function options, keyed by name. NOTE(review): presumably the option name; confirm. */
public abstract Map<String, FunctionOption> options();

/** Aggregation phase associated with this invocation. */
public abstract AggregationPhase aggregationPhase();

/** Expressions the window is partitioned by. */
public abstract List<Expression> partitionBy();

/** Sort fields applied to the window. */
public abstract List<SortField> sort();

/** Lower bound of the window. */
public abstract WindowBound lowerBound();

/** Upper bound of the window. */
public abstract WindowBound upperBound();

/** Declared output type of the invocation. */
public abstract Type outputType();

/** Returns the type of this expression, which is exactly {@link #outputType()}. */
public Type getType() {
return outputType();
}

/** Aggregation invocation setting associated with this call. */
public abstract AggregationInvocation invocation();

/** Returns a new builder for the generated immutable implementation of this class. */
public static ImmutableExpression.WindowFunctionInvocation.Builder builder() {
return ImmutableExpression.WindowFunctionInvocation.builder();
}

/** Dispatches to the visitor's {@code visit} overload for this expression type. */
public <R, E extends Throwable> R accept(ExpressionVisitor<R, E> visitor) throws E {
return visitor.visit(this);
}
}

@Value.Immutable
abstract static class SingleOrList implements Expression {
public abstract Expression condition();
Expand Down Expand Up @@ -684,41 +717,6 @@ public <R, E extends Throwable> R accept(ExpressionVisitor<R, E> visitor) throws
}
}

@Value.Immutable
abstract static class Window implements Expression {
@Nullable
public abstract Aggregate.Measure aggregateFunction();

@Nullable
public abstract WindowFunction windowFunction();
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of having separate handling for window calls with window functions and window calls with aggregate functions, this PR updates the extension handling to treat all aggregate functions as valid window functions, and as such removes the need for this distinction.


public abstract List<Expression> partitionBy();

public abstract List<SortField> orderBy();

public abstract WindowBound lowerBound();

public abstract WindowBound upperBound();

public abstract boolean hasNormalAggregateFunction();

public static ImmutableExpression.Window.Builder builder() {
return ImmutableExpression.Window.builder();
}

public <R, E extends Throwable> R accept(ExpressionVisitor<R, E> visitor) throws E {
return visitor.visit(this);
}
}

@Value.Immutable
public abstract static class WindowFunction {
public abstract WindowFunctionInvocation getFunction();

public abstract Optional<Expression> getPreMeasureFilter();
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The preMeasureFilter field was not used anywhere, and the WindowFunction message cannot actually encode a pre-measure filter (for now).

I opted to remove it entirely. Support can be added once the spec supports it.

/** public static ImmutableMeasure.Builder builder() { return ImmutableMeasure.builder(); } */
}

enum PredicateOp {
PREDICATE_OP_UNSPECIFIED(
io.substrait.proto.Expression.Subquery.SetPredicate.PredicateOp.PREDICATE_OP_UNSPECIFIED),
Expand Down
20 changes: 16 additions & 4 deletions core/src/main/java/io/substrait/expression/ExpressionCreator.java
Original file line number Diff line number Diff line change
Expand Up @@ -313,36 +313,48 @@ public static AggregateFunctionInvocation aggregateFunction(
.build();
}

public static WindowFunctionInvocation windowFunction(
public static Expression.WindowFunctionInvocation windowFunction(
SimpleExtension.WindowFunctionVariant declaration,
Type outputType,
Expression.AggregationPhase phase,
List<Expression.SortField> sort,
Expression.AggregationInvocation invocation,
List<Expression> partitionBy,
WindowBound lowerBound,
WindowBound upperBound,
Iterable<? extends FunctionArg> arguments) {
return WindowFunctionInvocation.builder()
return Expression.WindowFunctionInvocation.builder()
.declaration(declaration)
.outputType(outputType)
.aggregationPhase(phase)
.sort(sort)
.partitionBy(partitionBy)
.lowerBound(lowerBound)
.upperBound(upperBound)
.invocation(invocation)
.addAllArguments(arguments)
.build();
}

public static WindowFunctionInvocation windowFunction(
public static Expression.WindowFunctionInvocation windowFunction(
SimpleExtension.WindowFunctionVariant declaration,
Type outputType,
Expression.AggregationPhase phase,
List<Expression.SortField> sort,
Expression.AggregationInvocation invocation,
List<Expression> partitionBy,
WindowBound lowerBound,
WindowBound upperBound,
FunctionArg... arguments) {
return WindowFunctionInvocation.builder()
return Expression.WindowFunctionInvocation.builder()
.declaration(declaration)
.outputType(outputType)
.aggregationPhase(phase)
.sort(sort)
.invocation(invocation)
.partitionBy(partitionBy)
.lowerBound(lowerBound)
.upperBound(upperBound)
.addArguments(arguments)
.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ public interface ExpressionVisitor<R, E extends Throwable> {

R visit(Expression.ScalarFunctionInvocation expr) throws E;

R visit(Expression.WindowFunctionInvocation expr) throws E;

R visit(Expression.Cast expr) throws E;

R visit(Expression.SingleOrList expr) throws E;
Expand All @@ -70,6 +72,4 @@ public interface ExpressionVisitor<R, E extends Throwable> {
R visit(Expression.ScalarSubquery expr) throws E;

R visit(Expression.InPredicate expr) throws E;

R visit(Expression.Window expr) throws E;
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public BoundedKind boundedKind() {

public abstract Direction direction();

public abstract Expression offset();
public abstract long offset();
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the spec, offset must be int64.

The POJO layer should reflect this restriction.


public static ImmutableWindowBound.BoundedWindowBound.Builder builder() {
return ImmutableWindowBound.BoundedWindowBound.builder();
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
package io.substrait.expression.proto;

import io.substrait.expression.AbstractExpressionVisitor;
import io.substrait.expression.ExpressionVisitor;
import io.substrait.expression.FieldReference;
import io.substrait.expression.FunctionArg;
import io.substrait.expression.WindowBound;
import io.substrait.extension.ExtensionCollector;
import io.substrait.proto.Expression;
import io.substrait.proto.FunctionArgument;
import io.substrait.proto.Rel;
import io.substrait.proto.SortField;
import io.substrait.proto.Type;
import io.substrait.relation.RelVisitor;
import io.substrait.type.proto.TypeProtoConverter;
import java.util.List;
Expand Down Expand Up @@ -415,99 +416,48 @@ public Expression visit(io.substrait.expression.Expression.InPredicate expr)
.build();
}

public Expression visit(io.substrait.expression.Expression.Window expr) throws RuntimeException {
var partExps =
public Expression visit(io.substrait.expression.Expression.WindowFunctionInvocation expr)
throws RuntimeException {
var argVisitor = FunctionArg.toProto(typeProtoConverter, this);
List<FunctionArgument> args =
expr.arguments().stream()
.map(a -> a.accept(expr.declaration(), 0, argVisitor))
.collect(java.util.stream.Collectors.toList());
Type outputType = expr.getType().accept(typeProtoConverter);

List<Expression> partitionExprs =
expr.partitionBy().stream()
.map(e -> e.accept(this))
.collect(java.util.stream.Collectors.toList());
var outputType = expr.getType().accept(typeProtoConverter);
var builder = Expression.WindowFunction.newBuilder().setOutputType(outputType);
if (expr.hasNormalAggregateFunction()) {
var aggMeasureFunc = expr.aggregateFunction().getFunction();
var funcReference = extensionCollector.getFunctionReference(aggMeasureFunc.declaration());
var argVisitor = FunctionArg.toProto(typeProtoConverter, this);
var args =
aggMeasureFunc.arguments().stream()
.map(a -> a.accept(aggMeasureFunc.declaration(), 0, argVisitor))
.collect(java.util.stream.Collectors.toList());
builder
.setFunctionReference(funcReference)
.setPhase(aggMeasureFunc.aggregationPhase().toProto())
.addAllArguments(args);
} else {
var windowFunc = expr.windowFunction().getFunction();
var funcReference = extensionCollector.getFunctionReference(windowFunc.declaration());
var argVisitor = FunctionArg.toProto(typeProtoConverter, this);
var args =
windowFunc.arguments().stream()
.map(a -> a.accept(windowFunc.declaration(), 0, argVisitor))
.collect(java.util.stream.Collectors.toList());
builder
.setFunctionReference(funcReference)
.setPhase(windowFunc.aggregationPhase().toProto())
.addAllArguments(args);
}
var sortFields =
expr.orderBy().stream()

List<SortField> sortFields =
expr.sort().stream()
.map(
s -> {
return SortField.newBuilder()
.setDirection(s.direction().toProto())
.setExpr(s.expr().accept(this))
.build();
})
s ->
SortField.newBuilder()
.setDirection(s.direction().toProto())
.setExpr(s.expr().accept(this))
.build())
.collect(java.util.stream.Collectors.toList());
var upperBound = toBound(expr.upperBound());
var lowerBound = toBound(expr.lowerBound());

Expression.WindowFunction.Bound upperBound = toBound(expr.upperBound());
Expression.WindowFunction.Bound lowerBound = toBound(expr.lowerBound());

return Expression.newBuilder()
.setWindowFunction(
builder
.addAllPartitions(partExps)
Expression.WindowFunction.newBuilder()
.setFunctionReference(extensionCollector.getFunctionReference(expr.declaration()))
.addAllArguments(args)
.setOutputType(outputType)
.setPhase(expr.aggregationPhase().toProto())
.setInvocation(expr.invocation().toProto())
.addAllSorts(sortFields)
.addAllPartitions(partitionExprs)
.setLowerBound(lowerBound)
.setUpperBound(upperBound)
.build())
.setUpperBound(upperBound))
.build();
}

private static class LiteralToWindowBoundOffset
extends AbstractExpressionVisitor<Long, RuntimeException> {

@Override
public Long visitFallback(io.substrait.expression.Expression expr) {
throw new RuntimeException(
String.format("Expected positive integer for Window Bound offset, received: %s", expr));
}

private static long offsetIsPositive(long offset) {
if (offset >= 1) {
return offset;
}
throw new RuntimeException(
String.format("Expected positive offset for Window Bound offset, recieved: %d", offset));
}

@Override
public Long visit(io.substrait.expression.Expression.I8Literal expr) throws RuntimeException {
return offsetIsPositive(expr.value());
}

@Override
public Long visit(io.substrait.expression.Expression.I16Literal expr) throws RuntimeException {
return offsetIsPositive(expr.value());
}

@Override
public Long visit(io.substrait.expression.Expression.I32Literal expr) throws RuntimeException {
return offsetIsPositive(expr.value());
}

@Override
public Long visit(io.substrait.expression.Expression.I64Literal expr) throws RuntimeException {
return offsetIsPositive(expr.value());
}
}

private Expression.WindowFunction.Bound toBound(io.substrait.expression.WindowBound windowBound) {
var boundedKind = windowBound.boundedKind();
return switch (boundedKind) {
Expand All @@ -517,7 +467,7 @@ private Expression.WindowFunction.Bound toBound(io.substrait.expression.WindowBo
case BOUNDED -> {
WindowBound.BoundedWindowBound boundedWindowBound =
(WindowBound.BoundedWindowBound) windowBound;
var offset = boundedWindowBound.offset().accept(new LiteralToWindowBoundOffset());
var offset = boundedWindowBound.offset();
yield switch (boundedWindowBound.direction()) {
case PRECEDING -> Expression.WindowFunction.Bound.newBuilder()
.setPreceding(
Expand Down
Loading