Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: window function calcite support #172

Merged
merged 9 commits into from
Sep 1, 2023
29 changes: 29 additions & 0 deletions core/src/main/java/io/substrait/expression/Expression.java
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,8 @@ abstract class WindowFunctionInvocation implements Expression {

public abstract List<SortField> sort();

public abstract WindowBoundsType boundsType();

public abstract WindowBound lowerBound();

public abstract WindowBound upperBound();
Expand All @@ -606,6 +608,33 @@ public <R, E extends Throwable> R accept(ExpressionVisitor<R, E> visitor) throws
}
}

/**
 * Bounds type of a window function, with each constant paired to its protobuf {@code BoundsType}
 * counterpart so values can be converted in both directions.
 */
enum WindowBoundsType {
  UNSPECIFIED(io.substrait.proto.Expression.WindowFunction.BoundsType.BOUNDS_TYPE_UNSPECIFIED),
  ROWS(io.substrait.proto.Expression.WindowFunction.BoundsType.BOUNDS_TYPE_ROWS),
  RANGE(io.substrait.proto.Expression.WindowFunction.BoundsType.BOUNDS_TYPE_RANGE);

  /** Protobuf value this constant is paired with. */
  private final io.substrait.proto.Expression.WindowFunction.BoundsType protoValue;

  WindowBoundsType(io.substrait.proto.Expression.WindowFunction.BoundsType protoValue) {
    this.protoValue = protoValue;
  }

  /**
   * Returns the protobuf value paired with this constant.
   *
   * @return the protobuf {@code BoundsType} supplied when this constant was declared
   */
  public io.substrait.proto.Expression.WindowFunction.BoundsType toProto() {
    return this.protoValue;
  }

  /**
   * Looks up the constant paired with the given protobuf value.
   *
   * @param proto the protobuf bounds type to translate
   * @return the constant whose paired protobuf value is identical to {@code proto}
   * @throws IllegalArgumentException if no constant is paired with {@code proto}
   */
  public static WindowBoundsType fromProto(
      io.substrait.proto.Expression.WindowFunction.BoundsType proto) {
    WindowBoundsType[] candidates = WindowBoundsType.values();
    for (int i = 0; i < candidates.length; i++) {
      WindowBoundsType candidate = candidates[i];
      // Identity comparison is deliberate: the paired values are enum constants.
      if (candidate.protoValue == proto) {
        return candidate;
      }
    }

    throw new IllegalArgumentException("Unknown type: " + proto);
  }
}

