Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
924060929 committed Nov 26, 2024
1 parent 14d928b commit 42c0875
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.WindowFrame;
import org.apache.doris.nereids.trees.expressions.WindowFrame.FrameBoundType;
import org.apache.doris.nereids.trees.expressions.WindowFrame.FrameBoundary;
import org.apache.doris.nereids.trees.expressions.WindowFrame.FrameUnitsType;
import org.apache.doris.nereids.trees.expressions.functions.window.DenseRank;
import org.apache.doris.nereids.trees.expressions.functions.window.Rank;
import org.apache.doris.nereids.trees.expressions.functions.window.RowNumber;
Expand Down Expand Up @@ -215,13 +218,34 @@ && child(0).child(0) instanceof LogicalPartitionTopN)) {
long chosenPartitionLimit = Long.MAX_VALUE;
long chosenRowNumberPartitionLimit = Long.MAX_VALUE;
boolean hasRowNumber = false;
long atLeastLimit = -1;
for (NamedExpression windowExpr : windowExpressions) {
if (windowExpr == null || windowExpr.children().size() != 1
|| !(windowExpr.child(0) instanceof WindowExpression)) {
continue;
}
WindowExpression windowFunc = (WindowExpression) windowExpr.child(0);

Optional<WindowFrame> windowFrame = windowFunc.getWindowFrame();
if (windowFrame.isPresent()) {
WindowFrame frame = windowFrame.get();
FrameUnitsType frameUnits = frame.getFrameUnits();
FrameBoundary rightBoundary = frame.getRightBoundary();
if (rightBoundary.getFrameBoundType() == FrameBoundType.UNBOUNDED_FOLLOWING) {
return null;
} else if (frameUnits == FrameUnitsType.ROWS
&& rightBoundary.getFrameBoundType().isFollowing()
&& rightBoundary.getBoundOffset().isPresent()
&& rightBoundary.getBoundOffset().get() instanceof IntegerLikeLiteral) {
IntegerLikeLiteral intLiteral
= (IntegerLikeLiteral) rightBoundary.getBoundOffset().get();
long offset = intLiteral.getLongValue();
if (offset + 1 > atLeastLimit) {
atLeastLimit = offset + 1;
}
}
}

// Check the window function name.
if (!(windowFunc.getFunction() instanceof RowNumber
|| windowFunc.getFunction() instanceof Rank
Expand All @@ -235,7 +259,6 @@ && child(0).child(0) instanceof LogicalPartitionTopN)) {
}

// Check the window type and window frame.
Optional<WindowFrame> windowFrame = windowFunc.getWindowFrame();
if (windowFrame.isPresent()) {
WindowFrame frame = windowFrame.get();
if (!(frame.getLeftBoundary().getFrameBoundType() == WindowFrame.FrameBoundType.UNBOUNDED_PRECEDING
Expand Down Expand Up @@ -307,7 +330,9 @@ && child(0).child(0) instanceof LogicalPartitionTopN)) {
&& chosenRowNumberPartitionLimit == Long.MAX_VALUE)) {
return null;
} else {
return Pair.of(chosenWindowFunc, hasRowNumber ? chosenRowNumberPartitionLimit : chosenPartitionLimit);
return Pair.of(chosenWindowFunc,
Math.max(atLeastLimit, hasRowNumber ? chosenRowNumberPartitionLimit : chosenPartitionLimit)
);
}
}

Expand Down
42 changes: 42 additions & 0 deletions regression-test/suites/nereids_syntax_p0/window_function.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -240,4 +240,46 @@ suite("window_function") {
"""

qt_sql """ select LAST_VALUE(col_tinyint_undef_signed_not_null) over (partition by col_double_undef_signed_not_null, col_int_undef_signed, (col_float_undef_signed_not_null - col_int_undef_signed), round_bankers(col_int_undef_signed) order by pk rows between unbounded preceding and 4 preceding) AS col_alias56089 from table_200_undef_partitions2_keys3_properties4_distributed_by53 order by col_alias56089; """

test {
sql """select *
from (
select
row_number() over(partition by c1 order by c2) rn,
lead(c2, 2, '') over(partition by c1 order by c2)
from (
select 1 as c1, 'a' as c2
union all
select 1 as c1, 'b' as c2
union all
select 1 as c1, 'c' as c2
union all
select 1 as c1, 'd' as c2
union all
select 1 as c1, 'e' as c2
)t
)a where rn=1"""
result([[1L, "c"]])
}

test {
sql """select *
from (
select
row_number() over(partition by c1 order by c2) rn,
sum(c2) over(order by c2 range between unbounded preceding and unbounded following)
from (
select 1 as c1, 5 as c2
union all
select 1 as c1, 6 as c2
union all
select 1 as c1, 7 as c2
union all
select 1 as c1, 8 as c2
union all
select 1 as c1, 9 as c2
)t
)a where rn=1"""
result([[1L, 35L]])
}
}

0 comments on commit 42c0875

Please sign in to comment.