diff --git a/core/src/main/java/io/substrait/expression/Expression.java b/core/src/main/java/io/substrait/expression/Expression.java index ffcad476..7abe8d37 100644 --- a/core/src/main/java/io/substrait/expression/Expression.java +++ b/core/src/main/java/io/substrait/expression/Expression.java @@ -585,6 +585,8 @@ abstract class WindowFunctionInvocation implements Expression { public abstract List sort(); + public abstract WindowBoundsType boundsType(); + public abstract WindowBound lowerBound(); public abstract WindowBound upperBound(); @@ -606,6 +608,33 @@ public R accept(ExpressionVisitor visitor) throws } } + enum WindowBoundsType { + UNSPECIFIED(io.substrait.proto.Expression.WindowFunction.BoundsType.BOUNDS_TYPE_UNSPECIFIED), + ROWS(io.substrait.proto.Expression.WindowFunction.BoundsType.BOUNDS_TYPE_ROWS), + RANGE(io.substrait.proto.Expression.WindowFunction.BoundsType.BOUNDS_TYPE_RANGE); + + private final io.substrait.proto.Expression.WindowFunction.BoundsType proto; + + WindowBoundsType(io.substrait.proto.Expression.WindowFunction.BoundsType proto) { + this.proto = proto; + } + + public io.substrait.proto.Expression.WindowFunction.BoundsType toProto() { + return proto; + } + + public static WindowBoundsType fromProto( + io.substrait.proto.Expression.WindowFunction.BoundsType proto) { + for (var v : values()) { + if (v.proto == proto) { + return v; + } + } + + throw new IllegalArgumentException("Unknown type: " + proto); + } + } + @Value.Immutable abstract static class SingleOrList implements Expression { public abstract Expression condition(); diff --git a/core/src/main/java/io/substrait/expression/ExpressionCreator.java b/core/src/main/java/io/substrait/expression/ExpressionCreator.java index 9e466e8c..7257b6d4 100644 --- a/core/src/main/java/io/substrait/expression/ExpressionCreator.java +++ b/core/src/main/java/io/substrait/expression/ExpressionCreator.java @@ -320,6 +320,7 @@ public static Expression.WindowFunctionInvocation windowFunction( List sort, Expression.AggregationInvocation invocation, List partitionBy, + Expression.WindowBoundsType boundsType, WindowBound lowerBound, WindowBound upperBound, Iterable arguments) { @@ -329,6 +330,7 @@ public static Expression.WindowFunctionInvocation windowFunction( .aggregationPhase(phase) .sort(sort) .partitionBy(partitionBy) + .boundsType(boundsType) .lowerBound(lowerBound) .upperBound(upperBound) .invocation(invocation) @@ -343,6 +345,7 @@ public static Expression.WindowFunctionInvocation windowFunction( List sort, Expression.AggregationInvocation invocation, List partitionBy, + Expression.WindowBoundsType boundsType, WindowBound lowerBound, WindowBound upperBound, FunctionArg... arguments) { @@ -353,6 +356,7 @@ public static Expression.WindowFunctionInvocation windowFunction( .sort(sort) .invocation(invocation) .partitionBy(partitionBy) + .boundsType(boundsType) .lowerBound(lowerBound) .upperBound(upperBound) .addArguments(arguments) diff --git a/core/src/main/java/io/substrait/expression/WindowBound.java b/core/src/main/java/io/substrait/expression/WindowBound.java index e478d459..7c9d9312 100644 --- a/core/src/main/java/io/substrait/expression/WindowBound.java +++ b/core/src/main/java/io/substrait/expression/WindowBound.java @@ -5,58 +5,62 @@ @Value.Enclosing public interface WindowBound { - public BoundedKind boundedKind(); + interface WindowBoundVisitor { + R visit(Preceding preceding); - enum BoundedKind { - UNBOUNDED, - BOUNDED, - CURRENT_ROW - } + R visit(Following following); + + R visit(CurrentRow currentRow); - enum Direction { - PRECEDING, - FOLLOWING + R visit(Unbounded unbounded); } - public static CurrentRowWindowBound CURRENT_ROW = - ImmutableWindowBound.CurrentRowWindowBound.builder().build(); + R accept(WindowBoundVisitor visitor); + + CurrentRow CURRENT_ROW = ImmutableWindowBound.CurrentRow.builder().build(); + Unbounded UNBOUNDED = ImmutableWindowBound.Unbounded.builder().build(); @Value.Immutable - abstract static class UnboundedWindowBound implements WindowBound { - @Override - public BoundedKind boundedKind() { - return BoundedKind.UNBOUNDED; - } + abstract class Preceding implements WindowBound { + public abstract long offset(); - public abstract Direction direction(); + public static Preceding of(long offset) { + return ImmutableWindowBound.Preceding.builder().offset(offset).build(); + } - public static ImmutableWindowBound.UnboundedWindowBound.Builder builder() { - return ImmutableWindowBound.UnboundedWindowBound.builder(); + @Override + public R accept(WindowBoundVisitor visitor) { + return visitor.visit(this); } } @Value.Immutable - abstract static class BoundedWindowBound implements WindowBound { + abstract class Following implements WindowBound { + public abstract long offset(); - @Override - public BoundedKind boundedKind() { - return BoundedKind.BOUNDED; + public static Following of(long offset) { + return ImmutableWindowBound.Following.builder().offset(offset).build(); } - public abstract Direction direction(); - - public abstract long offset(); + @Override + public R accept(WindowBoundVisitor visitor) { + return visitor.visit(this); + } + } - public static ImmutableWindowBound.BoundedWindowBound.Builder builder() { - return ImmutableWindowBound.BoundedWindowBound.builder(); + @Value.Immutable + abstract class CurrentRow implements WindowBound { + @Override + public R accept(WindowBoundVisitor visitor) { + return visitor.visit(this); } } @Value.Immutable - static class CurrentRowWindowBound implements WindowBound { + abstract class Unbounded implements WindowBound { @Override - public BoundedKind boundedKind() { - return BoundedKind.CURRENT_ROW; + public R accept(WindowBoundVisitor visitor) { + return visitor.visit(this); } } } diff --git a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java index 8f0d0d50..c962e14f 100644 --- a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java @@ -440,8 +440,8 @@ public Expression visit(io.substrait.expression.Expression.WindowFunctionInvocat .build()) .collect(java.util.stream.Collectors.toList()); - Expression.WindowFunction.Bound upperBound = toBound(expr.upperBound()); - Expression.WindowFunction.Bound lowerBound = toBound(expr.lowerBound()); + Expression.WindowFunction.Bound lowerBound = BoundConverter.convert(expr.lowerBound()); + Expression.WindowFunction.Bound upperBound = BoundConverter.convert(expr.upperBound()); return Expression.newBuilder() .setWindowFunction( @@ -453,35 +453,51 @@ public Expression visit(io.substrait.expression.Expression.WindowFunctionInvocat .setInvocation(expr.invocation().toProto()) .addAllSorts(sortFields) .addAllPartitions(partitionExprs) + .setBoundsType(expr.boundsType().toProto()) .setLowerBound(lowerBound) .setUpperBound(upperBound)) .build(); } - private Expression.WindowFunction.Bound toBound(io.substrait.expression.WindowBound windowBound) { - var boundedKind = windowBound.boundedKind(); - return switch (boundedKind) { - case CURRENT_ROW -> Expression.WindowFunction.Bound.newBuilder() + static class BoundConverter + implements WindowBound.WindowBoundVisitor { + + static Expression.WindowFunction.Bound convert(WindowBound bound) { + return bound.accept(TO_BOUND_VISITOR); + } + + private static final BoundConverter TO_BOUND_VISITOR = new BoundConverter(); + + private BoundConverter() {} + + @Override + public Expression.WindowFunction.Bound visit(WindowBound.Preceding preceding) { + return Expression.WindowFunction.Bound.newBuilder() + .setPreceding( + Expression.WindowFunction.Bound.Preceding.newBuilder().setOffset(preceding.offset())) + .build(); + } + + @Override + public Expression.WindowFunction.Bound visit(WindowBound.Following following) { + return Expression.WindowFunction.Bound.newBuilder() + .setFollowing( + Expression.WindowFunction.Bound.Following.newBuilder().setOffset(following.offset())) + .build(); + } + + @Override + public Expression.WindowFunction.Bound visit(WindowBound.CurrentRow currentRow) { + return Expression.WindowFunction.Bound.newBuilder() .setCurrentRow(Expression.WindowFunction.Bound.CurrentRow.getDefaultInstance()) .build(); - case BOUNDED -> { - WindowBound.BoundedWindowBound boundedWindowBound = - (WindowBound.BoundedWindowBound) windowBound; - var offset = boundedWindowBound.offset(); - yield switch (boundedWindowBound.direction()) { - case PRECEDING -> Expression.WindowFunction.Bound.newBuilder() - .setPreceding( - Expression.WindowFunction.Bound.Preceding.newBuilder().setOffset(offset)) - .build(); - case FOLLOWING -> Expression.WindowFunction.Bound.newBuilder() - .setFollowing( - Expression.WindowFunction.Bound.Following.newBuilder().setOffset(offset)) - .build(); - }; - } - case UNBOUNDED -> Expression.WindowFunction.Bound.newBuilder() + } + + @Override + public Expression.WindowFunction.Bound visit(WindowBound.Unbounded unbounded) { + return Expression.WindowFunction.Bound.newBuilder() .setUnbounded(Expression.WindowFunction.Bound.Unbounded.getDefaultInstance()) .build(); - }; + } } } diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index cd07aed9..c1cbfc63 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -138,8 +138,8 @@ public Expression from(io.substrait.proto.Expression expr) { .build()) .collect(java.util.stream.Collectors.toList()); - WindowBound lowerBound = toLowerBound(windowFunction.getLowerBound()); - WindowBound upperBound = toUpperBound(windowFunction.getUpperBound()); + WindowBound lowerBound = toWindowBound(windowFunction.getLowerBound()); + WindowBound upperBound = toWindowBound(windowFunction.getUpperBound()); yield Expression.WindowFunctionInvocation.builder() .arguments(args) @@ -148,6 +148,7 @@ public Expression from(io.substrait.proto.Expression expr) { .aggregationPhase(Expression.AggregationPhase.fromProto(windowFunction.getPhase())) .partitionBy(partitionExprs) .sort(sortFields) + .boundsType(Expression.WindowBoundsType.fromProto(windowFunction.getBoundsType())) .lowerBound(lowerBound) .upperBound(upperBound) .invocation(Expression.AggregationInvocation.fromProto(windowFunction.getInvocation())) @@ -247,35 +248,16 @@ public Expression from(io.substrait.proto.Expression expr) { }; } - private WindowBound toLowerBound(io.substrait.proto.Expression.WindowFunction.Bound bound) { - return toWindowBound( - bound, - WindowBound.UnboundedWindowBound.builder() - .direction(WindowBound.Direction.PRECEDING) - .build()); - } - - private WindowBound toUpperBound(io.substrait.proto.Expression.WindowFunction.Bound bound) { - return toWindowBound( - bound, - WindowBound.UnboundedWindowBound.builder() - .direction(WindowBound.Direction.FOLLOWING) - .build()); - } - - private WindowBound toWindowBound( - io.substrait.proto.Expression.WindowFunction.Bound bound, WindowBound defaultBound) { + private WindowBound toWindowBound(io.substrait.proto.Expression.WindowFunction.Bound bound) { return switch (bound.getKindCase()) { - case PRECEDING -> WindowBound.BoundedWindowBound.builder() - .direction(WindowBound.Direction.PRECEDING) - .offset(bound.getPreceding().getOffset()) - .build(); - case FOLLOWING -> WindowBound.BoundedWindowBound.builder() - .direction(WindowBound.Direction.FOLLOWING) - .offset(bound.getFollowing().getOffset()) - .build(); + case PRECEDING -> WindowBound.Preceding.of(bound.getPreceding().getOffset()); + case FOLLOWING -> WindowBound.Following.of(bound.getFollowing().getOffset()); case CURRENT_ROW -> WindowBound.CURRENT_ROW; - case UNBOUNDED, KIND_NOT_SET -> defaultBound; + case UNBOUNDED -> WindowBound.UNBOUNDED; + case KIND_NOT_SET -> + // per the spec, the lower and upper bounds default to the start or end of the partition + // respectively if not set + WindowBound.UNBOUNDED; }; } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java index b528fce2..ba1dc3f9 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java @@ -7,6 +7,7 @@ import io.substrait.isthmus.expression.AggregateFunctionConverter; import io.substrait.isthmus.expression.ExpressionRexConverter; import io.substrait.isthmus.expression.ScalarFunctionConverter; +import io.substrait.isthmus.expression.WindowFunctionConverter; import io.substrait.relation.AbstractRelVisitor; import io.substrait.relation.Aggregate; import io.substrait.relation.Cross; @@ -72,6 +73,7 @@ public SubstraitRelNodeConverter( relBuilder, new ScalarFunctionConverter(extensions.scalarFunctions(), typeFactory), new AggregateFunctionConverter(extensions.aggregateFunctions(), typeFactory), + new WindowFunctionConverter(extensions.windowFunctions(), typeFactory), TypeConverter.DEFAULT); } @@ -80,6 +82,7 @@ public SubstraitRelNodeConverter( RelBuilder relBuilder, ScalarFunctionConverter scalarFunctionConverter, AggregateFunctionConverter aggregateFunctionConverter, + WindowFunctionConverter windowFunctionConverter, TypeConverter typeConverter) { this.typeFactory = typeFactory; this.typeConverter = typeConverter; @@ -89,7 +92,7 @@ public SubstraitRelNodeConverter( this.aggregateFunctionConverter = aggregateFunctionConverter; this.expressionRexConverter = new ExpressionRexConverter( - typeFactory, scalarFunctionConverter, aggregateFunctionConverter, typeConverter); + typeFactory, scalarFunctionConverter, windowFunctionConverter, typeConverter); } public static RelNode convert( diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java index 62ee84f5..23175263 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java @@ -1,7 +1,13 @@ package io.substrait.isthmus.expression; -import io.substrait.expression.*; +import com.google.common.collect.ImmutableList; +import io.substrait.expression.AbstractExpressionVisitor; +import io.substrait.expression.EnumArg; +import io.substrait.expression.Expression; import io.substrait.expression.Expression.SingleOrList; +import io.substrait.expression.FieldReference; +import io.substrait.expression.FunctionArg; +import io.substrait.expression.WindowBound; import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.TypeConverter; import io.substrait.type.StringTypeVisitor; @@ -9,17 +15,23 @@ import io.substrait.util.DecimalUtil; import java.math.BigDecimal; import java.util.List; -import java.util.Optional; +import java.util.Set; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import java.util.stream.IntStream; import java.util.stream.Stream; import org.apache.calcite.avatica.util.ByteString; +import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexFieldCollation; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexWindowBound; +import org.apache.calcite.rex.RexWindowBounds; +import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlIntervalQualifier; +import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.parser.SqlParserPos; @@ -36,8 +48,7 @@ public class ExpressionRexConverter extends AbstractExpressionVisitor + new IllegalArgumentException( + callConversionFailureMessage( + "scalar", expr.declaration().name(), expr.arguments()))); + var eArgs = expr.arguments(); var args = IntStream.range(0, expr.arguments().size()) .mapToObj(i -> eArgs.get(i).accept(expr.declaration(), i, this)) .collect(java.util.stream.Collectors.toList()); - Optional operator = - scalarFunctionConverter.getSqlOperatorFromSubstraitFunc( - expr.declaration().key(), expr.outputType()); - if (operator.isPresent()) { - return rexBuilder.makeCall(operator.get(), args); - } else { - String msg = - String.format( - "Unable to convert scalar function %s(%s).", - expr.declaration().name(), - expr.arguments().stream().map(this::convert).collect(Collectors.joining(", "))); - throw new IllegalArgumentException(msg); - } + return rexBuilder.makeCall(operator, args); + } + + private String callConversionFailureMessage( + String functionType, String name, List args) { + return String.format( + "Unable to convert %s function %s(%s).", + functionType, name, args.stream().map(this::convert).collect(Collectors.joining(", "))); } @Override public RexNode visit(Expression.WindowFunctionInvocation expr) throws RuntimeException { - // todo:to construct the RexOver - return visitFallback(expr); + SqlOperator operator = + windowFunctionConverter + .getSqlOperatorFromSubstraitFunc(expr.declaration().key(), expr.outputType()) + .orElseThrow( + () -> + new IllegalArgumentException( + callConversionFailureMessage( + "window", expr.declaration().name(), expr.arguments()))); + + RelDataType outputType = typeConverter.toCalcite(typeFactory, expr.outputType()); + + List eArgs = expr.arguments(); + List args = + IntStream.range(0, expr.arguments().size()) + .mapToObj(i -> eArgs.get(i).accept(expr.declaration(), i, this)) + .collect(java.util.stream.Collectors.toList()); + + List partitionKeys = + expr.partitionBy().stream().map(e -> e.accept(this)).collect(Collectors.toList()); + + ImmutableList orderKeys = + expr.sort().stream() + .map( + sf -> { + Set direction = + switch (sf.direction()) { + case ASC_NULLS_FIRST -> Set.of(SqlKind.NULLS_FIRST); + case ASC_NULLS_LAST -> Set.of(SqlKind.NULLS_LAST); + case DESC_NULLS_FIRST -> Set.of(SqlKind.DESCENDING, SqlKind.NULLS_FIRST); + case DESC_NULLS_LAST -> Set.of(SqlKind.DESCENDING, SqlKind.NULLS_LAST); + case CLUSTERED -> throw new IllegalArgumentException( + "SORT_DIRECTION_CLUSTERED is not supported"); + }; + return new RexFieldCollation(sf.expr().accept(this), direction); + }) + .collect(ImmutableList.toImmutableList()); + + RexWindowBound lowerBound = ToRexWindowBound.lowerBound(rexBuilder, expr.lowerBound()); + RexWindowBound upperBound = ToRexWindowBound.lowerBound(rexBuilder, expr.upperBound()); + + boolean rowMode = + switch (expr.boundsType()) { + case ROWS -> true; + case RANGE -> false; + case UNSPECIFIED -> throw new IllegalArgumentException( + "bounds type on window function must be specified"); + }; + + boolean distinct = + switch (expr.invocation()) { + case UNSPECIFIED, ALL -> false; + case DISTINCT -> true; + }; + + // For queries like: SELECT last_value() IGNORE NULLS OVER ... + // Substrait has no mechanism to set this, so by default it is false + boolean ignoreNulls = false; + + // These both control a rewrite rule within rexBuilder.makeOver that rewrites the given + // expression into a case expression. These values are set as such to avoid this rewrite. + boolean nullWhenCountZero = false; + boolean allowPartial = true; + + return rexBuilder.makeOver( + outputType, + (SqlAggFunction) operator, + args, + partitionKeys, + orderKeys, + lowerBound, + upperBound, + rowMode, + allowPartial, + nullWhenCountZero, + distinct, + ignoreNulls); + } + + static class ToRexWindowBound + implements WindowBound.WindowBoundVisitor { + + static RexWindowBound lowerBound(RexBuilder rexBuilder, WindowBound bound) { + // per the spec, unbounded on the lower bound means the start of the partition + // thus UNBOUNDED_PRECEDING should be used when bound is unbounded + return bound.accept(new ToRexWindowBound(rexBuilder, RexWindowBounds.UNBOUNDED_PRECEDING)); + } + + static RexWindowBound upperBound(RexBuilder rexBuilder, WindowBound bound) { + // per the spec, unbounded on the upper bound means the end of the partition + // thus UNBOUNDED_FOLLOWING should be used when bound is unbounded + return bound.accept(new ToRexWindowBound(rexBuilder, RexWindowBounds.UNBOUNDED_FOLLOWING)); + } + + private final RexBuilder rexBuilder; + private final RexWindowBound unboundedVariant; + + private ToRexWindowBound(RexBuilder rexBuilder, RexWindowBound unboundedVariant) { + this.rexBuilder = rexBuilder; + this.unboundedVariant = unboundedVariant; + } + + @Override + public RexWindowBound visit(WindowBound.Preceding preceding) { + var offset = BigDecimal.valueOf(preceding.offset()); + return RexWindowBounds.preceding(rexBuilder.makeBigintLiteral(offset)); + } + + @Override + public RexWindowBound visit(WindowBound.Following following) { + var offset = BigDecimal.valueOf(following.offset()); + return RexWindowBounds.following(rexBuilder.makeBigintLiteral(offset)); + } + + @Override + public RexWindowBound visit(WindowBound.CurrentRow currentRow) { + return RexWindowBounds.CURRENT_ROW; + } + + @Override + public RexWindowBound visit(WindowBound.Unbounded unbounded) { + return unboundedVariant; + } } private String convert(FunctionArg a) { diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java index 17bb03a3..5de28714 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java @@ -99,6 +99,10 @@ public Expression visitLiteral(RexLiteral literal) { @Override public Expression visitOver(RexOver over) { + if (over.ignoreNulls()) { + throw new IllegalArgumentException("IGNORE NULLS cannot be expressed in Substrait"); + } + return windowFunctionConverter .convert(over, rexNode -> rexNode.accept(this), this) .orElseThrow(() -> new IllegalArgumentException(callConversionFailureMessage(over))); diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/WindowFunctionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/WindowFunctionConverter.java index a108fa9d..784db387 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/WindowFunctionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/WindowFunctionConverter.java @@ -4,7 +4,6 @@ import io.substrait.expression.Expression; import io.substrait.expression.ExpressionCreator; import io.substrait.expression.FunctionArg; -import io.substrait.expression.ImmutableWindowBound; import io.substrait.expression.WindowBound; import io.substrait.extension.SimpleExtension; import io.substrait.type.Type; @@ -66,8 +65,11 @@ protected Expression.WindowFunctionInvocation generateBinding( ? Expression.AggregationInvocation.DISTINCT : Expression.AggregationInvocation.ALL; - WindowBound lowerBound = toWindowBound(window.getLowerBound(), call.rexExpressionConverter); - WindowBound upperBound = toWindowBound(window.getUpperBound(), call.rexExpressionConverter); + // Calcite only supports ROW or RANGE mode + Expression.WindowBoundsType boundsType = + window.isRows() ? Expression.WindowBoundsType.ROWS : Expression.WindowBoundsType.RANGE; + WindowBound lowerBound = toWindowBound(window.getLowerBound()); + WindowBound upperBound = toWindowBound(window.getUpperBound()); return ExpressionCreator.windowFunction( function, @@ -76,6 +78,7 @@ protected Expression.WindowFunctionInvocation generateBinding( sorts, invocation, partitionExprs, + boundsType, lowerBound, upperBound, arguments); @@ -98,37 +101,32 @@ public Optional convert( return m.attemptMatch(wrapped, topLevelConverter); } - private WindowBound toWindowBound( - RexWindowBound rexWindowBound, RexExpressionConverter rexExpressionConverter) { + private WindowBound toWindowBound(RexWindowBound rexWindowBound) { if (rexWindowBound.isCurrentRow()) { return WindowBound.CURRENT_ROW; } if (rexWindowBound.isUnbounded()) { - var direction = findWindowBoundDirection(rexWindowBound); - return ImmutableWindowBound.UnboundedWindowBound.builder().direction(direction).build(); + return WindowBound.UNBOUNDED; } else { - var direction = findWindowBoundDirection(rexWindowBound); if (rexWindowBound.getOffset() instanceof RexLiteral literal && SqlTypeName.EXACT_TYPES.contains(literal.getTypeName())) { BigDecimal offset = (BigDecimal) literal.getValue4(); - return ImmutableWindowBound.BoundedWindowBound.builder() - .direction(direction) - .offset(offset.longValue()) - .build(); + if (rexWindowBound.isPreceding()) { + return WindowBound.Preceding.of(offset.longValue()); + } + if (rexWindowBound.isFollowing()) { + return WindowBound.Following.of(offset.longValue()); + } + throw new IllegalStateException( + "window bound was none of CURRENT ROW, UNBOUNDED, PRECEDING or FOLLOWING"); } throw new IllegalArgumentException( String.format( - "substrait only supports integer window offsets. Received: %", + "substrait only supports integer window offsets. Received: %s", rexWindowBound.getOffset().getKind())); } } - private WindowBound.Direction findWindowBoundDirection(RexWindowBound rexWindowBound) { - return rexWindowBound.isFollowing() - ? WindowBound.Direction.FOLLOWING - : WindowBound.Direction.PRECEDING; - } - private Expression.SortField toSortField( RexFieldCollation rexFieldCollation, RexExpressionConverter rexExpressionConverter) { var expr = rexFieldCollation.left.accept(rexExpressionConverter); diff --git a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java index d11cc213..d7e18bfb 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java @@ -156,6 +156,7 @@ protected SubstraitRelNodeConverter createSubstraitRelNodeConverter(RelBuilder r relBuilder, scalarFunctionConverter, aggregateFunctionConverter, + windowFunctionConverter, typeConverter); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java index c1857f94..eb9d523f 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java +++ b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java @@ -2,6 +2,7 @@ import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; import com.google.common.annotations.Beta; import com.google.common.base.Charsets; @@ -156,8 +157,6 @@ protected void assertFullRoundTrip(String sqlQuery, List createStatement // Verify that POJOs are the same assertEquals(pojo1, pojo2); - /* - // TODO vbarua: go all the way once window function conversions are allowed // Substrait POJO 2 -> Calcite 2 RelNode calcite2 = new SubstraitToCalcite(EXTENSION_COLLECTION, typeFactory).convert(pojo2); // It would be ideal to compare calcite1 and calcite2, however there isn't a good mechanism to @@ -165,11 +164,11 @@ protected void assertFullRoundTrip(String sqlQuery, List createStatement assertNotNull(calcite2); // Calcite 2 -> Substrait POJO 3 - io.substrait.relation.Rel pojo3 = SubstraitRelVisitor.convert(calcite1, EXTENSION_COLLECTION); + io.substrait.relation.Rel pojo3 = + SubstraitRelVisitor.convert(RelRoot.of(calcite2, calcite1.kind), EXTENSION_COLLECTION); // Verify that POJOs are the same assertEquals(pojo1, pojo3); - */ } } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/WindowFunctionTest.java b/isthmus/src/test/java/io/substrait/isthmus/WindowFunctionTest.java index 112f4980..e31ceafb 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/WindowFunctionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/WindowFunctionTest.java @@ -1,5 +1,7 @@ package io.substrait.isthmus; +import static org.junit.jupiter.api.Assertions.assertThrows; + import java.io.IOException; import org.apache.calcite.sql.parser.SqlParseException; import org.junit.jupiter.api.Nested; @@ -123,6 +125,30 @@ void rowsPrecedingAndFollowing() throws IOException, SqlParseException { assertFullRoundTrip( String.format("select min(O_SHIPPRIORITY) over (%s) from ORDERS", overClause)); } + + @Test + void rangePrecedingToCurrent() throws IOException, SqlParseException { + /* + LogicalProject(EXPR$0=[MIN($7) OVER (PARTITION BY $1 ORDER BY $3 RANGE 10 PRECEDING)]) + LogicalTableScan(table=[[ORDERS]]) + */ + var overClause = + "partition by O_CUSTKEY order by O_TOTALPRICE range between 10 preceding and current row"; + assertFullRoundTrip( + String.format("select min(O_SHIPPRIORITY) over (%s) from ORDERS", overClause)); + } + + @Test + void rangeCurrentToFollowing() throws IOException, SqlParseException { + /* + LogicalProject(EXPR$0=[MIN($7) OVER (PARTITION BY $1 ORDER BY $3 RANGE BETWEEN CURRENT ROW AND 11 FOLLOWING)]) + LogicalTableScan(table=[[ORDERS]]) + */ + var overClause = + "partition by O_CUSTKEY order by O_TOTALPRICE range between current row and 11 following"; + assertFullRoundTrip( + String.format("select min(O_SHIPPRIORITY) over (%s) from ORDERS", overClause)); + } } @Nested @@ -136,4 +162,12 @@ void standardAggregateFunctions(String aggFunction) throws SqlParseException, IO "select %s(L_LINENUMBER) over (partition BY L_PARTKEY) from lineitem", aggFunction)); } } + + @Test + void rejectQueriesWithIgnoreNulls() { + // IGNORE NULLS cannot be specified in the Substrait representation. + // Queries using it should be rejected. + var query = "select last_value(L_LINENUMBER) ignore nulls over () from lineitem"; + assertThrows(IllegalArgumentException.class, () -> assertFullRoundTrip(query)); + } }