From 94c02ef8f9e9472d7a4f3e72758169ae399eba8a Mon Sep 17 00:00:00 2001 From: "R. C. Howell" Date: Fri, 27 Oct 2023 14:52:54 -0700 Subject: [PATCH] Allows providing headers to PartiQLPlanner --- .../PartiQLSchemaInferencerTests.kt | 89 +- .../main/kotlin/org/partiql/planner/Env.kt | 85 +- .../main/kotlin/org/partiql/planner/Header.kt | 934 +----------------- .../org/partiql/planner/PartiQLHeader.kt | 670 +++++++++++++ .../partiql/planner/PartiQLPlannerBuilder.kt | 7 +- .../partiql/planner/PartiQLPlannerDefault.kt | 12 +- .../org/partiql/planner/typer/FnResolver.kt | 395 ++++++++ .../partiql/planner/typer/FunctionResolver.kt | 79 -- .../org/partiql/planner/typer/PlanTyper.kt | 1 - .../org/partiql/planner/typer/TypeLattice.kt | 44 + .../kotlin/org/partiql/planner/HeaderTest.kt | 2 +- .../planner/typer/FunctionResolverTest.kt | 38 +- 12 files changed, 1235 insertions(+), 1121 deletions(-) create mode 100644 partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLHeader.kt create mode 100644 partiql-planner/src/main/kotlin/org/partiql/planner/typer/FnResolver.kt delete mode 100644 partiql-planner/src/main/kotlin/org/partiql/planner/typer/FunctionResolver.kt diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/planner/transforms/PartiQLSchemaInferencerTests.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/planner/transforms/PartiQLSchemaInferencerTests.kt index 59aed25d0f..fadbb741ef 100644 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/planner/transforms/PartiQLSchemaInferencerTests.kt +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/planner/transforms/PartiQLSchemaInferencerTests.kt @@ -112,6 +112,11 @@ class PartiQLSchemaInferencerTests { @Execution(ExecutionMode.CONCURRENT) fun testTupleUnion(tc: TestCase) = runTest(tc) + @ParameterizedTest + @MethodSource("aggregationCases") + @Execution(ExecutionMode.CONCURRENT) + fun testAggregations(tc: TestCase) = runTest(tc) + companion object { private val root = this::class.java.getResource("/catalogs/default")!!.toURI().toPath().pathString @@ -2123,6 +2128,50 @@ class PartiQLSchemaInferencerTests { ), ), ) + + @JvmStatic + fun aggregationCases() = listOf( + SuccessTestCase( + name = "AGGREGATE over INTS", + query = "SELECT a, COUNT(*) AS c, SUM(a) AS s, MIN(b) AS m FROM << {'a': 1, 'b': 2} >> GROUP BY a", + expected = BagType( + StructType( + fields = mapOf( + "a" to INT, + "c" to INT4, + "s" to INT.asNullable(), + "m" to INT.asNullable(), + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true), + TupleConstraint.Ordered + ) + ) + ) + ), + SuccessTestCase( + name = "AGGREGATE over DECIMALS", + query = "SELECT a, COUNT(*) AS c, SUM(a) AS s, MIN(b) AS m FROM << {'a': 1.0, 'b': 2.0}, {'a': 1.0, 'b': 2.0} >> GROUP BY a", + expected = BagType( + StructType( + fields = mapOf( + "a" to StaticType.DECIMAL, + "c" to INT4, + "s" to StaticType.DECIMAL.asNullable(), + "m" to StaticType.DECIMAL.asNullable(), + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true), + TupleConstraint.Ordered + ) + ) + ) + ), + ) } sealed class TestCase { @@ -2790,46 +2839,6 @@ class PartiQLSchemaInferencerTests { ) ) ), - SuccessTestCase( - name = "AGGREGATE over INTS", - query = "SELECT a, COUNT(*) AS c, SUM(a) AS s, MIN(b) AS m FROM << {'a': 1, 'b': 2} >> GROUP BY a", - expected = BagType( - StructType( - fields = mapOf( - "a" to INT, - "c" to INT, - "s" to INT, - "m" to INT, - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true), - TupleConstraint.Ordered - ) - ) - ) - ), - SuccessTestCase( - name = "AGGREGATE over DECIMALS", - query = "SELECT a, COUNT(*) AS c, SUM(a) AS s, MIN(b) AS m FROM << {'a': 1.0, 'b': 2.0}, {'a': 1.0, 'b': 2.0} >> GROUP BY a", - expected = BagType( - StructType( - fields = mapOf( - "a" to StaticType.DECIMAL, - "c" to INT, - "s" to StaticType.DECIMAL, - "m" to StaticType.DECIMAL, - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true), - TupleConstraint.Ordered - ) - ) - ) - ), SuccessTestCase( name = "Current User", query = "CURRENT_USER", diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/Env.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/Env.kt index 8d633351b8..797969c58e 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/Env.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/Env.kt @@ -9,9 +9,7 @@ import org.partiql.plan.Rex import org.partiql.plan.global import org.partiql.plan.identifierQualified import org.partiql.plan.identifierSymbol -import org.partiql.planner.typer.FunctionResolver -import org.partiql.planner.typer.Mapping -import org.partiql.planner.typer.toRuntimeType +import org.partiql.planner.typer.FnResolver import org.partiql.spi.BindingCase import org.partiql.spi.BindingName import org.partiql.spi.BindingPath @@ -25,9 +23,6 @@ import org.partiql.spi.connector.Constants import org.partiql.types.StaticType import org.partiql.types.StructType import org.partiql.types.TupleConstraint -import org.partiql.types.function.FunctionParameter -import org.partiql.types.function.FunctionSignature -import org.partiql.value.PartiQLValueExperimental /** * Handle for associating a catalog with the metadata; pair of catalog to data. @@ -70,32 +65,6 @@ internal class TypeEnv( } } -/** - * Result of attempting to match an unresolved function. - */ -internal sealed class FnMatch { - - /** - * 7.1 Inputs with wrong types - * It follows that all functions return MISSING when one of their inputs is MISSING - * - * @property signature - * @property mapping - * @property isMissable TRUE when anyone of the arguments _could_ be MISSING. We *always* propagate MISSING. - */ - public data class Ok( - public val signature: T, - public val mapping: Mapping, - public val isMissable: Boolean, - ) : FnMatch() - - public data class Error( - public val identifier: Identifier, - public val args: List, - public val candidates: List, - ) : FnMatch() -} - /** * Metadata regarding a resolved variable. */ @@ -148,13 +117,12 @@ internal enum class ResolutionStrategy { /** * PartiQL Planner Global Environment of Catalogs backed by given plugins. * - * @property header List of namespaced definitions + * @property headers List of namespaced definitions * @property plugins List of plugins for global resolution * @property session Session details */ -@OptIn(PartiQLValueExperimental::class) internal class Env( - private val header: Header, + private val headers: List
, private val plugins: List, private val session: PartiQLPlanner.Session, ) { @@ -165,9 +133,9 @@ internal class Env( public val globals = mutableListOf() /** - * Encapsulate function matching logic in + * Encapsulate all function resolving logic within [FnResolver]. */ - public val functionResolver = FunctionResolver(header) + public val fnResolver = FnResolver(headers) private val connectorSession = object : ConnectorSession { override fun getQueryId(): String = session.queryId @@ -197,51 +165,12 @@ internal class Env( /** * Leverages a [FunctionResolver] to find a matching function defined in the [Header] scalar function catalog. */ - internal fun resolveFn(fn: Fn.Unresolved, args: List): FnMatch { - val candidates = header.lookup(fn) - var hadMissingArg = false - val parameters = args.mapIndexed { i, arg -> - if (!hadMissingArg && arg.type.isMissable()) { - hadMissingArg = true - } - FunctionParameter("arg-$i", arg.type.toRuntimeType()) - } - val match = functionResolver.match(candidates, parameters) - return when (match) { - null -> FnMatch.Error(fn.identifier, args, candidates) - else -> { - val isMissable = hadMissingArg || header.isUnsafeCast(match.signature.specific) - FnMatch.Ok(match.signature, match.mapping, isMissable) - } - } - } + internal fun resolveFn(fn: Fn.Unresolved, args: List) = fnResolver.resolveFn(fn, args) /** * Leverages a [FunctionResolver] to find a matching function defined in the [Header] aggregation function catalog. */ - internal fun resolveAgg(agg: Agg.Unresolved, args: List): FnMatch { - val candidates = header.lookup(agg) - var hadMissingArg = false - val parameters = args.mapIndexed { i, arg -> - if (!hadMissingArg && arg.type.isMissable()) { - hadMissingArg = true - } - FunctionParameter("arg-$i", arg.type.toRuntimeType()) - } - val match = functionResolver.match(candidates, parameters) - return when (match) { - null -> FnMatch.Error(agg.identifier, args, candidates) - else -> { - val isMissable = hadMissingArg || header.isUnsafeCast(match.signature.specific) - FnMatch.Ok(match.signature, match.mapping, isMissable) - } - } - } - - /** - * TODO - */ - internal fun getFnAggHandle(identifier: Identifier): Nothing = TODO() + internal fun resolveAgg(agg: Agg.Unresolved, args: List) = fnResolver.resolveAgg(agg, args) /** * Fetch global object metadata from the given [BindingPath]. diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/Header.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/Header.kt index 89f2b3321f..257be92466 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/Header.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/Header.kt @@ -1,115 +1,41 @@ package org.partiql.planner -import org.partiql.ast.DatetimeField -import org.partiql.plan.Agg -import org.partiql.plan.Fn -import org.partiql.plan.Identifier -import org.partiql.planner.typer.CastType import org.partiql.planner.typer.TypeLattice import org.partiql.types.function.FunctionParameter import org.partiql.types.function.FunctionSignature import org.partiql.value.PartiQLValueExperimental import org.partiql.value.PartiQLValueType -import org.partiql.value.PartiQLValueType.ANY -import org.partiql.value.PartiQLValueType.BAG -import org.partiql.value.PartiQLValueType.BINARY -import org.partiql.value.PartiQLValueType.BLOB -import org.partiql.value.PartiQLValueType.BOOL -import org.partiql.value.PartiQLValueType.BYTE -import org.partiql.value.PartiQLValueType.CHAR -import org.partiql.value.PartiQLValueType.CLOB -import org.partiql.value.PartiQLValueType.DATE -import org.partiql.value.PartiQLValueType.DECIMAL -import org.partiql.value.PartiQLValueType.FLOAT32 -import org.partiql.value.PartiQLValueType.FLOAT64 -import org.partiql.value.PartiQLValueType.INT -import org.partiql.value.PartiQLValueType.INT16 -import org.partiql.value.PartiQLValueType.INT32 -import org.partiql.value.PartiQLValueType.INT64 -import org.partiql.value.PartiQLValueType.INT8 -import org.partiql.value.PartiQLValueType.INTERVAL -import org.partiql.value.PartiQLValueType.LIST -import org.partiql.value.PartiQLValueType.MISSING -import org.partiql.value.PartiQLValueType.NULL -import org.partiql.value.PartiQLValueType.SEXP -import org.partiql.value.PartiQLValueType.STRING -import org.partiql.value.PartiQLValueType.STRUCT -import org.partiql.value.PartiQLValueType.SYMBOL -import org.partiql.value.PartiQLValueType.TIME -import org.partiql.value.PartiQLValueType.TIMESTAMP - -private typealias FnMap = Map> - -/** - * Unicode non-character to be used for name sanitization - */ -private val HIDDEN_FLAG: Char = Char(0xFDD0) /** - * A place for type and function definitions. Eventually these will be read from Ion files. - * - * @property namespace Definition namespace e.g. partiql, spark, redshift, ... - * @property types Type definitions - * @property functions Scalar function definitions - * @property aggregations Aggregation function definitions + * A (temporary) place for function definitions; there are whispers of loading this as information_schema. */ @OptIn(PartiQLValueExperimental::class) -internal class Header( - private val namespace: String, - private val types: TypeLattice, - private val functions: FnMap, - private val aggregations: FnMap, - private val unsafeCastSet: Set, -) { +public abstract class Header { /** - * Return a list of all scalar function signatures matching the given identifier. + * Definition namespace e.g. partiql, spark, redshift, ... */ - public fun lookup(ref: Fn.Unresolved): List { - val name = getFnName(ref.identifier) - return if (ref.isHidden) - functions.getOrDefault("${HIDDEN_FLAG}_$name", emptyList()) - else functions.getOrDefault(name, emptyList()) - } + abstract val namespace: String /** - * Return a list of all aggregation function signatures matching the given identifier. + * Scalar function signatures available via call syntax. */ - public fun lookup(ref: Agg.Unresolved): List { - val name = getFnName(ref.identifier) - return aggregations.getOrDefault(name, emptyList()) - } + open val functions: List = emptyList() /** - * Returns the CAST function if exists, else null. + * Hidden scalar function signatures available via operator or special form syntax. */ - public fun lookupCoercion(valueType: PartiQLValueType, targetType: PartiQLValueType): FunctionSignature.Scalar? { - if (!types.canCoerce(valueType, targetType)) { - return null - } - val name = castName(targetType) - val casts = functions.getOrDefault("${HIDDEN_FLAG}_$name", emptyList()) - for (cast in casts) { - if (cast.parameters.isEmpty()) { - break // should be unreachable - } - if (valueType == cast.parameters[0].type) return cast - } - return null - } + open val operators: List = emptyList() /** - * Easy lookup of whether this CAST can return MISSING. + * Aggregation function signatures. */ - public fun isUnsafeCast(specific: String): Boolean = unsafeCastSet.contains(specific) + open val aggregations: List = emptyList() /** - * Return a normalized function identifier for lookup in our list of function definitions. + * Type relationships; this is primarily a helper for defining operators. */ - private fun getFnName(identifier: Identifier): String = when (identifier) { - is Identifier.Qualified -> throw IllegalArgumentException("Qualified function identifiers not supported") - is Identifier.Symbol -> identifier.symbol.lowercase() - } + internal val types: TypeLattice = TypeLattice.partiql() /** * Dump the Header as SQL commands @@ -117,7 +43,7 @@ internal class Header( * For functions, output CREATE FUNCTION statements. */ override fun toString(): String = buildString { - functions.forEach { + functions.groupBy { it.name }.forEach { appendLine("-- [${it.key}] ---------") appendLine() it.value.forEach { fn -> appendLine(fn) } @@ -125,822 +51,14 @@ internal class Header( } } - companion object { - - /** - * TODO TEMPORARY — Hardcoded PartiQL Global Catalog - */ - public fun partiql(): Header { - val namespace = "partiql" - val types = TypeLattice.partiql() - val (casts, unsafeCastSet) = Functions.casts(types) - val functions = Functions.combine( - Functions.builtins() - ) - val internalFunctions = Functions.combineInternal( - casts, - Functions.operators(), - Functions.special(), - Functions.system(), - ) - val aggregations = Functions.combine( - Functions.aggregations(), - ) - return Header(namespace, types, functions + internalFunctions, aggregations, unsafeCastSet) - } - - /** - * Define CASTS with some mangled name; CAST(x AS T) -> cast_t(x) - * - * CAST(x AS INT8) -> cast_int64(x) - * - * But what about parameterized types? Are the parameters dropped in casts, or do parameters become arguments? - */ - private fun castName(type: PartiQLValueType) = "cast_${type.name.lowercase()}" - } - - /** - * Utilities for building function signatures for the header / symbol table. - */ - internal object Functions { - - /** - * Group list of [FunctionSignature.Scalar] by name. - */ - public fun combine(vararg functions: List): FnMap { - return functions.flatMap { it.sortedWith(functionPrecedence) }.groupBy { it.name } - } - public fun combineInternal(vararg functions: List): FnMap { - return functions.flatMap { it.sortedWith(functionPrecedence) }.groupBy { "${HIDDEN_FLAG}_${it.name}" } - } - - // ==================================== - // TYPES - // ==================================== - private val allTypes = PartiQLValueType.values() - - private val nullableTypes = listOf( - NULL, // null.null - MISSING, // missing - ) - - private val intTypes = listOf( - INT8, - INT16, - INT32, - INT64, - INT, - ) - - private val numericTypes = listOf( - INT8, - INT16, - INT32, - INT64, - INT, - DECIMAL, - FLOAT32, - FLOAT64, - ) - - private val textTypes = listOf( - STRING, - SYMBOL, - CLOB, - ) - - private val collectionTypes = listOf( - BAG, - LIST, - SEXP, - ) - - private val datetimeTypes = listOf( - DATE, - TIME, - TIMESTAMP, - ) - - // ==================================== - // SCALAR FUNCTIONS - // ==================================== - - /** - * Generate all CAST functions from the given lattice. - * - * @param lattice - * @return Pair(0) is the function list, Pair(1) represents the unsafe cast specifics - */ - public fun casts(lattice: TypeLattice): Pair, Set> { - val casts = mutableListOf() - val unsafeCastSet = mutableSetOf() - for (t1 in lattice.types) { - for (t2 in lattice.types) { - val r = lattice.graph[t1.ordinal][t2.ordinal] - if (r != null) { - val fn = cast(t1, t2) - casts.add(fn) - if (r.cast == CastType.UNSAFE) unsafeCastSet.add(fn.specific) - } - } - } - return casts to unsafeCastSet - } - - /** - * Generate all unary and binary operator signatures. - */ - public fun operators(): List = listOf( - not(), - pos(), - neg(), - eq(), - ne(), - and(), - or(), - lt(), - lte(), - gt(), - gte(), - plus(), - minus(), - times(), - div(), - mod(), - concat(), - bitwiseAnd(), - ).flatten() - - /** - * SQL Builtins (not special forms) - */ - public fun builtins(): List = listOf( - upper(), - lower(), - coalesce(), - nullIf(), - position(), - substring(), - trim(), - utcNow(), - ).flatten() - - /** - * SQL and PartiQL special forms - */ - public fun special(): List = listOf( - like(), - between(), - inCollection(), - isType(), - isTypeSingleArg(), - isTypeDoubleArgsInt(), - isTypeTime(), - position(), - substring(), - trimSpecial(), - overlay(), - extract(), - dateAdd(), - dateDiff(), - ).flatten() - - /** - * System functions (for now, CURRENT_USER and CURRENT_DATE) - * - * @return - */ - public fun system(): List = listOf( - currentUser(), - currentDate(), - ) - - // OPERATORS - - private fun not(): List = listOf(unary("not", BOOL, BOOL)) - - private fun pos(): List = numericTypes.map { t -> - unary("pos", t, t) - } - - private fun neg(): List = numericTypes.map { t -> - unary("neg", t, t) - } - - private fun eq(): List = allTypes.map { t -> - FunctionSignature.Scalar( - name = "eq", - returns = BOOL, - isNullCall = false, - isNullable = false, - parameters = listOf(FunctionParameter("lhs", t), FunctionParameter("rhs", t)), - ) - } - - private fun ne(): List = allTypes.map { t -> - binary("ne", BOOL, t, t) - } - - private fun and(): List = listOf(binary("and", BOOL, BOOL, BOOL)) - - private fun or(): List = listOf(binary("or", BOOL, BOOL, BOOL)) - - private fun lt(): List = numericTypes.map { t -> - binary("lt", BOOL, t, t) - } - - private fun lte(): List = numericTypes.map { t -> - binary("lte", BOOL, t, t) - } - - private fun gt(): List = numericTypes.map { t -> - binary("gt", BOOL, t, t) - } - - private fun gte(): List = numericTypes.map { t -> - binary("gte", BOOL, t, t) - } - - private fun plus(): List = numericTypes.map { t -> - binary("plus", t, t, t) - } - - private fun minus(): List = numericTypes.map { t -> - binary("minus", t, t, t) - } - - private fun times(): List = numericTypes.map { t -> - binary("times", t, t, t) - } - - private fun div(): List = numericTypes.map { t -> - binary("divide", t, t, t) - } - - private fun mod(): List = numericTypes.map { t -> - binary("modulo", t, t, t) - } - - private fun concat(): List = textTypes.map { t -> - binary("concat", t, t, t) - } - - private fun bitwiseAnd(): List = intTypes.map { t -> - binary("bitwise_and", t, t, t) - } - - // BUILT INTS - - private fun upper(): List = textTypes.map { t -> - FunctionSignature.Scalar( - name = "upper", - returns = t, - parameters = listOf(FunctionParameter("value", t)), - isNullCall = true, - isNullable = false, - ) - } - - private fun lower(): List = textTypes.map { t -> - FunctionSignature.Scalar( - name = "lower", - returns = t, - parameters = listOf(FunctionParameter("value", t)), - isNullCall = true, - isNullable = false, - ) - } - - // SPECIAL FORMS - - private fun like(): List = listOf( - FunctionSignature.Scalar( - name = "like", - returns = BOOL, - parameters = listOf( - FunctionParameter("value", STRING), - FunctionParameter("pattern", STRING), - ), - isNullCall = true, - isNullable = false, - ), - FunctionSignature.Scalar( - name = "like_escape", - returns = BOOL, - parameters = listOf( - FunctionParameter("value", STRING), - FunctionParameter("pattern", STRING), - FunctionParameter("escape", STRING), - ), - isNullCall = true, - isNullable = false, - ), - ) - - private fun between(): List = numericTypes.map { t -> - FunctionSignature.Scalar( - name = "between", - returns = BOOL, - parameters = listOf( - FunctionParameter("value", t), - FunctionParameter("lower", t), - FunctionParameter("upper", t), - ), - isNullCall = true, - isNullable = false, - ) - } - - private fun inCollection(): List = allTypes.map { element -> - collectionTypes.map { collection -> - FunctionSignature.Scalar( - name = "in_collection", - returns = BOOL, - parameters = listOf( - FunctionParameter("value", element), - FunctionParameter("collection", collection), - ), - isNullCall = true, - isNullable = false, - ) - } - }.flatten() - - // To model type assertion, generating a list of assertion function based on the type, - // and the parameter will be the value entered. - // i.e., 1 is INT2 => is_int16(1) - // TODO: We can remove the types with parameter in this function. - // but, leaving out the decision to have, for example: - // is_decimal(null, null, value) vs is_decimal(value) later.... - private fun isType(): List = allTypes.map { element -> - FunctionSignature.Scalar( - name = "is_${element.name.lowercase()}", - returns = BOOL, - parameters = listOf( - FunctionParameter("value", ANY) // TODO: Decide if we need to further segment this - ), - isNullCall = false, - isNullable = false - ) - } - - // In type assertion, it is possible for types to have args - // i.e., 'a' is CHAR(2) - // we put type parameter before value. - private fun isTypeSingleArg(): List = listOf(CHAR, STRING).map { element -> - FunctionSignature.Scalar( - name = "is_${element.name.lowercase()}", - returns = BOOL, - parameters = listOf( - FunctionParameter("type_parameter_1", INT32), - FunctionParameter("value", ANY) // TODO: Decide if we need to further segment this - ), - isNullCall = false, - isNullable = false - ) - } - - private fun isTypeDoubleArgsInt(): List = listOf(DECIMAL).map { element -> - FunctionSignature.Scalar( - name = "is_${element.name.lowercase()}", - returns = BOOL, - parameters = listOf( - FunctionParameter("type_parameter_1", INT32), - FunctionParameter("type_parameter_2", INT32), - FunctionParameter("value", ANY) // TODO: Decide if we need to further segment this - ), - isNullCall = false, - isNullable = false - ) - } - - private fun isTypeTime(): List = listOf(TIME, TIMESTAMP).map { element -> - FunctionSignature.Scalar( - name = "is_${element.name.lowercase()}", - returns = BOOL, - parameters = listOf( - FunctionParameter("type_parameter_1", BOOL), - FunctionParameter("type_parameter_2", INT32), - FunctionParameter("value", ANY) // TODO: Decide if we need to further segment this - ), - isNullCall = false, - isNullable = false - ) - } - - // TODO - private fun coalesce(): List = emptyList() + // ==================================== + // HELPERS + // ==================================== - // NULLIF(x, y) - private fun nullIf(): List = nullableTypes.map { t -> - FunctionSignature.Scalar( - name = "null_if", - returns = t, - parameters = listOf( - FunctionParameter("value", t), - FunctionParameter("nullifier", BOOL), // TODO: why is this BOOL? - ), - isNullCall = true, - isNullable = true, - ) - } - - // SUBSTRING (expression, start[, length]?) - // SUBSTRINGG(expression from start [FOR length]? ) - private fun substring(): List = textTypes.map { t -> - listOf( - FunctionSignature.Scalar( - name = "substring", - returns = t, - parameters = listOf( - FunctionParameter("value", t), - FunctionParameter("start", INT64), - ), - isNullCall = true, - isNullable = false, - ), - FunctionSignature.Scalar( - name = "substring", - returns = t, - parameters = listOf( - FunctionParameter("value", t), - FunctionParameter("start", INT64), - FunctionParameter("end", INT64), - ), - isNullCall = true, - isNullable = false, - ) - ) - }.flatten() - - // position (str1, str2) - // position (str1 in str2) - private fun position(): List = textTypes.map { t -> - FunctionSignature.Scalar( - name = "position", - returns = INT64, - parameters = listOf( - FunctionParameter("probe", t), - FunctionParameter("value", t), - ), - isNullCall = true, - isNullable = false, - ) - } - - // trim(str) - private fun trim(): List = textTypes.map { t -> - FunctionSignature.Scalar( - name = "trim", - returns = t, - parameters = listOf( - FunctionParameter("value", t), - ), - isNullCall = true, - isNullable = false, - ) - } - - // TODO: We need to add a special form function for TRIM(BOTH FROM value) - private fun trimSpecial(): List = textTypes.map { t -> - listOf( - // TRIM(chars FROM value) - // TRIM(both chars from value) - FunctionSignature.Scalar( - name = "trim_chars", - returns = t, - parameters = listOf( - FunctionParameter("value", t), - FunctionParameter("chars", t), - ), - isNullCall = true, - isNullable = false, - ), - // TRIM(LEADING FROM value) - FunctionSignature.Scalar( - name = "trim_leading", - returns = t, - parameters = listOf( - FunctionParameter("value", t), - ), - isNullCall = true, - isNullable = false, - ), - // TRIM(LEADING chars FROM value) - FunctionSignature.Scalar( - name = "trim_leading_chars", - returns = t, - parameters = listOf( - FunctionParameter("value", t), - FunctionParameter("chars", t), - ), - isNullCall = true, - isNullable = false, - ), - // TRIM(TRAILING FROM value) - FunctionSignature.Scalar( - name = "trim_trailing", - returns = t, - parameters = listOf( - FunctionParameter("value", t), - ), - isNullCall = true, - isNullable = false, - ), - // TRIM(TRAILING chars FROM value) - FunctionSignature.Scalar( - name = "trim_trailing_chars", - returns = t, - parameters = listOf( - FunctionParameter("value", t), - FunctionParameter("chars", t), - ), - isNullCall = true, - isNullable = false, - ), - ) - }.flatten() - - // TODO - private fun overlay(): List = emptyList() - - // TODO - private fun extract(): List = emptyList() - - private fun dateAdd(): List { - val operators = mutableListOf() - for (type in datetimeTypes) { - for (field in DatetimeField.values()) { - if (field == DatetimeField.TIMEZONE_HOUR || field == DatetimeField.TIMEZONE_MINUTE) { - continue - } - val signature = FunctionSignature.Scalar( - name = "date_add_${field.name.lowercase()}", - returns = type, - parameters = listOf( - FunctionParameter("interval", INT), - FunctionParameter("datetime", type), - ), - isNullCall = true, - isNullable = false, - ) - operators.add(signature) - } - } - return operators - } - - private fun dateDiff(): List { - val operators = mutableListOf() - for (type in datetimeTypes) { - for (field in DatetimeField.values()) { - if (field == DatetimeField.TIMEZONE_HOUR || field == DatetimeField.TIMEZONE_MINUTE) { - continue - } - val signature = FunctionSignature.Scalar( - name = "date_diff_${field.name.lowercase()}", - returns = INT64, - parameters = listOf( - FunctionParameter("datetime1", type), - FunctionParameter("datetime2", type), - ), - isNullCall = true, - isNullable = false, - ) - operators.add(signature) - } - } - return operators - } - - private fun utcNow(): List = listOf( - FunctionSignature.Scalar( - name = "utcnow", - returns = TIMESTAMP, - parameters = emptyList(), - isNullable = false, - ) - ) - - private fun currentUser() = FunctionSignature.Scalar( - name = "current_user", - returns = STRING, - parameters = emptyList(), - isNullable = true, - ) - - private fun currentDate() = FunctionSignature.Scalar( - name = "current_date", - returns = DATE, - parameters = emptyList(), - isNullable = false, - ) - - // ==================================== - // AGGREGATIONS - // ==================================== - - /** - * SQL and PartiQL Aggregation Builtins - */ - public fun aggregations(): List = listOf( - every(), - any(), - some(), - count(), - min(), - max(), - sum(), - avg(), - ).flatten() - - private fun every() = listOf( - FunctionSignature.Aggregation( - name = "every", - returns = BOOL, - parameters = listOf(FunctionParameter("value", BOOL)), - isNullable = true, - ), - FunctionSignature.Aggregation( - name = "every", - returns = BOOL, - parameters = listOf(FunctionParameter("value", ANY)), - isNullable = true, - ), - ) - - private fun any() = listOf( - FunctionSignature.Aggregation( - name = "any", - returns = BOOL, - parameters = listOf(FunctionParameter("value", BOOL)), - isNullable = true, - ), - FunctionSignature.Aggregation( - name = "any", - returns = BOOL, - parameters = listOf(FunctionParameter("value", ANY)), - isNullable = true, - ), - ) - - private fun some() = listOf( - FunctionSignature.Aggregation( - name = "some", - returns = BOOL, - parameters = listOf(FunctionParameter("value", BOOL)), - isNullable = true, - ), - FunctionSignature.Aggregation( - name = "some", - returns = BOOL, - parameters = listOf(FunctionParameter("value", ANY)), - isNullable = true, - ), - ) - - private fun count() = listOf( - FunctionSignature.Aggregation( - name = "count", - returns = INT32, - parameters = listOf(FunctionParameter("value", ANY)), - isNullable = false, - ), - FunctionSignature.Aggregation( - name = "count_star", - returns = INT32, - parameters = listOf(), - isNullable = false, - ), - ) - - private fun min() = numericTypes.map { - FunctionSignature.Aggregation( - name = "min", - returns = it, - parameters = listOf(FunctionParameter("value", it)), - isNullable = true, - ) - } + FunctionSignature.Aggregation( - name = "min", - returns = ANY, - parameters = listOf(FunctionParameter("value", ANY)), - isNullable = true, - ) - - private fun max() = numericTypes.map { - FunctionSignature.Aggregation( - name = "max", - returns = it, - parameters = listOf(FunctionParameter("value", it)), - isNullable = true, - ) - } + FunctionSignature.Aggregation( - name = "max", - returns = ANY, - parameters = listOf(FunctionParameter("value", ANY)), - isNullable = true, - ) - - private fun sum() = numericTypes.map { - FunctionSignature.Aggregation( - name = "sum", - returns = it, - parameters = listOf(FunctionParameter("value", it)), - isNullable = true, - ) - } + FunctionSignature.Aggregation( - name = "sum", - returns = ANY, - parameters = listOf(FunctionParameter("value", ANY)), - isNullable = true, - ) - - private fun avg() = numericTypes.map { - FunctionSignature.Aggregation( - name = "avg", - returns = it, - parameters = listOf(FunctionParameter("value", it)), - isNullable = true, - ) - } + FunctionSignature.Aggregation( - name = "avg", - returns = ANY, - parameters = listOf(FunctionParameter("value", ANY)), - isNullable = true, - ) - - // ==================================== - // SORTING - // ==================================== - - // Function precedence comparator - // 1. Fewest args first - // 2. Parameters are compared left-to-right - private val functionPrecedence = Comparator { fn1, fn2 -> - // Compare number of arguments - if (fn1.parameters.size != fn2.parameters.size) { - return@Comparator fn1.parameters.size - fn2.parameters.size - } - // Compare operand type precedence - for (i in fn1.parameters.indices) { - val p1 = fn1.parameters[i] - val p2 = fn2.parameters[i] - val comparison = p1.compareTo(p2) - if (comparison != 0) return@Comparator comparison - } - // unreachable? - 0 - } - - private fun FunctionParameter.compareTo(other: FunctionParameter): Int = - comparePrecedence(this.type, other.type) - - private fun comparePrecedence(t1: PartiQLValueType, t2: PartiQLValueType): Int { - if (t1 == t2) return 0 - val p1 = typePrecedence[t1]!! - val p2 = typePrecedence[t2]!! - return p1 - p2 - } - - // This simply describes some precedence for ordering functions. - // This is not explicitly defined in the PartiQL Specification - // This does not imply the ability to CAST. - private val typePrecedence: Map = listOf( - NULL, - MISSING, - BOOL, - INT8, - INT16, - INT32, - INT64, - INT, - DECIMAL, - FLOAT32, - FLOAT64, - CHAR, - STRING, - CLOB, - SYMBOL, - BINARY, - BYTE, - BLOB, - DATE, - TIME, - TIMESTAMP, - INTERVAL, - LIST, - SEXP, - BAG, - STRUCT, - ANY, - ).mapIndexed { precedence, type -> type to precedence }.toMap() - - // ==================================== - // HELPERS - // ==================================== + companion object { - public fun unary(name: String, returns: PartiQLValueType, value: PartiQLValueType) = + @JvmStatic + internal fun unary(name: String, returns: PartiQLValueType, value: PartiQLValueType) = FunctionSignature.Scalar( name = name, returns = returns, @@ -949,7 +67,8 @@ internal class Header( isNullable = false, ) - public fun binary(name: String, returns: PartiQLValueType, lhs: PartiQLValueType, rhs: PartiQLValueType) = + @JvmStatic + internal fun binary(name: String, returns: PartiQLValueType, lhs: PartiQLValueType, rhs: PartiQLValueType) = FunctionSignature.Scalar( name = name, returns = returns, @@ -957,16 +76,5 @@ internal class Header( isNullCall = true, isNullable = false, ) - - public fun cast(operand: PartiQLValueType, target: PartiQLValueType) = - FunctionSignature.Scalar( - name = castName(target), - returns = target, - isNullCall = true, - isNullable = false, - parameters = listOf( - FunctionParameter("value", operand), - ) - ) } } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLHeader.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLHeader.kt new file mode 100644 index 0000000000..02ab2411c9 --- /dev/null +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLHeader.kt @@ -0,0 +1,670 @@ +package org.partiql.planner + +import org.partiql.ast.DatetimeField +import org.partiql.types.function.FunctionParameter +import org.partiql.types.function.FunctionSignature +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.PartiQLValueType.ANY +import org.partiql.value.PartiQLValueType.BOOL +import org.partiql.value.PartiQLValueType.CHAR +import org.partiql.value.PartiQLValueType.DATE +import org.partiql.value.PartiQLValueType.DECIMAL +import org.partiql.value.PartiQLValueType.INT +import org.partiql.value.PartiQLValueType.INT32 +import org.partiql.value.PartiQLValueType.INT64 +import org.partiql.value.PartiQLValueType.STRING +import org.partiql.value.PartiQLValueType.TIME +import org.partiql.value.PartiQLValueType.TIMESTAMP + +/** + * A header which uses the PartiQL Lang Kotlin default standard library. All functions exist in a global namespace. + * Once we have catalogs with information_schema, the PartiQL Header will be fixed on a specification version and + * user defined functions will be defined within their own schema. + */ +@OptIn(PartiQLValueExperimental::class) +object PartiQLHeader : Header() { + + override val namespace: String = "partiql" + + /** + * PartiQL Scalar Functions accessible via call syntax. + */ + override val functions = scalarBuiltins() + + /** + * PartiQL Scalar Functions accessible via special form syntax (unary, binary, infix keywords, etc). + */ + override val operators = listOf( + operators(), + special(), + system(), + ).flatten() + + /** + * PartiQL Aggregation Functions accessible via + */ + override val aggregations = aggBuiltins() + + /** + * Generate all unary and binary operator signatures. + */ + private fun operators(): List = listOf( + not(), + pos(), + neg(), + eq(), + ne(), + and(), + or(), + lt(), + lte(), + gt(), + gte(), + plus(), + minus(), + times(), + div(), + mod(), + concat(), + bitwiseAnd(), + ).flatten() + + /** + * SQL Builtins (not special forms) + */ + private fun scalarBuiltins(): List = listOf( + upper(), + lower(), + coalesce(), + nullIf(), + position(), + substring(), + trim(), + utcNow(), + ).flatten() + + /** + * SQL and PartiQL special forms + */ + private fun special(): List = listOf( + like(), + between(), + inCollection(), + isType(), + isTypeSingleArg(), + isTypeDoubleArgsInt(), + isTypeTime(), + position(), + substring(), + trimSpecial(), + overlay(), + extract(), + dateAdd(), + dateDiff(), + ).flatten() + + /** + * System functions (for now, CURRENT_USER and CURRENT_DATE) + * + * @return + */ + private fun system(): List = listOf( + currentUser(), + currentDate(), + ) + + // OPERATORS + + private fun not(): List = listOf(unary("not", BOOL, BOOL)) + + private fun pos(): List = types.numeric.map { t -> + unary("pos", t, t) + } + + private fun neg(): List = types.numeric.map { t -> + unary("neg", t, t) + } + + private fun eq(): List = types.all.map { t -> + FunctionSignature.Scalar( + name = "eq", + returns = BOOL, + isNullCall = false, + isNullable = false, + parameters = listOf(FunctionParameter("lhs", t), FunctionParameter("rhs", t)), + ) + } + + private fun ne(): List = types.all.map { t -> + binary("ne", BOOL, t, t) + } + + private fun and(): List = listOf(binary("and", BOOL, BOOL, BOOL)) + + private fun or(): List = listOf(binary("or", BOOL, BOOL, BOOL)) + + private fun lt(): List = types.numeric.map { t -> + binary("lt", BOOL, t, t) + } + + private fun lte(): List = types.numeric.map { t -> + binary("lte", BOOL, t, t) + } + + private fun gt(): List = types.numeric.map { t -> + binary("gt", BOOL, t, t) + } + + private fun gte(): List = types.numeric.map { t -> + binary("gte", BOOL, t, t) + } + + private fun plus(): List = types.numeric.map { t -> + binary("plus", t, t, t) + } + + private fun minus(): List = types.numeric.map { t -> + binary("minus", t, t, t) + } + + private fun times(): List = types.numeric.map { t -> + binary("times", t, t, t) + } + + private fun div(): List = types.numeric.map { t -> + binary("divide", t, t, t) + } + + private fun mod(): List = types.numeric.map { t -> + binary("modulo", t, t, t) + } + + private fun concat(): List = types.text.map { t -> + binary("concat", t, t, t) + } + + private fun bitwiseAnd(): List = types.integer.map { t -> + binary("bitwise_and", t, t, t) + } + + // BUILT INTS + + private fun upper(): List = types.text.map { t -> + FunctionSignature.Scalar( + name = "upper", + returns = t, + parameters = listOf(FunctionParameter("value", t)), + isNullCall = true, + isNullable = false, + ) + } + + private fun lower(): List = types.text.map { t -> + FunctionSignature.Scalar( + name = "lower", + returns = t, + parameters = listOf(FunctionParameter("value", t)), + isNullCall = true, + isNullable = false, + ) + } + + // SPECIAL FORMS + + private fun like(): List = listOf( + FunctionSignature.Scalar( + name = "like", + returns = BOOL, + parameters = listOf( + FunctionParameter("value", STRING), + FunctionParameter("pattern", STRING), + ), + isNullCall = true, + isNullable = false, + ), + FunctionSignature.Scalar( + name = "like_escape", + returns = BOOL, + parameters = listOf( + FunctionParameter("value", STRING), + FunctionParameter("pattern", STRING), + FunctionParameter("escape", STRING), + ), + isNullCall = true, + isNullable = false, + ), + ) + + private fun between(): List = types.numeric.map { t -> + FunctionSignature.Scalar( + name = "between", + returns = BOOL, + parameters = listOf( + FunctionParameter("value", t), + FunctionParameter("lower", t), + FunctionParameter("upper", t), + ), + isNullCall = true, + isNullable = false, + ) + } + + private fun inCollection(): List = types.all.map { element -> + types.collections.map { collection -> + FunctionSignature.Scalar( + name = "in_collection", + returns = BOOL, + parameters = listOf( + FunctionParameter("value", element), + FunctionParameter("collection", collection), + ), + isNullCall = true, + isNullable = false, + ) + } + }.flatten() + + // To model type assertion, generating a list of assertion function based on the type, + // and the parameter will be the value entered. + // i.e., 1 is INT2 => is_int16(1) + // TODO: We can remove the types with parameter in this function. + // but, leaving out the decision to have, for example: + // is_decimal(null, null, value) vs is_decimal(value) later.... + private fun isType(): List = types.all.map { element -> + FunctionSignature.Scalar( + name = "is_${element.name.lowercase()}", + returns = BOOL, + parameters = listOf( + FunctionParameter("value", ANY) // TODO: Decide if we need to further segment this + ), + isNullCall = false, + isNullable = false + ) + } + + // In type assertion, it is possible for types to have args + // i.e., 'a' is CHAR(2) + // we put type parameter before value. + private fun isTypeSingleArg(): List = listOf(CHAR, STRING).map { element -> + FunctionSignature.Scalar( + name = "is_${element.name.lowercase()}", + returns = BOOL, + parameters = listOf( + FunctionParameter("type_parameter_1", INT32), + FunctionParameter("value", ANY) // TODO: Decide if we need to further segment this + ), + isNullCall = false, + isNullable = false + ) + } + + private fun isTypeDoubleArgsInt(): List = listOf(DECIMAL).map { element -> + FunctionSignature.Scalar( + name = "is_${element.name.lowercase()}", + returns = BOOL, + parameters = listOf( + FunctionParameter("type_parameter_1", INT32), + FunctionParameter("type_parameter_2", INT32), + FunctionParameter("value", ANY) // TODO: Decide if we need to further segment this + ), + isNullCall = false, + isNullable = false + ) + } + + private fun isTypeTime(): List = listOf(TIME, TIMESTAMP).map { element -> + FunctionSignature.Scalar( + name = "is_${element.name.lowercase()}", + returns = BOOL, + parameters = listOf( + FunctionParameter("type_parameter_1", BOOL), + FunctionParameter("type_parameter_2", INT32), + FunctionParameter("value", ANY) // TODO: Decide if we need to further segment this + ), + isNullCall = false, + isNullable = false + ) + } + + // TODO + private fun coalesce(): List = emptyList() + + // NULLIF(x, y) + private fun nullIf(): List = types.nullable.map { t -> + FunctionSignature.Scalar( + name = "null_if", + returns = t, + parameters = listOf( + FunctionParameter("value", t), + FunctionParameter("nullifier", BOOL), // TODO: why is this BOOL? + ), + isNullCall = true, + isNullable = true, + ) + } + + // SUBSTRING (expression, start[, length]?) + // SUBSTRINGG(expression from start [FOR length]? ) + private fun substring(): List = types.text.map { t -> + listOf( + FunctionSignature.Scalar( + name = "substring", + returns = t, + parameters = listOf( + FunctionParameter("value", t), + FunctionParameter("start", INT64), + ), + isNullCall = true, + isNullable = false, + ), + FunctionSignature.Scalar( + name = "substring", + returns = t, + parameters = listOf( + FunctionParameter("value", t), + FunctionParameter("start", INT64), + FunctionParameter("end", INT64), + ), + isNullCall = true, + isNullable = false, + ) + ) + }.flatten() + + // position (str1, str2) + // position (str1 in str2) + private fun position(): List = types.text.map { t -> + FunctionSignature.Scalar( + name = "position", + returns = INT64, + parameters = listOf( + FunctionParameter("probe", t), + FunctionParameter("value", t), + ), + isNullCall = true, + isNullable = false, + ) + } + + // trim(str) + private fun trim(): List = types.text.map { t -> + FunctionSignature.Scalar( + name = "trim", + returns = t, + parameters = listOf( + FunctionParameter("value", t), + ), + isNullCall = true, + isNullable = false, + ) + } + + // TODO: We need to add a special form function for TRIM(BOTH FROM value) + private fun trimSpecial(): List = types.text.map { t -> + listOf( + // TRIM(chars FROM value) + // TRIM(both chars from value) + FunctionSignature.Scalar( + name = "trim_chars", + returns = t, + parameters = listOf( + FunctionParameter("value", t), + FunctionParameter("chars", t), + ), + isNullCall = true, + isNullable = false, + ), + // TRIM(LEADING FROM value) + FunctionSignature.Scalar( + name = "trim_leading", + returns = t, + parameters = listOf( + FunctionParameter("value", t), + ), + isNullCall = true, + isNullable = false, + ), + // TRIM(LEADING chars FROM value) + FunctionSignature.Scalar( + name = "trim_leading_chars", + returns = t, + parameters = listOf( + FunctionParameter("value", t), + FunctionParameter("chars", t), + ), + isNullCall = true, + isNullable = false, + ), + // TRIM(TRAILING FROM value) + FunctionSignature.Scalar( + name = "trim_trailing", + returns = t, + parameters = listOf( + FunctionParameter("value", t), + ), + isNullCall = true, + isNullable = false, + ), + // TRIM(TRAILING chars FROM value) + FunctionSignature.Scalar( + name = "trim_trailing_chars", + returns = t, + parameters = listOf( + FunctionParameter("value", t), + FunctionParameter("chars", t), + ), + isNullCall = true, + isNullable = false, + ), + ) + }.flatten() + + // TODO + private fun overlay(): List = emptyList() + + // TODO + private fun extract(): List = emptyList() + + private fun dateAdd(): List { + val operators = mutableListOf() + for (type in types.datetime) { + for (field in DatetimeField.values()) { + if (field == DatetimeField.TIMEZONE_HOUR || field == DatetimeField.TIMEZONE_MINUTE) { + continue + } + val signature = FunctionSignature.Scalar( + name = "date_add_${field.name.lowercase()}", + returns = type, + parameters = listOf( + FunctionParameter("interval", INT), + FunctionParameter("datetime", type), + ), + isNullCall = true, + isNullable = false, + ) + operators.add(signature) + } + } + return operators + } + + private fun dateDiff(): List { + val operators = mutableListOf() + for (type in types.datetime) { + for (field in DatetimeField.values()) { + if (field == DatetimeField.TIMEZONE_HOUR || field == DatetimeField.TIMEZONE_MINUTE) { + continue + } + val signature = FunctionSignature.Scalar( + name = "date_diff_${field.name.lowercase()}", + returns = INT64, + parameters = listOf( + FunctionParameter("datetime1", type), + FunctionParameter("datetime2", type), + ), + isNullCall = true, + isNullable = false, + ) + operators.add(signature) + } + } + return operators + } + + private fun utcNow(): List = listOf( + FunctionSignature.Scalar( + name = "utcnow", + returns = TIMESTAMP, + parameters = emptyList(), + isNullable = false, + ) + ) + + private fun currentUser() = FunctionSignature.Scalar( + name = "current_user", + returns = STRING, + parameters = emptyList(), + isNullable = true, + ) + + private fun currentDate() = FunctionSignature.Scalar( + name = "current_date", + returns = DATE, + parameters = emptyList(), + isNullable = false, + ) + + // ==================================== + // AGGREGATIONS + // ==================================== + + /** + * SQL and PartiQL Aggregation Builtins + */ + private fun aggBuiltins(): List = listOf( + every(), + any(), + some(), + count(), + min(), + max(), + sum(), + avg(), + ).flatten() + + private fun every() = listOf( + FunctionSignature.Aggregation( + name = "every", + returns = BOOL, + parameters = listOf(FunctionParameter("value", BOOL)), + isNullable = true, + ), + FunctionSignature.Aggregation( + name = "every", + returns = BOOL, + parameters = listOf(FunctionParameter("value", ANY)), + isNullable = true, + ), + ) + + private fun any() = listOf( + FunctionSignature.Aggregation( + name = "any", + returns = BOOL, + parameters = listOf(FunctionParameter("value", BOOL)), + isNullable = true, + ), + FunctionSignature.Aggregation( + name = "any", + returns = BOOL, + parameters = listOf(FunctionParameter("value", ANY)), + isNullable = true, + ), + ) + + private fun some() = listOf( + FunctionSignature.Aggregation( + name = "some", + returns = BOOL, + parameters = listOf(FunctionParameter("value", BOOL)), + isNullable = true, + ), + FunctionSignature.Aggregation( + name = "some", + returns = BOOL, + parameters = listOf(FunctionParameter("value", ANY)), + isNullable = true, + ), + ) + + private fun count() = listOf( + FunctionSignature.Aggregation( + name = "count", + returns = INT32, + parameters = listOf(FunctionParameter("value", ANY)), + isNullable = false, + ), + FunctionSignature.Aggregation( + name = "count_star", + returns = INT32, + parameters = listOf(), + isNullable = false, + ), + ) + + private fun min() = types.numeric.map { + FunctionSignature.Aggregation( + name = "min", + returns = it, + parameters = listOf(FunctionParameter("value", it)), + isNullable = true, + ) + } + FunctionSignature.Aggregation( + name = "min", + returns = ANY, + parameters = listOf(FunctionParameter("value", ANY)), + isNullable = true, + ) + + private fun max() = types.numeric.map { + FunctionSignature.Aggregation( + name = "max", + returns = it, + parameters = listOf(FunctionParameter("value", it)), + isNullable = true, + ) + } + FunctionSignature.Aggregation( + name = "max", + returns = ANY, + parameters = listOf(FunctionParameter("value", ANY)), + isNullable = true, + ) + + private fun sum() = types.numeric.map { + FunctionSignature.Aggregation( + name = "sum", + returns = it, + parameters = listOf(FunctionParameter("value", it)), + isNullable = true, + ) + } + FunctionSignature.Aggregation( + name = "sum", + returns = ANY, + parameters = listOf(FunctionParameter("value", ANY)), + isNullable = true, + ) + + private fun avg() = types.numeric.map { + FunctionSignature.Aggregation( + name = "avg", + returns = it, + parameters = listOf(FunctionParameter("value", it)), + isNullable = true, + ) + } + FunctionSignature.Aggregation( + name = "avg", + returns = ANY, + parameters = listOf(FunctionParameter("value", ANY)), + isNullable = true, + ) +} diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlannerBuilder.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlannerBuilder.kt index 0b4cf16b44..cec415092c 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlannerBuilder.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlannerBuilder.kt @@ -7,10 +7,11 @@ import org.partiql.spi.Plugin */ class PartiQLPlannerBuilder { + private var headers: MutableList
= mutableListOf(PartiQLHeader) private var plugins: List = emptyList() private var passes: List = emptyList() - fun build(): PartiQLPlanner = PartiQLPlannerDefault(plugins, passes) + fun build(): PartiQLPlanner = PartiQLPlannerDefault(headers, plugins, passes) public fun plugins(plugins: List): PartiQLPlannerBuilder = this.apply { this.plugins = plugins @@ -19,4 +20,8 @@ class PartiQLPlannerBuilder { public fun passes(passes: List): PartiQLPlannerBuilder = this.apply { this.passes = passes } + + public fun headers(headers: List
): PartiQLPlannerBuilder = this.apply { + this.headers += headers + } } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlannerDefault.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlannerDefault.kt index c7354bce65..ee94048183 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlannerDefault.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlannerDefault.kt @@ -13,18 +13,20 @@ import org.partiql.spi.Plugin * Default PartiQL logical query planner. */ internal class PartiQLPlannerDefault( + private val headers: List
, private val plugins: List, private val passes: List, ) : PartiQLPlanner { private val version = PartiQLVersion.VERSION_0_1 - // For now, only have the default header - private val header = Header.partiql() - - override fun plan(statement: Statement, session: PartiQLPlanner.Session, onProblem: ProblemCallback): PartiQLPlanner.Result { + override fun plan( + statement: Statement, + session: PartiQLPlanner.Session, + onProblem: ProblemCallback, + ): PartiQLPlanner.Result { // 0. Initialize the planning environment - val env = Env(header, plugins, session) + val env = Env(headers, plugins, session) // 1. Normalize val ast = statement.normalize() diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/typer/FnResolver.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/typer/FnResolver.kt new file mode 100644 index 0000000000..ad1808f1d3 --- /dev/null +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/typer/FnResolver.kt @@ -0,0 +1,395 @@ +package org.partiql.planner.typer + +import org.partiql.plan.Agg +import org.partiql.plan.Fn +import org.partiql.plan.Identifier +import org.partiql.plan.Rex +import org.partiql.planner.Header +import org.partiql.types.function.FunctionParameter +import org.partiql.types.function.FunctionSignature +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.PartiQLValueType +import org.partiql.value.PartiQLValueType.ANY +import org.partiql.value.PartiQLValueType.BAG +import org.partiql.value.PartiQLValueType.BINARY +import org.partiql.value.PartiQLValueType.BLOB +import org.partiql.value.PartiQLValueType.BOOL +import org.partiql.value.PartiQLValueType.BYTE +import org.partiql.value.PartiQLValueType.CHAR +import org.partiql.value.PartiQLValueType.CLOB +import org.partiql.value.PartiQLValueType.DATE +import org.partiql.value.PartiQLValueType.DECIMAL +import org.partiql.value.PartiQLValueType.FLOAT32 +import org.partiql.value.PartiQLValueType.FLOAT64 +import org.partiql.value.PartiQLValueType.INT +import org.partiql.value.PartiQLValueType.INT16 +import org.partiql.value.PartiQLValueType.INT32 +import org.partiql.value.PartiQLValueType.INT64 +import org.partiql.value.PartiQLValueType.INT8 +import org.partiql.value.PartiQLValueType.INTERVAL +import org.partiql.value.PartiQLValueType.LIST +import org.partiql.value.PartiQLValueType.MISSING +import org.partiql.value.PartiQLValueType.NULL +import org.partiql.value.PartiQLValueType.SEXP +import org.partiql.value.PartiQLValueType.STRING +import org.partiql.value.PartiQLValueType.STRUCT +import org.partiql.value.PartiQLValueType.SYMBOL +import org.partiql.value.PartiQLValueType.TIME +import org.partiql.value.PartiQLValueType.TIMESTAMP + +/** + * Function signature lookup by name. + */ +internal typealias FnMap = Map> + +/** + * Function arguments list. The planner is responsible for mapping arguments to parameters. + */ +internal typealias Args = List + +/** + * Parameter mapping list tells the planner where to insert implicit casts. Null is the identity. + */ +internal typealias Mapping = List + +/** + * Tells us which function matched, and how the arguments are mapped. + */ +internal class Match( + public val signature: T, + public val mapping: Mapping, +) + +/** + * Result of attempting to match an unresolved function. + */ +internal sealed class FnMatch { + + /** + * 7.1 Inputs with wrong types + * It follows that all functions return MISSING when one of their inputs is MISSING + * + * @property signature + * @property mapping + * @property isMissable TRUE when anyone of the arguments _could_ be MISSING. We *always* propagate MISSING. + */ + public data class Ok( + public val signature: T, + public val mapping: Mapping, + public val isMissable: Boolean, + ) : FnMatch() + + public data class Error( + public val identifier: Identifier, + public val args: List, + public val candidates: List, + ) : FnMatch() +} + +/** + * Logic for matching signatures to arguments — this class contains all cast/coercion logic. In my opinion, casts + * and coercions should come along with the type lattice. Considering we don't really have this, it is simple enough + * at the moment to keep that information (derived from the current TypeLattice) with the [FnResolver]. + */ +@OptIn(PartiQLValueExperimental::class) +internal class FnResolver(private val headers: List
) { + + /** + * All headers use the same type lattice (we don't have a design for plugging type systems at the moment). + */ + private val types = TypeLattice.partiql() + + /** + * Calculate a queryable map of scalar function signatures. + */ + private val functions: FnMap + + /** + * Calculate a queryable map of scalar function signatures from special forms. + */ + private val operators: FnMap + + /** + * Calculate a queryable map of aggregation function signatures + */ + private val aggregations: FnMap + + /** + * A place to quickly lookup a cast can return missing; lookup by "SPECIFIC" + */ + private val unsafeCastSet: Set + + init { + val (casts, unsafeCasts) = casts() + unsafeCastSet = unsafeCasts + // combine all header definitions + val fns = headers.flatMap { it.functions } + functions = fns.toFnMap() + operators = (headers.flatMap { it.operators } + casts).toFnMap() + aggregations = headers.flatMap { it.aggregations }.toFnMap() + } + + /** + * Group list of [FunctionSignature] by name. + */ + private fun List.toFnMap(): FnMap = this + .distinctBy { it.specific } + .sortedWith(fnPrecedence) + .groupBy { it.name } + + /** + * Leverages a [FnResolver] to find a matching function defined in the [Header] scalar function catalog. + */ + public fun resolveFn(fn: Fn.Unresolved, args: List): FnMatch { + val candidates = lookup(fn) + var hadMissingArg = false + val parameters = args.mapIndexed { i, arg -> + if (!hadMissingArg && arg.type.isMissable()) { + hadMissingArg = true + } + FunctionParameter("arg-$i", arg.type.toRuntimeType()) + } + val match = match(candidates, parameters) + return when (match) { + null -> FnMatch.Error(fn.identifier, args, candidates) + else -> { + val isMissable = hadMissingArg || isUnsafeCast(match.signature.specific) + FnMatch.Ok(match.signature, match.mapping, isMissable) + } + } + } + + /** + * Leverages a [FnResolver] to find a matching function defined in the [Header] aggregation function catalog. + */ + public fun resolveAgg(agg: Agg.Unresolved, args: List): FnMatch { + val candidates = lookup(agg) + var hadMissingArg = false + val parameters = args.mapIndexed { i, arg -> + if (!hadMissingArg && arg.type.isMissable()) { + hadMissingArg = true + } + FunctionParameter("arg-$i", arg.type.toRuntimeType()) + } + val match = match(candidates, parameters) + return when (match) { + null -> FnMatch.Error(agg.identifier, args, candidates) + else -> { + val isMissable = hadMissingArg || isUnsafeCast(match.signature.specific) + FnMatch.Ok(match.signature, match.mapping, isMissable) + } + } + } + + /** + * Functions are sorted by precedence (which is not rigorously defined/specified at the moment). + */ + private fun match(signatures: List, args: Args): Match? { + for (signature in signatures) { + val mapping = match(signature, args) + if (mapping != null) { + return Match(signature, mapping) + } + } + return null + } + + /** + * Attempt to match arguments to the parameters; return the implicit casts if necessary. + * + * TODO we need to constrain the allowable runtime types for an ANY typed parameter. + */ + fun match(signature: FunctionSignature, args: Args): Mapping? { + if (signature.parameters.size != args.size) { + return null + } + val mapping = ArrayList(args.size) + for (i in args.indices) { + val a = args[i] + val p = signature.parameters[i] + when { + // 1. Exact match + a.type == p.type -> mapping.add(null) + // 2. Match ANY, no coercion needed + p.type == ANY -> mapping.add(null) + // 3. Match NULL argument + a.type == NULL -> mapping.add(null) + // 4. Check for a coercion + else -> { + val coercion = lookupCoercion(a.type, p.type) + when (coercion) { + null -> return null // short-circuit + else -> mapping.add(coercion) + } + } + } + } + // we made a match + return mapping + } + + /** + * Return a list of all scalar function signatures matching the given identifier. + */ + private fun lookup(ref: Fn.Unresolved): List { + val name = getFnName(ref.identifier) + return when (ref.isHidden) { + true -> operators.getOrDefault(name, emptyList()) + else -> functions.getOrDefault(name, emptyList()) + } + } + + /** + * Return a list of all aggregation function signatures matching the given identifier. + */ + private fun lookup(ref: Agg.Unresolved): List { + val name = getFnName(ref.identifier) + return aggregations.getOrDefault(name, emptyList()) + } + + /** + * Return a normalized function identifier for lookup in our list of function definitions. + */ + private fun getFnName(identifier: Identifier): String = when (identifier) { + is Identifier.Qualified -> throw IllegalArgumentException("Qualified function identifiers not supported") + is Identifier.Symbol -> identifier.symbol.lowercase() + } + + // ==================================== + // CASTS and COERCIONS + // ==================================== + + /** + * Returns the CAST function if exists, else null. + */ + private fun lookupCoercion(valueType: PartiQLValueType, targetType: PartiQLValueType): FunctionSignature.Scalar? { + if (!types.canCoerce(valueType, targetType)) { + return null + } + val name = castName(targetType) + val casts = operators.getOrDefault(name, emptyList()) + for (cast in casts) { + if (cast.parameters.isEmpty()) { + break // should be unreachable + } + if (valueType == cast.parameters[0].type) return cast + } + return null + } + + /** + * Easy lookup of whether this CAST can return MISSING. + */ + private fun isUnsafeCast(specific: String): Boolean = unsafeCastSet.contains(specific) + + /** + * Generate all CAST functions from the given lattice. + * + * @return Pair(0) is the function list, Pair(1) represents the unsafe cast specifics + */ + private fun casts(): Pair, Set> { + val casts = mutableListOf() + val unsafeCastSet = mutableSetOf() + for (t1 in types.types) { + for (t2 in types.types) { + val r = types.graph[t1.ordinal][t2.ordinal] + if (r != null) { + val fn = cast(t1, t2) + casts.add(fn) + if (r.cast == CastType.UNSAFE) unsafeCastSet.add(fn.specific) + } + } + } + return casts to unsafeCastSet + } + + /** + * Define CASTS with some mangled name; CAST(x AS T) -> cast_t(x) + * + * CAST(x AS INT8) -> cast_int64(x) + * + * But what about parameterized types? Are the parameters dropped in casts, or do parameters become arguments? + */ + private fun castName(type: PartiQLValueType) = "cast_${type.name.lowercase()}" + + internal fun cast(operand: PartiQLValueType, target: PartiQLValueType) = + FunctionSignature.Scalar( + name = castName(target), + returns = target, + isNullCall = true, + isNullable = false, + parameters = listOf( + FunctionParameter("value", operand), + ) + ) + + companion object { + + // ==================================== + // SORTING + // ==================================== + + // Function precedence comparator + // 1. Fewest args first + // 2. Parameters are compared left-to-right + @JvmStatic + private val fnPrecedence = Comparator { fn1, fn2 -> + // Compare number of arguments + if (fn1.parameters.size != fn2.parameters.size) { + return@Comparator fn1.parameters.size - fn2.parameters.size + } + // Compare operand type precedence + for (i in fn1.parameters.indices) { + val p1 = fn1.parameters[i] + val p2 = fn2.parameters[i] + val comparison = p1.compareTo(p2) + if (comparison != 0) return@Comparator comparison + } + // unreachable? + 0 + } + + private fun FunctionParameter.compareTo(other: FunctionParameter): Int = + comparePrecedence(this.type, other.type) + + private fun comparePrecedence(t1: PartiQLValueType, t2: PartiQLValueType): Int { + if (t1 == t2) return 0 + val p1 = precedence[t1]!! + val p2 = precedence[t2]!! + return p1 - p2 + } + + // This simply describes some precedence for ordering functions. + // This is not explicitly defined in the PartiQL Specification!! + // This does not imply the ability to CAST; this defines function resolution behavior. + private val precedence: Map = listOf( + NULL, + MISSING, + BOOL, + INT8, + INT16, + INT32, + INT64, + INT, + DECIMAL, + FLOAT32, + FLOAT64, + CHAR, + STRING, + CLOB, + SYMBOL, + BINARY, + BYTE, + BLOB, + DATE, + TIME, + TIMESTAMP, + INTERVAL, + LIST, + SEXP, + BAG, + STRUCT, + ANY, + ).mapIndexed { precedence, type -> type to precedence }.toMap() + } +} diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/typer/FunctionResolver.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/typer/FunctionResolver.kt deleted file mode 100644 index dbd03fb8b7..0000000000 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/typer/FunctionResolver.kt +++ /dev/null @@ -1,79 +0,0 @@ -package org.partiql.planner.typer - -import org.partiql.planner.Header -import org.partiql.types.function.FunctionParameter -import org.partiql.types.function.FunctionSignature -import org.partiql.value.PartiQLValueExperimental -import org.partiql.value.PartiQLValueType - -/** - * Function arguments list. The planner is responsible for mapping arguments to parameters. - */ -internal typealias Args = List - -/** - * Parameter mapping list tells the planner where to insert implicit casts. Null is the identity. - */ -internal typealias Mapping = List - -/** - * Tells us which function matched, and how the arguments are mapped. - */ -internal class Match( - public val signature: T, - public val mapping: Mapping, -) - -/** - * Logic for matching signatures to arguments. - */ -@OptIn(PartiQLValueExperimental::class) -internal class FunctionResolver(private val header: Header) { - - /** - * Functions are sorted by precedence (which is not rigorously defined/specified at the moment). - */ - public fun match(signatures: List, args: Args): Match? { - for (signature in signatures) { - val mapping = match(signature, args) - if (mapping != null) { - return Match(signature, mapping) - } - } - return null - } - - /** - * Attempt to match arguments to the parameters; return the implicit casts if necessary. - * - * TODO we need to constrain the allowable runtime types for an ANY typed parameter. - */ - public fun match(signature: FunctionSignature, args: Args): Mapping? { - if (signature.parameters.size != args.size) { - return null - } - val mapping = ArrayList(args.size) - for (i in args.indices) { - val a = args[i] - val p = signature.parameters[i] - when { - // 1. Exact match - a.type == p.type -> mapping.add(null) - // 2. Match ANY, no coercion needed - p.type == PartiQLValueType.ANY -> mapping.add(null) - // 3. Match NULL argument - a.type == PartiQLValueType.NULL -> mapping.add(null) - // 4. Check for a coercion - else -> { - val coercion = header.lookupCoercion(a.type, p.type) - when (coercion) { - null -> return null // short-circuit - else -> mapping.add(coercion) - } - } - } - } - // we made a match - return mapping - } -} diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/typer/PlanTyper.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/typer/PlanTyper.kt index e055bfbe2d..e0a61160c7 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/typer/PlanTyper.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/typer/PlanTyper.kt @@ -59,7 +59,6 @@ import org.partiql.plan.rexOpVarResolved import org.partiql.plan.statementQuery import org.partiql.plan.util.PlanRewriter import org.partiql.planner.Env -import org.partiql.planner.FnMatch import org.partiql.planner.PlanningProblemDetails import org.partiql.planner.ResolutionStrategy import org.partiql.planner.ResolvedVar diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/typer/TypeLattice.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/typer/TypeLattice.kt index 10f0457af6..63af484732 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/typer/TypeLattice.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/typer/TypeLattice.kt @@ -64,6 +64,50 @@ internal class TypeLattice private constructor( return graph[operand][target]?.cast == CastType.COERCION } + internal val all = PartiQLValueType.values() + + internal val nullable = listOf( + NULL, // null.null + MISSING, // missing + ) + + internal val integer = listOf( + INT8, + INT16, + INT32, + INT64, + INT, + ) + + internal val numeric = listOf( + INT8, + INT16, + INT32, + INT64, + INT, + DECIMAL, + FLOAT32, + FLOAT64, + ) + + internal val text = listOf( + STRING, + SYMBOL, + CLOB, + ) + + internal val collections = listOf( + BAG, + LIST, + SEXP, + ) + + internal val datetime = listOf( + DATE, + TIME, + TIMESTAMP, + ) + /** * Dump the graph as an Asciidoc table. */ diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/HeaderTest.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/HeaderTest.kt index 474ee8d033..3dad436cae 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/HeaderTest.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/HeaderTest.kt @@ -8,6 +8,6 @@ class HeaderTest { @Test @Disabled fun print() { - println(Header.partiql()) + println(PartiQLHeader) } } diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/typer/FunctionResolverTest.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/typer/FunctionResolverTest.kt index b269826973..eac7669401 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/typer/FunctionResolverTest.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/typer/FunctionResolverTest.kt @@ -3,10 +3,12 @@ package org.partiql.planner.typer import org.junit.jupiter.api.Test import org.junit.jupiter.api.fail import org.partiql.planner.Header +import org.partiql.planner.PartiQLHeader import org.partiql.types.function.FunctionParameter import org.partiql.types.function.FunctionSignature import org.partiql.value.PartiQLValueExperimental import org.partiql.value.PartiQLValueType +import javax.print.DocFlavor.STRING /** * As far as testing is concerned, we can stub out all value related things. @@ -19,7 +21,7 @@ class FunctionResolverTest { @Test fun sanity() { // 1 + 1.0 -> 2.0 - val fn = Header.Functions.binary( + val fn = Header.binary( name = "plus", returns = PartiQLValueType.FLOAT64, lhs = PartiQLValueType.FLOAT64, @@ -34,9 +36,39 @@ class FunctionResolverTest { case.assert() } + @Test + fun split() { + val args = listOf( + FunctionParameter("arg-0", PartiQLValueType.STRING), + FunctionParameter("arg-1", PartiQLValueType.STRING), + ) + val expectedImplicitCasts = listOf(false, false) + val case = Case.Success(split, args, expectedImplicitCasts) + case.assert() + } + companion object { - private val header = Header.partiql() - private val resolver = FunctionResolver(header) + + val split = FunctionSignature.Scalar( + name = "split", + returns = PartiQLValueType.LIST, + parameters = listOf( + FunctionParameter("value", PartiQLValueType.STRING), + FunctionParameter("delimiter", PartiQLValueType.STRING), + ), + isNullable = false, + ) + + private val myHeader = object : Header() { + + override val namespace: String = "my_header" + + override val functions: List = listOf( + split + ) + } + + private val resolver = FnResolver(listOf(PartiQLHeader, myHeader)) } private sealed class Case {