From 9dc218b359172917041d31d7f5f8ab1af0c55aa9 Mon Sep 17 00:00:00 2001 From: feiniaofeiafei Date: Tue, 10 Dec 2024 21:34:09 +0800 Subject: [PATCH] remove first_value and second_value if the second parameter is false --- .../rules/analysis/WindowFunctionChecker.java | 14 ++++++-- ...CheckAndStandardizeWindowFunctionTest.java | 26 +++++++++++++++ .../nereids_syntax_p0/window_function.out | 32 +++++++++++++++++++ .../nereids_syntax_p0/window_function.groovy | 3 ++ 4 files changed, 73 insertions(+), 2 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/WindowFunctionChecker.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/WindowFunctionChecker.java index d6904ae074c7da..2da28269fd711c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/WindowFunctionChecker.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/WindowFunctionChecker.java @@ -317,9 +317,15 @@ public Lead visitLead(Lead lead, Void ctx) { @Override public FirstOrLastValue visitFirstValue(FirstValue firstValue, Void ctx) { FirstOrLastValue.checkSecondParameter(firstValue); - if (2 == firstValue.arity() && firstValue.child(1).equals(BooleanLiteral.TRUE)) { - return firstValue; + if (2 == firstValue.arity()) { + if (firstValue.child(1).equals(BooleanLiteral.TRUE)) { + return firstValue; + } else { + firstValue = (FirstValue) firstValue.withChildren(firstValue.child(0)); + windowExpression = windowExpression.withFunction(firstValue); + } } + Optional windowFrame = windowExpression.getWindowFrame(); if (windowFrame.isPresent()) { WindowFrame wf = windowFrame.get(); @@ -347,6 +353,10 @@ public FirstOrLastValue visitFirstValue(FirstValue firstValue, Void ctx) { @Override public FirstOrLastValue visitLastValue(LastValue lastValue, Void ctx) { FirstOrLastValue.checkSecondParameter(lastValue); + if (2 == lastValue.arity() && lastValue.child(1).equals(BooleanLiteral.FALSE)) { + lastValue = (LastValue) lastValue.withChildren(lastValue.child(0)); + windowExpression = windowExpression.withFunction(lastValue); + } return lastValue; } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/CheckAndStandardizeWindowFunctionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/CheckAndStandardizeWindowFunctionTest.java index ffe88af2fea0ba..2f3b133fdeb1ea 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/CheckAndStandardizeWindowFunctionTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/CheckAndStandardizeWindowFunctionTest.java @@ -28,11 +28,14 @@ import org.apache.doris.nereids.trees.expressions.WindowFrame.FrameBoundary; import org.apache.doris.nereids.trees.expressions.WindowFrame.FrameUnitsType; import org.apache.doris.nereids.trees.expressions.functions.window.DenseRank; +import org.apache.doris.nereids.trees.expressions.functions.window.FirstValue; import org.apache.doris.nereids.trees.expressions.functions.window.Lag; +import org.apache.doris.nereids.trees.expressions.functions.window.LastValue; import org.apache.doris.nereids.trees.expressions.functions.window.Lead; import org.apache.doris.nereids.trees.expressions.functions.window.Rank; import org.apache.doris.nereids.trees.expressions.functions.window.RowNumber; import org.apache.doris.nereids.trees.expressions.functions.window.WindowFunction; +import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral; import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; import org.apache.doris.nereids.trees.plans.Plan; @@ -240,6 +243,29 @@ public void testCheckWindowFrameBeforeFunc5() { forCheckWindowFrameBeforeFunc(windowFrame2, errorMsg2); } + @Test + public void testFirstValueRewrite() { + age = rStudent.getOutput().get(3).toSlot(); + WindowExpression window = new WindowExpression(new FirstValue(age, BooleanLiteral.FALSE), partitionKeyList, orderKeyList); + Alias windowAlias = new Alias(window, window.toSql()); + WindowExpression windowLastValue = new WindowExpression(new LastValue(age, BooleanLiteral.FALSE), partitionKeyList, orderKeyList); + Alias windowLastValueAlias = new Alias(windowLastValue, windowLastValue.toSql()); + List outputExpressions = Lists.newArrayList(windowAlias, windowLastValueAlias); + Plan root = new LogicalWindow<>(outputExpressions, rStudent); + + PlanChecker.from(MemoTestUtils.createConnectContext(), root) + .applyTopDown(new ExtractAndNormalizeWindowExpression()) + .applyTopDown(new CheckAndStandardizeWindowFunctionAndFrame()) + .matches( + logicalWindow() + .when(logicalWindow -> { + WindowExpression newWindowFirstValue = (WindowExpression) logicalWindow.getWindowExpressions().get(0).child(0); + WindowExpression newWindowLastValue = (WindowExpression) logicalWindow.getWindowExpressions().get(0).child(0); + return newWindowFirstValue.getFunction().arity() == 1 && newWindowLastValue.getFunction().arity() == 1; + }) + ); + } + private void forCheckWindowFrameBeforeFunc(WindowFrame windowFrame, String errorMsg) { WindowExpression window = new WindowExpression(new Rank(), partitionKeyList, orderKeyList, windowFrame); forCheckWindowFrameBeforeFunc(window, errorMsg); diff --git a/regression-test/data/nereids_syntax_p0/window_function.out b/regression-test/data/nereids_syntax_p0/window_function.out index 38eba68274e18f..378524125c9474 100644 --- a/regression-test/data/nereids_syntax_p0/window_function.out +++ b/regression-test/data/nereids_syntax_p0/window_function.out @@ -567,3 +567,35 @@ -- !multi_winf2 -- 1 35 +-- !first_value_false -- +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +2 +2 +2 +2 + +-- !last_value_false -- +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 + diff --git a/regression-test/suites/nereids_syntax_p0/window_function.groovy b/regression-test/suites/nereids_syntax_p0/window_function.groovy index f2ce708f4c372a..bb19aba17c10fc 100644 --- a/regression-test/suites/nereids_syntax_p0/window_function.groovy +++ b/regression-test/suites/nereids_syntax_p0/window_function.groovy @@ -314,4 +314,7 @@ suite("window_function") { sql "select last_value(c1,false) over() from window_test" sql "select first_value(c1,1) over() from window_test" sql "select last_value(c1,0) over() from window_test" + + qt_first_value_false "select last_value(c1,false) over(partition by c2 order by c1) from window_test order by 1" + qt_last_value_false "select first_value(c1,false) over(partition by c2 order by c1) from window_test order by 1" }