From dae7196929e824e28df6c0724bf5d76528952614 Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Wed, 30 Aug 2023 11:42:26 -0700 Subject: [PATCH] refactor: unify aggregate and window functions in window handling (#170) feat: allow window function operands beyond just column references fix: invocation was not set when building WindowFunction proto message BREAKING CHANGES: * Window and WindowFunction have been merged into WindowFunctionInvocation * WindowFunctionInvocation is now an inner class of expression * WindowFunctionInvocation now implements Expression. * Expression visitor now visits WindowFunctionInvocation * WindowFunctionConvert constructor requires fewer arguments * windowCreator has additional fields * POJO WindowBound offset must now be a long, better reflecting spec --- .../expression/AbstractExpressionVisitor.java | 10 +- .../io/substrait/expression/Expression.java | 74 ++++--- .../expression/ExpressionCreator.java | 20 +- .../expression/ExpressionVisitor.java | 4 +- .../io/substrait/expression/WindowBound.java | 2 +- .../expression/WindowFunctionInvocation.java | 33 ---- .../proto/ExpressionProtoConverter.java | 116 ++++------- .../proto/ProtoExpressionConverter.java | 59 ++---- .../substrait/extension/SimpleExtension.java | 55 ++++-- .../isthmus/SubstraitRelVisitor.java | 8 +- .../expression/ExpressionRexConverter.java | 12 +- .../isthmus/expression/FunctionMappings.java | 2 + .../expression/RexExpressionConverter.java | 64 +++--- .../expression/WindowFunctionConverter.java | 184 +++++------------- .../substrait/isthmus/CustomFunctionTest.java | 6 +- .../io/substrait/isthmus/PlanTestBase.java | 61 ++++++ .../substrait/isthmus/WindowFunctionTest.java | 28 ++- 17 files changed, 313 insertions(+), 425 deletions(-) delete mode 100644 core/src/main/java/io/substrait/expression/WindowFunctionInvocation.java diff --git a/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java b/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java index b3062b2e..3192edc8 100644 --- a/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java +++ b/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java @@ -139,6 +139,11 @@ public OUTPUT visit(Expression.ScalarFunctionInvocation expr) throws EXCEPTION { return visitFallback(expr); } + @Override + public OUTPUT visit(Expression.WindowFunctionInvocation expr) throws EXCEPTION { + return visitFallback(expr); + } + @Override public OUTPUT visit(Expression.Cast expr) throws EXCEPTION { return visitFallback(expr); @@ -173,9 +178,4 @@ public OUTPUT visit(Expression.ScalarSubquery expr) throws EXCEPTION { public OUTPUT visit(Expression.InPredicate expr) throws EXCEPTION { return visitFallback(expr); } - - @Override - public OUTPUT visit(Expression.Window expr) throws EXCEPTION { - return visitFallback(expr); - } } diff --git a/core/src/main/java/io/substrait/expression/Expression.java b/core/src/main/java/io/substrait/expression/Expression.java index 80088ca7..ffcad476 100644 --- a/core/src/main/java/io/substrait/expression/Expression.java +++ b/core/src/main/java/io/substrait/expression/Expression.java @@ -3,15 +3,12 @@ import com.google.protobuf.ByteString; import io.substrait.extension.SimpleExtension; import io.substrait.proto.AggregateFunction; -import io.substrait.relation.Aggregate; import io.substrait.relation.Rel; import io.substrait.type.Type; import java.nio.ByteBuffer; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.UUID; -import javax.annotation.Nullable; import org.immutables.value.Value; @Value.Enclosing @@ -573,6 +570,42 @@ public R accept(ExpressionVisitor visitor) throws } } + @Value.Immutable + abstract class WindowFunctionInvocation implements Expression { + + public abstract SimpleExtension.WindowFunctionVariant declaration(); + + public abstract List arguments(); + + public abstract Map options(); + + public abstract AggregationPhase aggregationPhase(); + + public abstract List partitionBy(); + + public abstract List sort(); + + public abstract WindowBound lowerBound(); + + public abstract WindowBound upperBound(); + + public abstract Type outputType(); + + public Type getType() { + return outputType(); + } + + public abstract AggregationInvocation invocation(); + + public static ImmutableExpression.WindowFunctionInvocation.Builder builder() { + return ImmutableExpression.WindowFunctionInvocation.builder(); + } + + public R accept(ExpressionVisitor visitor) throws E { + return visitor.visit(this); + } + } + @Value.Immutable abstract static class SingleOrList implements Expression { public abstract Expression condition(); @@ -684,41 +717,6 @@ public R accept(ExpressionVisitor visitor) throws } } - @Value.Immutable - abstract static class Window implements Expression { - @Nullable - public abstract Aggregate.Measure aggregateFunction(); - - @Nullable - public abstract WindowFunction windowFunction(); - - public abstract List partitionBy(); - - public abstract List orderBy(); - - public abstract WindowBound lowerBound(); - - public abstract WindowBound upperBound(); - - public abstract boolean hasNormalAggregateFunction(); - - public static ImmutableExpression.Window.Builder builder() { - return ImmutableExpression.Window.builder(); - } - - public R accept(ExpressionVisitor visitor) throws E { - return visitor.visit(this); - } - } - - @Value.Immutable - public abstract static class WindowFunction { - public abstract WindowFunctionInvocation getFunction(); - - public abstract Optional getPreMeasureFilter(); - /** public static ImmutableMeasure.Builder builder() { return ImmutableMeasure.builder(); } */ - } - enum PredicateOp { PREDICATE_OP_UNSPECIFIED( io.substrait.proto.Expression.Subquery.SetPredicate.PredicateOp.PREDICATE_OP_UNSPECIFIED), diff --git a/core/src/main/java/io/substrait/expression/ExpressionCreator.java b/core/src/main/java/io/substrait/expression/ExpressionCreator.java index ddf94b3f..9e466e8c 100644 --- a/core/src/main/java/io/substrait/expression/ExpressionCreator.java +++ b/core/src/main/java/io/substrait/expression/ExpressionCreator.java @@ -313,36 +313,48 @@ public static AggregateFunctionInvocation aggregateFunction( .build(); } - public static WindowFunctionInvocation windowFunction( + public static Expression.WindowFunctionInvocation windowFunction( SimpleExtension.WindowFunctionVariant declaration, Type outputType, Expression.AggregationPhase phase, List sort, Expression.AggregationInvocation invocation, + List partitionBy, + WindowBound lowerBound, + WindowBound upperBound, Iterable arguments) { - return WindowFunctionInvocation.builder() + return Expression.WindowFunctionInvocation.builder() .declaration(declaration) .outputType(outputType) .aggregationPhase(phase) .sort(sort) + .partitionBy(partitionBy) + .lowerBound(lowerBound) + .upperBound(upperBound) .invocation(invocation) .addAllArguments(arguments) .build(); } - public static WindowFunctionInvocation windowFunction( + public static Expression.WindowFunctionInvocation windowFunction( SimpleExtension.WindowFunctionVariant declaration, Type outputType, Expression.AggregationPhase phase, List sort, Expression.AggregationInvocation invocation, + List partitionBy, + WindowBound lowerBound, + WindowBound upperBound, FunctionArg... arguments) { - return WindowFunctionInvocation.builder() + return Expression.WindowFunctionInvocation.builder() .declaration(declaration) .outputType(outputType) .aggregationPhase(phase) .sort(sort) .invocation(invocation) + .partitionBy(partitionBy) + .lowerBound(lowerBound) + .upperBound(upperBound) .addArguments(arguments) .build(); } diff --git a/core/src/main/java/io/substrait/expression/ExpressionVisitor.java b/core/src/main/java/io/substrait/expression/ExpressionVisitor.java index c95330c0..dcde321a 100644 --- a/core/src/main/java/io/substrait/expression/ExpressionVisitor.java +++ b/core/src/main/java/io/substrait/expression/ExpressionVisitor.java @@ -57,6 +57,8 @@ public interface ExpressionVisitor { R visit(Expression.ScalarFunctionInvocation expr) throws E; + R visit(Expression.WindowFunctionInvocation expr) throws E; + R visit(Expression.Cast expr) throws E; R visit(Expression.SingleOrList expr) throws E; @@ -70,6 +72,4 @@ public interface ExpressionVisitor { R visit(Expression.ScalarSubquery expr) throws E; R visit(Expression.InPredicate expr) throws E; - - R visit(Expression.Window expr) throws E; } diff --git a/core/src/main/java/io/substrait/expression/WindowBound.java b/core/src/main/java/io/substrait/expression/WindowBound.java index 15724c7b..e478d459 100644 --- a/core/src/main/java/io/substrait/expression/WindowBound.java +++ b/core/src/main/java/io/substrait/expression/WindowBound.java @@ -45,7 +45,7 @@ public BoundedKind boundedKind() { public abstract Direction direction(); - public abstract Expression offset(); + public abstract long offset(); public static ImmutableWindowBound.BoundedWindowBound.Builder builder() { return ImmutableWindowBound.BoundedWindowBound.builder(); diff --git a/core/src/main/java/io/substrait/expression/WindowFunctionInvocation.java b/core/src/main/java/io/substrait/expression/WindowFunctionInvocation.java deleted file mode 100644 index 5007f4d9..00000000 --- a/core/src/main/java/io/substrait/expression/WindowFunctionInvocation.java +++ /dev/null @@ -1,33 +0,0 @@ -package io.substrait.expression; - -import io.substrait.extension.SimpleExtension; -import io.substrait.type.Type; -import java.util.List; -import java.util.Map; -import org.immutables.value.Value; - -@Value.Immutable -public abstract class WindowFunctionInvocation { - - public abstract SimpleExtension.WindowFunctionVariant declaration(); - - public abstract List arguments(); - - public abstract Map options(); - - public abstract Expression.AggregationPhase aggregationPhase(); - - public abstract List sort(); - - public abstract Type outputType(); - - public Type getType() { - return outputType(); - } - - public abstract Expression.AggregationInvocation invocation(); - - public static ImmutableWindowFunctionInvocation.Builder builder() { - return ImmutableWindowFunctionInvocation.builder(); - } -} diff --git a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java index c287257a..8f0d0d50 100644 --- a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java @@ -1,14 +1,15 @@ package io.substrait.expression.proto; -import io.substrait.expression.AbstractExpressionVisitor; import io.substrait.expression.ExpressionVisitor; import io.substrait.expression.FieldReference; import io.substrait.expression.FunctionArg; import io.substrait.expression.WindowBound; import io.substrait.extension.ExtensionCollector; import io.substrait.proto.Expression; +import io.substrait.proto.FunctionArgument; import io.substrait.proto.Rel; import io.substrait.proto.SortField; +import io.substrait.proto.Type; import io.substrait.relation.RelVisitor; import io.substrait.type.proto.TypeProtoConverter; import java.util.List; @@ -415,99 +416,48 @@ public Expression visit(io.substrait.expression.Expression.InPredicate expr) .build(); } - public Expression visit(io.substrait.expression.Expression.Window expr) throws RuntimeException { - var partExps = + public Expression visit(io.substrait.expression.Expression.WindowFunctionInvocation expr) + throws RuntimeException { + var argVisitor = FunctionArg.toProto(typeProtoConverter, this); + List args = + expr.arguments().stream() + .map(a -> a.accept(expr.declaration(), 0, argVisitor)) + .collect(java.util.stream.Collectors.toList()); + Type outputType = expr.getType().accept(typeProtoConverter); + + List partitionExprs = expr.partitionBy().stream() .map(e -> e.accept(this)) .collect(java.util.stream.Collectors.toList()); - var outputType = expr.getType().accept(typeProtoConverter); - var builder = Expression.WindowFunction.newBuilder().setOutputType(outputType); - if (expr.hasNormalAggregateFunction()) { - var aggMeasureFunc = expr.aggregateFunction().getFunction(); - var funcReference = extensionCollector.getFunctionReference(aggMeasureFunc.declaration()); - var argVisitor = FunctionArg.toProto(typeProtoConverter, this); - var args = - aggMeasureFunc.arguments().stream() - .map(a -> a.accept(aggMeasureFunc.declaration(), 0, argVisitor)) - .collect(java.util.stream.Collectors.toList()); - builder - .setFunctionReference(funcReference) - .setPhase(aggMeasureFunc.aggregationPhase().toProto()) - .addAllArguments(args); - } else { - var windowFunc = expr.windowFunction().getFunction(); - var funcReference = extensionCollector.getFunctionReference(windowFunc.declaration()); - var argVisitor = FunctionArg.toProto(typeProtoConverter, this); - var args = - windowFunc.arguments().stream() - .map(a -> a.accept(windowFunc.declaration(), 0, argVisitor)) - .collect(java.util.stream.Collectors.toList()); - builder - .setFunctionReference(funcReference) - .setPhase(windowFunc.aggregationPhase().toProto()) - .addAllArguments(args); - } - var sortFields = - expr.orderBy().stream() + + List sortFields = + expr.sort().stream() .map( - s -> { - return SortField.newBuilder() - .setDirection(s.direction().toProto()) - .setExpr(s.expr().accept(this)) - .build(); - }) + s -> + SortField.newBuilder() + .setDirection(s.direction().toProto()) + .setExpr(s.expr().accept(this)) + .build()) .collect(java.util.stream.Collectors.toList()); - var upperBound = toBound(expr.upperBound()); - var lowerBound = toBound(expr.lowerBound()); + + Expression.WindowFunction.Bound upperBound = toBound(expr.upperBound()); + Expression.WindowFunction.Bound lowerBound = toBound(expr.lowerBound()); + return Expression.newBuilder() .setWindowFunction( - builder - .addAllPartitions(partExps) + Expression.WindowFunction.newBuilder() + .setFunctionReference(extensionCollector.getFunctionReference(expr.declaration())) + .addAllArguments(args) + .setOutputType(outputType) + .setPhase(expr.aggregationPhase().toProto()) + .setInvocation(expr.invocation().toProto()) .addAllSorts(sortFields) + .addAllPartitions(partitionExprs) .setLowerBound(lowerBound) - .setUpperBound(upperBound) - .build()) + .setUpperBound(upperBound)) .build(); } - private static class LiteralToWindowBoundOffset - extends AbstractExpressionVisitor { - - @Override - public Long visitFallback(io.substrait.expression.Expression expr) { - throw new RuntimeException( - String.format("Expected positive integer for Window Bound offset, received: %s", expr)); - } - - private static long offsetIsPositive(long offset) { - if (offset >= 1) { - return offset; - } - throw new RuntimeException( - String.format("Expected positive offset for Window Bound offset, recieved: %d", offset)); - } - - @Override - public Long visit(io.substrait.expression.Expression.I8Literal expr) throws RuntimeException { - return offsetIsPositive(expr.value()); - } - - @Override - public Long visit(io.substrait.expression.Expression.I16Literal expr) throws RuntimeException { - return offsetIsPositive(expr.value()); - } - - @Override - public Long visit(io.substrait.expression.Expression.I32Literal expr) throws RuntimeException { - return offsetIsPositive(expr.value()); - } - - @Override - public Long visit(io.substrait.expression.Expression.I64Literal expr) throws RuntimeException { - return offsetIsPositive(expr.value()); - } - } - private Expression.WindowFunction.Bound toBound(io.substrait.expression.WindowBound windowBound) { var boundedKind = windowBound.boundedKind(); return switch (boundedKind) { @@ -517,7 +467,7 @@ private Expression.WindowFunction.Bound toBound(io.substrait.expression.WindowBo case BOUNDED -> { WindowBound.BoundedWindowBound boundedWindowBound = (WindowBound.BoundedWindowBound) windowBound; - var offset = boundedWindowBound.offset().accept(new LiteralToWindowBoundOffset()); + var offset = boundedWindowBound.offset(); yield switch (boundedWindowBound.direction()) { case PRECEDING -> Expression.WindowFunction.Bound.newBuilder() .setPreceding( diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index 3503f712..cd07aed9 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -6,9 +6,7 @@ import io.substrait.expression.FunctionArg; import io.substrait.expression.ImmutableExpression; import io.substrait.expression.WindowBound; -import io.substrait.expression.WindowFunctionInvocation; import io.substrait.extension.ExtensionLookup; -import io.substrait.extension.ImmutableSimpleExtension; import io.substrait.extension.SimpleExtension; import io.substrait.relation.ProtoRelConverter; import io.substrait.type.Type; @@ -119,30 +117,12 @@ public Expression from(io.substrait.proto.Expression expr) { case WINDOW_FUNCTION -> { var windowFunction = expr.getWindowFunction(); var functionReference = windowFunction.getFunctionReference(); - SimpleExtension.WindowFunctionVariant functionVariant; - try { - functionVariant = lookup.getWindowFunction(functionReference, extensions); - } catch (RuntimeException e) { - // TODO: Ideally we shouldn't need to catch a RuntimeException to be able to attempt our - // second lookup - var aggFunctionVariant = lookup.getAggregateFunction(functionReference, extensions); - functionVariant = - ImmutableSimpleExtension.WindowFunctionVariant.builder() - // Sets all fields declared in the Function interface - .from(aggFunctionVariant) - // Set WindowFunctionVariant fields - .decomposability(aggFunctionVariant.decomposability()) - .intermediate(aggFunctionVariant.intermediate()) - // Aggregate Functions used in Windows have WindowType Streaming - .windowType(SimpleExtension.WindowType.STREAMING) - .build(); - } - final SimpleExtension.WindowFunctionVariant declaration = functionVariant; + var declaration = lookup.getWindowFunction(functionReference, extensions); - var pF = new FunctionArg.ProtoFrom(this, protoTypeConverter); + var argVisitor = new FunctionArg.ProtoFrom(this, protoTypeConverter); var args = IntStream.range(0, windowFunction.getArgumentsCount()) - .mapToObj(i -> pF.convert(declaration, i, windowFunction.getArguments(i))) + .mapToObj(i -> argVisitor.convert(declaration, i, windowFunction.getArguments(i))) .collect(java.util.stream.Collectors.toList()); var partitionExprs = windowFunction.getPartitionsList().stream() @@ -157,29 +137,20 @@ public Expression from(io.substrait.proto.Expression expr) { .expr(from(s.getExpr())) .build()) .collect(java.util.stream.Collectors.toList()); - var wfi = - WindowFunctionInvocation.builder() - .addAllArguments(args) - .declaration(declaration) - .outputType(protoTypeConverter.from(windowFunction.getOutputType())) - .aggregationPhase(Expression.AggregationPhase.fromProto(windowFunction.getPhase())) - .addAllSort(sortFields) - .invocation( - Expression.AggregationInvocation.fromProto(windowFunction.getInvocation())) - .build(); WindowBound lowerBound = toLowerBound(windowFunction.getLowerBound()); WindowBound upperBound = toUpperBound(windowFunction.getUpperBound()); - var wf = ImmutableExpression.WindowFunction.builder().function(wfi).build(); - yield Expression.Window.builder() - .windowFunction(wf) - .hasNormalAggregateFunction(false) - .type(protoTypeConverter.from(windowFunction.getOutputType())) + yield Expression.WindowFunctionInvocation.builder() + .arguments(args) + .declaration(declaration) + .outputType(protoTypeConverter.from(windowFunction.getOutputType())) + .aggregationPhase(Expression.AggregationPhase.fromProto(windowFunction.getPhase())) .partitionBy(partitionExprs) - .orderBy(sortFields) + .sort(sortFields) .lowerBound(lowerBound) .upperBound(upperBound) + .invocation(Expression.AggregationInvocation.fromProto(windowFunction.getInvocation())) .build(); } case IF_THEN -> { @@ -297,17 +268,11 @@ private WindowBound toWindowBound( return switch (bound.getKindCase()) { case PRECEDING -> WindowBound.BoundedWindowBound.builder() .direction(WindowBound.Direction.PRECEDING) - .offset( - Expression.Literal.I64Literal.builder() - .value(bound.getPreceding().getOffset()) - .build()) + .offset(bound.getPreceding().getOffset()) .build(); case FOLLOWING -> WindowBound.BoundedWindowBound.builder() .direction(WindowBound.Direction.FOLLOWING) - .offset( - Expression.Literal.I64Literal.builder() - .value(bound.getFollowing().getOffset()) - .build()) + .offset(bound.getFollowing().getOffset()) .build(); case CURRENT_ROW -> WindowBound.CURRENT_ROW; case UNBOUNDED, KIND_NOT_SET -> defaultBound; diff --git a/core/src/main/java/io/substrait/extension/SimpleExtension.java b/core/src/main/java/io/substrait/extension/SimpleExtension.java index fbc7b8e9..5bb5a4c8 100644 --- a/core/src/main/java/io/substrait/extension/SimpleExtension.java +++ b/core/src/main/java/io/substrait/extension/SimpleExtension.java @@ -464,6 +464,10 @@ public abstract static class WindowFunction { public Stream resolve(String uri) { return impls().stream().map(f -> f.resolve(uri, name(), description())); } + + public static ImmutableSimpleExtension.WindowFunction.Builder builder() { + return ImmutableSimpleExtension.WindowFunction.builder(); + } } @JsonDeserialize(as = ImmutableSimpleExtension.AggregateFunctionVariant.class) @@ -542,6 +546,10 @@ WindowFunctionVariant resolve(String uri, String name, String description) { .windowType(windowType()) .build(); } + + public static ImmutableSimpleExtension.WindowFunctionVariant.Builder builder() { + return ImmutableSimpleExtension.WindowFunctionVariant.builder(); + } } @JsonDeserialize(as = ImmutableSimpleExtension.Type.class) @@ -796,20 +804,43 @@ public static ExtensionCollection load(String namespace, InputStream stream) { public static ExtensionCollection buildExtensionCollection( String namespace, ExtensionSignatures extensionSignatures) { + List scalarFunctionVariants = + extensionSignatures.scalars().stream() + .flatMap(t -> t.resolve(namespace)) + .collect(java.util.stream.Collectors.toList()); + + List aggregateFunctionVariants = + extensionSignatures.aggregates().stream() + .flatMap(t -> t.resolve(namespace)) + .collect(java.util.stream.Collectors.toList()); + + Stream windowFunctionVariants = + extensionSignatures.windows().stream().flatMap(t -> t.resolve(namespace)); + + // Aggregate functions can be used as Window Functions + Stream windowAggFunctionVariants = + aggregateFunctionVariants.stream() + .map( + afi -> + WindowFunctionVariant.builder() + // Sets all fields declared in the Function interface + .from(afi) + // Set WindowFunctionVariant fields + .decomposability(afi.decomposability()) + .intermediate(afi.intermediate()) + // Aggregate Functions used in Windows have WindowType Streaming + .windowType(SimpleExtension.WindowType.STREAMING) + .build()); + + List allWindowFunctionVariants = + Stream.concat(windowFunctionVariants, windowAggFunctionVariants) + .collect(Collectors.toList()); + var collection = ImmutableSimpleExtension.ExtensionCollection.builder() - .addAllAggregateFunctions( - extensionSignatures.aggregates().stream() - .flatMap(t -> t.resolve(namespace)) - .collect(java.util.stream.Collectors.toList())) - .addAllScalarFunctions( - extensionSignatures.scalars().stream() - .flatMap(t -> t.resolve(namespace)) - .collect(java.util.stream.Collectors.toList())) - .addAllWindowFunctions( - extensionSignatures.windows().stream() - .flatMap(t -> t.resolve(namespace)) - .collect(java.util.stream.Collectors.toList())) + .scalarFunctions(scalarFunctionVariants) + .aggregateFunctions(aggregateFunctionVariants) + .windowFunctions(allWindowFunctionVariants) .addAllTypes(extensionSignatures.types()) .build(); logger.debug( diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java index 53336a8f..44013231 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java @@ -90,8 +90,7 @@ public SubstraitRelVisitor( this.aggregateFunctionConverter = new AggregateFunctionConverter(extensions.aggregateFunctions(), typeFactory); var windowFunctionConverter = - new WindowFunctionConverter( - extensions.windowFunctions(), typeFactory, aggregateFunctionConverter, typeConverter); + new WindowFunctionConverter(extensions.windowFunctions(), typeFactory); this.converter = new RexExpressionConverter(this, converters, windowFunctionConverter, typeConverter); this.featureBoard = features; @@ -168,15 +167,10 @@ public Rel visit(LogicalCalc calc) { @Override public Rel visit(LogicalProject project) { - var input = apply(project.getInput()); - this.converter.setInputRel(project.getInput()); - this.converter.setInputType(input.getRecordType()); var expressions = project.getProjects().stream() .map(this::toExpression) .collect(java.util.stream.Collectors.toList()); - this.converter.setInputRel(null); - this.converter.setInputType(null); // todo: eliminate excessive projects. This should be done by converting rexinputrefs to remaps. return Project.builder() diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java index 8fda2ca1..62ee84f5 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java @@ -243,6 +243,12 @@ public RexNode visit(Expression.ScalarFunctionInvocation expr) throws RuntimeExc } } + @Override + public RexNode visit(Expression.WindowFunctionInvocation expr) throws RuntimeException { + // todo:to construct the RexOver + return visitFallback(expr); + } + private String convert(FunctionArg a) { String v; if (a instanceof EnumArg ea) { @@ -282,12 +288,6 @@ public RexNode visit(FieldReference expr) throws RuntimeException { return visitFallback(expr); } - @Override - public RexNode visit(Expression.Window expr) throws RuntimeException { - // todo:to construct the RexOver - return visitFallback(expr); - } - @Override public RexNode visitFallback(Expression expr) { throw new UnsupportedOperationException( diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java index 1d7cabc0..96f32163 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java @@ -92,6 +92,8 @@ public class FunctionMappings { s(SqlStdOperatorTable.FIRST_VALUE, "first_value"), s(SqlStdOperatorTable.LAST_VALUE, "last_value"), s(SqlStdOperatorTable.NTH_VALUE, "nth_value")) + // Aggregate Functions can be used in Windows + .addAll(AGGREGATE_SIGS) .build(); // contains return-type based resolver for both scalar and aggregator operator diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java index 648c415e..17bb03a3 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java @@ -7,13 +7,24 @@ import io.substrait.isthmus.TypeConverter; import io.substrait.relation.Rel; import io.substrait.type.StringTypeVisitor; -import io.substrait.type.Type; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.stream.Collectors; -import org.apache.calcite.rel.RelNode; -import org.apache.calcite.rex.*; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexCorrelVariable; +import org.apache.calcite.rex.RexDynamicParam; +import org.apache.calcite.rex.RexFieldAccess; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexLiteral; +import org.apache.calcite.rex.RexLocalRef; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexOver; +import org.apache.calcite.rex.RexPatternFieldRef; +import org.apache.calcite.rex.RexRangeRef; +import org.apache.calcite.rex.RexSubQuery; +import org.apache.calcite.rex.RexTableInputRef; +import org.apache.calcite.rex.RexVisitor; import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.fun.SqlStdOperatorTable; @@ -27,9 +38,6 @@ public class RexExpressionConverter implements RexVisitor { private final TypeConverter typeConverter; private WindowFunctionConverter windowFunctionConverter; - private RelNode inputRel; - private Type.Struct inputType; - public RexExpressionConverter(SubstraitRelVisitor relVisitor, CallConverter... callConverters) { this(relVisitor, Arrays.asList(callConverters), null, TypeConverter.DEFAULT); } @@ -66,20 +74,22 @@ public Expression visitInputRef(RexInputRef inputRef) { @Override public Expression visitCall(RexCall call) { for (var c : callConverters) { - var out = c.convert(call, r -> r.accept(this)); + var out = c.convert(call, rexNode -> rexNode.accept(this)); if (out.isPresent()) { return out.get(); } } - String msg = - String.format( - "Unable to convert call %s(%s).", - call.getOperator().getName(), - call.getOperands().stream() - .map(t -> t.accept(this).getType().accept(new StringTypeVisitor())) - .collect(Collectors.joining(", "))); - throw new IllegalArgumentException(msg); + throw new IllegalArgumentException(callConversionFailureMessage(call)); + } + + private String callConversionFailureMessage(RexCall call) { + return String.format( + "Unable to convert call %s(%s).", + call.getOperator().getName(), + call.getOperands().stream() + .map(t -> t.accept(this).getType().accept(new StringTypeVisitor())) + .collect(Collectors.joining(", "))); } @Override @@ -89,19 +99,9 @@ public Expression visitLiteral(RexLiteral literal) { @Override public Expression visitOver(RexOver over) { - // maybe a aggregate function or a window function - var exp = - windowFunctionConverter.convert( - inputRel, - inputType, - over, - t -> { - RexNode r = (RexNode) t; - return r.accept(this); - }, - this); - - return exp; + return windowFunctionConverter + .convert(over, rexNode -> rexNode.accept(this), this) + .orElseThrow(() -> new IllegalArgumentException(callConversionFailureMessage(over))); } @Override @@ -189,12 +189,4 @@ public Expression visitLocalRef(RexLocalRef localRef) { public Expression visitPatternFieldRef(RexPatternFieldRef fieldRef) { throw new UnsupportedOperationException("RexPatternFieldRef not supported"); } - - public void setInputRel(RelNode inputRel) { - this.inputRel = inputRel; - } - - public void setInputType(Type.Struct inputType) { - this.inputType = inputType; - } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/WindowFunctionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/WindowFunctionConverter.java index 5b7f08ad..a108fa9d 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/WindowFunctionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/WindowFunctionConverter.java @@ -4,40 +4,32 @@ import io.substrait.expression.Expression; import io.substrait.expression.ExpressionCreator; import io.substrait.expression.FunctionArg; -import io.substrait.expression.ImmutableExpression; import io.substrait.expression.ImmutableWindowBound; import io.substrait.expression.WindowBound; -import io.substrait.expression.WindowFunctionInvocation; import io.substrait.extension.SimpleExtension; -import io.substrait.isthmus.SubstraitRelVisitor; -import io.substrait.isthmus.TypeConverter; -import io.substrait.relation.Aggregate; import io.substrait.type.Type; +import java.math.BigDecimal; import java.util.Collections; import java.util.List; import java.util.Optional; import java.util.function.Function; import java.util.stream.Stream; -import org.apache.calcite.rel.RelCollations; import org.apache.calcite.rel.RelFieldCollation; -import org.apache.calcite.rel.RelNode; -import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexFieldCollation; +import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexOver; -import org.apache.calcite.rex.RexSlot; +import org.apache.calcite.rex.RexWindow; import org.apache.calcite.rex.RexWindowBound; +import org.apache.calcite.sql.type.SqlTypeName; public class WindowFunctionConverter extends FunctionConverter< SimpleExtension.WindowFunctionVariant, - WindowFunctionInvocation, - WindowFunctionConverter.WrappedAggregateCall> { - - private AggregateFunctionConverter aggregateFunctionConverter; + Expression.WindowFunctionInvocation, + WindowFunctionConverter.WrappedWindowCall> { @Override protected ImmutableList getSigs() { @@ -45,131 +37,65 @@ protected ImmutableList getSigs() { } public WindowFunctionConverter( - List functions, - RelDataTypeFactory typeFactory, - AggregateFunctionConverter aggregateFunctionConverter, - TypeConverter typeConverter) { + List functions, RelDataTypeFactory typeFactory) { super(functions, typeFactory); - this.aggregateFunctionConverter = aggregateFunctionConverter; } @Override - protected WindowFunctionInvocation generateBinding( - WrappedAggregateCall call, + protected Expression.WindowFunctionInvocation generateBinding( + WrappedWindowCall call, SimpleExtension.WindowFunctionVariant function, List arguments, Type outputType) { - AggregateCall agg = call.getUnderlying(); + RexOver over = call.over; + RexWindow window = over.getWindow(); + + var partitionExprs = + window.partitionKeys.stream() + .map(r -> r.accept(call.rexExpressionConverter)) + .collect(java.util.stream.Collectors.toList()); List sorts = - agg.getCollation() != null - ? agg.getCollation().getFieldCollations().stream() - .map(r -> SubstraitRelVisitor.toSortField(r, call.inputType)) + window.orderKeys != null + ? window.orderKeys.stream() + .map(rfc -> toSortField(rfc, call.rexExpressionConverter)) .collect(java.util.stream.Collectors.toList()) : Collections.emptyList(); Expression.AggregationInvocation invocation = - agg.isDistinct() + over.isDistinct() ? Expression.AggregationInvocation.DISTINCT : Expression.AggregationInvocation.ALL; + + WindowBound lowerBound = toWindowBound(window.getLowerBound(), call.rexExpressionConverter); + WindowBound upperBound = toWindowBound(window.getUpperBound(), call.rexExpressionConverter); + return ExpressionCreator.windowFunction( function, outputType, Expression.AggregationPhase.INITIAL_TO_RESULT, sorts, invocation, + partitionExprs, + lowerBound, + upperBound, arguments); } - public Expression.Window convert( - RelNode input, - Type.Struct inputType, + public Optional convert( RexOver over, Function topLevelConverter, RexExpressionConverter rexExpressionConverter) { - - var lowerBound = toWindowBound(over.getWindow().getLowerBound(), rexExpressionConverter); - var upperBound = toWindowBound(over.getWindow().getUpperBound(), rexExpressionConverter); - var sqlAggFunction = over.getAggOperator(); - var argList = - over.getOperands().stream() - .map(r -> ((RexSlot) r).getIndex()) - .collect(java.util.stream.Collectors.toList()); - boolean approximate = false; - int filterArg = -1; - var call = - AggregateCall.create( - sqlAggFunction, - over.isDistinct(), - approximate, - over.ignoreNulls(), - argList, - filterArg, - null, - RelCollations.EMPTY, - over.getType(), - sqlAggFunction.getName()); - var windowBuilder = Expression.Window.builder(); - var aggregateFunctionInvocation = - aggregateFunctionConverter.convert(input, inputType, call, topLevelConverter); - boolean find = false; - if (aggregateFunctionInvocation.isPresent()) { - var aggMeasure = - Aggregate.Measure.builder().function(aggregateFunctionInvocation.get()).build(); - windowBuilder.aggregateFunction(aggMeasure).hasNormalAggregateFunction(true); - find = true; - } else { - // maybe it's a window function - var windowFuncInvocation = - findWindowFunctionInvocation(input, inputType, call, topLevelConverter); - if (windowFuncInvocation.isPresent()) { - var windowFunc = - ImmutableExpression.WindowFunction.builder() - .function(windowFuncInvocation.get()) - .build(); - windowBuilder.windowFunction(windowFunc).hasNormalAggregateFunction(false); - find = true; - } - } - if (!find) { - throw new RuntimeException( - String.format( - "Not found the corresponding window aggregate function:%s", sqlAggFunction)); - } - var window = over.getWindow(); - var partitionExps = - window.partitionKeys.stream() - .map(r -> r.accept(rexExpressionConverter)) - .collect(java.util.stream.Collectors.toList()); - var sortFields = - window.orderKeys.stream() - .map(r -> toSortField(r, rexExpressionConverter)) - .collect(java.util.stream.Collectors.toList()); - - return windowBuilder - .addAllOrderBy(sortFields) - .addAllPartitionBy(partitionExps) - .lowerBound(lowerBound) - .upperBound(upperBound) - .type(typeConverter.toSubstrait(over.getType())) - .build(); - } - - private Optional findWindowFunctionInvocation( - RelNode input, - Type.Struct inputType, - AggregateCall call, - Function topLevelConverter) { - FunctionFinder m = signatures.get(call.getAggregation()); + var aggFunction = over.getAggOperator(); + FunctionFinder m = signatures.get(aggFunction); if (m == null) { return Optional.empty(); } - if (!m.allowedArgCount(call.getArgList().size())) { + if (!m.allowedArgCount(over.getOperands().size())) { return Optional.empty(); } - var wrapped = new WrappedAggregateCall(call, input, rexBuilder, inputType); - var windowFunctionInvocation = m.attemptMatch(wrapped, topLevelConverter); - return windowFunctionInvocation; + var wrapped = new WrappedWindowCall(over, rexExpressionConverter); + return m.attemptMatch(wrapped, topLevelConverter); } private WindowBound toWindowBound( @@ -182,11 +108,18 @@ private WindowBound toWindowBound( return ImmutableWindowBound.UnboundedWindowBound.builder().direction(direction).build(); } else { var direction = findWindowBoundDirection(rexWindowBound); - var offset = rexWindowBound.getOffset().accept(rexExpressionConverter); - return ImmutableWindowBound.BoundedWindowBound.builder() - .direction(direction) - .offset(offset) - .build(); + if (rexWindowBound.getOffset() instanceof RexLiteral literal + && SqlTypeName.EXACT_TYPES.contains(literal.getTypeName())) { + BigDecimal offset = (BigDecimal) literal.getValue4(); + return ImmutableWindowBound.BoundedWindowBound.builder() + .direction(direction) + .offset(offset.longValue()) + .build(); + } + throw new IllegalArgumentException( + String.format( + "substrait only supports integer window offsets. Received: %", + rexWindowBound.getOffset().getKind())); } } @@ -219,32 +152,23 @@ private Expression.SortField toSortField( return Expression.SortField.builder().expr(expr).direction(direction).build(); } - static class WrappedAggregateCall implements FunctionConverter.GenericCall { - private final AggregateCall call; - private final RelNode input; - private final RexBuilder rexBuilder; - private final Type.Struct inputType; - - private WrappedAggregateCall( - AggregateCall call, RelNode input, RexBuilder rexBuilder, Type.Struct inputType) { - this.call = call; - this.input = input; - this.rexBuilder = rexBuilder; - this.inputType = inputType; + static class WrappedWindowCall implements FunctionConverter.GenericCall { + private final RexOver over; + private final RexExpressionConverter rexExpressionConverter; + + private WrappedWindowCall(RexOver over, RexExpressionConverter rexExpressionConverter) { + this.over = over; + this.rexExpressionConverter = rexExpressionConverter; } @Override public Stream getOperands() { - return call.getArgList().stream().map(r -> rexBuilder.makeInputRef(input, r)); - } - - public AggregateCall getUnderlying() { - return call; + return over.getOperands().stream(); } @Override public RelDataType getType() { - return call.getType(); + return over.getType(); } } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java index bc6488f7..d11cc213 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java @@ -137,11 +137,7 @@ public RelDataType toCalcite(Type.UserDefined type) { typeFactory, typeConverter); WindowFunctionConverter windowFunctionConverter = - new WindowFunctionConverter( - extensionCollection.windowFunctions(), - typeFactory, - aggregateFunctionConverter, - typeConverter); + new WindowFunctionConverter(extensionCollection.windowFunctions(), typeFactory); // Create a SubstraitToCalcite converter that has access to the custom Function Converters class CustomSubstraitToCalcite extends SubstraitToCalcite { diff --git a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java index 9ee62d8a..c1857f94 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java +++ b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java @@ -3,13 +3,17 @@ import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION; import static org.junit.jupiter.api.Assertions.assertEquals; +import com.google.common.annotations.Beta; import com.google.common.base.Charsets; import com.google.common.io.Resources; +import io.substrait.extension.ExtensionCollector; import io.substrait.extension.SimpleExtension; import io.substrait.plan.Plan; import io.substrait.plan.PlanProtoConverter; import io.substrait.plan.ProtoPlanConverter; +import io.substrait.relation.ProtoRelConverter; import io.substrait.relation.Rel; +import io.substrait.relation.RelProtoConverter; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; @@ -111,4 +115,61 @@ protected List assertSqlSubstraitRelRoundTrip(String query, ListFor the given transformations: + * SQL -> Calcite 1 -> Substrait POJO 1 -> Substrait Proto -> Substrait POJO 2 -> Calcite 2 -> Substrait POJO 3 + * this code also checks that: + * + *
    + *
  • Substrait POJO 1 == Substrait POJO 2 + *
  • Substrait POJO 2 == Substrait POJO 3 + *
+ */ + @Beta + protected void assertFullRoundTrip(String sqlQuery, List createStatements) + throws SqlParseException { + SqlToSubstrait sqlConverter = new SqlToSubstrait(); + List relRoots = sqlConverter.sqlToRelNode(sqlQuery, createStatements); + + for (RelRoot calcite1 : relRoots) { + var extensionCollector = new ExtensionCollector(); + + // Calcite 1 -> Substrait POJO 1 + io.substrait.relation.Rel pojo1 = SubstraitRelVisitor.convert(calcite1, EXTENSION_COLLECTION); + + // Substrait POJO 1 -> Substrait Proto + io.substrait.proto.Rel proto = new RelProtoConverter(extensionCollector).toProto(pojo1); + + // Substrait Proto -> Substrait Pojo 2 + io.substrait.relation.Rel pojo2 = + new ProtoRelConverter(extensionCollector, EXTENSION_COLLECTION).from(proto); + + // Verify that POJOs are the same + assertEquals(pojo1, pojo2); + + /* + // TODO vbarua: go all the way once window function conversions are allowed + // Substrait POJO 2 -> Calcite 2 + RelNode calcite2 = new SubstraitToCalcite(EXTENSION_COLLECTION, typeFactory).convert(pojo2); + // It would be ideal to compare calcite1 and calcite2, however there isn't a good mechanism to + // do so + assertNotNull(calcite2); + + // Calcite 2 -> Substrait POJO 3 + io.substrait.relation.Rel pojo3 = SubstraitRelVisitor.convert(calcite1, EXTENSION_COLLECTION); + + // Verify that POJOs are the same + assertEquals(pojo1, pojo3); + */ + } + } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/WindowFunctionTest.java b/isthmus/src/test/java/io/substrait/isthmus/WindowFunctionTest.java index 48667815..112f4980 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/WindowFunctionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/WindowFunctionTest.java @@ -2,7 +2,6 @@ import java.io.IOException; import org.apache.calcite.sql.parser.SqlParseException; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -15,7 +14,7 @@ class WindowFunctionInvocations { @Test void rowNumber() throws IOException, SqlParseException { - assertProtoPlanRoundrip("select O_ORDERKEY, row_number() over () from ORDERS"); + assertFullRoundTrip("select O_ORDERKEY, row_number() over () from ORDERS"); } @ParameterizedTest @@ -24,7 +23,7 @@ void rankFunctions(String rankFunction) throws IOException, SqlParseException { var query = String.format( "select O_ORDERKEY, %s() over (order by O_SHIPPRIORITY) from ORDERS", rankFunction); - assertProtoPlanRoundrip(query); + assertFullRoundTrip(query); } @ParameterizedTest @@ -34,21 +33,18 @@ void rankFunctionsWithPartitions(String rankFunction) throws IOException, SqlPar String.format( "select O_ORDERKEY, %s() over (partition by O_CUSTKEY order by O_SHIPPRIORITY) from ORDERS", rankFunction); - assertProtoPlanRoundrip(query); + assertFullRoundTrip(query); } @Test void cumeDist() throws IOException, SqlParseException { - assertProtoPlanRoundrip( + assertFullRoundTrip( "select O_ORDERKEY, cume_dist() over (order by O_SHIPPRIORITY) from ORDERS"); } @Test - @Disabled void ntile() throws IOException, SqlParseException { - // TODO: The WindowFunctionConverter has some assumptions about function arguments that need - // to be addressed for this to work. - assertProtoPlanRoundrip("select O_ORDERKEY, ntile(4) over () from ORDERS"); + assertFullRoundTrip("select O_ORDERKEY, ntile(4) over () from ORDERS"); } } @@ -66,7 +62,7 @@ void unbounded() throws IOException, SqlParseException { LogicalProject(EXPR$0=[MAX($7) OVER ()]) LogicalTableScan(table=[[ORDERS]]) */ - assertProtoPlanRoundrip("select max(O_SHIPPRIORITY) over () from ORDERS"); + assertFullRoundTrip("select max(O_SHIPPRIORITY) over () from ORDERS"); } @Test @@ -76,7 +72,7 @@ void unboundedPreceding() throws IOException, SqlParseException { LogicalTableScan(table=[[ORDERS]]) */ var overClause = "partition by O_CUSTKEY order by O_ORDERDATE rows unbounded preceding"; - assertProtoPlanRoundrip( + assertFullRoundTrip( String.format("select min(O_SHIPPRIORITY) over (%s) from ORDERS", overClause)); } @@ -88,7 +84,7 @@ void unboundedFollowing() throws IOException, SqlParseException { */ var overClaus = "partition by O_CUSTKEY order by O_ORDERDATE rows between current row AND unbounded following"; - assertProtoPlanRoundrip( + assertFullRoundTrip( String.format("select max(O_SHIPPRIORITY) over (%s) from ORDERS", overClaus)); } @@ -100,7 +96,7 @@ void rowsPrecedingToCurrent() throws IOException, SqlParseException { */ var overClause = "partition by O_CUSTKEY order by O_ORDERDATE rows between 1 preceding and current row"; - assertProtoPlanRoundrip( + assertFullRoundTrip( String.format("select min(O_SHIPPRIORITY) over (%s) from ORDERS", overClause)); } @@ -112,7 +108,7 @@ void currentToRowsFollowing() throws IOException, SqlParseException { */ var overClause = "partition by O_CUSTKEY order by O_ORDERDATE rows between current row and 2 following"; - assertProtoPlanRoundrip( + assertFullRoundTrip( String.format("select max(O_SHIPPRIORITY) over (%s) from ORDERS", overClause)); } @@ -124,7 +120,7 @@ void rowsPrecedingAndFollowing() throws IOException, SqlParseException { */ var overClause = "partition by O_CUSTKEY order by O_ORDERDATE rows between 3 preceding and 4 following"; - assertProtoPlanRoundrip( + assertFullRoundTrip( String.format("select min(O_SHIPPRIORITY) over (%s) from ORDERS", overClause)); } } @@ -135,7 +131,7 @@ class AggregateFunctionInvocations { @ParameterizedTest @ValueSource(strings = {"avg", "count", "max", "min", "sum"}) void standardAggregateFunctions(String aggFunction) throws SqlParseException, IOException { - assertProtoPlanRoundrip( + assertFullRoundTrip( String.format( "select %s(L_LINENUMBER) over (partition BY L_PARTKEY) from lineitem", aggFunction)); }