From c7cde31a890ee6c94e85ae749e8718de2252e0a5 Mon Sep 17 00:00:00 2001 From: Zoltan Haindrich Date: Mon, 29 Jul 2024 10:11:36 +0200 Subject: [PATCH] HAVING clauses may not contain window functions (#16742) Rejects having clauses if they contain windowed expressions. Also added a check to produce a more descriptive error if an OVER expression reaches the filter translation layer. --------- Co-authored-by: Benedict Jin --- .../sql/calcite/expression/Expressions.java | 7 +++ .../calcite/planner/DruidSqlValidator.java | 50 +++++++++++++++++++ .../sql/calcite/CalciteSelectQueryTest.java | 18 +++++++ .../calcite/expression/ExpressionsTest.java | 36 +++++++++++++ 4 files changed, 111 insertions(+) diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/Expressions.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/Expressions.java index ec518ee4522f..689db77c0faa 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/Expressions.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/Expressions.java @@ -28,6 +28,7 @@ import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexOver; import org.apache.calcite.rex.RexUtil; import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.SqlOperator; @@ -36,6 +37,7 @@ import org.apache.druid.common.config.NullHandling; import org.apache.druid.java.util.common.DateTimes; import org.apache.druid.java.util.common.ISE; +import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.granularity.Granularity; import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExpressionType; @@ -66,6 +68,7 @@ import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.ExpressionParser; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.CannotBuildQueryException; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.table.RowSignatures; import org.joda.time.Interval; @@ -238,6 +241,10 @@ public static DruidExpression toDruidExpressionWithPostAggOperands( final SqlKind kind = rexNode.getKind(); if (kind == SqlKind.INPUT_REF) { return inputRefToDruidExpression(rowSignature, rexNode); + } else if (rexNode instanceof RexOver) { + throw new CannotBuildQueryException( + StringUtils.format("Unexpected OVER expression during translation [%s]", rexNode) + ); } else if (rexNode instanceof RexCall) { return rexCallToDruidExpression(plannerContext, rowSignature, rexNode, postAggregatorVisitor); } else if (kind == SqlKind.LITERAL) { diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidSqlValidator.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidSqlValidator.java index 16d3541e96c6..857c8cb0d120 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidSqlValidator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidSqlValidator.java @@ -36,6 +36,7 @@ import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.SqlNodeList; import org.apache.calcite.sql.SqlOperatorTable; +import org.apache.calcite.sql.SqlOverOperator; import org.apache.calcite.sql.SqlSelect; import org.apache.calcite.sql.SqlSelectKeyword; import org.apache.calcite.sql.SqlUpdate; @@ -46,6 +47,8 @@ import org.apache.calcite.sql.parser.SqlParserPos; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.type.SqlTypeUtil; +import org.apache.calcite.sql.util.SqlBasicVisitor; +import org.apache.calcite.sql.util.SqlVisitor; import org.apache.calcite.sql.validate.IdentifierNamespace; import org.apache.calcite.sql.validate.SelectNamespace; import org.apache.calcite.sql.validate.SqlNonNullableAccessors; @@ -83,6 +86,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.function.Predicate; import java.util.regex.Pattern; /** @@ -902,4 +906,50 @@ private boolean isSqlCallDistinct(@Nullable SqlCall call) && call.getFunctionQuantifier() != null && call.getFunctionQuantifier().getValue() == SqlSelectKeyword.DISTINCT; } + + @Override + protected void validateHavingClause(SqlSelect select) + { + super.validateHavingClause(select); + SqlNode having = select.getHaving(); + if (containsOver(having)) { + throw buildCalciteContextException("Window functions are not allowed in HAVING", having); + } + } + + private boolean containsOver(SqlNode having) + { + if (having == null) { + return false; + } + final Predicate callPredicate = call -> call.getOperator() instanceof SqlOverOperator; + return containsCall(having, callPredicate); + } + + // copy of SqlUtil#containsCall + /** Returns whether an AST tree contains a call that matches a given + * predicate. */ + private static boolean containsCall(SqlNode node, + Predicate callPredicate) + { + try { + SqlVisitor visitor = + new SqlBasicVisitor() { + @Override public Void visit(SqlCall call) + { + if (callPredicate.test(call)) { + throw new Util.FoundOne(call); + } + return super.visit(call); + } + }; + node.accept(visitor); + return false; + } + catch (Util.FoundOne e) { + Util.swallow(e, null); + return true; + } + } + } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteSelectQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteSelectQueryTest.java index 3c47eb20491e..6a0f1742155c 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteSelectQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteSelectQueryTest.java @@ -23,6 +23,7 @@ import com.google.common.collect.ImmutableMap; import org.apache.calcite.rel.RelNode; import org.apache.druid.common.config.NullHandling; +import org.apache.druid.error.DruidException; import org.apache.druid.error.DruidExceptionMatcher; import org.apache.druid.java.util.common.DateTimes; import org.apache.druid.java.util.common.Intervals; @@ -2139,4 +2140,21 @@ public void testSqlToRelInConversion() ) .run(); } + + @Test + public void testRejectHavingWithWindowExpression() + { + assertEquals( + "1.37.0", + RelNode.class.getPackage().getImplementationVersion(), + "Calcite version changed; check if CALCITE-6473 is fixed and remove:\n * this assertion\n * DruidSqlValidator#validateHavingClause" + ); + + testQueryThrows( + "SELECT cityName,sum(1) OVER () as w FROM wikipedia group by cityName HAVING w > 10", + ImmutableMap.of(PlannerContext.CTX_ENABLE_WINDOW_FNS, true), + DruidException.class, + invalidSqlContains("Window functions are not allowed in HAVING") + ); + } } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionsTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionsTest.java index 86b5d38639be..c94eac39703d 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionsTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionsTest.java @@ -24,6 +24,8 @@ import org.apache.calcite.avatica.util.TimeUnit; import org.apache.calcite.avatica.util.TimeUnitRange; import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexWindowBounds; import org.apache.calcite.sql.SqlFunction; import org.apache.calcite.sql.SqlIntervalQualifier; import org.apache.calcite.sql.SqlOperator; @@ -72,7 +74,9 @@ import org.apache.druid.sql.calcite.planner.DruidOperatorTable; import org.apache.druid.sql.calcite.planner.DruidTypeSystem; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.CannotBuildQueryException; import org.apache.druid.sql.calcite.util.CalciteTestBase; +import org.hamcrest.core.StringContains; import org.joda.time.DateTimeZone; import org.joda.time.Period; import org.junit.Assert; @@ -84,6 +88,8 @@ import java.util.Collections; import java.util.Map; +import static org.hamcrest.MatcherAssert.assertThat; + public class ExpressionsTest extends CalciteTestBase { private static final RowSignature ROW_SIGNATURE = RowSignature @@ -2827,6 +2833,36 @@ public void testHumanReadableDecimalByteFormat() ); } + @Test + public void testPresenceOfOverIsInvalid() + { + final RexBuilder rexBuilder = new RexBuilder(DruidTypeSystem.TYPE_FACTORY); + final PlannerContext plannerContext = Mockito.mock(PlannerContext.class); + Mockito.when(plannerContext.getTimeZone()).thenReturn(DateTimeZone.UTC); + + RexNode rexNode = rexBuilder.makeOver( + testHelper.createSqlType(SqlTypeName.BIGINT), + SqlStdOperatorTable.SUM, + Collections.emptyList(), + Collections.emptyList(), + ImmutableList.of(), + RexWindowBounds.CURRENT_ROW, + RexWindowBounds.CURRENT_ROW, + false, + true, + false, + false, + false + ); + + CannotBuildQueryException t = Assert.assertThrows( + CannotBuildQueryException.class, + () -> testHelper.testExpression(rexNode, null, plannerContext) + ); + + assertThat(t.getMessage(), StringContains.containsString("Unexpected OVER expression")); + } + @Test public void testCalciteLiteralToDruidLiteral() {