From 1a8065e393b7525f5c09af04ea29aef3ee8c97f0 Mon Sep 17 00:00:00 2001 From: John Ed Quinn Date: Wed, 6 Mar 2024 14:58:06 -0800 Subject: [PATCH] Fixes aggregations of attribute references to values of union types --- CHANGELOG.md | 4 +- .../main/kotlin/org/partiql/planner/Errors.kt | 8 ++ .../partiql/planner/internal/PartiQLHeader.kt | 13 +- .../internal/transforms/PlanTransform.kt | 58 ++++++++- .../planner/internal/typer/PlanTyper.kt | 2 +- .../partiql/planner/internal/typer/TypeEnv.kt | 69 +++++++--- .../planner/internal/typer/TypeUtils.kt | 2 +- .../internal/typer/PlanTyperTestsPorted.kt | 118 +++++++++++++++++- .../kotlin/org/partiql/types/StaticType.kt | 5 +- 9 files changed, 242 insertions(+), 37 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 28803d07d9..a02eef7ea9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,10 +32,12 @@ Thank you to all who have contributed! ### Changed - Function resolution logic: Now the function resolver would match all possible candidate(based on if the argument can be coerced to the Signature parameter type). If there are multiple match it will first attempt to pick the one requires the least cast, then pick the function with the highest precedence. +- **Behavioral change**: The COUNT aggregate function now returns INT64. ### Deprecated ### Fixed +- Fixes aggregations of attribute references to values of union types. This fix also allows for proper error handling by passing the UnknownAggregateFunction problem to the ProblemCallback. Please note that, with this change, the planner will no longer immediately throw an IllegalStateException for this exact scenario. ### Removed @@ -43,7 +45,7 @@ Thank you to all who have contributed! ### Contributors Thank you to all who have contributed! -- @ +- @johnedquinn ## [0.14.3] - 2024-02-14 diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/Errors.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/Errors.kt index 4866c350a5..5e21d9e45d 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/Errors.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/Errors.kt @@ -94,6 +94,14 @@ public sealed class PlanningProblemDetails( "Unknown function `$identifier($types)" }) + public data class UnknownAggregateFunction( + val identifier: String, + val args: List, + ) : PlanningProblemDetails(ProblemSeverity.ERROR, { + val types = args.joinToString { "<${it.toString().lowercase()}>" } + "Unknown aggregate function `$identifier($types)" + }) + public object ExpressionAlwaysReturnsNullOrMissing : PlanningProblemDetails( severity = ProblemSeverity.ERROR, messageFormatter = { "Expression always returns null or missing." } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PartiQLHeader.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PartiQLHeader.kt index 24e5e22bff..ab50c6ac09 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PartiQLHeader.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PartiQLHeader.kt @@ -702,13 +702,13 @@ internal object PartiQLHeader : Header() { private fun count() = listOf( FunctionSignature.Aggregation( name = "count", - returns = INT32, + returns = INT64, parameters = listOf(FunctionParameter("value", ANY)), isNullable = false, ), FunctionSignature.Aggregation( name = "count_star", - returns = INT32, + returns = INT64, parameters = listOf(), isNullable = false, ), @@ -741,6 +741,15 @@ internal object PartiQLHeader : Header() { ) } + /** + * According to SQL:1999 Section 6.16 Syntax Rule 14.c and Rule 14.d: + * > If AVG is specified and DT is exact numeric, then the declared type of the result is exact + * numeric with implementation-defined precision not less than the precision of DT and + * implementation-defined scale not less than the scale of DT. + * + * > If DT is approximate numeric, then the declared type of the result is approximate numeric + * with implementation-defined precision not less than the precision of DT. + */ private fun avg() = types.numeric.map { FunctionSignature.Aggregation( name = "avg", diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt index 89741aadb2..87952da46d 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt @@ -1,8 +1,11 @@ package org.partiql.planner.internal.transforms +import org.partiql.errors.Problem import org.partiql.errors.ProblemCallback +import org.partiql.errors.UNKNOWN_PROBLEM_LOCATION import org.partiql.plan.PlanNode import org.partiql.plan.partiQLPlan +import org.partiql.planner.PlanningProblemDetails import org.partiql.planner.internal.ir.Agg import org.partiql.planner.internal.ir.Catalog import org.partiql.planner.internal.ir.Fn @@ -12,7 +15,9 @@ import org.partiql.planner.internal.ir.Rel import org.partiql.planner.internal.ir.Rex import org.partiql.planner.internal.ir.Statement import org.partiql.planner.internal.ir.visitor.PlanBaseVisitor +import org.partiql.types.function.FunctionSignature import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.PartiQLValueType /** * This is an internal utility to translate from the internal unresolved plan used for typing to the public plan IR. @@ -58,7 +63,7 @@ internal object PlanTransform : PlanBaseVisitor() { override fun visitAggResolved(node: Agg.Resolved, ctx: ProblemCallback) = org.partiql.plan.Agg(node.signature) override fun visitAggUnresolved(node: Agg.Unresolved, ctx: ProblemCallback): org.partiql.plan.Rex.Op { - error("Unresolved aggregation ${node.identifier}") + error("Internal error: This should have been handled somewhere else. Cause: Unresolved aggregation ${node.identifier}.") } override fun visitStatement(node: Statement, ctx: ProblemCallback) = @@ -331,11 +336,56 @@ internal object PlanTransform : PlanBaseVisitor() { groups = node.groups.map { visitRex(it, ctx) }, ) - override fun visitRelOpAggregateCall(node: Rel.Op.Aggregate.Call, ctx: ProblemCallback) = - org.partiql.plan.Rel.Op.Aggregate.Call( - agg = visitAgg(node.agg, ctx), + @OptIn(PartiQLValueExperimental::class) + override fun visitRelOpAggregateCall(node: Rel.Op.Aggregate.Call, ctx: ProblemCallback): org.partiql.plan.Rel.Op.Aggregate.Call { + val agg = when (val agg = node.agg) { + is Agg.Unresolved -> { + val name = agg.identifier.toNormalizedString() + ctx.invoke( + Problem( + UNKNOWN_PROBLEM_LOCATION, + PlanningProblemDetails.UnknownAggregateFunction( + agg.identifier.toString(), + node.args.map { it.type } + ) + ) + ) + org.partiql.plan.Agg( + FunctionSignature.Aggregation( + "UNKNOWN_AGG::$name", + returns = PartiQLValueType.MISSING, + parameters = emptyList() + ) + ) + } + is Agg.Resolved -> { + visitAggResolved(agg, ctx) + } + } + return org.partiql.plan.Rel.Op.Aggregate.Call( + agg = agg, args = node.args.map { visitRex(it, ctx) }, ) + } + + private fun Identifier.toNormalizedString(): String { + return when (this) { + is Identifier.Symbol -> this.toNormalizedString() + is Identifier.Qualified -> { + val toJoin = listOf(this.root) + this.steps + toJoin.joinToString(separator = ".") { ident -> + ident.toNormalizedString() + } + } + } + } + + private fun Identifier.Symbol.toNormalizedString(): String { + return when (this.caseSensitivity) { + Identifier.CaseSensitivity.SENSITIVE -> "\"${this.symbol}\"" + Identifier.CaseSensitivity.INSENSITIVE -> this.symbol + } + } override fun visitRelOpExclude(node: Rel.Op.Exclude, ctx: ProblemCallback) = org.partiql.plan.Rel.Op.Exclude( input = visitRel(node.input, ctx), diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt index 42fd0b9485..04e7673832 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt @@ -1168,7 +1168,7 @@ internal class PlanTyper( fun resolveAgg(agg: Agg.Unresolved, arguments: List): Pair { var missingArg = false val args = arguments.map { - val arg = visitRex(it, null) + val arg = visitRex(it, it.type) if (arg.type.isMissable()) missingArg = true arg } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeEnv.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeEnv.kt index d413abde04..e2b26d44fd 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeEnv.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeEnv.kt @@ -10,6 +10,8 @@ import org.partiql.planner.internal.ir.rexOpVarResolved import org.partiql.spi.BindingCase import org.partiql.spi.BindingName import org.partiql.spi.BindingPath +import org.partiql.types.AnyOfType +import org.partiql.types.AnyType import org.partiql.types.StaticType import org.partiql.types.StructType import org.partiql.types.TupleConstraint @@ -85,30 +87,28 @@ internal class TypeEnv(public val schema: List) { for (i in schema.indices) { val local = schema[i] val type = local.type - if (type is StructType) { - when (type.containsKey(name)) { - true -> { - if (c != null && known) { - // TODO root was already definitively matched, emit ambiguous error. - return null - } - c = rex(type, rexOpVarResolved(i)) - known = true + when (type.containsKey(name)) { + true -> { + if (c != null && known) { + // TODO root was already definitively matched, emit ambiguous error. + return null } - null -> { - if (c != null) { - if (known) { - continue - } else { - // TODO we have more than one possible match, emit ambiguous error. - return null - } + c = rex(type, rexOpVarResolved(i)) + known = true + } + null -> { + if (c != null) { + if (known) { + continue + } else { + // TODO we have more than one possible match, emit ambiguous error. + return null } - c = rex(type, rexOpVarResolved(i)) - known = false } - false -> continue + c = rex(type, rexOpVarResolved(i)) + known = false } + false -> continue } } return c @@ -152,4 +152,33 @@ internal class TypeEnv(public val schema: List) { val closed = constraints.contains(TupleConstraint.Open(false)) return if (closed) false else null } + + /** + * Searches for the [BindingName] within the given [StaticType]. + * + * Returns + * - true iff known to contain key + * - false iff known to NOT contain key + * - null iff NOT known to contain key + * + * @param name + * @return + */ + private fun StaticType.containsKey(name: BindingName): Boolean? { + return when (val type = this.flatten()) { + is StructType -> type.containsKey(name) + is AnyOfType -> { + val anyKnownToContainKey = type.allTypes.any { it.containsKey(name) == true } + val anyKnownToNotContainKey = type.allTypes.any { it.containsKey(name) == false } + val anyNotKnownToContainKey = type.allTypes.any { it.containsKey(name) == null } + when { + anyKnownToNotContainKey.not() && anyNotKnownToContainKey.not() -> true + anyKnownToContainKey.not() && anyNotKnownToContainKey -> false + else -> null + } + } + is AnyType -> null + else -> false + } + } } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeUtils.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeUtils.kt index bccd22c451..d83c45de5f 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeUtils.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeUtils.kt @@ -93,7 +93,7 @@ internal fun StaticType.toRuntimeType(): PartiQLValueType { // handle anyOf(null, T) cases val t = types.filter { it !is NullType && it !is MissingType } return if (t.size != 1) { - error("Cannot have a UNION runtime type: $this") + PartiQLValueType.ANY } else { t.first().asRuntimeType() } diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt index ad8114067c..787287e34d 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt @@ -1,6 +1,8 @@ package org.partiql.planner.internal.typer import com.amazon.ionelement.api.loadSingleElement +import org.junit.jupiter.api.Disabled +import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertThrows import org.junit.jupiter.api.extension.ExtensionContext import org.junit.jupiter.api.parallel.Execution @@ -2679,14 +2681,16 @@ class PlanTyperTestsPorted { fun aggregationCases() = listOf( SuccessTestCase( name = "AGGREGATE over INTS, without alias", - query = "SELECT a, COUNT(*), SUM(a), MIN(b) FROM << {'a': 1, 'b': 2} >> GROUP BY a", + query = "SELECT a, COUNT(*), COUNT(a), SUM(a), MIN(b), MAX(a) FROM << {'a': 1, 'b': 2} >> GROUP BY a", expected = BagType( StructType( fields = mapOf( "a" to StaticType.INT4, - "_1" to StaticType.INT4, - "_2" to StaticType.INT4.asNullable(), + "_1" to StaticType.INT8, + "_2" to StaticType.INT8, "_3" to StaticType.INT4.asNullable(), + "_4" to StaticType.INT4.asNullable(), + "_5" to StaticType.INT4.asNullable(), ), contentClosed = true, constraints = setOf( @@ -2699,12 +2703,13 @@ class PlanTyperTestsPorted { ), SuccessTestCase( name = "AGGREGATE over INTS, with alias", - query = "SELECT a, COUNT(*) AS c, SUM(a) AS s, MIN(b) AS m FROM << {'a': 1, 'b': 2} >> GROUP BY a", + query = "SELECT a, COUNT(*) AS c_s, COUNT(a) AS c, SUM(a) AS s, MIN(b) AS m FROM << {'a': 1, 'b': 2} >> GROUP BY a", expected = BagType( StructType( fields = mapOf( "a" to StaticType.INT4, - "c" to StaticType.INT4, + "c_s" to StaticType.INT8, + "c" to StaticType.INT8, "s" to StaticType.INT4.asNullable(), "m" to StaticType.INT4.asNullable(), ), @@ -2724,7 +2729,7 @@ class PlanTyperTestsPorted { StructType( fields = mapOf( "a" to StaticType.DECIMAL, - "c" to StaticType.INT4, + "c" to StaticType.INT8, "s" to StaticType.DECIMAL.asNullable(), "m" to StaticType.DECIMAL.asNullable(), ), @@ -2737,6 +2742,53 @@ class PlanTyperTestsPorted { ) ) ), + SuccessTestCase( + name = "AGGREGATE over nullable integers", + query = """ + SELECT + a AS a, + COUNT(*) AS count_star, + COUNT(a) AS count_a, + COUNT(b) AS count_b, + SUM(a) AS sum_a, + SUM(b) AS sum_b, + MIN(a) AS min_a, + MIN(b) AS min_b, + MAX(a) AS max_a, + MAX(b) AS max_b, + AVG(a) AS avg_a, + AVG(b) AS avg_b + FROM << + { 'a': 1, 'b': 2 }, + { 'a': 3, 'b': 4 }, + { 'a': 5, 'b': NULL } + >> GROUP BY a + """.trimIndent(), + expected = BagType( + StructType( + fields = mapOf( + "a" to StaticType.INT4, + "count_star" to StaticType.INT8, + "count_a" to StaticType.INT8, + "count_b" to StaticType.INT8, + "sum_a" to StaticType.INT4.asNullable(), + "sum_b" to StaticType.INT4.asNullable(), + "min_a" to StaticType.INT4.asNullable(), + "min_b" to StaticType.INT4.asNullable(), + "max_a" to StaticType.INT4.asNullable(), + "max_b" to StaticType.INT4.asNullable(), + "avg_a" to StaticType.INT4.asNullable(), + "avg_b" to StaticType.INT4.asNullable(), + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true), + TupleConstraint.Ordered + ) + ) + ) + ), ) @JvmStatic @@ -2987,6 +3039,60 @@ class PlanTyperTestsPorted { // // Parameterized Tests // + + @Test + @Disabled("The planner doesn't support heterogeneous input to aggregation functions (yet?).") + fun failingTest() { + val tc = SuccessTestCase( + name = "AGGREGATE over heterogeneous data", + query = """ + SELECT + a AS a, + COUNT(*) AS count_star, + COUNT(a) AS count_a, + COUNT(b) AS count_b, + SUM(a) AS sum_a, + SUM(b) AS sum_b, + MIN(a) AS min_a, + MIN(b) AS min_b, + MAX(a) AS max_a, + MAX(b) AS max_b, + AVG(a) AS avg_a, + AVG(b) AS avg_b + FROM << + { 'a': 1.0, 'b': 2.0 }, + { 'a': 3, 'b': 4 }, + { 'a': 5, 'b': NULL } + >> GROUP BY a + """.trimIndent(), + expected = BagType( + StructType( + fields = mapOf( + "a" to StaticType.DECIMAL, + "count_star" to StaticType.INT8, + "count_a" to StaticType.INT8, + "count_b" to StaticType.INT8, + "sum_a" to StaticType.DECIMAL.asNullable(), + "sum_b" to StaticType.DECIMAL.asNullable(), + "min_a" to StaticType.DECIMAL.asNullable(), + "min_b" to StaticType.DECIMAL.asNullable(), + "max_a" to StaticType.DECIMAL.asNullable(), + "max_b" to StaticType.DECIMAL.asNullable(), + "avg_a" to StaticType.DECIMAL.asNullable(), + "avg_b" to StaticType.DECIMAL.asNullable(), + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true), + TupleConstraint.Ordered + ) + ) + ) + ) + runTest(tc) + } + @ParameterizedTest @ArgumentsSource(TestProvider::class) fun test(tc: TestCase) = runTest(tc) diff --git a/partiql-types/src/main/kotlin/org/partiql/types/StaticType.kt b/partiql-types/src/main/kotlin/org/partiql/types/StaticType.kt index d2ef1756f8..11b4cff532 100644 --- a/partiql-types/src/main/kotlin/org/partiql/types/StaticType.kt +++ b/partiql-types/src/main/kotlin/org/partiql/types/StaticType.kt @@ -599,9 +599,10 @@ public data class StructType( get() = listOf(this) override fun toString(): String { - val firstSeveral = fields.take(3).joinToString { "${it.key}: ${it.value}" } + val firstFieldsSize = 15 + val firstSeveral = fields.take(firstFieldsSize).joinToString { "${it.key}: ${it.value}" } return when { - fields.size <= 3 -> "struct($firstSeveral, $constraints)" + fields.size <= firstFieldsSize -> "struct($firstSeveral, $constraints)" else -> "struct($firstSeveral, ... and ${fields.size - 3} other field(s), $constraints)" } }