Skip to content

Commit

Permalink
feat: convert Substrait window functions to Calcite RexOvers (#172)
Browse files Browse the repository at this point in the history
feat: support for window bounds type
feat: reject IGNORE NULLS

BREAKING CHANGE:
* windowFunction expression creator now requires window bound type parameter
* the WindowBound POJO representation has been reworked to use visitation and more closely match the spec
* ExpressionRexConverter now requires a WindowFunctionConverter
  • Loading branch information
vbarua committed Sep 6, 2023
1 parent daf7499 commit f20d065
Show file tree
Hide file tree
Showing 12 changed files with 337 additions and 128 deletions.
29 changes: 29 additions & 0 deletions core/src/main/java/io/substrait/expression/Expression.java
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,8 @@ abstract class WindowFunctionInvocation implements Expression {

public abstract List<SortField> sort();

public abstract WindowBoundsType boundsType();

public abstract WindowBound lowerBound();

public abstract WindowBound upperBound();
Expand All @@ -606,6 +608,33 @@ public <R, E extends Throwable> R accept(ExpressionVisitor<R, E> visitor) throws
}
}

/**
 * Bounds type of a window function invocation.
 *
 * <p>Each constant is paired with exactly one constant of the protobuf enum {@code
 * io.substrait.proto.Expression.WindowFunction.BoundsType}, which allows conversion in both
 * directions via {@link #toProto()} and {@link #fromProto}.
 */
enum WindowBoundsType {
  UNSPECIFIED(io.substrait.proto.Expression.WindowFunction.BoundsType.BOUNDS_TYPE_UNSPECIFIED),
  ROWS(io.substrait.proto.Expression.WindowFunction.BoundsType.BOUNDS_TYPE_ROWS),
  RANGE(io.substrait.proto.Expression.WindowFunction.BoundsType.BOUNDS_TYPE_RANGE);

  /** Protobuf constant this bounds type is serialized as. */
  private final io.substrait.proto.Expression.WindowFunction.BoundsType protoValue;

  WindowBoundsType(io.substrait.proto.Expression.WindowFunction.BoundsType protoValue) {
    this.protoValue = protoValue;
  }

  /**
   * Returns the protobuf constant paired with this bounds type.
   *
   * @return the protobuf representation of this constant
   */
  public io.substrait.proto.Expression.WindowFunction.BoundsType toProto() {
    return protoValue;
  }

  /**
   * Looks up the bounds type paired with the given protobuf constant.
   *
   * @param proto the protobuf constant to translate
   * @return the constant whose protobuf representation is {@code proto}
   * @throws IllegalArgumentException if no constant is paired with {@code proto}
   */
  public static WindowBoundsType fromProto(
      io.substrait.proto.Expression.WindowFunction.BoundsType proto) {
    WindowBoundsType[] candidates = values();
    for (int i = 0; i < candidates.length; i++) {
      WindowBoundsType candidate = candidates[i];
      // Identity comparison is sufficient: both sides are enum constants.
      if (candidate.protoValue == proto) {
        return candidate;
      }
    }

    throw new IllegalArgumentException("Unknown type: " + proto);
  }
}