@Value.Immutable
abstract static class SingleOrList implements Expression {
public abstract Expression condition();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ public static Expression.WindowFunctionInvocation windowFunction(
List<Expression.SortField> sort,
Expression.AggregationInvocation invocation,
List<Expression> partitionBy,
Expression.WindowBoundsType boundsType,
WindowBound lowerBound,
WindowBound upperBound,
Iterable<? extends FunctionArg> arguments) {
Expand All @@ -329,6 +330,7 @@ public static Expression.WindowFunctionInvocation windowFunction(
.aggregationPhase(phase)
.sort(sort)
.partitionBy(partitionBy)
.boundsType(boundsType)
.lowerBound(lowerBound)
.upperBound(upperBound)
.invocation(invocation)
Expand All @@ -343,6 +345,7 @@ public static Expression.WindowFunctionInvocation windowFunction(
List<Expression.SortField> sort,
Expression.AggregationInvocation invocation,
List<Expression> partitionBy,
Expression.WindowBoundsType boundsType,
WindowBound lowerBound,
WindowBound upperBound,
FunctionArg... arguments) {
Expand All @@ -353,6 +356,7 @@ public static Expression.WindowFunctionInvocation windowFunction(
.sort(sort)
.invocation(invocation)
.partitionBy(partitionBy)
.boundsType(boundsType)
.lowerBound(lowerBound)
.upperBound(upperBound)
.addArguments(arguments)
Expand Down
66 changes: 35 additions & 31 deletions core/src/main/java/io/substrait/expression/WindowBound.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,58 +5,62 @@
@Value.Enclosing
public interface WindowBound {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than encoding the WindowBound type using a combination of BoundedKind and Direction, I've switched this to 4 classes:

  1. Preceding
  2. Following
  3. CurrentRow
  4. Unbounded

and introduced a WindowBoundVisitor to handle dispatch when converting from these to various other representations.

This representation of Window Bounds:

  1. More closely matches the spec.
  2. Reduces the need for casting of WindowBound interfaces to concrete types.
  3. Reduces cognitive complexity (in my opinion) as users no longer need to reason about 6 different enum combinations (3 bounded kinds * 2 directions), and instead have 4 concrete classes.


public BoundedKind boundedKind();
interface WindowBoundVisitor<R, E extends Throwable> {
R visit(Preceding preceding);

enum BoundedKind {
UNBOUNDED,
BOUNDED,
CURRENT_ROW
}
R visit(Following following);

R visit(CurrentRow currentRow);

enum Direction {
PRECEDING,
FOLLOWING
R visit(Unbounded unbounded);
}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Direction doesn't have an equivalent in the Substrait spec itself. Calcite has a concept of UNBOUNDED_PRECEDING and UNBOUNDED_FOLLOWING, but those can be reduced to a single Unbounded WindowBound when converting from Calcite to Pojo.


public static CurrentRowWindowBound CURRENT_ROW =
ImmutableWindowBound.CurrentRowWindowBound.builder().build();
<R, E extends Throwable> R accept(WindowBoundVisitor<R, E> visitor);

CurrentRow CURRENT_ROW = ImmutableWindowBound.CurrentRow.builder().build();
Unbounded UNBOUNDED = ImmutableWindowBound.Unbounded.builder().build();

@Value.Immutable
abstract static class UnboundedWindowBound implements WindowBound {
@Override
public BoundedKind boundedKind() {
return BoundedKind.UNBOUNDED;
}
abstract class Preceding implements WindowBound {
public abstract long offset();

public abstract Direction direction();
public static Preceding of(long offset) {
return ImmutableWindowBound.Preceding.builder().offset(offset).build();
}

public static ImmutableWindowBound.UnboundedWindowBound.Builder builder() {
return ImmutableWindowBound.UnboundedWindowBound.builder();
@Override
public <R, E extends Throwable> R accept(WindowBoundVisitor<R, E> visitor) {
return visitor.visit(this);
}
}

@Value.Immutable
abstract static class BoundedWindowBound implements WindowBound {
abstract class Following implements WindowBound {
public abstract long offset();

@Override
public BoundedKind boundedKind() {
return BoundedKind.BOUNDED;
public static Following of(long offset) {
return ImmutableWindowBound.Following.builder().offset(offset).build();
}

public abstract Direction direction();

public abstract long offset();
@Override
public <R, E extends Throwable> R accept(WindowBoundVisitor<R, E> visitor) {
return visitor.visit(this);
}
}

public static ImmutableWindowBound.BoundedWindowBound.Builder builder() {
return ImmutableWindowBound.BoundedWindowBound.builder();
@Value.Immutable
abstract class CurrentRow implements WindowBound {
@Override
public <R, E extends Throwable> R accept(WindowBoundVisitor<R, E> visitor) {
return visitor.visit(this);
}
}

@Value.Immutable
static class CurrentRowWindowBound implements WindowBound {
abstract class Unbounded implements WindowBound {
@Override
public BoundedKind boundedKind() {
return BoundedKind.CURRENT_ROW;
public <R, E extends Throwable> R accept(WindowBoundVisitor<R, E> visitor) {
return visitor.visit(this);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -440,8 +440,8 @@ public Expression visit(io.substrait.expression.Expression.WindowFunctionInvocat
.build())
.collect(java.util.stream.Collectors.toList());

Expression.WindowFunction.Bound upperBound = toBound(expr.upperBound());
Expression.WindowFunction.Bound lowerBound = toBound(expr.lowerBound());
Expression.WindowFunction.Bound lowerBound = BoundConverter.convert(expr.lowerBound());
Expression.WindowFunction.Bound upperBound = BoundConverter.convert(expr.upperBound());

return Expression.newBuilder()
.setWindowFunction(
Expand All @@ -453,35 +453,51 @@ public Expression visit(io.substrait.expression.Expression.WindowFunctionInvocat
.setInvocation(expr.invocation().toProto())
.addAllSorts(sortFields)
.addAllPartitions(partitionExprs)
.setBoundsType(expr.boundsType().toProto())
.setLowerBound(lowerBound)
.setUpperBound(upperBound))
.build();
}

private Expression.WindowFunction.Bound toBound(io.substrait.expression.WindowBound windowBound) {
var boundedKind = windowBound.boundedKind();
return switch (boundedKind) {
case CURRENT_ROW -> Expression.WindowFunction.Bound.newBuilder()
static class BoundConverter
implements WindowBound.WindowBoundVisitor<Expression.WindowFunction.Bound, RuntimeException> {

static Expression.WindowFunction.Bound convert(WindowBound bound) {
return bound.accept(TO_BOUND_VISITOR);
}

private static final BoundConverter TO_BOUND_VISITOR = new BoundConverter();

private BoundConverter() {}

@Override
public Expression.WindowFunction.Bound visit(WindowBound.Preceding preceding) {
return Expression.WindowFunction.Bound.newBuilder()
.setPreceding(
Expression.WindowFunction.Bound.Preceding.newBuilder().setOffset(preceding.offset()))
.build();
}

@Override
public Expression.WindowFunction.Bound visit(WindowBound.Following following) {
return Expression.WindowFunction.Bound.newBuilder()
.setFollowing(
Expression.WindowFunction.Bound.Following.newBuilder().setOffset(following.offset()))
.build();
}

@Override
public Expression.WindowFunction.Bound visit(WindowBound.CurrentRow currentRow) {
return Expression.WindowFunction.Bound.newBuilder()
.setCurrentRow(Expression.WindowFunction.Bound.CurrentRow.getDefaultInstance())
.build();
case BOUNDED -> {
WindowBound.BoundedWindowBound boundedWindowBound =
(WindowBound.BoundedWindowBound) windowBound;
var offset = boundedWindowBound.offset();
yield switch (boundedWindowBound.direction()) {
case PRECEDING -> Expression.WindowFunction.Bound.newBuilder()
.setPreceding(
Expression.WindowFunction.Bound.Preceding.newBuilder().setOffset(offset))
.build();
case FOLLOWING -> Expression.WindowFunction.Bound.newBuilder()
.setFollowing(
Expression.WindowFunction.Bound.Following.newBuilder().setOffset(offset))
.build();
};
}
case UNBOUNDED -> Expression.WindowFunction.Bound.newBuilder()
}

@Override
public Expression.WindowFunction.Bound visit(WindowBound.Unbounded unbounded) {
return Expression.WindowFunction.Bound.newBuilder()
.setUnbounded(Expression.WindowFunction.Bound.Unbounded.getDefaultInstance())
.build();
};
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@ public Expression from(io.substrait.proto.Expression expr) {
.build())
.collect(java.util.stream.Collectors.toList());

WindowBound lowerBound = toLowerBound(windowFunction.getLowerBound());
WindowBound upperBound = toUpperBound(windowFunction.getUpperBound());
WindowBound lowerBound = toWindowBound(windowFunction.getLowerBound());
WindowBound upperBound = toWindowBound(windowFunction.getUpperBound());

yield Expression.WindowFunctionInvocation.builder()
.arguments(args)
Expand All @@ -148,6 +148,7 @@ public Expression from(io.substrait.proto.Expression expr) {
.aggregationPhase(Expression.AggregationPhase.fromProto(windowFunction.getPhase()))
.partitionBy(partitionExprs)
.sort(sortFields)
.boundsType(Expression.WindowBoundsType.fromProto(windowFunction.getBoundsType()))
.lowerBound(lowerBound)
.upperBound(upperBound)
.invocation(Expression.AggregationInvocation.fromProto(windowFunction.getInvocation()))
Expand Down Expand Up @@ -247,35 +248,16 @@ public Expression from(io.substrait.proto.Expression expr) {
};
}

private WindowBound toLowerBound(io.substrait.proto.Expression.WindowFunction.Bound bound) {
return toWindowBound(
bound,
WindowBound.UnboundedWindowBound.builder()
.direction(WindowBound.Direction.PRECEDING)
.build());
}

private WindowBound toUpperBound(io.substrait.proto.Expression.WindowFunction.Bound bound) {
return toWindowBound(
bound,
WindowBound.UnboundedWindowBound.builder()
.direction(WindowBound.Direction.FOLLOWING)
.build());
}

private WindowBound toWindowBound(
io.substrait.proto.Expression.WindowFunction.Bound bound, WindowBound defaultBound) {
private WindowBound toWindowBound(io.substrait.proto.Expression.WindowFunction.Bound bound) {
return switch (bound.getKindCase()) {
case PRECEDING -> WindowBound.BoundedWindowBound.builder()
.direction(WindowBound.Direction.PRECEDING)
.offset(bound.getPreceding().getOffset())
.build();
case FOLLOWING -> WindowBound.BoundedWindowBound.builder()
.direction(WindowBound.Direction.FOLLOWING)
.offset(bound.getFollowing().getOffset())
.build();
case PRECEDING -> WindowBound.Preceding.of(bound.getPreceding().getOffset());
case FOLLOWING -> WindowBound.Following.of(bound.getFollowing().getOffset());
case CURRENT_ROW -> WindowBound.CURRENT_ROW;
case UNBOUNDED, KIND_NOT_SET -> defaultBound;
case UNBOUNDED -> WindowBound.UNBOUNDED;
case KIND_NOT_SET ->
// per the spec, the lower and upper bounds default to the start or end of the partition
// respectively if not set
WindowBound.UNBOUNDED;
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import io.substrait.isthmus.expression.AggregateFunctionConverter;
import io.substrait.isthmus.expression.ExpressionRexConverter;
import io.substrait.isthmus.expression.ScalarFunctionConverter;
import io.substrait.isthmus.expression.WindowFunctionConverter;
import io.substrait.relation.AbstractRelVisitor;
import io.substrait.relation.Aggregate;
import io.substrait.relation.Cross;
Expand Down Expand Up @@ -72,6 +73,7 @@ public SubstraitRelNodeConverter(
relBuilder,
new ScalarFunctionConverter(extensions.scalarFunctions(), typeFactory),
new AggregateFunctionConverter(extensions.aggregateFunctions(), typeFactory),
new WindowFunctionConverter(extensions.windowFunctions(), typeFactory),
TypeConverter.DEFAULT);
}

Expand All @@ -80,6 +82,7 @@ public SubstraitRelNodeConverter(
RelBuilder relBuilder,
ScalarFunctionConverter scalarFunctionConverter,
AggregateFunctionConverter aggregateFunctionConverter,
WindowFunctionConverter windowFunctionConverter,
TypeConverter typeConverter) {
this.typeFactory = typeFactory;
this.typeConverter = typeConverter;
Expand All @@ -89,7 +92,7 @@ public SubstraitRelNodeConverter(
this.aggregateFunctionConverter = aggregateFunctionConverter;
this.expressionRexConverter =
new ExpressionRexConverter(
typeFactory, scalarFunctionConverter, aggregateFunctionConverter, typeConverter);
typeFactory, scalarFunctionConverter, windowFunctionConverter, typeConverter);
}

public static RelNode convert(
Expand Down
Loading