diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/AstPass.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/AstPass.kt index e8ad093e70..722e0309eb 100644 --- a/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/AstPass.kt +++ b/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/AstPass.kt @@ -1,3 +1,17 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + */ + package org.partiql.ast.normalize import org.partiql.ast.Statement diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/Normalize.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/Normalize.kt index 6276c6b1c2..c4aadcf42a 100644 --- a/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/Normalize.kt +++ b/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/Normalize.kt @@ -1,3 +1,17 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + */ + package org.partiql.ast.normalize import org.partiql.ast.Statement @@ -10,5 +24,6 @@ public fun Statement.normalize(): Statement { var ast = this ast = NormalizeFromSource.apply(ast) ast = NormalizeSelect.apply(ast) + ast = NormalizeGroupBy.apply(ast) return ast } diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeFromSource.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeFromSource.kt index ef6e9dde3b..d0c071da84 100644 --- a/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeFromSource.kt +++ b/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeFromSource.kt @@ -1,3 +1,17 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + */ + package org.partiql.ast.normalize import org.partiql.ast.AstNode diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeGroupBy.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeGroupBy.kt new file mode 100644 index 0000000000..4ef1f701c6 --- /dev/null +++ b/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeGroupBy.kt @@ -0,0 +1,46 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + */ + +package org.partiql.ast.normalize + +import org.partiql.ast.Expr +import org.partiql.ast.GroupBy +import org.partiql.ast.Statement +import org.partiql.ast.groupByKey +import org.partiql.ast.helpers.toBinder +import org.partiql.ast.util.AstRewriter + +/** + * Adds a unique binder to each group key. + */ +object NormalizeGroupBy : AstPass { + + override fun apply(statement: Statement) = Visitor.visitStatement(statement, 0) as Statement + + private object Visitor : AstRewriter() { + + override fun visitGroupByKey(node: GroupBy.Key, ctx: Int): GroupBy.Key { + val expr = visitExpr(node.expr, 0) as Expr + val alias = when (node.asAlias) { + null -> expr.toBinder(ctx) + else -> node.asAlias + } + return if (expr !== node.expr || alias !== node.asAlias) { + groupByKey(expr, alias) + } else { + node + } + } + } +} diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeSelectList.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeSelectList.kt new file mode 100644 index 0000000000..e0fe892cdc --- /dev/null +++ b/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeSelectList.kt @@ -0,0 +1,67 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + */ + +package org.partiql.ast.normalize + +import org.partiql.ast.Expr +import org.partiql.ast.Select +import org.partiql.ast.Statement +import org.partiql.ast.builder.ast +import org.partiql.ast.helpers.toBinder +import org.partiql.ast.util.AstRewriter + +/** + * Adds an `as` alias to every select-list item. + * + * - [org.partiql.ast.helpers.toBinder] + * - https://partiql.org/assets/PartiQL-Specification.pdf#page=28 + * - https://web.cecs.pdx.edu/~len/sql1999.pdf#page=287 + */ +internal object NormalizeSelectList : AstPass { + + override fun apply(statement: Statement) = Visitor.visitStatement(statement, 0) as Statement + + private object Visitor : AstRewriter() { + + override fun visitSelectProject(node: Select.Project, ctx: Int) = ast { + if (node.items.isEmpty()) { + return@ast node + } + var diff = false + val transformed = ArrayList(node.items.size) + node.items.forEachIndexed { i, n -> + val item = visitSelectProjectItem(n, i) as Select.Project.Item + if (item !== n) diff = true + transformed.add(item) + } + // We don't want to create a new list unless we have to, as to not trigger further rewrites up the tree. + if (diff) selectProject(transformed) else node + } + + override fun visitSelectProjectItemAll(node: Select.Project.Item.All, ctx: Int) = node.copy() + + override fun visitSelectProjectItemExpression(node: Select.Project.Item.Expression, ctx: Int) = ast { + val expr = visitExpr(node.expr, 0) as Expr + val alias = when (node.asAlias) { + null -> expr.toBinder(ctx) + else -> node.asAlias + } + if (expr != node.expr || alias != node.asAlias) { + selectProjectItemExpression(expr, alias) + } else { + node + } + } + } +} diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeSelectStar.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeSelectStar.kt index 7ad039559f..97f4272d4c 100644 --- a/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeSelectStar.kt +++ b/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeSelectStar.kt @@ -1,3 +1,17 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + */ + package org.partiql.ast.normalize import org.partiql.ast.Expr diff --git a/partiql-plan/src/main/resources/partiql_plan_0_1.ion b/partiql-plan/src/main/resources/partiql_plan_0_1.ion index 0cdab738e6..7e1a0efa1f 100644 --- a/partiql-plan/src/main/resources/partiql_plan_0_1.ion +++ b/partiql-plan/src/main/resources/partiql_plan_0_1.ion @@ -2,7 +2,8 @@ imports::{ kotlin: [ partiql_value::'org.partiql.value.PartiQLValue', static_type::'org.partiql.types.StaticType', - function_signature::'org.partiql.types.function.FunctionSignature', + scalar_signature::'org.partiql.types.function.FunctionSignature$Scalar', + aggregation_signature::'org.partiql.types.function.FunctionSignature$Aggregation', ], } @@ -28,7 +29,16 @@ global::{ fn::[ resolved::{ - signature: function_signature, + signature: scalar_signature, + }, + unresolved::{ + identifier: identifier, + }, +] + +agg::[ + resolved::{ + signature: aggregation_signature, }, unresolved::{ identifier: identifier, @@ -262,11 +272,11 @@ rel::{ aggregate::{ input: rel, strategy: [ FULL, PARTIAL ], - aggs: list::[agg], + calls: list::[call], groups: list::[rex], _: [ - agg::{ - fn: fn, + call::{ + agg: agg, args: list::[rex], }, ], diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/Env.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/Env.kt index 5fcc5b4914..225157599b 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/Env.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/Env.kt @@ -1,5 +1,6 @@ package org.partiql.planner +import org.partiql.plan.Agg import org.partiql.plan.Fn import org.partiql.plan.Global import org.partiql.plan.Identifier @@ -10,7 +11,6 @@ import org.partiql.plan.identifierQualified import org.partiql.plan.identifierSymbol import org.partiql.planner.typer.FunctionResolver import org.partiql.planner.typer.Mapping -import org.partiql.planner.typer.isNullOrMissing import org.partiql.planner.typer.toRuntimeType import org.partiql.spi.BindingCase import org.partiql.spi.BindingName @@ -73,7 +73,7 @@ internal class TypeEnv( /** * Result of attempting to match an unresolved function. */ -internal sealed class FnMatch { +internal sealed class FnMatch { /** * 7.1 Inputs with wrong types @@ -83,17 +83,17 @@ internal sealed class FnMatch { * @property mapping * @property isMissable TRUE when anyone of the arguments _could_ be MISSING. We *always* propagate MISSING. */ - public data class Ok( - public val signature: FunctionSignature, + public data class Ok( + public val signature: T, public val mapping: Mapping, public val isMissable: Boolean, - ) : FnMatch() + ) : FnMatch() - public data class Error( - public val fn: Fn.Unresolved, + public data class Error( + public val identifier: Identifier, public val args: List, public val candidates: List, - ) : FnMatch() + ) : FnMatch() } /** @@ -195,21 +195,42 @@ internal class Env( } /** - * Leverages a [FunctionResolver] to find a matching function defined in the [Header]. + * Leverages a [FunctionResolver] to find a matching function defined in the [Header] scalar function catalog. */ - internal fun resolveFn(fn: Fn.Unresolved, args: List): FnMatch { + internal fun resolveFn(fn: Fn.Unresolved, args: List): FnMatch { val candidates = header.lookup(fn) var hadMissingArg = false val parameters = args.mapIndexed { i, arg -> if (!hadMissingArg && arg.type.isMissable()) { hadMissingArg = true } - arg.type.isNullOrMissing() FunctionParameter("arg-$i", arg.type.toRuntimeType()) } val match = functionResolver.match(candidates, parameters) return when (match) { - null -> FnMatch.Error(fn, args, candidates) + null -> FnMatch.Error(fn.identifier, args, candidates) + else -> { + val isMissable = hadMissingArg || header.isUnsafeCast(match.signature.specific) + FnMatch.Ok(match.signature, match.mapping, isMissable) + } + } + } + + /** + * Leverages a [FunctionResolver] to find a matching function defined in the [Header] aggregation function catalog. + */ + internal fun resolveAgg(agg: Agg.Unresolved, args: List): FnMatch { + val candidates = header.lookup(agg) + var hadMissingArg = false + val parameters = args.mapIndexed { i, arg -> + if (!hadMissingArg && arg.type.isMissable()) { + hadMissingArg = true + } + FunctionParameter("arg-$i", arg.type.toRuntimeType()) + } + val match = functionResolver.match(candidates, parameters) + return when (match) { + null -> FnMatch.Error(agg.identifier, args, candidates) else -> { val isMissable = hadMissingArg || header.isUnsafeCast(match.signature.specific) FnMatch.Ok(match.signature, match.mapping, isMissable) diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/Header.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/Header.kt index 91557b621b..33d114b063 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/Header.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/Header.kt @@ -1,6 +1,7 @@ package org.partiql.planner import org.partiql.ast.DatetimeField +import org.partiql.plan.Agg import org.partiql.plan.Fn import org.partiql.plan.Identifier import org.partiql.planner.typer.CastType @@ -38,14 +39,9 @@ import org.partiql.value.PartiQLValueType.TIME import org.partiql.value.PartiQLValueType.TIMESTAMP /** - * A structure for function lookup. + * A structure for scalar function lookup. */ -private typealias FunctionMap = Map> - -/** - * Unicode non-character to be used for name sanitization - */ -// private val OP: Char = Char(0xFDEF) +private typealias FnMap = Map> /** * Map session attributes to underlying function name. @@ -60,28 +56,38 @@ internal val ATTRIBUTES: Map = mapOf( * * @property namespace Definition namespace e.g. partiql, spark, redshift, ... * @property types Type definitions - * @property functions Function definitions + * @property functions Scalar function definitions + * @property aggregations Aggregation function definitions */ @OptIn(PartiQLValueExperimental::class) internal class Header( private val namespace: String, private val types: TypeLattice, - private val functions: FunctionMap, + private val functions: FnMap, + private val aggregations: FnMap, private val unsafeCastSet: Set, ) { /** - * Return a list of all function signatures matching the given identifier. + * Return a list of all scalar function signatures matching the given identifier. */ - public fun lookup(ref: Fn.Unresolved): List { + public fun lookup(ref: Fn.Unresolved): List { val name = getFnName(ref.identifier) return functions.getOrDefault(name, emptyList()) } + /** + * Return a list of all aggregation function signatures matching the given identifier. + */ + public fun lookup(ref: Agg.Unresolved): List { + val name = getFnName(ref.identifier) + return aggregations.getOrDefault(name, emptyList()) + } + /** * Returns the CAST function if exists, else null. */ - public fun lookupCoercion(valueType: PartiQLValueType, targetType: PartiQLValueType): FunctionSignature? { + public fun lookupCoercion(valueType: PartiQLValueType, targetType: PartiQLValueType): FunctionSignature.Scalar? { if (!types.canCoerce(valueType, targetType)) { return null } @@ -127,7 +133,6 @@ internal class Header( /** * TODO TEMPORARY — Hardcoded PartiQL Global Catalog - * TODO BUG — We don't validate function overloads */ public fun partiql(): Header { val namespace = "partiql" @@ -137,10 +142,12 @@ internal class Header( casts, Functions.operators(), Functions.builtins(), - Functions.special(), Functions.system(), ) - return Header(namespace, types, functions, unsafeCastSet) + val aggregations = Functions.combine( + Functions.aggregations(), + ) + return Header(namespace, types, functions, aggregations, unsafeCastSet) } /** @@ -159,20 +166,72 @@ internal class Header( internal object Functions { /** - * Produce a function map (grouping by name) from a list of signatures. + * Group list of [FunctionSignature.Scalar] by name. */ - public fun combine(vararg functions: List): FunctionMap { + public fun combine(vararg functions: List): FnMap { return functions.flatMap { it.sortedWith(functionPrecedence) }.groupBy { it.name } } + // ==================================== + // TYPES + // ==================================== + + private val allTypes = PartiQLValueType.values() + + private val nullableTypes = listOf( + NULL, // null.null + MISSING, // missing + ) + + private val intTypes = listOf( + INT8, + INT16, + INT32, + INT64, + INT, + ) + + private val numericTypes = listOf( + INT8, + INT16, + INT32, + INT64, + INT, + DECIMAL, + FLOAT32, + FLOAT64, + ) + + private val textTypes = listOf( + STRING, + SYMBOL, + CLOB, + ) + + private val collectionTypes = listOf( + BAG, + LIST, + SEXP, + ) + + private val datetimeTypes = listOf( + DATE, + TIME, + TIMESTAMP, + ) + + // ==================================== + // SCALAR FUNCTIONS + // ==================================== + /** * Generate all CAST functions from the given lattice. * * @param lattice - * @return Pair(0) is the function list, Pair(1) represens the unsafe cast specifics + * @return Pair(0) is the function list, Pair(1) represents the unsafe cast specifics */ - public fun casts(lattice: TypeLattice): Pair, Set> { - val casts = mutableListOf() + public fun casts(lattice: TypeLattice): Pair, Set> { + val casts = mutableListOf() val unsafeCastSet = mutableSetOf() for (t1 in lattice.types) { for (t2 in lattice.types) { @@ -190,7 +249,7 @@ internal class Header( /** * Generate all unary and binary operator signatures. */ - public fun operators(): List = listOf( + public fun operators(): List = listOf( not(), pos(), neg(), @@ -212,17 +271,11 @@ internal class Header( ).flatten() /** - * SQL Builtins (not special forms) + * SQL and PartiQL Scalar Builtins */ - public fun builtins(): List = listOf( + public fun builtins(): List = listOf( upper(), lower(), - ).flatten() - - /** - * SQL and PartiQL special forms - */ - public fun special(): List = listOf( like(), between(), inCollection(), @@ -242,98 +295,30 @@ internal class Header( utcNow(), ).flatten() - public fun system(): List = listOf( + /** + * System functions (for now, CURRENT_USER and CURRENT_DATE) + * + * @return + */ + public fun system(): List = listOf( currentUser(), currentDate(), ) - private val allTypes = PartiQLValueType.values() - - private val nullableTypes = listOf( - NULL, // null.null - MISSING, // missing - ) - - private val intTypes = listOf( - INT8, - INT16, - INT32, - INT64, - INT, - ) - - private val numericTypes = listOf( - INT8, - INT16, - INT32, - INT64, - INT, - DECIMAL, - FLOAT32, - FLOAT64, - ) - - private val textTypes = listOf( - STRING, - SYMBOL, - CLOB, - ) - - private val collectionTypes = listOf( - BAG, - LIST, - SEXP, - ) - - private val datetimeTypes = listOf( - DATE, - TIME, - TIMESTAMP, - ) - - public fun unary(name: String, returns: PartiQLValueType, value: PartiQLValueType) = - FunctionSignature( - name = name, - returns = returns, - parameters = listOf(FunctionParameter("value", value)), - isNullCall = true, - isNullable = false, - ) - - public fun binary(name: String, returns: PartiQLValueType, lhs: PartiQLValueType, rhs: PartiQLValueType) = - FunctionSignature( - name = name, - returns = returns, - parameters = listOf(FunctionParameter("lhs", lhs), FunctionParameter("rhs", rhs)), - isNullCall = true, - isNullable = false, - ) - - public fun cast(operand: PartiQLValueType, target: PartiQLValueType) = - FunctionSignature( - name = castName(target), - returns = target, - isNullCall = true, - isNullable = false, - parameters = listOf( - FunctionParameter("value", operand), - ) - ) - // OPERATORS - private fun not(): List = listOf(unary("not", BOOL, BOOL)) + private fun not(): List = listOf(unary("not", BOOL, BOOL)) - private fun pos(): List = numericTypes.map { t -> + private fun pos(): List = numericTypes.map { t -> unary("pos", t, t) } - private fun neg(): List = numericTypes.map { t -> + private fun neg(): List = numericTypes.map { t -> unary("neg", t, t) } - private fun eq(): List = allTypes.map { t -> - FunctionSignature( + private fun eq(): List = allTypes.map { t -> + FunctionSignature.Scalar( name = "eq", returns = BOOL, isNullCall = false, @@ -342,62 +327,62 @@ internal class Header( ) } - private fun ne(): List = allTypes.map { t -> + private fun ne(): List = allTypes.map { t -> binary("ne", BOOL, t, t) } - private fun and(): List = listOf(binary("and", BOOL, BOOL, BOOL)) + private fun and(): List = listOf(binary("and", BOOL, BOOL, BOOL)) - private fun or(): List = listOf(binary("or", BOOL, BOOL, BOOL)) + private fun or(): List = listOf(binary("or", BOOL, BOOL, BOOL)) - private fun lt(): List = numericTypes.map { t -> + private fun lt(): List = numericTypes.map { t -> binary("lt", BOOL, t, t) } - private fun lte(): List = numericTypes.map { t -> + private fun lte(): List = numericTypes.map { t -> binary("lte", BOOL, t, t) } - private fun gt(): List = numericTypes.map { t -> + private fun gt(): List = numericTypes.map { t -> binary("gt", BOOL, t, t) } - private fun gte(): List = numericTypes.map { t -> + private fun gte(): List = numericTypes.map { t -> binary("gte", BOOL, t, t) } - private fun plus(): List = numericTypes.map { t -> + private fun plus(): List = numericTypes.map { t -> binary("plus", t, t, t) } - private fun minus(): List = numericTypes.map { t -> + private fun minus(): List = numericTypes.map { t -> binary("minus", t, t, t) } - private fun times(): List = numericTypes.map { t -> + private fun times(): List = numericTypes.map { t -> binary("times", t, t, t) } - private fun div(): List = numericTypes.map { t -> + private fun div(): List = numericTypes.map { t -> binary("divide", t, t, t) } - private fun mod(): List = numericTypes.map { t -> + private fun mod(): List = numericTypes.map { t -> binary("modulo", t, t, t) } - private fun concat(): List = textTypes.map { t -> + private fun concat(): List = textTypes.map { t -> binary("concat", t, t, t) } - private fun bitwiseAnd(): List = intTypes.map { t -> + private fun bitwiseAnd(): List = intTypes.map { t -> binary("bitwise_and", t, t, t) } // BUILT INTS - private fun upper(): List = textTypes.map { t -> - FunctionSignature( + private fun upper(): List = textTypes.map { t -> + FunctionSignature.Scalar( name = "upper", returns = t, parameters = listOf(FunctionParameter("value", t)), @@ -406,8 +391,8 @@ internal class Header( ) } - private fun lower(): List = textTypes.map { t -> - FunctionSignature( + private fun lower(): List = textTypes.map { t -> + FunctionSignature.Scalar( name = "lower", returns = t, parameters = listOf(FunctionParameter("value", t)), @@ -418,8 +403,8 @@ internal class Header( // SPECIAL FORMS - private fun like(): List = listOf( - FunctionSignature( + private fun like(): List = listOf( + FunctionSignature.Scalar( name = "like", returns = BOOL, parameters = listOf( @@ -429,7 +414,7 @@ internal class Header( isNullCall = true, isNullable = false, ), - FunctionSignature( + FunctionSignature.Scalar( name = "like_escape", returns = BOOL, parameters = listOf( @@ -442,8 +427,8 @@ internal class Header( ), ) - private fun between(): List = numericTypes.map { t -> - FunctionSignature( + private fun between(): List = numericTypes.map { t -> + FunctionSignature.Scalar( name = "between", returns = BOOL, parameters = listOf( @@ -456,9 +441,9 @@ internal class Header( ) } - private fun inCollection(): List = allTypes.map { element -> + private fun inCollection(): List = allTypes.map { element -> collectionTypes.map { collection -> - FunctionSignature( + FunctionSignature.Scalar( name = "in_collection", returns = BOOL, parameters = listOf( @@ -477,8 +462,8 @@ internal class Header( // TODO: We can remove the types with parameter in this function. // but, leaving out the decision to have, for example: // is_decimal(null, null, value) vs is_decimal(value) later.... - private fun isType(): List = allTypes.map { element -> - FunctionSignature( + private fun isType(): List = allTypes.map { element -> + FunctionSignature.Scalar( name = "is_${element.name.lowercase()}", returns = BOOL, parameters = listOf( @@ -492,8 +477,8 @@ internal class Header( // In type assertion, it is possible for types to have args // i.e., 'a' is CHAR(2) // we put type parameter before value. - private fun isTypeSingleArg(): List = listOf(CHAR, STRING).map { element -> - FunctionSignature( + private fun isTypeSingleArg(): List = listOf(CHAR, STRING).map { element -> + FunctionSignature.Scalar( name = "is_${element.name.lowercase()}", returns = BOOL, parameters = listOf( @@ -505,8 +490,8 @@ internal class Header( ) } - private fun isTypeDoubleArgsInt(): List = listOf(DECIMAL).map { element -> - FunctionSignature( + private fun isTypeDoubleArgsInt(): List = listOf(DECIMAL).map { element -> + FunctionSignature.Scalar( name = "is_${element.name.lowercase()}", returns = BOOL, parameters = listOf( @@ -519,8 +504,8 @@ internal class Header( ) } - private fun isTypeTime(): List = listOf(TIME, TIMESTAMP).map { element -> - FunctionSignature( + private fun isTypeTime(): List = listOf(TIME, TIMESTAMP).map { element -> + FunctionSignature.Scalar( name = "is_${element.name.lowercase()}", returns = BOOL, parameters = listOf( @@ -534,10 +519,10 @@ internal class Header( } // TODO - private fun coalesce(): List = emptyList() + private fun coalesce(): List = emptyList() - private fun nullIf(): List = nullableTypes.map { t -> - FunctionSignature( + private fun nullIf(): List = nullableTypes.map { t -> + FunctionSignature.Scalar( name = "null_if", returns = t, parameters = listOf( @@ -549,9 +534,9 @@ internal class Header( ) } - private fun substring(): List = textTypes.map { t -> + private fun substring(): List = textTypes.map { t -> listOf( - FunctionSignature( + FunctionSignature.Scalar( name = "substring", returns = t, parameters = listOf( @@ -561,7 +546,7 @@ internal class Header( isNullCall = true, isNullable = false, ), - FunctionSignature( + FunctionSignature.Scalar( name = "substring_length", returns = t, parameters = listOf( @@ -575,8 +560,8 @@ internal class Header( ) }.flatten() - private fun position(): List = textTypes.map { t -> - FunctionSignature( + private fun position(): List = textTypes.map { t -> + FunctionSignature.Scalar( name = "position", returns = INT64, parameters = listOf( @@ -588,9 +573,9 @@ internal class Header( ) } - private fun trim(): List = textTypes.map { t -> + private fun trim(): List = textTypes.map { t -> listOf( - FunctionSignature( + FunctionSignature.Scalar( name = "trim", returns = t, parameters = listOf( @@ -599,7 +584,7 @@ internal class Header( isNullCall = true, isNullable = false, ), - FunctionSignature( + FunctionSignature.Scalar( name = "trim_chars", returns = t, parameters = listOf( @@ -609,7 +594,7 @@ internal class Header( isNullCall = true, isNullable = false, ), - FunctionSignature( + FunctionSignature.Scalar( name = "trim_leading", returns = t, parameters = listOf( @@ -618,7 +603,7 @@ internal class Header( isNullCall = true, isNullable = false, ), - FunctionSignature( + FunctionSignature.Scalar( name = "trim_leading_chars", returns = t, parameters = listOf( @@ -628,7 +613,7 @@ internal class Header( isNullCall = true, isNullable = false, ), - FunctionSignature( + FunctionSignature.Scalar( name = "trim_trailing", returns = t, parameters = listOf( @@ -637,7 +622,7 @@ internal class Header( isNullCall = true, isNullable = false, ), - FunctionSignature( + FunctionSignature.Scalar( name = "trim_trailing_chars", returns = t, parameters = listOf( @@ -651,19 +636,19 @@ internal class Header( }.flatten() // TODO - private fun overlay(): List = emptyList() + private fun overlay(): List = emptyList() // TODO - private fun extract(): List = emptyList() + private fun extract(): List = emptyList() - private fun dateArithmetic(prefix: String): List { - val operators = mutableListOf() + private fun dateArithmetic(prefix: String): List { + val operators = mutableListOf() for (type in datetimeTypes) { for (field in DatetimeField.values()) { if (field == DatetimeField.TIMEZONE_HOUR || field == DatetimeField.TIMEZONE_MINUTE) { continue } - val signature = FunctionSignature( + val signature = FunctionSignature.Scalar( name = "${prefix}_${field.name.lowercase()}", returns = type, parameters = listOf( @@ -679,12 +664,12 @@ internal class Header( return operators } - private fun dateAdd(): List = dateArithmetic("date_add") + private fun dateAdd(): List = dateArithmetic("date_add") - private fun dateDiff(): List = dateArithmetic("date_diff") + private fun dateDiff(): List = dateArithmetic("date_diff") - private fun utcNow(): List = listOf( - FunctionSignature( + private fun utcNow(): List = listOf( + FunctionSignature.Scalar( name = "utcnow", returns = TIMESTAMP, parameters = emptyList(), @@ -692,20 +677,158 @@ internal class Header( ) ) - private fun currentUser() = FunctionSignature( + private fun currentUser() = FunctionSignature.Scalar( name = "\$__current_user", returns = STRING, parameters = emptyList(), isNullable = true, ) - private fun currentDate() = FunctionSignature( + private fun currentDate() = FunctionSignature.Scalar( name = "\$__current_date", returns = DATE, parameters = emptyList(), isNullable = false, ) + // ==================================== + // AGGREGATIONS + // ==================================== + + /** + * SQL and PartiQL Aggregation Builtins + */ + public fun aggregations(): List = listOf( + every(), + any(), + some(), + count(), + min(), + max(), + sum(), + avg(), + ).flatten() + + private fun every() = listOf( + FunctionSignature.Aggregation( + name = "every", + returns = BOOL, + parameters = listOf(FunctionParameter("value", BOOL)), + isNullable = true, + ), + FunctionSignature.Aggregation( + name = "every", + returns = BOOL, + parameters = listOf(FunctionParameter("value", ANY)), + isNullable = true, + ), + ) + + private fun any() = listOf( + FunctionSignature.Aggregation( + name = "any", + returns = BOOL, + parameters = listOf(FunctionParameter("value", BOOL)), + isNullable = true, + ), + FunctionSignature.Aggregation( + name = "any", + returns = BOOL, + parameters = listOf(FunctionParameter("value", ANY)), + isNullable = true, + ), + ) + + private fun some() = listOf( + FunctionSignature.Aggregation( + name = "some", + returns = BOOL, + parameters = listOf(FunctionParameter("value", BOOL)), + isNullable = true, + ), + FunctionSignature.Aggregation( + name = "some", + returns = BOOL, + parameters = listOf(FunctionParameter("value", ANY)), + isNullable = true, + ), + ) + + private fun count() = listOf( + FunctionSignature.Aggregation( + name = "count", + returns = INT32, + parameters = listOf(FunctionParameter("value", ANY)), + isNullable = false, + ), + FunctionSignature.Aggregation( + name = "count_star", + returns = INT32, + parameters = listOf(), + isNullable = false, + ), + ) + + private fun min() = numericTypes.map { + FunctionSignature.Aggregation( + name = "min", + returns = it, + parameters = listOf(FunctionParameter("value", it)), + isNullable = true, + ) + } + FunctionSignature.Aggregation( + name = "min", + returns = ANY, + parameters = listOf(FunctionParameter("value", ANY)), + isNullable = true, + ) + + private fun max() = numericTypes.map { + FunctionSignature.Aggregation( + name = "max", + returns = it, + parameters = listOf(FunctionParameter("value", it)), + isNullable = true, + ) + } + FunctionSignature.Aggregation( + name = "max", + returns = ANY, + parameters = listOf(FunctionParameter("value", ANY)), + isNullable = true, + ) + + private fun sum() = numericTypes.map { + FunctionSignature.Aggregation( + name = "sum", + returns = it, + parameters = listOf(FunctionParameter("value", it)), + isNullable = true, + ) + } + FunctionSignature.Aggregation( + name = "sum", + returns = ANY, + parameters = listOf(FunctionParameter("value", ANY)), + isNullable = true, + ) + + private fun avg() = numericTypes.map { + FunctionSignature.Aggregation( + name = "avg", + returns = it, + parameters = listOf(FunctionParameter("value", it)), + isNullable = true, + ) + } + FunctionSignature.Aggregation( + name = "avg", + returns = ANY, + parameters = listOf(FunctionParameter("value", ANY)), + isNullable = true, + ) + + // ==================================== + // SORTING + // ==================================== + // Function precedence comparator // 1. Fewest args first // 2. Parameters are compared left-to-right @@ -767,5 +890,38 @@ internal class Header( STRUCT, ANY, ).mapIndexed { precedence, type -> type to precedence }.toMap() + + // ==================================== + // HELPERS + // ==================================== + + public fun unary(name: String, returns: PartiQLValueType, value: PartiQLValueType) = + FunctionSignature.Scalar( + name = name, + returns = returns, + parameters = listOf(FunctionParameter("value", value)), + isNullCall = true, + isNullable = false, + ) + + public fun binary(name: String, returns: PartiQLValueType, lhs: PartiQLValueType, rhs: PartiQLValueType) = + FunctionSignature.Scalar( + name = name, + returns = returns, + parameters = listOf(FunctionParameter("lhs", lhs), FunctionParameter("rhs", rhs)), + isNullCall = true, + isNullable = false, + ) + + public fun cast(operand: PartiQLValueType, target: PartiQLValueType) = + FunctionSignature.Scalar( + name = castName(target), + returns = target, + isNullCall = true, + isNullable = false, + parameters = listOf( + FunctionParameter("value", operand), + ) + ) } } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/transforms/RelConverter.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/transforms/RelConverter.kt index 1cf02d2461..8b72e34347 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/transforms/RelConverter.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/transforms/RelConverter.kt @@ -31,11 +31,11 @@ import org.partiql.ast.util.AstRewriter import org.partiql.ast.visitor.AstBaseVisitor import org.partiql.plan.Rel import org.partiql.plan.Rex -import org.partiql.plan.fnUnresolved +import org.partiql.plan.aggUnresolved import org.partiql.plan.rel import org.partiql.plan.relBinding import org.partiql.plan.relOpAggregate -import org.partiql.plan.relOpAggregateAgg +import org.partiql.plan.relOpAggregateCall import org.partiql.plan.relOpErr import org.partiql.plan.relOpExcept import org.partiql.plan.relOpExclude @@ -328,22 +328,22 @@ internal object RelConverter { return Pair(select, input) } - // Build the schema -> (aggs... groups...) + // Build the schema -> (calls... groups...) val schema = mutableListOf() val props = emptySet() // Build the rel operator var strategy = Rel.Op.Aggregate.Strategy.FULL - val aggs = aggregations.mapIndexed { i, agg -> + val calls = aggregations.mapIndexed { i, expr -> val binding = relBinding( name = syntheticAgg(i), type = (StaticType.ANY), ) schema.add(binding) - val args = agg.args.map { arg -> arg.toRex(env) } - val id = AstToPlan.convert(agg.function) - val fn = fnUnresolved(id) - relOpAggregateAgg(fn, args) + val args = expr.args.map { arg -> arg.toRex(env) } + val id = AstToPlan.convert(expr.function) + val agg = aggUnresolved(id) + relOpAggregateCall(agg, args) } var groups = emptyList() if (groupBy != null) { @@ -364,7 +364,7 @@ internal object RelConverter { } } val type = relType(schema, props) - val op = relOpAggregate(input, strategy, aggs, groups) + val op = relOpAggregate(input, strategy, calls, groups) val rel = rel(type, op) return Pair(sel, rel) } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/typer/FunctionResolver.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/typer/FunctionResolver.kt index ccd792eb08..dbd03fb8b7 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/typer/FunctionResolver.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/typer/FunctionResolver.kt @@ -14,13 +14,13 @@ internal typealias Args = List /** * Parameter mapping list tells the planner where to insert implicit casts. Null is the identity. */ -internal typealias Mapping = List +internal typealias Mapping = List /** * Tells us which function matched, and how the arguments are mapped. */ -internal class Match( - public val signature: FunctionSignature, +internal class Match( + public val signature: T, public val mapping: Mapping, ) @@ -33,7 +33,7 @@ internal class FunctionResolver(private val header: Header) { /** * Functions are sorted by precedence (which is not rigorously defined/specified at the moment). */ - public fun match(signatures: List, args: Args): Match? { + public fun match(signatures: List, args: Args): Match? { for (signature in signatures) { val mapping = match(signature, args) if (mapping != null) { @@ -52,7 +52,7 @@ internal class FunctionResolver(private val header: Header) { if (signature.parameters.size != args.size) { return null } - val mapping = ArrayList(args.size) + val mapping = ArrayList(args.size) for (i in args.indices) { val a = args[i] val p = signature.parameters[i] diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/typer/PlanTyper.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/typer/PlanTyper.kt index 897d8b8e80..e055bfbe2d 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/typer/PlanTyper.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/typer/PlanTyper.kt @@ -19,16 +19,19 @@ package org.partiql.planner.typer import org.partiql.errors.Problem import org.partiql.errors.ProblemCallback import org.partiql.errors.UNKNOWN_PROBLEM_LOCATION +import org.partiql.plan.Agg import org.partiql.plan.Fn import org.partiql.plan.Identifier import org.partiql.plan.PlanNode import org.partiql.plan.Rel import org.partiql.plan.Rex import org.partiql.plan.Statement +import org.partiql.plan.aggResolved import org.partiql.plan.fnResolved import org.partiql.plan.identifierSymbol import org.partiql.plan.rel import org.partiql.plan.relBinding +import org.partiql.plan.relOpAggregateCall import org.partiql.plan.relOpErr import org.partiql.plan.relOpFilter import org.partiql.plan.relOpJoin @@ -218,10 +221,6 @@ internal class PlanTyper( return rel(type, op) } - override fun visitRelOpSortSpec(node: Rel.Op.Sort.Spec, ctx: Rel.Type?): Rel { - TODO("Type RelOp SortSpec") - } - override fun visitRelOpUnion(node: Rel.Op.Union, ctx: Rel.Type?): Rel { TODO("Type RelOp Union") } @@ -269,7 +268,9 @@ internal class PlanTyper( val input = visitRel(node.input, ctx) // type sub-nodes val typeEnv = TypeEnv(input.type.schema, ResolutionStrategy.LOCAL) - val projections = node.projections.map { it.type(typeEnv) } + val projections = node.projections.map { + it.type(typeEnv) + } // compute output schema val schema = projections.map { it.type } val type = ctx!!.copyWithSchema(schema) @@ -303,19 +304,20 @@ internal class PlanTyper( /** * Initial implementation of `EXCLUDE` schema inference. Until an RFC is finalized for `EXCLUDE` - * (https://github.com/partiql/partiql-spec/issues/39), this behavior is considered experimental and subject to - * change. + * (https://github.com/partiql/partiql-spec/issues/39), + * + * This behavior is considered experimental and subject to change. * - * So far this implementation includes + * This implementation includes * - Excluding tuple bindings (e.g. t.a.b.c) * - Excluding tuple wildcards (e.g. t.a.*.b) * - Excluding collection indexes (e.g. t.a[0].b -- behavior subject to change; see below discussion) * - Excluding collection wildcards (e.g. t.a[*].b) * * There are still discussion points regarding the following edge cases: - * - EXCLUDE on a tuple bindingibute that doesn't exist -- give an error/warning? + * - EXCLUDE on a tuple attribute that doesn't exist -- give an error/warning? * - currently no error - * - EXCLUDE on a tuple bindingibute that has duplicates -- give an error/warning? exclude one? exclude both? + * - EXCLUDE on a tuple attribute that has duplicates -- give an error/warning? exclude one? exclude both? * - currently excludes both w/ no error * - EXCLUDE on a collection index as the last step -- mark element type as optional? * - currently element type as-is @@ -325,7 +327,7 @@ internal class PlanTyper( * - currently a parser error * - EXCLUDE on a union type -- give an error/warning? no-op? exclude on each type in union? * - currently exclude on each union type - * - If SELECT list includes an bindingibute that is excluded, we could consider giving an error in PlanTyper or + * - If SELECT list includes an attribute that is excluded, we could consider giving an error in PlanTyper or * some other semantic pass * - currently does not give an error */ @@ -343,15 +345,33 @@ internal class PlanTyper( } override fun visitRelOpAggregate(node: Rel.Op.Aggregate, ctx: Rel.Type?): Rel { - TODO("Type RelOp Aggregate") - } + // compute input schema + val input = visitRel(node.input, ctx) - override fun visitRelOpAggregateAgg(node: Rel.Op.Aggregate.Agg, ctx: Rel.Type?): Rel { - TODO("Type RelOp Agg") - } + // type the calls and groups + val typer = RexTyper(locals = TypeEnv(input.type.schema, ResolutionStrategy.LOCAL)) + + // typing of aggregate calls is slightly more complicated because they are not expressions. + val calls = node.calls.mapIndexed { i, call -> + when (val agg = call.agg) { + is Agg.Resolved -> call to ctx!!.schema[i].type + is Agg.Unresolved -> typer.resolveAgg(agg, call.args) + } + } + val groups = node.groups.map { typer.visitRex(it, null) } + + // Compute schema using order (calls...groups...) + val schema = mutableListOf() + schema += calls.map { it.second } + schema += groups.map { it.type } - override fun visitRelBinding(node: Rel.Binding, ctx: Rel.Type?): Rel { - TODO("Type RelOp Binding") + // rewrite with typed calls and groups + val type = ctx!!.copyWithSchema(schema) + val op = node.copy( + calls = calls.map { it.first }, + groups = groups, + ) + return rel(type, op) } } @@ -455,7 +475,7 @@ internal class PlanTyper( // 4. Invalid path reference; always MISSING if (type == StaticType.MISSING) { handleAlwaysMissing() - return rex(type, rexOpErr("Unknown identifier $node")) + return rexErr("Unknown identifier $node") } // 5. Non-missing, root is resolved @@ -463,10 +483,7 @@ internal class PlanTyper( } /** - * Typing of functions is - * - * 1. If any argument is MISSING, the function return type is MISSING - * 2. If all arguments are NULL + * Resolve and type scalar function calls. * * @param node * @param ctx @@ -536,7 +553,7 @@ internal class PlanTyper( } is FnMatch.Error -> { handleUnknownFunction(match) - rex(StaticType.MISSING, rexOpErr("Unknown function $fn")) + rexErr("Unknown scalar function $fn") } } } @@ -927,6 +944,63 @@ internal class PlanTyper( false -> StaticType.ANY } } + + /** + * Resolution and typing of aggregation function calls. + * + * I've chosen to place this in RexTyper because all arguments will be typed using the same locals. + * There's no need to create new RexTyper instances for each argument. There is no reason to limit aggregations + * to a single argument (covar, corr, pct, etc.) but in practice we typically only have single . + * + * This method is _very_ similar to scalar function resolution, so it is temping to DRY these two out; but the + * separation is cleaner as the typing of NULLS is subtly different. + * + * SQL-99 6.16 General Rules on + * Let TX be the single-column table that is the result of applying the + * to each row of T and eliminating null values <--- all NULL values are eliminated as inputs + */ + public fun resolveAgg(agg: Agg.Unresolved, arguments: List): Pair { + var missingArg = false + val args = arguments.map { + val arg = visitRex(it, null) + if (arg.type.isMissable()) missingArg = true + arg + } + + // + if (missingArg) { + handleAlwaysMissing() + return relOpAggregateCall(agg, listOf(rexErr("MISSING"))) to MissingType + } + + // Try to match the arguments to functions defined in the catalog + return when (val match = env.resolveAgg(agg, args)) { + is FnMatch.Ok -> { + // Found a match! + val newAgg = aggResolved(match.signature) + val newArgs = rewriteFnArgs(match.mapping, args) + val returns = newAgg.signature.returns + + // Return type with calculated nullability + var type = when { + newAgg.signature.isNullable -> returns.toStaticType() + else -> returns.toNonNullStaticType() + } + + // Some operators can return MISSING during runtime + if (match.isMissable) { + type = StaticType.unionOf(type, StaticType.MISSING).flatten() + } + + // Finally, rewrite this node + relOpAggregateCall(newAgg, newArgs) to type + } + is FnMatch.Error -> { + handleUnknownFunction(match) + return relOpAggregateCall(agg, listOf(rexErr("MISSING"))) to MissingType + } + } + } } // HELPERS @@ -935,6 +1009,8 @@ internal class PlanTyper( private fun Rex.type(typeEnv: TypeEnv) = RexTyper(typeEnv).visitRex(this, this.type) + private fun rexErr(message: String) = rex(StaticType.MISSING, rexOpErr(message)) + /** * I found decorating the tree with the binding names (for resolution) was easier than associating introduced * bindings with a node via an id->list map. ONLY because right now I don't think we have a good way @@ -1026,7 +1102,7 @@ internal class PlanTyper( /** * Rewrites function arguments, wrapping in the given function if exists. */ - private fun rewriteFnArgs(mapping: List, args: List): List { + private fun rewriteFnArgs(mapping: List, args: List): List { if (mapping.size != args.size) { error("Fatal, malformed function mapping") // should be unreachable given how a mapping is generated. } @@ -1089,12 +1165,12 @@ internal class PlanTyper( ) } - private fun handleUnknownFunction(match: FnMatch.Error) { + private fun handleUnknownFunction(match: FnMatch.Error<*>) { onProblem( Problem( sourceLocation = UNKNOWN_PROBLEM_LOCATION, details = PlanningProblemDetails.UnknownFunction( - match.fn.identifier.normalize(), + match.identifier.normalize(), match.args.map { a -> a.type }, ) ) diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/PlannerTestJunit.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/PlannerTestJunit.kt index 187bc1d8bc..09fd3c0aea 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/PlannerTestJunit.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/PlannerTestJunit.kt @@ -3,7 +3,6 @@ package org.partiql.planner import com.amazon.ionelement.api.field import com.amazon.ionelement.api.ionString import com.amazon.ionelement.api.ionStructOf -import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.DynamicContainer import org.junit.jupiter.api.DynamicContainer.dynamicContainer import org.junit.jupiter.api.DynamicNode @@ -14,29 +13,49 @@ import org.junit.jupiter.api.fail import org.partiql.errors.ProblemSeverity import org.partiql.parser.PartiQLParserBuilder import org.partiql.plan.Statement -import org.partiql.planner.test.PartiQLTest -import org.partiql.planner.test.PartiQLTestProvider +import org.partiql.plan.debug.PlanPrinter +import org.partiql.planner.test.PlannerTest +import org.partiql.planner.test.PlannerTestProvider +import org.partiql.planner.test.PlannerTestSuite +import org.partiql.planner.test.toIon import org.partiql.plugins.local.LocalPlugin -import org.partiql.plugins.local.toIon import java.util.stream.Stream +import kotlin.io.path.pathString +import kotlin.io.path.toPath + +data class PlannerTestFilter( + val suite: String?, + val test: String?, +) -/** - * PlannerTestJunit is responsible for constructing JUnit test suites from all input queries in the testFixtures. - * - * I believe this can be more generic and added to testFixtures; but that is outside the scope of current work. - */ class PlannerTestJunit { @TestFactory fun mapSuitesToJunitTests(): Stream { - val inputs = PartiQLTestProvider().inputs() - val cases = PlannerTestProvider().groups() - return cases.map { groupNode(it, inputs) } + val filter = PlannerTestFilter( + suite = null, + test = null, + ) + val provider = PlannerTestProvider() + // filter suites + val suites = provider.suites().apply { + if (filter.suite != null) { + this.filter { it.name.contains(filter.suite) } + } + } + return provider.suites().map { + // filter tests + val suite = when (val f = filter.test) { + null -> it + else -> PlannerTestSuite(it.name, it.session, it.tests.filter { t -> t.key.contains(f) }) + } + suiteNode(suite) + } } companion object { - private val root = PartiQLTestProvider::class.java.getResource("/catalogs")!!.toURI().path + private val root = PlannerTest::class.java.getResource("/catalogs")!!.toURI().toPath().pathString private val parser = PartiQLParserBuilder.standard().build() @@ -45,52 +64,35 @@ class PlannerTestJunit { field("connector_name", ionString("local")), field("root", ionString("$root/default")), ), - "tpc_ds" to ionStructOf( - field("connector_name", ionString("local")), - field("root", ionString("$root/tpc_ds")), - ), ) - private fun groupNode(group: PlannerTestGroup, inputs: Map): DynamicContainer { + private fun suiteNode(suite: PlannerTestSuite): DynamicContainer { val plugin = LocalPlugin() val planner = PartiQLPlannerBuilder() .plugins(listOf(plugin)) .build() - // Map all cases to an input - val tests = group.cases.map { case -> - val key = "${group.name}__${case.input}" - val input = inputs[key] - // Report bad input mapping - if (input == null) { - return@map failTestNode(key, "Missing input for `$key`") - } + val tests = suite.tests.map { (name, test) -> + val testName = "${suite.name}__$name" val session = PartiQLPlanner.Session( - queryId = key, + queryId = "q__$testName", userId = "Planner_test_runner", - currentCatalog = case.catalog, - currentDirectory = case.catalogPath, + currentCatalog = suite.session.catalog, + currentDirectory = suite.session.path, catalogConfig = catalogConfig, ) - testNode(key, planner, session, input.statement, case) - } - return dynamicContainer(group.name, tests.stream()) - } - - private fun failTestNode(id: String, message: String): DynamicTest { - return dynamicTest(id) { - fail { message } + testNode(testName, planner, session, test) } + return dynamicContainer(suite.name, tests.stream()) } private fun testNode( displayName: String, planner: PartiQLPlanner, session: PartiQLPlanner.Session, - statement: String, - case: PlannerTestCase, + test: PlannerTest, ): DynamicTest { return dynamicTest(displayName) { - val ast = parser.parse(statement).root + val ast = parser.parse(test.statement).root val result = planner.plan(ast, session) for (problem in result.problems) { if (problem.details.severity == ProblemSeverity.ERROR) { @@ -101,9 +103,17 @@ class PlannerTestJunit { if (statement !is Statement.Query) { fail { "Expected plan statement to be a Statement.Query" } } - val expected = case.schema.toIon() + val expected = test.schema.toIon() val actual = statement.root.type.toIon() - assertEquals(expected, actual) + assert(expected == actual) { + buildString { + appendLine() + appendLine("Expect: $expected") + appendLine("Actual: $actual") + appendLine() + PlanPrinter.append(this, statement) + } + } } } } diff --git a/partiql-planner/src/testFixtures/resources/catalogs/default/pql/numbers.ion b/partiql-planner/src/testFixtures/resources/catalogs/default/pql/numbers.ion new file mode 100644 index 0000000000..14c03f7135 --- /dev/null +++ b/partiql-planner/src/testFixtures/resources/catalogs/default/pql/numbers.ion @@ -0,0 +1,114 @@ +{ + type: "struct", + fields: [ + { + name: "nullable_int16s", + type: { + type: "list", + items: [ + "int16", + "null" + ] + } + }, + { + name: "nullable_int32s", + type: { + type: "list", + items: [ + "int32", + "null" + ] + } + }, + { + name: "nullable_int64s", + type: { + type: "list", + items: [ + "int64", + "null" + ] + } + }, + { + name: "nullable_ints", + type: { + type: "list", + items: [ + "int", + "null" + ] + } + }, + { + name: "int16s", + type: { + type: "list", + items: "int16", + }, + }, + { + name: "int32s", + type: { + type: "list", + items: "int32", + }, + }, + { + name: "int64s", + type: { + type: "list", + items: "int64", + }, + }, + { + name: "ints", + type: { + type: "list", + items: "int", + }, + }, + { + name: "decimals", + type: { + type: "list", + items: "decimal", + }, + }, + { + name: "nullable_float32s", + type: { + type: "list", + items: [ + "float32", + "null" + ] + } + }, + { + name: "nullable_float64s", + type: { + type: "list", + items: [ + "float64", + "null" + ] + } + }, + { + name: "float32s", + type: { + type: "list", + items: "float32", + }, + }, + { + name: "float64s", + type: { + type: "list", + items: "float64", + }, + } + ], +} diff --git a/partiql-planner/src/testFixtures/resources/catalogs/default/pql/points.ion b/partiql-planner/src/testFixtures/resources/catalogs/default/pql/points.ion new file mode 100644 index 0000000000..0d18fe569c --- /dev/null +++ b/partiql-planner/src/testFixtures/resources/catalogs/default/pql/points.ion @@ -0,0 +1,21 @@ +{ + type: "bag", + items: { + type: "struct", + constraints: [closed], + fields: [ + { + name: "x", + type: "float32", + }, + { + name: "y", + type: "float32", + }, + { + name: "z", + type: "float32", + }, + ] + } +} diff --git a/partiql-planner/src/testFixtures/resources/tests/aggregations.ion b/partiql-planner/src/testFixtures/resources/tests/aggregations.ion new file mode 100644 index 0000000000..2badeb8dd6 --- /dev/null +++ b/partiql-planner/src/testFixtures/resources/tests/aggregations.ion @@ -0,0 +1,286 @@ +suite::{ + name: "aggregations", + session: { + catalog: "default", + path: [ + "pql" + ], + vars: {}, + }, + tests: { + 'avg(int32|null)': { + statement: ''' + SELECT AVG(n) as "avg" FROM numbers.nullable_int32s AS n + ''', + schema: { + type: "bag", + items: { + type: "struct", + fields: [ + { + name: "avg", + type: [ + "int32", + "null", + ], + }, + ], + }, + }, + }, + 'count(int32|null)': { + statement: ''' + SELECT COUNT(n) as "count" FROM numbers.nullable_int32s AS n + ''', + schema: { + type: "bag", + items: { + type: "struct", + fields: [ + { + name: "count", + type: "int32", + }, + ], + }, + }, + }, + 'min(int32|null)': { + statement: ''' + SELECT MIN(n) as "min" FROM numbers.nullable_int32s AS n + ''', + schema: { + type: "bag", + items: { + type: "struct", + fields: [ + { + name: "min", + type: [ + "int32", + "null", + ], + }, + ], + }, + }, + }, + 'max(int32|null)': { + statement: ''' + SELECT MAX(n) as "max" FROM numbers.nullable_int32s AS n + ''', + schema: { + type: "bag", + items: { + type: "struct", + fields: [ + { + name: "max", + type: [ + "int32", + "null", + ], + }, + ], + }, + }, + }, + 'sum(int32|null)': { + statement: ''' + SELECT SUM(n) as "sum" FROM numbers.nullable_int32s AS n + ''', + schema: { + type: "bag", + items: { + type: "struct", + fields: [ + { + name: "sum", + type: [ + "int32", + "null", + ], + }, + ], + }, + }, + }, + 'avg(int32)': { + statement: ''' + SELECT AVG(n) as "avg" FROM numbers.int32s AS n + ''', + schema: { + type: "bag", + items: { + type: "struct", + fields: [ + { + name: "avg", + type: [ + "int32", + "null" + ], + }, + ], + }, + }, + }, + 'count(int32)': { + statement: ''' + SELECT COUNT(n) as "count" FROM numbers.int32s AS n + ''', + schema: { + type: "bag", + items: { + type: "struct", + fields: [ + { + name: "count", + type: "int32", + }, + ], + }, + }, + }, + 'min(int32)': { + statement: ''' + SELECT MIN(n) as "min" FROM numbers.int32s AS n + ''', + schema: { + type: "bag", + items: { + type: "struct", + fields: [ + { + name: "min", + type: [ + "int32", + "null" + ], + }, + ], + }, + }, + }, + 'max(int32)': { + statement: ''' + SELECT MAX(n) as "max" FROM numbers.int32s AS n + ''', + schema: { + type: "bag", + items: { + type: "struct", + fields: [ + { + name: "max", + type: [ + "int32", + "null" + ], + }, + ], + }, + }, + }, + 'sum(int32)': { + statement: ''' + SELECT SUM(n) as "sum" FROM numbers.int32s AS n + ''', + schema: { + type: "bag", + items: { + type: "struct", + fields: [ + { + name: "sum", + type: [ + "int32", + "null" + ], + }, + ], + }, + }, + }, + 'group_by_key': { + statement: ''' + SELECT COUNT(*) as "count", isOdd FROM numbers.int32s AS n + GROUP BY n % 2 = 0 AS isOdd + ''', + schema: { + type: "bag", + items: { + type: "struct", + fields: [ + { + name: "count", + type: "int32" + }, + { + name: "isOdd", + type: "bool" + }, + ], + }, + }, + }, + 'group_by_keys_noalias': { + statement: ''' + SELECT AVG(x), y, z FROM points + GROUP BY y, z + ''', + schema: { + type: "bag", + items: { + type: "struct", + fields: [ + { + name: "_1", + type: [ + "float32", + "null" + ], + }, + { + name: "y", + type: "float32" + }, + { + name: "z", + type: "float32" + }, + ], + }, + }, + }, + 'group_by_keys_alias': { + statement: ''' + SELECT AVG(x), a, b FROM points + GROUP BY y as a, z as b + ''', + schema: { + type: "bag", + items: { + type: "struct", + fields: [ + { + name: "_1", + type: [ + "float32", + "null" + ], + }, + { + name: "a", + type: "float32" + }, + { + name: "b", + type: "float32" + }, + ], + }, + }, + }, + }, +} diff --git a/partiql-types/src/main/kotlin/org/partiql/types/function/FunctionSignature.kt b/partiql-types/src/main/kotlin/org/partiql/types/function/FunctionSignature.kt index e9666d1ca1..89839c5edd 100644 --- a/partiql-types/src/main/kotlin/org/partiql/types/function/FunctionSignature.kt +++ b/partiql-types/src/main/kotlin/org/partiql/types/function/FunctionSignature.kt @@ -4,7 +4,6 @@ import org.partiql.value.PartiQLValueExperimental import org.partiql.value.PartiQLValueType /** - * Represents the signature of a PartiQL function. * * The signature includes the names of the function (which allows for function overloading), * the return type, a list of parameters, a flag indicating whether the function is deterministic @@ -13,26 +12,20 @@ import org.partiql.value.PartiQLValueType * @property name Function name * @property returns Operator return type * @property parameters Operator parameters - * @property isDeterministic Flag indicating this function always produces the same output given the same input. - * @property isNullCall Flag indicating if any of the call arguments is NULL, then return NULL. - * @property isNullable Flag indicating this function's operator may return a NULL value. * @property description Optional operator description + * @property isNullable Flag indicating this function's operator may return a NULL value. */ @OptIn(PartiQLValueExperimental::class) -public class FunctionSignature( - public val name: String, - public val returns: PartiQLValueType, - public val parameters: List, - public val isDeterministic: Boolean = true, - public val isNullCall: Boolean = false, - public val isNullable: Boolean = true, - public val description: String? = null, +public sealed class FunctionSignature( + @JvmField public val name: String, + @JvmField public val returns: PartiQLValueType, + @JvmField public val parameters: List, + @JvmField public val description: String? = null, + @JvmField public val isNullable: Boolean = true, ) { /** - * String mangling of a function signature to generate a specific identifier. - * - * Format NAME__INPUTS__RETURNS + * Symbolic name of this operator of the form NAME__INPUTS__RETURNS */ public val specific: String = buildString { append(name.uppercase()) @@ -43,74 +36,145 @@ public class FunctionSignature( } /** - * SQL-99 p.542 + * Use the symbolic name for easy debugging + * + * @return */ - private val deterministicCharacteristic = when (isDeterministic) { - true -> "DETERMINISTIC" - else -> "NOT DETERMINISTIC" - } + override fun toString(): String = specific /** - * SQL-99 p.543 + * Represents the signature of a PartiQL scalar function. + * + * @property isDeterministic Flag indicating this function always produces the same output given the same input. + * @property isNullCall Flag indicating if any of the call arguments is NULL, then return NULL. + * @constructor */ - private val nullCallClause = when (isNullCall) { - true -> "RETURNS NULL ON NULL INPUT" - else -> "CALLED ON NULL INPUT" - } + public class Scalar( + name: String, + returns: PartiQLValueType, + parameters: List, + description: String? = null, + isNullable: Boolean = true, + @JvmField public val isDeterministic: Boolean = true, + @JvmField public val isNullCall: Boolean = false, + ) : FunctionSignature(name, returns, parameters, description, isNullable) { - override fun toString(): String = specific - - internal fun sql(): String = buildString { - val fn = name.uppercase() - val indent = " " - append("CREATE FUNCTION \"$fn\" (") - if (parameters.isNotEmpty()) { - val extent = parameters.maxOf { it.name.length } + override fun equals(other: Any?): Boolean { + if (other !is Scalar) return false + if ( + other.name != name || + other.returns != returns || + other.parameters.size != parameters.size || + other.isDeterministic != isDeterministic || + other.isNullCall != isNullCall || + other.isNullable != isNullable + ) { + return false + } + // all other parts equal, compare parameters (ignore names) for (i in parameters.indices) { - val p = parameters[i] - val ws = (extent - p.name.length) + 1 - appendLine() - append(indent).append(p.name.uppercase()).append(" ".repeat(ws)).append(p.type.name) - if (i != parameters.size - 1) append(",") + val p1 = parameters[i] + val p2 = other.parameters[i] + if (p1.type != p2.type) return false } + return true + } + + override fun hashCode(): Int { + var result = name.hashCode() + result = 31 * result + returns.hashCode() + result = 31 * result + parameters.hashCode() + result = 31 * result + isDeterministic.hashCode() + result = 31 * result + isNullCall.hashCode() + result = 31 * result + isNullable.hashCode() + result = 31 * result + (description?.hashCode() ?: 0) + return result } - appendLine(" )") - append(indent).appendLine("RETURNS $returns") - append(indent).appendLine("SPECIFIC $specific") - append(indent).appendLine(deterministicCharacteristic) - append(indent).appendLine(nullCallClause) - append(indent).appendLine("RETURN $fn ( ${parameters.joinToString { it.name.uppercase() }} ) ;") } - override fun equals(other: Any?): Boolean { - if (other !is FunctionSignature) return false - if ( - other.name != name || - other.returns != returns || - other.isDeterministic != isDeterministic || - other.isNullCall != isNullCall || - other.isNullable != isNullable || - other.parameters.size != parameters.size - ) { - return false + /** + * Represents the signature of a PartiQL aggregation function. + * + * @property isDecomposable Flag indicating this aggregation can be decomposed + * @constructor + */ + public class Aggregation( + name: String, + returns: PartiQLValueType, + parameters: List, + description: String? = null, + isNullable: Boolean = true, + @JvmField public val isDecomposable: Boolean = true, + ) : FunctionSignature(name, returns, parameters, description, isNullable) { + + override fun equals(other: Any?): Boolean { + if (other !is Aggregation) return false + if ( + other.name != name || + other.returns != returns || + other.parameters.size != parameters.size || + other.isDecomposable != isDecomposable || + other.isNullable != isNullable + ) { + return false + } + // all other parts equal, compare parameters (ignore names) + for (i in parameters.indices) { + val p1 = parameters[i] + val p2 = other.parameters[i] + if (p1.type != p2.type) return false + } + return true } - // all other parts equal, compare parameters (ignore names) - for (i in parameters.indices) { - val p1 = parameters[i] - val p2 = other.parameters[i] - if (p1.type != p2.type) return false + + override fun hashCode(): Int { + var result = name.hashCode() + result = 31 * result + returns.hashCode() + result = 31 * result + parameters.hashCode() + result = 31 * result + isDecomposable.hashCode() + result = 31 * result + isNullable.hashCode() + result = 31 * result + (description?.hashCode() ?: 0) + return result } - return true } - override fun hashCode(): Int { - var result = name.hashCode() - result = 31 * result + returns.hashCode() - result = 31 * result + parameters.hashCode() - result = 31 * result + isDeterministic.hashCode() - result = 31 * result + isNullCall.hashCode() - result = 31 * result + isNullable.hashCode() - result = 31 * result + (description?.hashCode() ?: 0) - return result - } + // // Logic for writing a [FunctionSignature] using SQL `CREATE FUNCTION` syntax. + // + // /** + // * SQL-99 p.542 + // */ + // private val deterministicCharacteristic = when (isDeterministic) { + // true -> "DETERMINISTIC" + // else -> "NOT DETERMINISTIC" + // } + // + // /** + // * SQL-99 p.543 + // */ + // private val nullCallClause = when (isNullCall) { + // true -> "RETURNS NULL ON NULL INPUT" + // else -> "CALLED ON NULL INPUT" + // } + // + // private fun sql(): String = buildString { + // val fn = name.uppercase() + // val indent = " " + // append("CREATE FUNCTION \"$fn\" (") + // if (parameters.isNotEmpty()) { + // val extent = parameters.maxOf { it.name.length } + // for (i in parameters.indices) { + // val p = parameters[i] + // val ws = (extent - p.name.length) + 1 + // appendLine() + // append(indent).append(p.name.uppercase()).append(" ".repeat(ws)).append(p.type.name) + // if (i != parameters.size - 1) append(",") + // } + // } + // appendLine(" )") + // append(indent).appendLine("RETURNS $returns") + // append(indent).appendLine("SPECIFIC $specific") + // append(indent).appendLine(deterministicCharacteristic) + // append(indent).appendLine(nullCallClause) + // append(indent).appendLine("RETURN $fn ( ${parameters.joinToString { it.name.uppercase() }} ) ;") + // } } diff --git a/plugins/partiql-local/src/main/kotlin/org/partiql/plugins/local/functions/Pow.kt b/plugins/partiql-local/src/main/kotlin/org/partiql/plugins/local/functions/Pow.kt index 3ef0482ae3..9ee2b7d578 100644 --- a/plugins/partiql-local/src/main/kotlin/org/partiql/plugins/local/functions/Pow.kt +++ b/plugins/partiql-local/src/main/kotlin/org/partiql/plugins/local/functions/Pow.kt @@ -15,7 +15,7 @@ import org.partiql.value.float64Value object Pow : PartiQLFunction { @OptIn(PartiQLValueExperimental::class) - override val signature = FunctionSignature( + override val signature = FunctionSignature.Scalar( name = "test_power", returns = PartiQLValueType.FLOAT64, parameters = listOf( diff --git a/plugins/partiql-local/src/main/kotlin/org/partiql/plugins/local/functions/TrimLead.kt b/plugins/partiql-local/src/main/kotlin/org/partiql/plugins/local/functions/TrimLead.kt index cf5a3912d3..9baa6b9365 100644 --- a/plugins/partiql-local/src/main/kotlin/org/partiql/plugins/local/functions/TrimLead.kt +++ b/plugins/partiql-local/src/main/kotlin/org/partiql/plugins/local/functions/TrimLead.kt @@ -15,7 +15,7 @@ import org.partiql.value.stringValue object TrimLead : PartiQLFunction { @OptIn(PartiQLValueExperimental::class) - override val signature = FunctionSignature( + override val signature = FunctionSignature.Scalar( name = "trim_lead", returns = PartiQLValueType.STRING, parameters = listOf(