@Value.Immutable
abstract static class SingleOrList implements Expression {
public abstract Expression condition();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ public static Expression.WindowFunctionInvocation windowFunction(
List<Expression.SortField> sort,
Expression.AggregationInvocation invocation,
List<Expression> partitionBy,
Expression.WindowBoundsType boundsType,
WindowBound lowerBound,
WindowBound upperBound,
Iterable<? extends FunctionArg> arguments) {
Expand All @@ -329,6 +330,7 @@ public static Expression.WindowFunctionInvocation windowFunction(
.aggregationPhase(phase)
.sort(sort)
.partitionBy(partitionBy)
.boundsType(boundsType)
.lowerBound(lowerBound)
.upperBound(upperBound)
.invocation(invocation)
Expand All @@ -343,6 +345,7 @@ public static Expression.WindowFunctionInvocation windowFunction(
List<Expression.SortField> sort,
Expression.AggregationInvocation invocation,
List<Expression> partitionBy,
Expression.WindowBoundsType boundsType,
WindowBound lowerBound,
WindowBound upperBound,
FunctionArg... arguments) {
Expand All @@ -353,6 +356,7 @@ public static Expression.WindowFunctionInvocation windowFunction(
.sort(sort)
.invocation(invocation)
.partitionBy(partitionBy)
.boundsType(boundsType)
.lowerBound(lowerBound)
.upperBound(upperBound)
.addArguments(arguments)
Expand Down
66 changes: 35 additions & 31 deletions core/src/main/java/io/substrait/expression/WindowBound.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,58 +5,62 @@
@Value.Enclosing
public interface WindowBound {

public BoundedKind boundedKind();
interface WindowBoundVisitor<R, E extends Throwable> {
R visit(Preceding preceding);

enum BoundedKind {
UNBOUNDED,
BOUNDED,
CURRENT_ROW
}
R visit(Following following);

R visit(CurrentRow currentRow);

enum Direction {
PRECEDING,
FOLLOWING
R visit(Unbounded unbounded);
}

public static CurrentRowWindowBound CURRENT_ROW =
ImmutableWindowBound.CurrentRowWindowBound.builder().build();
<R, E extends Throwable> R accept(WindowBoundVisitor<R, E> visitor);

CurrentRow CURRENT_ROW = ImmutableWindowBound.CurrentRow.builder().build();
Unbounded UNBOUNDED = ImmutableWindowBound.Unbounded.builder().build();

@Value.Immutable
abstract static class UnboundedWindowBound implements WindowBound {
@Override
public BoundedKind boundedKind() {
return BoundedKind.UNBOUNDED;
}
abstract class Preceding implements WindowBound {
public abstract long offset();

public abstract Direction direction();
public static Preceding of(long offset) {
return ImmutableWindowBound.Preceding.builder().offset(offset).build();
}

public static ImmutableWindowBound.UnboundedWindowBound.Builder builder() {
return ImmutableWindowBound.UnboundedWindowBound.builder();
@Override
public <R, E extends Throwable> R accept(WindowBoundVisitor<R, E> visitor) {
return visitor.visit(this);
}
}

@Value.Immutable
abstract static class BoundedWindowBound implements WindowBound {
abstract class Following implements WindowBound {
public abstract long offset();

@Override
public BoundedKind boundedKind() {
return BoundedKind.BOUNDED;
public static Following of(long offset) {
return ImmutableWindowBound.Following.builder().offset(offset).build();
}

public abstract Direction direction();

public abstract long offset();
@Override
public <R, E extends Throwable> R accept(WindowBoundVisitor<R, E> visitor) {
return visitor.visit(this);
}
}

public static ImmutableWindowBound.BoundedWindowBound.Builder builder() {
return ImmutableWindowBound.BoundedWindowBound.builder();
/**
 * Window bound variant that declares no fields of its own.
 *
 * <p>Its only behavior is to dispatch to the {@code visit(CurrentRow)} overload of the supplied
 * {@link WindowBoundVisitor} and return that visitor's result.
 */
@Value.Immutable
abstract class CurrentRow implements WindowBound {
@Override
public <R, E extends Throwable> R accept(WindowBoundVisitor<R, E> visitor) {
return visitor.visit(this);
}
}

@Value.Immutable
static class CurrentRowWindowBound implements WindowBound {
abstract class Unbounded implements WindowBound {
@Override
public BoundedKind boundedKind() {
return BoundedKind.CURRENT_ROW;
public <R, E extends Throwable> R accept(WindowBoundVisitor<R, E> visitor) {
return visitor.visit(this);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -440,8 +440,8 @@ public Expression visit(io.substrait.expression.Expression.WindowFunctionInvocat
.build())
.collect(java.util.stream.Collectors.toList());

Expression.WindowFunction.Bound upperBound = toBound(expr.upperBound());
Expression.WindowFunction.Bound lowerBound = toBound(expr.lowerBound());
Expression.WindowFunction.Bound lowerBound = BoundConverter.convert(expr.lowerBound());
Expression.WindowFunction.Bound upperBound = BoundConverter.convert(expr.upperBound());

return Expression.newBuilder()
.setWindowFunction(
Expand All @@ -453,35 +453,51 @@ public Expression visit(io.substrait.expression.Expression.WindowFunctionInvocat
.setInvocation(expr.invocation().toProto())
.addAllSorts(sortFields)
.addAllPartitions(partitionExprs)
.setBoundsType(expr.boundsType().toProto())
.setLowerBound(lowerBound)
.setUpperBound(upperBound))
.build();
}

private Expression.WindowFunction.Bound toBound(io.substrait.expression.WindowBound windowBound) {
var boundedKind = windowBound.boundedKind();
return switch (boundedKind) {
case CURRENT_ROW -> Expression.WindowFunction.Bound.newBuilder()
static class BoundConverter
implements WindowBound.WindowBoundVisitor<Expression.WindowFunction.Bound, RuntimeException> {

static Expression.WindowFunction.Bound convert(WindowBound bound) {
return bound.accept(TO_BOUND_VISITOR);
}

private static final BoundConverter TO_BOUND_VISITOR = new BoundConverter();

private BoundConverter() {}

/**
 * Converts a {@code Preceding} window bound into its protobuf message, carrying over the offset.
 *
 * @param preceding the bound to convert
 * @return a protobuf {@code Bound} whose {@code preceding} field is set
 */
@Override
public Expression.WindowFunction.Bound visit(WindowBound.Preceding preceding) {
  var precedingBuilder = Expression.WindowFunction.Bound.Preceding.newBuilder();
  precedingBuilder.setOffset(preceding.offset());
  return Expression.WindowFunction.Bound.newBuilder().setPreceding(precedingBuilder).build();
}

/**
 * Converts a {@code Following} window bound into its protobuf message, carrying over the offset.
 *
 * @param following the bound to convert
 * @return a protobuf {@code Bound} whose {@code following} field is set
 */
@Override
public Expression.WindowFunction.Bound visit(WindowBound.Following following) {
  var followingBuilder = Expression.WindowFunction.Bound.Following.newBuilder();
  followingBuilder.setOffset(following.offset());
  return Expression.WindowFunction.Bound.newBuilder().setFollowing(followingBuilder).build();
}

@Override
public Expression.WindowFunction.Bound visit(WindowBound.CurrentRow currentRow) {
return Expression.WindowFunction.Bound.newBuilder()
.setCurrentRow(Expression.WindowFunction.Bound.CurrentRow.getDefaultInstance())
.build();
case BOUNDED -> {
WindowBound.BoundedWindowBound boundedWindowBound =
(WindowBound.BoundedWindowBound) windowBound;
var offset = boundedWindowBound.offset();
yield switch (boundedWindowBound.direction()) {
case PRECEDING -> Expression.WindowFunction.Bound.newBuilder()
.setPreceding(
Expression.WindowFunction.Bound.Preceding.newBuilder().setOffset(offset))
.build();
case FOLLOWING -> Expression.WindowFunction.Bound.newBuilder()
.setFollowing(
Expression.WindowFunction.Bound.Following.newBuilder().setOffset(offset))
.build();
};
}
case UNBOUNDED -> Expression.WindowFunction.Bound.newBuilder()
}

@Override
public Expression.WindowFunction.Bound visit(WindowBound.Unbounded unbounded) {
return Expression.WindowFunction.Bound.newBuilder()
.setUnbounded(Expression.WindowFunction.Bound.Unbounded.getDefaultInstance())
.build();
};
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@ public Expression from(io.substrait.proto.Expression expr) {
.build())
.collect(java.util.stream.Collectors.toList());

WindowBound lowerBound = toLowerBound(windowFunction.getLowerBound());
WindowBound upperBound = toUpperBound(windowFunction.getUpperBound());
WindowBound lowerBound = toWindowBound(windowFunction.getLowerBound());
WindowBound upperBound = toWindowBound(windowFunction.getUpperBound());

yield Expression.WindowFunctionInvocation.builder()
.arguments(args)
Expand All @@ -148,6 +148,7 @@ public Expression from(io.substrait.proto.Expression expr) {
.aggregationPhase(Expression.AggregationPhase.fromProto(windowFunction.getPhase()))
.partitionBy(partitionExprs)
.sort(sortFields)
.boundsType(Expression.WindowBoundsType.fromProto(windowFunction.getBoundsType()))
.lowerBound(lowerBound)
.upperBound(upperBound)
.invocation(Expression.AggregationInvocation.fromProto(windowFunction.getInvocation()))
Expand Down Expand Up @@ -247,35 +248,16 @@ public Expression from(io.substrait.proto.Expression expr) {
};
}

private WindowBound toLowerBound(io.substrait.proto.Expression.WindowFunction.Bound bound) {
return toWindowBound(
bound,
WindowBound.UnboundedWindowBound.builder()
.direction(WindowBound.Direction.PRECEDING)
.build());
}

private WindowBound toUpperBound(io.substrait.proto.Expression.WindowFunction.Bound bound) {
return toWindowBound(
bound,
WindowBound.UnboundedWindowBound.builder()
.direction(WindowBound.Direction.FOLLOWING)
.build());
}

private WindowBound toWindowBound(
io.substrait.proto.Expression.WindowFunction.Bound bound, WindowBound defaultBound) {
private WindowBound toWindowBound(io.substrait.proto.Expression.WindowFunction.Bound bound) {
return switch (bound.getKindCase()) {
case PRECEDING -> WindowBound.BoundedWindowBound.builder()
.direction(WindowBound.Direction.PRECEDING)
.offset(bound.getPreceding().getOffset())
.build();
case FOLLOWING -> WindowBound.BoundedWindowBound.builder()
.direction(WindowBound.Direction.FOLLOWING)
.offset(bound.getFollowing().getOffset())
.build();
case PRECEDING -> WindowBound.Preceding.of(bound.getPreceding().getOffset());
case FOLLOWING -> WindowBound.Following.of(bound.getFollowing().getOffset());
case CURRENT_ROW -> WindowBound.CURRENT_ROW;
case UNBOUNDED, KIND_NOT_SET -> defaultBound;
case UNBOUNDED -> WindowBound.UNBOUNDED;
case KIND_NOT_SET ->
// per the spec, the lower and upper bounds default to the start or end of the partition
// respectively if not set
WindowBound.UNBOUNDED;
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import io.substrait.isthmus.expression.AggregateFunctionConverter;
import io.substrait.isthmus.expression.ExpressionRexConverter;
import io.substrait.isthmus.expression.ScalarFunctionConverter;
import io.substrait.isthmus.expression.WindowFunctionConverter;
import io.substrait.relation.AbstractRelVisitor;
import io.substrait.relation.Aggregate;
import io.substrait.relation.Cross;
Expand Down Expand Up @@ -72,6 +73,7 @@ public SubstraitRelNodeConverter(
relBuilder,
new ScalarFunctionConverter(extensions.scalarFunctions(), typeFactory),
new AggregateFunctionConverter(extensions.aggregateFunctions(), typeFactory),
new WindowFunctionConverter(extensions.windowFunctions(), typeFactory),
TypeConverter.DEFAULT);
}

Expand All @@ -80,6 +82,7 @@ public SubstraitRelNodeConverter(
RelBuilder relBuilder,
ScalarFunctionConverter scalarFunctionConverter,
AggregateFunctionConverter aggregateFunctionConverter,
WindowFunctionConverter windowFunctionConverter,
TypeConverter typeConverter) {
this.typeFactory = typeFactory;
this.typeConverter = typeConverter;
Expand All @@ -89,7 +92,7 @@ public SubstraitRelNodeConverter(
this.aggregateFunctionConverter = aggregateFunctionConverter;
this.expressionRexConverter =
new ExpressionRexConverter(
typeFactory, scalarFunctionConverter, aggregateFunctionConverter, typeConverter);
typeFactory, scalarFunctionConverter, windowFunctionConverter, typeConverter);
}

public static RelNode convert(
Expand Down
Loading

0 comments on commit f20d065

Please sign in to comment.