diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeGroupBy.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeGroupBy.kt index 4ef1f701c6..ccc2c9d1f4 100644 --- a/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeGroupBy.kt +++ b/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeGroupBy.kt @@ -14,6 +14,7 @@ package org.partiql.ast.normalize +import org.partiql.ast.AstNode import org.partiql.ast.Expr import org.partiql.ast.GroupBy import org.partiql.ast.Statement @@ -30,6 +31,13 @@ object NormalizeGroupBy : AstPass { private object Visitor : AstRewriter() { + override fun visitGroupBy(node: GroupBy, ctx: Int): AstNode { + val keys = node.keys.mapIndexed { index, key -> + visitGroupByKey(key, index + 1) + } + return node.copy(keys = keys) + } + override fun visitGroupByKey(node: GroupBy.Key, ctx: Int): GroupBy.Key { val expr = visitExpr(node.expr, 0) as Expr val alias = when (node.asAlias) { diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/Compiler.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/Compiler.kt index ba2cedf685..24522ab165 100644 --- a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/Compiler.kt +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/Compiler.kt @@ -2,6 +2,7 @@ package org.partiql.eval.internal import org.partiql.eval.PartiQLEngine import org.partiql.eval.internal.operator.Operator +import org.partiql.eval.internal.operator.rel.RelAggregate import org.partiql.eval.internal.operator.rel.RelDistinct import org.partiql.eval.internal.operator.rel.RelExclude import org.partiql.eval.internal.operator.rel.RelFilter @@ -45,6 +46,7 @@ import org.partiql.plan.Rex import org.partiql.plan.Statement import org.partiql.plan.debug.PlanPrinter import org.partiql.plan.visitor.PlanBaseVisitor +import org.partiql.spi.fn.Agg import org.partiql.spi.fn.FnExperimental import org.partiql.types.StaticType import org.partiql.value.PartiQLValueExperimental @@ -170,6 +172,30 @@ internal class Compiler( override fun visitRexOpGlobal(node: Rex.Op.Global, ctx: StaticType?): Operator = symbols.getGlobal(node.ref) + override fun visitRelOpAggregate(node: Rel.Op.Aggregate, ctx: StaticType?): Operator.Relation { + val input = visitRel(node.input, ctx) + val calls = node.calls.map { + visitRelOpAggregateCall(it, ctx) + } + val groups = node.groups.map { visitRex(it, ctx).modeHandled() } + return RelAggregate(input, groups, calls) + } + + @OptIn(FnExperimental::class) + override fun visitRelOpAggregateCall(node: Rel.Op.Aggregate.Call, ctx: StaticType?): Operator.Aggregation { + val args = node.args.map { visitRex(it, it.type).modeHandled() } + val setQuantifier: Operator.Aggregation.SetQuantifier = when (node.setQuantifier) { + Rel.Op.Aggregate.Call.SetQuantifier.ALL -> Operator.Aggregation.SetQuantifier.ALL + Rel.Op.Aggregate.Call.SetQuantifier.DISTINCT -> Operator.Aggregation.SetQuantifier.DISTINCT + } + val agg = symbols.getAgg(node.agg) + return object : Operator.Aggregation { + override val delegate: Agg = agg + override val args: List = args + override val setQuantifier: Operator.Aggregation.SetQuantifier = setQuantifier + } + } + override fun visitRexOpPathKey(node: Rex.Op.Path.Key, ctx: StaticType?): Operator { val root = visitRex(node.root, ctx) val key = visitRex(node.key, ctx) @@ -206,7 +232,7 @@ internal class Compiler( val args = node.args.map { visitRex(it, ctx).modeHandled() }.toTypedArray() val candidates = node.candidates.map { candidate -> val fn = symbols.getFn(candidate.fn) - val types = fn.signature.parameters.map { it.type }.toTypedArray() + val types = candidate.parameters.toTypedArray() val coercions = candidate.coercions.toTypedArray() ExprCallDynamic.Candidate(fn, types, coercions) } diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/Symbols.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/Symbols.kt index 2d12b1fc7d..d770c79484 100644 --- a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/Symbols.kt +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/Symbols.kt @@ -7,9 +7,11 @@ import org.partiql.eval.internal.operator.rex.ExprVarGlobal import org.partiql.plan.Catalog import org.partiql.plan.PartiQLPlan import org.partiql.plan.Ref +import org.partiql.spi.connector.ConnectorAggProvider import org.partiql.spi.connector.ConnectorBindings import org.partiql.spi.connector.ConnectorFnProvider import org.partiql.spi.connector.ConnectorPath +import org.partiql.spi.fn.Agg import org.partiql.spi.fn.Fn import org.partiql.spi.fn.FnExperimental @@ -25,6 +27,7 @@ internal class Symbols private constructor(private val catalogs: Array) { val name: String, val bindings: ConnectorBindings, val functions: ConnectorFnProvider, + val aggregations: ConnectorAggProvider, val items: Array, ) { @@ -53,6 +56,18 @@ internal class Symbols private constructor(private val catalogs: Array) { ?: error("Catalog `$catalog` has no entry for function $item") } + fun getAgg(ref: Ref): Agg { + val catalog = catalogs[ref.catalog] + val item = catalog.items.getOrNull(ref.symbol) + if (item == null || item !is Catalog.Item.Agg) { + error("Invalid reference $ref; missing aggregation entry for catalog `$catalog`.") + } + // Lookup in connector + val path = ConnectorPath(item.path) + return catalog.aggregations.getAgg(path, item.specific) + ?: error("Catalog `$catalog` has no entry for aggregation function $item") + } + companion object { /** @@ -71,6 +86,7 @@ internal class Symbols private constructor(private val catalogs: Array) { name = it.name, bindings = connector.getBindings(), functions = connector.getFunctions(), + aggregations = connector.getAggregations(), items = it.items.toTypedArray() ) }.toTypedArray() diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/Operator.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/Operator.kt index ade5920f07..7b28fcbed1 100644 --- a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/Operator.kt +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/Operator.kt @@ -1,6 +1,8 @@ package org.partiql.eval.internal.operator import org.partiql.eval.internal.Record +import org.partiql.spi.fn.Agg +import org.partiql.spi.fn.FnExperimental import org.partiql.value.PartiQLValue import org.partiql.value.PartiQLValueExperimental @@ -26,4 +28,19 @@ internal sealed interface Operator { override fun close() } + + interface Aggregation : Operator { + + @OptIn(FnExperimental::class) + val delegate: Agg + + val args: List + + val setQuantifier: SetQuantifier + + enum class SetQuantifier { + ALL, + DISTINCT + } + } } diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelAggregate.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelAggregate.kt new file mode 100644 index 0000000000..0e19bcf610 --- /dev/null +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelAggregate.kt @@ -0,0 +1,132 @@ +package org.partiql.eval.internal.operator.rel + +import org.partiql.eval.internal.Record +import org.partiql.eval.internal.operator.Operator +import org.partiql.spi.fn.Agg +import org.partiql.spi.fn.FnExperimental +import org.partiql.value.PartiQLValue +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.PartiQLValueType +import org.partiql.value.nullValue +import java.util.TreeMap +import java.util.TreeSet + +internal class RelAggregate( + val input: Operator.Relation, + val keys: List, + val functions: List +) : Operator.Relation { + + lateinit var records: Iterator + + @OptIn(PartiQLValueExperimental::class) + val aggregationMap = TreeMap, List>(PartiQLValueListComparator) + + @OptIn(PartiQLValueExperimental::class) + object PartiQLValueListComparator : Comparator> { + private val delegate = PartiQLValue.comparator(nullsFirst = false) + override fun compare(o1: List, o2: List): Int { + if (o1.size < o2.size) { + return -1 + } + if (o1.size > o2.size) { + return 1 + } + for (index in 0..o2.lastIndex) { + val element1 = o1[index] + val element2 = o2[index] + val compared = delegate.compare(element1, element2) + if (compared != 0) { + return compared + } + } + return 0 + } + } + + /** + * Wraps an [Agg.Accumulator] to help with filtering distinct values. + * + * @property seen maintains which values have already been seen. If null, we accumulate all values coming through. + */ + class AccumulatorWrapper @OptIn(PartiQLValueExperimental::class, FnExperimental::class) constructor( + val delegate: Agg.Accumulator, + val args: List, + val seen: TreeSet>? + ) + + @OptIn(PartiQLValueExperimental::class, FnExperimental::class) + override fun open() { + input.open() + var inputRecord = input.next() + while (inputRecord != null) { + // Initialize the AggregationMap + val evaluatedGroupByKeys = keys.map { + val key = it.eval(inputRecord!!) + when (key.type == PartiQLValueType.MISSING) { + true -> nullValue() + false -> key + } + } + val accumulators = aggregationMap.getOrPut(evaluatedGroupByKeys) { + functions.map { + AccumulatorWrapper( + delegate = it.delegate.accumulator(), + args = it.args, + seen = when (it.setQuantifier) { + Operator.Aggregation.SetQuantifier.DISTINCT -> TreeSet(PartiQLValueListComparator) + Operator.Aggregation.SetQuantifier.ALL -> null + } + ) + } + } + + // Aggregate Values in Aggregation State + accumulators.forEachIndexed { index, function -> + val valueToAggregate = function.args.map { it.eval(inputRecord!!) } + // Skip over aggregation if NULL/MISSING + if (valueToAggregate.any { it.type == PartiQLValueType.MISSING || it.isNull }) { + return@forEachIndexed + } + // Skip over aggregation if DISTINCT and SEEN + if (function.seen != null && (function.seen.add(valueToAggregate).not())) { + return@forEachIndexed + } + accumulators[index].delegate.next(valueToAggregate.toTypedArray()) + } + inputRecord = input.next() + } + + // No Aggregations Created + if (keys.isEmpty() && aggregationMap.isEmpty()) { + val record = mutableListOf() + functions.forEach { function -> + val accumulator = function.delegate.accumulator() + record.add(accumulator.value()) + } + records = iterator { yield(Record.of(*record.toTypedArray())) } + return + } + + records = iterator { + aggregationMap.forEach { (keysEvaluated, accumulators) -> + val recordValues = accumulators.map { acc -> acc.delegate.value() } + keysEvaluated.map { value -> value } + yield(Record.of(*recordValues.toTypedArray())) + } + } + } + + override fun next(): Record? { + return if (records.hasNext()) { + records.next() + } else { + null + } + } + + @OptIn(PartiQLValueExperimental::class) + override fun close() { + aggregationMap.clear() + input.close() + } +} diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCallDynamic.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCallDynamic.kt index 52a6a45a9c..23ab8c0e81 100644 --- a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCallDynamic.kt +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCallDynamic.kt @@ -31,7 +31,11 @@ internal class ExprCallDynamic( return candidate.eval(actualArgs) } } - throw TypeCheckException() + val errorString = buildString { + val argString = actualArgs.joinToString(", ") + append("Could not dynamically find function for arguments $argString in $candidates.") + } + throw TypeCheckException(errorString) } /** diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprPathSymbol.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprPathSymbol.kt index bc94c229b1..1de788dfc4 100644 --- a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprPathSymbol.kt +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprPathSymbol.kt @@ -25,6 +25,6 @@ internal class ExprPathSymbol( return v } } - throw TypeCheckException() + throw TypeCheckException("Couldn't find symbol '$symbol' in $struct.") } } diff --git a/partiql-eval/src/test/kotlin/org/partiql/eval/internal/PartiQLEngineDefaultTest.kt b/partiql-eval/src/test/kotlin/org/partiql/eval/internal/PartiQLEngineDefaultTest.kt index 9105bed264..1f73f76b5b 100644 --- a/partiql-eval/src/test/kotlin/org/partiql/eval/internal/PartiQLEngineDefaultTest.kt +++ b/partiql-eval/src/test/kotlin/org/partiql/eval/internal/PartiQLEngineDefaultTest.kt @@ -9,7 +9,7 @@ import org.junit.jupiter.params.provider.MethodSource import org.partiql.eval.PartiQLEngine import org.partiql.eval.PartiQLResult import org.partiql.parser.PartiQLParser -import org.partiql.plan.PlanNode +import org.partiql.plan.PartiQLPlan import org.partiql.plan.debug.PlanPrinter import org.partiql.planner.PartiQLPlanner import org.partiql.planner.PartiQLPlannerBuilder @@ -55,6 +55,11 @@ class PartiQLEngineDefaultTest { @Execution(ExecutionMode.CONCURRENT) fun subqueryTests(tc: SuccessTestCase) = tc.assert() + @ParameterizedTest + @MethodSource("aggregationTestCases") + @Execution(ExecutionMode.CONCURRENT) + fun aggregationTests(tc: SuccessTestCase) = tc.assert() + companion object { @JvmStatic @@ -314,6 +319,132 @@ class PartiQLEngineDefaultTest { ) ) + @JvmStatic + fun aggregationTestCases() = kotlin.collections.listOf( + SuccessTestCase( + input = """ + SELECT + gk_0, SUM(t.c) AS t_c_sum + FROM << + { 'b': NULL, 'c': 1 }, + { 'b': MISSING, 'c': 2 }, + { 'b': 1, 'c': 1 }, + { 'b': 1, 'c': 2 }, + { 'b': 2, 'c': NULL }, + { 'b': 2, 'c': 2 }, + { 'b': 3, 'c': MISSING }, + { 'b': 3, 'c': 2 }, + { 'b': 4, 'c': MISSING }, + { 'b': 4, 'c': NULL } + >> AS t GROUP BY t.b AS gk_0; + """.trimIndent(), + expected = org.partiql.value.bagValue( + org.partiql.value.structValue( + "gk_0" to org.partiql.value.int32Value(1), + "t_c_sum" to org.partiql.value.int32Value(3) + ), + org.partiql.value.structValue( + "gk_0" to org.partiql.value.int32Value(2), + "t_c_sum" to org.partiql.value.int32Value(2) + ), + org.partiql.value.structValue( + "gk_0" to org.partiql.value.int32Value(3), + "t_c_sum" to org.partiql.value.int32Value(2) + ), + org.partiql.value.structValue( + "gk_0" to org.partiql.value.int32Value(4), + "t_c_sum" to org.partiql.value.int32Value(null) + ), + org.partiql.value.structValue( + "gk_0" to org.partiql.value.nullValue(), + "t_c_sum" to org.partiql.value.int32Value(3) + ), + ), + mode = org.partiql.eval.PartiQLEngine.Mode.PERMISSIVE + ), + SuccessTestCase( + input = """ + SELECT VALUE { 'sensor': sensor, + 'readings': (SELECT VALUE v.l.co FROM g AS v) + } + FROM [{'sensor':1, 'co':0.4}, {'sensor':1, 'co':0.2}, {'sensor':2, 'co':0.3}] AS l + GROUP BY l.sensor AS sensor GROUP AS g + """.trimIndent(), + expected = org.partiql.value.bagValue( + org.partiql.value.structValue( + "sensor" to org.partiql.value.int32Value(1), + "readings" to org.partiql.value.bagValue( + org.partiql.value.decimalValue(0.4.toBigDecimal()), + org.partiql.value.decimalValue(0.2.toBigDecimal()) + ) + ), + org.partiql.value.structValue( + "sensor" to org.partiql.value.int32Value(2), + "readings" to org.partiql.value.bagValue( + org.partiql.value.decimalValue(0.3.toBigDecimal()) + ) + ), + ) + ), + SuccessTestCase( + input = """ + SELECT col1, g + FROM [{ 'col1':1 }, { 'col1':1 }] simple_1_col_1_group + GROUP BY col1 GROUP AS g + """.trimIndent(), + expected = bagValue( + structValue( + "col1" to int32Value(1), + "g" to bagValue( + structValue( + "simple_1_col_1_group" to structValue("col1" to int32Value(1)) + ), + structValue( + "simple_1_col_1_group" to structValue("col1" to int32Value(1)) + ), + ) + ), + ) + ), + SuccessTestCase( + input = """ + SELECT p.supplierId_mixed + FROM [ + { 'productId': 5, 'categoryId': 21, 'regionId': 100, 'supplierId_nulls': null, 'price_nulls': null }, + { 'productId': 4, 'categoryId': 20, 'regionId': 100, 'supplierId_nulls': null, 'supplierId_mixed': null, 'price_nulls': null, 'price_mixed': null } + ] AS p + GROUP BY p.supplierId_mixed + """.trimIndent(), + expected = bagValue( + structValue( + "supplierId_mixed" to nullValue(), + ), + ) + ), + SuccessTestCase( + input = """ + SELECT * + FROM << { 'a': 1, 'b': 2 } >> AS t + GROUP BY a, b, a + b GROUP AS g + """.trimIndent(), + expected = bagValue( + structValue( + "a" to int32Value(1), + "b" to int32Value(2), + "_3" to int32Value(3), + "g" to bagValue( + structValue( + "t" to structValue( + "a" to int32Value(1), + "b" to int32Value(2), + ) + ) + ), + ), + ) + ), + ) + @JvmStatic fun sanityTestsCases() = listOf( SuccessTestCase( @@ -938,11 +1069,11 @@ class PartiQLEngineDefaultTest { } } val output = result.value - assertEquals(expected, output, comparisonString(plan.plan, expected, output)) + assertEquals(expected, output, comparisonString(expected, output, plan.plan)) } @OptIn(PartiQLValueExperimental::class) - private fun comparisonString(plan: PlanNode, expected: PartiQLValue, actual: PartiQLValue): String { + private fun comparisonString(expected: PartiQLValue, actual: PartiQLValue, plan: PartiQLPlan): String { val expectedBuffer = ByteArrayOutputStream() val expectedWriter = PartiQLValueIonWriterBuilder.standardIonTextBuilder().build(expectedBuffer) expectedWriter.append(expected) diff --git a/partiql-plan/src/main/resources/partiql_plan.ion b/partiql-plan/src/main/resources/partiql_plan.ion index 6839065fdb..9d3ec6baac 100644 --- a/partiql-plan/src/main/resources/partiql_plan.ion +++ b/partiql-plan/src/main/resources/partiql_plan.ion @@ -24,6 +24,10 @@ catalog::{ path: list::[string], specific: string, }, + agg::{ + path: list::[string], + specific: string, + }, ] ] } @@ -128,6 +132,7 @@ rex::{ // is necessary. candidate::{ fn: ref, + parameters: list::[partiql_value_type], coercions: list::[optional::'.ref.cast'], } ] @@ -291,7 +296,8 @@ rel::{ groups: list::[rex], _: [ call::{ - agg: string, + agg: ref, + set_quantifier: [ ALL, DISTINCT ], args: list::[rex], }, ], diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/Env.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/Env.kt index d124259852..16f723f86a 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/Env.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/Env.kt @@ -2,6 +2,7 @@ package org.partiql.planner.internal import org.partiql.planner.PartiQLPlanner import org.partiql.planner.internal.casts.CastTable +import org.partiql.planner.internal.ir.Ref import org.partiql.planner.internal.ir.Rel import org.partiql.planner.internal.ir.Rex import org.partiql.planner.internal.ir.refAgg @@ -17,8 +18,11 @@ import org.partiql.planner.internal.ir.rexOpVarGlobal import org.partiql.planner.internal.typer.TypeEnv.Companion.toPath import org.partiql.planner.internal.typer.toRuntimeType import org.partiql.planner.internal.typer.toStaticType +import org.partiql.spi.BindingCase +import org.partiql.spi.BindingName import org.partiql.spi.BindingPath import org.partiql.spi.connector.ConnectorMetadata +import org.partiql.spi.fn.AggSignature import org.partiql.spi.fn.FnExperimental import org.partiql.types.StaticType import org.partiql.value.PartiQLValueExperimental @@ -58,7 +62,7 @@ internal class Env(private val session: PartiQLPlanner.Session) { /** * A [PathResolver] for aggregation function lookup. */ - private val aggs: PathResolverAgg = PathResolverAgg + private val aggs: PathResolverAgg = PathResolverAgg(catalog, session) /** * This function looks up a global [BindingPath], returning a global reference expression. @@ -101,9 +105,10 @@ internal class Env(private val session: PartiQLPlanner.Session) { fn = refFn( catalog = item.catalog, path = item.handle.path.steps, - signature = it.signature, + signature = it.fn.signature, ), - coercions = it.mapping.toList(), + parameters = it.parameters, + coercions = it.fn.mapping.toList(), ) } // Rewrite as a dynamic call to be typed by PlanTyper @@ -130,12 +135,23 @@ internal class Env(private val session: PartiQLPlanner.Session) { } @OptIn(FnExperimental::class, PartiQLValueExperimental::class) - fun resolveAgg(name: String, args: List): Rel.Op.Aggregate.Call.Resolved? { - val match = aggs.resolve(name, args) ?: return null + fun resolveAgg(name: String, setQuantifier: Rel.Op.Aggregate.SetQuantifier, args: List): Rel.Op.Aggregate.Call.Resolved? { + // TODO: Eventually, do we want to support sensitive lookup? With a path? + val path = BindingPath(listOf(BindingName(name, BindingCase.INSENSITIVE))) + val item = aggs.lookup(path) ?: return null + val candidates = item.handle.entity.getVariants() + var hadMissingArg = false + val parameters = args.mapIndexed { i, arg -> + if (!hadMissingArg && arg.type.isMissable()) { + hadMissingArg = true + } + arg.type.toRuntimeType() + } + val match = match(candidates, parameters) ?: return null val agg = match.first val mapping = match.second // Create an internal typed reference - val ref = refAgg(name, agg) + val ref = refAgg(item.catalog, item.handle.path.steps, agg) // Apply the coercions as explicit casts val coercions: List = args.mapIndexed { i, arg -> when (val cast = mapping[i]) { @@ -143,7 +159,7 @@ internal class Env(private val session: PartiQLPlanner.Session) { else -> rex(cast.target.toStaticType(), rexOpCastResolved(cast, arg)) } } - return relOpAggregateCallResolved(ref, coercions) + return relOpAggregateCallResolved(ref, setQuantifier, coercions) } @OptIn(PartiQLValueExperimental::class) @@ -159,16 +175,107 @@ internal class Env(private val session: PartiQLPlanner.Session) { /** * Logic for determining how many BindingNames were “matched” by the ConnectorMetadata - * 1. Matched = RelativePath - Not Found - * 2. Not Found = Input CatalogPath - Output CatalogPath - * 3. Matched = RelativePath - (Input CatalogPath - Output CatalogPath) - * 4. Matched = RelativePath + Output CatalogPath - Input CatalogPath + * + * Assume: + * - steps_matched = user_input_path_size - path_steps_not_found_size + * - path_steps_not_found_size = catalog_path_sent_to_spi_size - actual_catalog_absolute_path_size + * + * Therefore, we present the equation to [calculateMatched]: + * - steps_matched = user_input_path_size - (catalog_path_sent_to_spi_size - actual_catalog_absolute_path_size) + * = user_input_path_size + actual_catalog_absolute_path_size - catalog_path_sent_to_spi_size + * + * For example: + * + * Assume we are in some catalog, C, in some schema, S. There is a tuple, T, with attribute, A1. Assume A1 is of type + * tuple with an attribute A2. + * If our query references `T.A1.A2`, we will eventually ask SPI (connector C) for `S.T.A1.A2`. In this scenario: + * - The original user input was `T.A1.A2` (length 3) + * - The absolute path returned from SPI will be `S.T` (length 2) + * - The path we eventually sent to SPI to resolve was `S.T.A1.A2` (length 4) + * + * So, we can now use [calculateMatched] to determine how many were actually matched from the user input. Using the + * equation from above: + * + * - steps_matched = len(user input) + len(absolute catalog path) - len(path sent to SPI) + * = len([userInputPath]) + len([actualAbsolutePath]) - len([pathSentToConnector]) + * = 3 + 2 - 4 + * = 5 - 4 + * = 1 + * + * + * Therefore, in this example we have determined that from the original input (`T.A1.A2`) `T` is the value matched in the + * database environment. */ private fun calculateMatched( - originalPath: BindingPath, - inputCatalogPath: BindingPath, - outputCatalogPath: List, + userInputPath: BindingPath, + pathSentToConnector: BindingPath, + actualAbsolutePath: List, ): Int { - return originalPath.steps.size + outputCatalogPath.size - inputCatalogPath.steps.size + return userInputPath.steps.size + actualAbsolutePath.size - pathSentToConnector.steps.size + } + + @OptIn(FnExperimental::class, PartiQLValueExperimental::class) + private fun match(candidates: List, args: List): Pair>? { + // 1. Check for an exact match + for (candidate in candidates) { + if (candidate.matches(args)) { + return candidate to arrayOfNulls(args.size) + } + } + // 2. Look for best match. + var match: Pair>? = null + for (candidate in candidates) { + val m = candidate.match(args) ?: continue + // TODO AggMatch comparison + // if (match != null && m.exact < match.exact) { + // // already had a better match. + // continue + // } + match = m + } + // 3. Return best match or null + return match + } + + /** + * Check if this function accepts the exact input argument types. Assume same arity. + */ + @OptIn(FnExperimental::class, PartiQLValueExperimental::class) + private fun AggSignature.matches(args: List): Boolean { + for (i in args.indices) { + val a = args[i] + val p = parameters[i] + if (p.type != PartiQLValueType.ANY && a != p.type) return false + } + return true + } + + /** + * Attempt to match arguments to the parameters; return the implicit casts if necessary. + * + * @param args + * @return + */ + @OptIn(FnExperimental::class, PartiQLValueExperimental::class) + private fun AggSignature.match(args: List): Pair>? { + val mapping = arrayOfNulls(args.size) + for (i in args.indices) { + val arg = args[i] + val p = parameters[i] + when { + // 1. Exact match + arg == p.type -> continue + // 2. Match ANY, no coercion needed + p.type == PartiQLValueType.ANY -> continue + // 3. Match NULL argument + arg == PartiQLValueType.NULL -> continue + // 4. Check for a coercion + else -> when (val coercion = PathResolverAgg.casts.lookupCoercion(arg, p.type)) { + null -> return null // short-circuit + else -> mapping[i] = coercion + } + } + } + return this to mapping } } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/FnMatch.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/FnMatch.kt index cd497a7604..acfeee59e5 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/FnMatch.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/FnMatch.kt @@ -3,6 +3,8 @@ package org.partiql.planner.internal import org.partiql.planner.internal.ir.Ref import org.partiql.spi.fn.FnExperimental import org.partiql.spi.fn.FnSignature +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.PartiQLValueType /** * Result of matching an unresolved function. @@ -41,7 +43,19 @@ internal sealed class FnMatch { * @property exhaustive True if all argument permutations (branches) are matched. */ data class Dynamic( - val candidates: List, + val candidates: List, val exhaustive: Boolean, - ) : FnMatch() + ) : FnMatch() { + + /** + * Represents a candidate of dynamic dispatch. + * + * @property fn Function to invoke. + * @property parameters Represents the input type(s) to match. (ex: INT32) + */ + data class Candidate @OptIn(PartiQLValueExperimental::class) constructor( + val fn: Static, + val parameters: List + ) + } } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/FnResolver.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/FnResolver.kt index edbd8dcb85..aedc2a8a14 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/FnResolver.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/FnResolver.kt @@ -71,7 +71,7 @@ internal object FnResolver { // Static call iff only one match for every branch return when { n == 0 -> null - n == 1 && exhaustive -> orderedUniqueFunctions.first() + n == 1 && exhaustive -> orderedUniqueFunctions.first().fn else -> FnMatch.Dynamic(orderedUniqueFunctions, exhaustive) } } @@ -83,11 +83,11 @@ internal object FnResolver { * @param args * @return */ - private fun match(candidates: List, args: List): FnMatch.Static? { + private fun match(candidates: List, args: List): FnMatch.Dynamic.Candidate? { // 1. Check for an exact match for (candidate in candidates) { if (candidate.matches(args)) { - return FnMatch.Static(candidate, arrayOfNulls(args.size)) + return FnMatch.Dynamic.Candidate(fn = FnMatch.Static(candidate, arrayOfNulls(args.size)), args) } } // 2. Look for best match (for now, first match). @@ -124,7 +124,7 @@ internal object FnResolver { * @param args * @return */ - private fun FnSignature.match(args: List): FnMatch.Static? { + private fun FnSignature.match(args: List): FnMatch.Dynamic.Candidate? { val mapping = arrayOfNulls(args.size) for (i in args.indices) { val arg = args[i] @@ -143,7 +143,7 @@ internal object FnResolver { } } } - return FnMatch.Static(this, mapping) + return FnMatch.Dynamic.Candidate(fn = FnMatch.Static(this, mapping), args) } private fun buildArgumentPermutations(args: List): List> { diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PathItem.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PathItem.kt index 0a05038f9b..08a7ebca5a 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PathItem.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PathItem.kt @@ -8,7 +8,7 @@ import org.partiql.spi.connector.ConnectorHandle * * @param T * @property catalog The resolved entity's catalog name. - * @property input The input binding path that resulted in this item match. + * @property input The input binding path (sent to SPI) that resulted in this item match. * @property handle The resolved entity's catalog path and type information. */ internal data class PathItem( diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PathResolverAgg.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PathResolverAgg.kt index 4a711306ae..7a19e15ede 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PathResolverAgg.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PathResolverAgg.kt @@ -1,52 +1,12 @@ package org.partiql.planner.internal +import org.partiql.planner.PartiQLPlanner import org.partiql.planner.internal.casts.CastTable -import org.partiql.planner.internal.ir.Ref -import org.partiql.planner.internal.ir.Rex -import org.partiql.planner.internal.typer.toRuntimeType -import org.partiql.spi.connector.sql.builtins.Agg_ANY__BOOL__BOOL -import org.partiql.spi.connector.sql.builtins.Agg_AVG__DECIMAL_ARBITRARY__DECIMAL_ARBITRARY -import org.partiql.spi.connector.sql.builtins.Agg_AVG__FLOAT32__FLOAT32 -import org.partiql.spi.connector.sql.builtins.Agg_AVG__FLOAT64__FLOAT64 -import org.partiql.spi.connector.sql.builtins.Agg_AVG__INT16__INT16 -import org.partiql.spi.connector.sql.builtins.Agg_AVG__INT32__INT32 -import org.partiql.spi.connector.sql.builtins.Agg_AVG__INT64__INT64 -import org.partiql.spi.connector.sql.builtins.Agg_AVG__INT8__INT8 -import org.partiql.spi.connector.sql.builtins.Agg_AVG__INT__INT -import org.partiql.spi.connector.sql.builtins.Agg_COUNT_STAR____INT32 -import org.partiql.spi.connector.sql.builtins.Agg_COUNT__ANY__INT32 -import org.partiql.spi.connector.sql.builtins.Agg_EVERY__BOOL__BOOL -import org.partiql.spi.connector.sql.builtins.Agg_MAX__DECIMAL_ARBITRARY__DECIMAL_ARBITRARY -import org.partiql.spi.connector.sql.builtins.Agg_MAX__FLOAT32__FLOAT32 -import org.partiql.spi.connector.sql.builtins.Agg_MAX__FLOAT64__FLOAT64 -import org.partiql.spi.connector.sql.builtins.Agg_MAX__INT16__INT16 -import org.partiql.spi.connector.sql.builtins.Agg_MAX__INT32__INT32 -import org.partiql.spi.connector.sql.builtins.Agg_MAX__INT64__INT64 -import org.partiql.spi.connector.sql.builtins.Agg_MAX__INT8__INT8 -import org.partiql.spi.connector.sql.builtins.Agg_MAX__INT__INT -import org.partiql.spi.connector.sql.builtins.Agg_MIN__DECIMAL_ARBITRARY__DECIMAL_ARBITRARY -import org.partiql.spi.connector.sql.builtins.Agg_MIN__FLOAT32__FLOAT32 -import org.partiql.spi.connector.sql.builtins.Agg_MIN__FLOAT64__FLOAT64 -import org.partiql.spi.connector.sql.builtins.Agg_MIN__INT16__INT16 -import org.partiql.spi.connector.sql.builtins.Agg_MIN__INT32__INT32 -import org.partiql.spi.connector.sql.builtins.Agg_MIN__INT64__INT64 -import org.partiql.spi.connector.sql.builtins.Agg_MIN__INT8__INT8 -import org.partiql.spi.connector.sql.builtins.Agg_MIN__INT__INT -import org.partiql.spi.connector.sql.builtins.Agg_SOME__BOOL__BOOL -import org.partiql.spi.connector.sql.builtins.Agg_SUM__DECIMAL_ARBITRARY__DECIMAL_ARBITRARY -import org.partiql.spi.connector.sql.builtins.Agg_SUM__FLOAT32__FLOAT32 -import org.partiql.spi.connector.sql.builtins.Agg_SUM__FLOAT64__FLOAT64 -import org.partiql.spi.connector.sql.builtins.Agg_SUM__INT16__INT16 -import org.partiql.spi.connector.sql.builtins.Agg_SUM__INT32__INT32 -import org.partiql.spi.connector.sql.builtins.Agg_SUM__INT64__INT64 -import org.partiql.spi.connector.sql.builtins.Agg_SUM__INT8__INT8 -import org.partiql.spi.connector.sql.builtins.Agg_SUM__INT__INT -import org.partiql.spi.fn.AggSignature +import org.partiql.spi.BindingPath +import org.partiql.spi.connector.ConnectorAgg +import org.partiql.spi.connector.ConnectorHandle +import org.partiql.spi.connector.ConnectorMetadata import org.partiql.spi.fn.FnExperimental -import org.partiql.value.PartiQLValueExperimental -import org.partiql.value.PartiQLValueType -import org.partiql.value.PartiQLValueType.ANY -import org.partiql.value.PartiQLValueType.NULL /** * Today, all aggregations are hard-coded into the grammar. We cannot implement user-defined aggregations until @@ -61,123 +21,18 @@ import org.partiql.value.PartiQLValueType.NULL * ; * */ -@OptIn(FnExperimental::class, PartiQLValueExperimental::class) -internal object PathResolverAgg { - - @JvmStatic - private val casts = CastTable.partiql - - private val map = listOf( - Agg_ANY__BOOL__BOOL, - Agg_AVG__INT8__INT8, - Agg_AVG__INT16__INT16, - Agg_AVG__INT32__INT32, - Agg_AVG__INT64__INT64, - Agg_AVG__INT__INT, - Agg_AVG__DECIMAL_ARBITRARY__DECIMAL_ARBITRARY, - Agg_AVG__FLOAT32__FLOAT32, - Agg_AVG__FLOAT64__FLOAT64, - Agg_COUNT__ANY__INT32, - Agg_COUNT_STAR____INT32, - Agg_EVERY__BOOL__BOOL, - Agg_MAX__INT8__INT8, - Agg_MAX__INT16__INT16, - Agg_MAX__INT32__INT32, - Agg_MAX__INT64__INT64, - Agg_MAX__INT__INT, - Agg_MAX__DECIMAL_ARBITRARY__DECIMAL_ARBITRARY, - Agg_MAX__FLOAT32__FLOAT32, - Agg_MAX__FLOAT64__FLOAT64, - Agg_MIN__INT8__INT8, - Agg_MIN__INT16__INT16, - Agg_MIN__INT32__INT32, - Agg_MIN__INT64__INT64, - Agg_MIN__INT__INT, - Agg_MIN__DECIMAL_ARBITRARY__DECIMAL_ARBITRARY, - Agg_MIN__FLOAT32__FLOAT32, - Agg_MIN__FLOAT64__FLOAT64, - Agg_SOME__BOOL__BOOL, - Agg_SUM__INT8__INT8, - Agg_SUM__INT16__INT16, - Agg_SUM__INT32__INT32, - Agg_SUM__INT64__INT64, - Agg_SUM__INT__INT, - Agg_SUM__DECIMAL_ARBITRARY__DECIMAL_ARBITRARY, - Agg_SUM__FLOAT32__FLOAT32, - Agg_SUM__FLOAT64__FLOAT64, - ).map { it.signature }.groupBy { it.name } - - fun resolve(name: String, args: List): Pair>? { - val candidates = map[name] ?: return null - var hadMissingArg = false - val parameters = args.mapIndexed { i, arg -> - if (!hadMissingArg && arg.type.isMissable()) { - hadMissingArg = true - } - arg.type.toRuntimeType() - } - return match(candidates, parameters) - } - - private fun match(candidates: List, args: List): Pair>? { - // 1. Check for an exact match - for (candidate in candidates) { - if (candidate.matches(args)) { - return candidate to arrayOfNulls(args.size) - } - } - // 2. Look for best match. - var match: Pair>? = null - for (candidate in candidates) { - val m = candidate.match(args) ?: continue - // TODO AggMatch comparison - // if (match != null && m.exact < match.exact) { - // // already had a better match. - // continue - // } - match = m - } - // 3. Return best match or null - return match - } - - /** - * Check if this function accepts the exact input argument types. Assume same arity. - */ - private fun AggSignature.matches(args: List): Boolean { - for (i in args.indices) { - val a = args[i] - val p = parameters[i] - if (a != p.type) return false - } - return true +@OptIn(FnExperimental::class) +internal class PathResolverAgg( + catalog: ConnectorMetadata, + session: PartiQLPlanner.Session, +) : PathResolver(catalog, session) { + + companion object { + @JvmStatic + public val casts = CastTable.partiql } - /** - * Attempt to match arguments to the parameters; return the implicit casts if necessary. - * - * @param args - * @return - */ - private fun AggSignature.match(args: List): Pair>? { - val mapping = arrayOfNulls(args.size) - for (i in args.indices) { - val arg = args[i] - val p = parameters[i] - when { - // 1. Exact match - arg == p.type -> continue - // 2. Match ANY, no coercion needed - p.type == ANY -> continue - // 3. Match NULL argument - arg == NULL -> continue - // 4. Check for a coercion - else -> when (val coercion = casts.lookupCoercion(arg, p.type)) { - null -> return null // short-circuit - else -> mapping[i] = coercion - } - } - } - return this to mapping + override fun get(metadata: ConnectorMetadata, path: BindingPath): ConnectorHandle.Agg? { + return metadata.getAggregation(path) } } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/Nodes.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/Nodes.kt index b94f58c29c..bb37267097 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/Nodes.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/Nodes.kt @@ -106,6 +106,7 @@ internal sealed class Ref : PlanNode() { public override fun accept(visitor: PlanVisitor, ctx: C): R = when (this) { is Obj -> visitor.visitRefObj(this, ctx) is Fn -> visitor.visitRefFn(this, ctx) + is Agg -> visitor.visitRefAgg(this, ctx) } internal data class Obj( @@ -138,6 +139,21 @@ internal sealed class Ref : PlanNode() { } } + internal data class Agg( + @JvmField internal val catalog: String, + @JvmField internal val path: List, + @JvmField internal val signature: AggSignature, + ) : Ref() { + public override val children: List = emptyList() + + public override fun accept(visitor: PlanVisitor, ctx: C): R = visitor.visitRefAgg(this, ctx) + + internal companion object { + @JvmStatic + internal fun builder(): RefAggBuilder = RefAggBuilder() + } + } + internal data class Cast( @JvmField internal val input: PartiQLValueType, @JvmField internal val target: PartiQLValueType, @@ -156,20 +172,6 @@ internal sealed class Ref : PlanNode() { internal fun builder(): RefCastBuilder = RefCastBuilder() } } - - internal data class Agg( - @JvmField internal val name: String, - @JvmField internal val signature: AggSignature, - ) : PlanNode() { - public override val children: List = emptyList() - - public override fun accept(visitor: PlanVisitor, ctx: C): R = visitor.visitRefAgg(this, ctx) - - internal companion object { - @JvmStatic - internal fun builder(): RefAggBuilder = RefAggBuilder() - } - } } internal sealed class Statement : PlanNode() { @@ -527,6 +529,7 @@ internal data class Rex( internal data class Candidate( @JvmField internal val fn: Ref.Fn, + @JvmField internal val parameters: List, @JvmField internal val coercions: List, ) : PlanNode() { public override val children: List by lazy { @@ -1117,6 +1120,7 @@ internal data class Rel( internal data class Unresolved( @JvmField internal val name: String, + @JvmField internal val setQuantifier: SetQuantifier, @JvmField internal val args: List, ) : Call() { public override val children: List by lazy { @@ -1137,6 +1141,7 @@ internal data class Rel( internal data class Resolved( @JvmField internal val agg: Ref.Agg, + @JvmField internal val setQuantifier: SetQuantifier, @JvmField internal val args: List, ) : Call() { public override val children: List by lazy { @@ -1156,6 +1161,10 @@ internal data class Rel( } } + internal enum class SetQuantifier { + ALL, DISTINCT, + } + internal companion object { @JvmStatic internal fun builder(): RelOpAggregateBuilder = RelOpAggregateBuilder() diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/NormalizeSelect.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/NormalizeSelect.kt index 9860aa5f33..5edecc3e4f 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/NormalizeSelect.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/NormalizeSelect.kt @@ -16,6 +16,7 @@ package org.partiql.planner.internal.transforms import org.partiql.ast.Expr import org.partiql.ast.From +import org.partiql.ast.GroupBy import org.partiql.ast.Identifier import org.partiql.ast.Select import org.partiql.ast.exprCall @@ -138,7 +139,13 @@ internal object NormalizeSelect { internal fun visitSFW(node: Expr.SFW, ctx: () -> Int): Expr.SFW { val sfw = super.visitExprSFW(node, ctx) as Expr.SFW return when (val select = sfw.select) { - is Select.Star -> sfw.copy(select = visitSelectAll(select, sfw.from)) + is Select.Star -> { + val selectValue = when (val group = sfw.groupBy) { + null -> visitSelectAll(select, sfw.from) + else -> visitSelectAll(select, group) + } + sfw.copy(select = selectValue) + } else -> sfw } } @@ -215,6 +222,25 @@ internal object NormalizeSelect { ) } + /** + * We need to call this from [visitExprSFW] and not override [visitSelectStar] because we need access to the + * [GroupBy] aliases. + * + * Note: We assume that [select] and [group] have already been visited. + */ + private fun visitSelectAll(select: Select.Star, group: GroupBy): Select.Value { + val groupAs = group.asAlias?.let { structField(it.symbol, varLocal(it.symbol)) } + val fields = group.keys.map { key -> + val alias = key.asAlias ?: error("Expected a GROUP BY alias.") + structField(alias.symbol, varLocal(alias.symbol)) + } + listOfNotNull(groupAs) + val constructor = exprStruct(fields) + return selectValue( + constructor = constructor, + setq = select.setq + ) + } + private fun visitSelectProjectWithProjectAll(node: Select.Project): Select.Value { val tupleUnionArgs = node.items.mapIndexed { index, item -> when (item) { @@ -274,6 +300,17 @@ internal object NormalizeSelect { ) ) + @OptIn(PartiQLValueExperimental::class) + private fun structField(name: String, expr: Expr): Expr.Struct.Field = Expr.Struct.Field( + name = Expr.Lit(stringValue(name)), + value = expr + ) + + private fun varLocal(name: String): Expr.Var = Expr.Var( + identifier = Identifier.Symbol(name, Identifier.CaseSensitivity.SENSITIVE), + scope = Expr.Var.Scope.LOCAL + ) + private fun From.aliases(): List> = when (this) { is From.Join -> lhs.aliases() + rhs.aliases() is From.Value -> { diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt index d4b9e9ad27..e3fda26c9b 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt @@ -66,6 +66,11 @@ internal object PlanTransform { */ override fun visitRefFn(node: Ref.Fn, ctx: Unit) = symbols.insert(node) + /** + * Insert into symbol table, returning the public reference. + */ + override fun visitRefAgg(node: Ref.Agg, ctx: Unit) = symbols.insert(node) + @OptIn(PartiQLValueExperimental::class) override fun visitRefCast(node: Ref.Cast, ctx: Unit) = org.partiql.plan.refCast(node.input, node.target) @@ -179,10 +184,11 @@ internal object PlanTransform { ) } + @OptIn(PartiQLValueExperimental::class) override fun visitRexOpCallDynamicCandidate(node: Rex.Op.Call.Dynamic.Candidate, ctx: Unit): PlanNode { val fn = visitRef(node.fn, ctx) val coercions = node.coercions.map { it?.let { visitRefCast(it, ctx) } } - return org.partiql.plan.Rex.Op.Call.Dynamic.Candidate(fn, coercions) + return org.partiql.plan.Rex.Op.Call.Dynamic.Candidate(fn, node.parameters, coercions) } override fun visitRexOpCase(node: Rex.Op.Case, ctx: Unit) = org.partiql.plan.Rex.Op.Case( @@ -345,9 +351,13 @@ internal object PlanTransform { } override fun visitRelOpAggregateCallResolved(node: Rel.Op.Aggregate.Call.Resolved, ctx: Unit): PlanNode { - val agg = node.agg.name + val agg = visitRef(node.agg, ctx) val args = node.args.map { visitRex(it, ctx) } - return org.partiql.plan.relOpAggregateCall(node.agg.name, args) + val setQuantifier = when (node.setQuantifier) { + Rel.Op.Aggregate.SetQuantifier.ALL -> org.partiql.plan.Rel.Op.Aggregate.Call.SetQuantifier.ALL + Rel.Op.Aggregate.SetQuantifier.DISTINCT -> org.partiql.plan.Rel.Op.Aggregate.Call.SetQuantifier.DISTINCT + } + return org.partiql.plan.relOpAggregateCall(agg, setQuantifier, args) } override fun visitRelOpExclude(node: Rel.Op.Exclude, ctx: Unit) = org.partiql.plan.Rel.Op.Exclude( diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RelConverter.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RelConverter.kt index b26ecd48d6..ab071d3169 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RelConverter.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RelConverter.kt @@ -28,7 +28,9 @@ import org.partiql.ast.SetOp import org.partiql.ast.SetQuantifier import org.partiql.ast.Sort import org.partiql.ast.builder.ast +import org.partiql.ast.exprVar import org.partiql.ast.helpers.toBinder +import org.partiql.ast.identifierSymbol import org.partiql.ast.util.AstRewriter import org.partiql.ast.visitor.AstBaseVisitor import org.partiql.planner.internal.Env @@ -66,10 +68,13 @@ import org.partiql.planner.internal.ir.rex import org.partiql.planner.internal.ir.rexOpLit import org.partiql.planner.internal.ir.rexOpPivot import org.partiql.planner.internal.ir.rexOpSelect +import org.partiql.planner.internal.ir.rexOpStruct +import org.partiql.planner.internal.ir.rexOpStructField import org.partiql.planner.internal.ir.rexOpVarLocal import org.partiql.types.StaticType import org.partiql.value.PartiQLValueExperimental import org.partiql.value.boolValue +import org.partiql.value.stringValue import org.partiql.planner.internal.ir.Identifier as InternalId /** @@ -158,7 +163,9 @@ internal object RelConverter { val project = visitSelectValue(projection, rel) visitSetQuantifier(projection.setq, project) } - is Select.Star, is Select.Project -> error("AST not normalized, found ${projection.javaClass.simpleName}") + is Select.Star, is Select.Project -> { + error("AST not normalized, found ${projection.javaClass.simpleName}") + } is Select.Pivot -> rel // Skip PIVOT } return rel @@ -244,7 +251,7 @@ internal object RelConverter { override fun visitFromJoin(node: From.Join, nil: Rel): Rel { val lhs = visitFrom(node.lhs, nil) val rhs = visitFrom(node.rhs, nil) - val schema = listOf() + val schema = lhs.type.schema + rhs.type.schema // Note: This gets more specific in PlanTyper. It is only used to find binding names here. val props = emptySet() val condition = node.condition?.let { RexConverter.apply(it, env) } ?: rex(StaticType.BOOL, rexOpLit(boolValue(true))) val joinType = when (node.type) { @@ -336,16 +343,13 @@ internal object RelConverter { * 1. Ast.Expr.SFW has every Ast.Expr.CallAgg replaced by a synthetic Ast.Expr.Var * 2. Rel which has the appropriate Rex.Agg calls and groups */ + @OptIn(PartiQLValueExperimental::class) private fun convertAgg(input: Rel, select: Expr.SFW, groupBy: GroupBy?): Pair { // Rewrite and extract all aggregations in the SELECT clause val (sel, aggregations) = AggregationTransform.apply(select) // No aggregation planning required for GROUP BY - if (aggregations.isEmpty()) { - if (groupBy != null) { - // GROUP BY with no aggregations is considered an error. - error("GROUP BY with no aggregations in SELECT clause") - } + if (aggregations.isEmpty() && groupBy == null) { return Pair(select, input) } @@ -367,7 +371,28 @@ internal object RelConverter { is InternalId.Qualified -> error("Qualified aggregation calls are not supported.") is InternalId.Symbol -> id.symbol.lowercase() } - relOpAggregateCallUnresolved(name, args) + val setq = when (expr.setq) { + null -> Rel.Op.Aggregate.SetQuantifier.ALL + SetQuantifier.ALL -> Rel.Op.Aggregate.SetQuantifier.ALL + SetQuantifier.DISTINCT -> Rel.Op.Aggregate.SetQuantifier.DISTINCT + } + relOpAggregateCallUnresolved(name, setq, args) + }.toMutableList() + + // Add GROUP_AS aggregation + groupBy?.let { gb -> + gb.asAlias?.let { groupAs -> + val binding = relBinding(groupAs.symbol, StaticType.ANY) + schema.add(binding) + val fields = input.type.schema.mapIndexed { bindingIndex, currBinding -> + rexOpStructField( + k = rex(StaticType.STRING, rexOpLit(stringValue(currBinding.name))), + v = rex(StaticType.ANY, rexOpVarLocal(0, bindingIndex)) + ) + } + val arg = listOf(rex(StaticType.ANY, rexOpStruct(fields))) + calls.add(relOpAggregateCallUnresolved("group_as", Rel.Op.Aggregate.SetQuantifier.ALL, arg)) + } } var groups = emptyList() if (groupBy != null) { @@ -559,23 +584,38 @@ internal object RelConverter { /** * Rewrites a SELECT node replacing (and extracting) each aggregation `i` with a synthetic field name `$agg_i`. */ - private object AggregationTransform : AstRewriter>() { + private object AggregationTransform : AstRewriter() { + + private data class Context( + val aggregations: MutableList, + val keys: List + ) fun apply(node: Expr.SFW): Pair> { val aggs = mutableListOf() - val select = super.visitExprSFW(node, aggs) as Expr.SFW + val keys = node.groupBy?.keys ?: emptyList() + val context = Context(aggs, keys) + val select = super.visitExprSFW(node, context) as Expr.SFW return Pair(select, aggs) } + override fun visitSelectValue(node: Select.Value, ctx: Context): AstNode { + val visited = super.visitSelectValue(node, ctx) + val substitutions = ctx.keys.associate { + it.expr to exprVar(identifierSymbol(it.asAlias!!.symbol, Identifier.CaseSensitivity.SENSITIVE), Expr.Var.Scope.DEFAULT) + } + return SubstitutionVisitor.visit(visited, substitutions) + } + // only rewrite top-level SFW - override fun visitExprSFW(node: Expr.SFW, ctx: MutableList): AstNode = node + override fun visitExprSFW(node: Expr.SFW, ctx: Context): AstNode = node - override fun visitExprAgg(node: Expr.Agg, ctx: MutableList) = ast { + override fun visitExprAgg(node: Expr.Agg, ctx: Context) = ast { val id = identifierSymbol { - symbol = syntheticAgg(ctx.size) + symbol = syntheticAgg(ctx.aggregations.size) caseSensitivity = org.partiql.ast.Identifier.CaseSensitivity.INSENSITIVE } - ctx += node + ctx.aggregations += node exprVar(id, Expr.Var.Scope.DEFAULT) } } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/SubstitutionVisitor.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/SubstitutionVisitor.kt new file mode 100644 index 0000000000..04114b5346 --- /dev/null +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/SubstitutionVisitor.kt @@ -0,0 +1,15 @@ +package org.partiql.planner.internal.transforms + +import org.partiql.ast.AstNode +import org.partiql.ast.Expr +import org.partiql.ast.util.AstRewriter + +internal object SubstitutionVisitor : AstRewriter>() { + override fun visitExpr(node: Expr, ctx: Map<*, AstNode>): AstNode { + val visited = super.visitExpr(node, ctx) + if (ctx.containsKey(visited)) { + return ctx[visited]!! + } + return visited + } +} diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/Symbols.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/Symbols.kt index 349cdf2949..7a14eb85ce 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/Symbols.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/Symbols.kt @@ -2,6 +2,7 @@ package org.partiql.planner.internal.transforms import org.partiql.plan.Catalog import org.partiql.plan.builder.CatalogBuilder +import org.partiql.plan.catalogItemAgg import org.partiql.plan.catalogItemFn import org.partiql.plan.catalogItemValue import org.partiql.planner.internal.ir.Ref @@ -36,6 +37,12 @@ internal class Symbols private constructor() { item = catalogItemFn(ref.path, ref.signature.specific), ) + @OptIn(FnExperimental::class) + fun insert(ref: Ref.Agg): CatalogRef = insert( + catalog = ref.catalog, + item = catalogItemAgg(ref.path, ref.signature.specific), + ) + private fun insert(catalog: String, item: Catalog.Item): CatalogRef { val i = upsert(catalog) val c = catalogs[i] diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt index 6490f250e5..9f55db39d4 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt @@ -1218,7 +1218,7 @@ internal class PlanTyper( return node to ANY } else if (arg.type is MissingType) { handleAlwaysMissing() - return relOpAggregateCallUnresolved(node.name, listOf(rexErr("MISSING"))) to MissingType + return relOpAggregateCallUnresolved(node.name, node.setQuantifier, listOf(rexErr("MISSING"))) to MissingType } else if (arg.type.isMissable()) { isMissable = true } @@ -1226,7 +1226,7 @@ internal class PlanTyper( } // Resolve the function - val call = env.resolveAgg(node.name, args) + val call = env.resolveAgg(node.name, node.setQuantifier, args) if (call == null) { handleUnknownAggregation(node) return node to ANY diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeEnv.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeEnv.kt index 3f7de468ea..9606541ec6 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeEnv.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeEnv.kt @@ -10,6 +10,8 @@ import org.partiql.planner.internal.ir.rexOpVarLocal import org.partiql.spi.BindingCase import org.partiql.spi.BindingName import org.partiql.spi.BindingPath +import org.partiql.types.AnyOfType +import org.partiql.types.AnyType import org.partiql.types.StaticType import org.partiql.types.StructType import org.partiql.types.TupleConstraint @@ -100,30 +102,28 @@ internal data class TypeEnv( for (i in schema.indices) { val local = schema[i] val type = local.type - if (type is StructType) { - when (type.containsKey(name)) { - true -> { - if (c != null && known) { - // TODO root was already definitively matched, emit ambiguous error. - return null - } - c = rex(type, rexOpVarLocal(depth, i)) - known = true + when (type.containsKey(name)) { + true -> { + if (c != null && known) { + // TODO root was already definitively matched, emit ambiguous error. + return null } - null -> { - if (c != null) { - if (known) { - continue - } else { - // TODO we have more than one possible match, emit ambiguous error. - return null - } + c = rex(type, rexOpVarLocal(depth, i)) + known = true + } + null -> { + if (c != null) { + if (known) { + continue + } else { + // TODO we have more than one possible match, emit ambiguous error. + return null } - c = rex(type, rexOpVarLocal(depth, i)) - known = false } - false -> continue + c = rex(type, rexOpVarLocal(depth, i)) + known = false } + false -> continue } } if (c == null && outer.isNotEmpty()) { @@ -153,6 +153,35 @@ internal data class TypeEnv( return if (closed) false else null } + /** + * Searches for the [BindingName] within the given [StaticType]. + * + * Returns + * - true iff known to contain key + * - false iff known to NOT contain key + * - null iff NOT known to contain key + * + * @param name + * @return + */ + private fun StaticType.containsKey(name: BindingName): Boolean? { + return when (val type = this.flatten()) { + is StructType -> type.containsKey(name) + is AnyOfType -> { + val anyKnownToContainKey = type.allTypes.any { it.containsKey(name) == true } + val anyKnownToNotContainKey = type.allTypes.any { it.containsKey(name) == false } + val anyNotKnownToContainKey = type.allTypes.any { it.containsKey(name) == null } + when { + anyKnownToNotContainKey.not() && anyNotKnownToContainKey.not() -> true + anyKnownToContainKey.not() && anyNotKnownToContainKey -> false + else -> null + } + } + is AnyType -> null + else -> false + } + } + companion object { /** diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeUtils.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeUtils.kt index 5dd156908a..9a6455711c 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeUtils.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeUtils.kt @@ -92,7 +92,7 @@ internal fun StaticType.toRuntimeType(): PartiQLValueType { // handle anyOf(null, T) cases val t = types.filter { it !is NullType && it !is MissingType } return if (t.size != 1) { - error("Cannot have a UNION runtime type: $this") + PartiQLValueType.ANY } else { t.first().asRuntimeType() } diff --git a/partiql-planner/src/main/resources/partiql_plan_internal.ion b/partiql-planner/src/main/resources/partiql_plan_internal.ion index 8a810f9f68..d69180f572 100644 --- a/partiql-planner/src/main/resources/partiql_plan_internal.ion +++ b/partiql-planner/src/main/resources/partiql_plan_internal.ion @@ -28,6 +28,11 @@ ref::[ path: list::[string], signature: fn_signature, }, + agg::{ + catalog: string, + path: list::[string], + signature: agg_signature, + }, _::[ cast::{ input: partiql_value_type, @@ -37,11 +42,7 @@ ref::[ EXPLICIT, // Lossy CAST(V AS T) -> T UNSAFE, // CAST(V AS T) -> T|MISSING ] - }, - agg::{ - name: string, - signature: agg_signature, - }, + } ] ] @@ -140,6 +141,7 @@ rex::{ // Represents a dynamic function call. If all candidates are exhausted, dynamic calls will return MISSING. // // args: represent the original typed arguments. These will eventually be wrapped by coercions from [candidates]. + // parameters: represents the input type(s) to match. (ex: INT32) // candidates: represent the potentially applicable resolved functions with coercions. Each of these candidates // should be overloaded functions of the same name and number of arguments. dynamic::{ @@ -149,6 +151,7 @@ rex::{ _: [ candidate::{ fn: '.ref.fn', + parameters: list::[partiql_value_type], coercions: list::[optional::'.ref.cast'], } ] @@ -314,13 +317,16 @@ rel::{ call::[ unresolved::{ name: string, + set_quantifier: set_quantifier, args: list::[rex], }, resolved::{ agg: '.ref.agg', + set_quantifier: set_quantifier, args: list::[rex], }, - ] + ], + set_quantifier::[ ALL, DISTINCT ] ], }, diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/Connector.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/Connector.kt index 74f061a249..5726996dc2 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/Connector.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/Connector.kt @@ -47,6 +47,14 @@ public interface Connector { @FnExperimental public fun getFunctions(): ConnectorFnProvider + /** + * Returns a [ConnectorAggProvider] which the engine uses to load aggregation function implementations. + * + * @return + */ + @FnExperimental + public fun getAggregations(): ConnectorAggProvider + /** * A Plugin leverages a [Factory] to produce a [Connector] which is used for catalog metadata and data access. */ diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/ConnectorAgg.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/ConnectorAgg.kt new file mode 100644 index 0000000000..2bad2ec44f --- /dev/null +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/ConnectorAgg.kt @@ -0,0 +1,29 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + */ + +package org.partiql.spi.connector + +import org.partiql.spi.fn.AggSignature +import org.partiql.spi.fn.FnExperimental + +@OptIn(FnExperimental::class) +public interface ConnectorAgg { + + /** + * Returns a function's variants. + * + * @return + */ + public fun getVariants(): List +} diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/ConnectorAggProvider.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/ConnectorAggProvider.kt new file mode 100644 index 0000000000..f496275e20 --- /dev/null +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/ConnectorAggProvider.kt @@ -0,0 +1,26 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + */ + +package org.partiql.spi.connector + +import org.partiql.spi.fn.Agg +import org.partiql.spi.fn.FnExperimental + +/** + * A [ConnectorAggProvider] implementation is responsible for providing an aggregation function implementation given a handle. + */ +@FnExperimental +public interface ConnectorAggProvider { + public fun getAgg(path: ConnectorPath, specific: String): Agg? +} diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/ConnectorHandle.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/ConnectorHandle.kt index bc924d7a9f..e4624278b0 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/ConnectorHandle.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/ConnectorHandle.kt @@ -38,4 +38,9 @@ public sealed class ConnectorHandle { override val path: ConnectorPath, override val entity: ConnectorFn, ) : ConnectorHandle() + + public data class Agg( + override val path: ConnectorPath, + override val entity: ConnectorAgg, + ) : ConnectorHandle() } diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/ConnectorMetadata.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/ConnectorMetadata.kt index 77dea443c9..3b447afb9d 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/ConnectorMetadata.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/ConnectorMetadata.kt @@ -46,4 +46,13 @@ public interface ConnectorMetadata { */ @FnExperimental public fun getFunction(path: BindingPath): ConnectorHandle.Fn? + + /** + * Returns all aggregation function signatures matching the given name. + * + * @param path + * @return + */ + @FnExperimental + public fun getAggregation(path: BindingPath): ConnectorHandle.Agg? } diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/SqlAgg.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/SqlAgg.kt new file mode 100644 index 0000000000..90912f93dd --- /dev/null +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/SqlAgg.kt @@ -0,0 +1,35 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + */ + +package org.partiql.spi.connector.sql + +import org.partiql.spi.connector.ConnectorAgg +import org.partiql.spi.connector.ConnectorFn +import org.partiql.spi.fn.AggSignature +import org.partiql.spi.fn.FnExperimental + +/** + * Simple [ConnectorFn] implementation wrapping a signature. + * + * @property name + * @property variants + */ +@OptIn(FnExperimental::class) +public class SqlAgg( + private val name: String, + private val variants: List, +) : ConnectorAgg { + + override fun getVariants(): List = variants +} diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/SqlAggProvider.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/SqlAggProvider.kt new file mode 100644 index 0000000000..38a4f0236d --- /dev/null +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/SqlAggProvider.kt @@ -0,0 +1,33 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + */ + +package org.partiql.spi.connector.sql + +import org.partiql.spi.connector.ConnectorAggProvider +import org.partiql.spi.connector.ConnectorFnProvider +import org.partiql.spi.connector.ConnectorPath +import org.partiql.spi.fn.Agg +import org.partiql.spi.fn.FnExperimental +import org.partiql.spi.fn.Index + +/** + * A basic [ConnectorFnProvider] over an [Index]. + */ +@OptIn(FnExperimental::class) +public class SqlAggProvider(private val index: Index) : ConnectorAggProvider { + + override fun getAgg(path: ConnectorPath, specific: String): Agg? { + return index.get(path, specific) + } +} diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/SqlBuiltins.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/SqlBuiltins.kt index a67452a353..f17a6b8f64 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/SqlBuiltins.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/SqlBuiltins.kt @@ -2,6 +2,7 @@ package org.partiql.spi.connector.sql /* ktlint-disable no-wildcard-imports */ import org.partiql.spi.connector.sql.builtins.* +import org.partiql.spi.fn.Agg import org.partiql.spi.fn.Fn import org.partiql.spi.fn.FnExperimental @@ -24,6 +25,14 @@ internal object SqlBuiltins { Fn_CHAR_LENGTH__STRING__INT, Fn_CHAR_LENGTH__SYMBOL__INT, Fn_CHAR_LENGTH__CLOB__INT, + Fn_COLL_AGG__BAG__ANY.ANY, + Fn_COLL_AGG__BAG__ANY.AVG, + Fn_COLL_AGG__BAG__ANY.COUNT, + Fn_COLL_AGG__BAG__ANY.EVERY, + Fn_COLL_AGG__BAG__ANY.MAX, + Fn_COLL_AGG__BAG__ANY.MIN, + Fn_COLL_AGG__BAG__ANY.SOME, + Fn_COLL_AGG__BAG__ANY.SUM, Fn_POS__INT8__INT8, Fn_POS__INT16__INT16, Fn_POS__INT32__INT32, @@ -478,4 +487,52 @@ internal object SqlBuiltins { Fn_CURRENT_USER____STRING, Fn_CURRENT_DATE____DATE ) + + @JvmStatic + val aggregations: List = listOf( + Agg_ANY__BOOL__BOOL, + Agg_AVG__INT8__INT8, + Agg_AVG__INT16__INT16, + Agg_AVG__INT32__INT32, + Agg_AVG__INT64__INT64, + Agg_AVG__INT__INT, + Agg_AVG__DECIMAL_ARBITRARY__DECIMAL_ARBITRARY, + Agg_AVG__FLOAT32__FLOAT32, + Agg_AVG__FLOAT64__FLOAT64, + Agg_AVG__ANY__ANY, + Agg_COUNT__ANY__INT32, + Agg_COUNT_STAR____INT32, + Agg_EVERY__BOOL__BOOL, + Agg_EVERY__ANY__BOOL, + Agg_MAX__INT8__INT8, + Agg_MAX__INT16__INT16, + Agg_MAX__INT32__INT32, + Agg_MAX__INT64__INT64, + Agg_MAX__INT__INT, + Agg_MAX__DECIMAL_ARBITRARY__DECIMAL_ARBITRARY, + Agg_MAX__FLOAT32__FLOAT32, + Agg_MAX__FLOAT64__FLOAT64, + Agg_MAX__ANY__ANY, + Agg_MIN__INT8__INT8, + Agg_MIN__INT16__INT16, + Agg_MIN__INT32__INT32, + Agg_MIN__INT64__INT64, + Agg_MIN__INT__INT, + Agg_MIN__DECIMAL_ARBITRARY__DECIMAL_ARBITRARY, + Agg_MIN__FLOAT32__FLOAT32, + Agg_MIN__FLOAT64__FLOAT64, + Agg_MIN__ANY__ANY, + Agg_SOME__BOOL__BOOL, + Agg_SOME__ANY__BOOL, + Agg_SUM__INT8__INT8, + Agg_SUM__INT16__INT16, + Agg_SUM__INT32__INT32, + Agg_SUM__INT64__INT64, + Agg_SUM__INT__INT, + Agg_SUM__DECIMAL_ARBITRARY__DECIMAL_ARBITRARY, + Agg_SUM__FLOAT32__FLOAT32, + Agg_SUM__FLOAT64__FLOAT64, + Agg_SUM__ANY__ANY, + Agg_GROUP_AS__ANY__ANY + ) } diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/SqlConnector.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/SqlConnector.kt index 15e2f3c032..5d2ca12131 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/SqlConnector.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/SqlConnector.kt @@ -15,6 +15,7 @@ package org.partiql.spi.connector.sql import org.partiql.spi.connector.Connector +import org.partiql.spi.connector.ConnectorAggProvider import org.partiql.spi.connector.ConnectorBindings import org.partiql.spi.connector.ConnectorFnProvider import org.partiql.spi.connector.ConnectorMetadata @@ -45,4 +46,7 @@ public abstract class SqlConnector : Connector { @FnExperimental override fun getFunctions(): ConnectorFnProvider = SqlFnProvider(info.functions) + + @FnExperimental + override fun getAggregations(): ConnectorAggProvider = SqlAggProvider(info.aggregations) } diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/SqlFnProvider.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/SqlFnProvider.kt index f1c9b4bbd4..a74e5e95b9 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/SqlFnProvider.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/SqlFnProvider.kt @@ -18,16 +18,15 @@ import org.partiql.spi.connector.ConnectorFnProvider import org.partiql.spi.connector.ConnectorPath import org.partiql.spi.fn.Fn import org.partiql.spi.fn.FnExperimental -import org.partiql.spi.fn.FnIndex +import org.partiql.spi.fn.Index /** - * A basic [ConnectorFnProvider] over an [FnIndex]. + * A basic [ConnectorFnProvider] over an [Index]. */ @OptIn(FnExperimental::class) -public class SqlFnProvider(private val index: FnIndex) : ConnectorFnProvider { +public class SqlFnProvider(private val index: Index) : ConnectorFnProvider { override fun getFn(path: ConnectorPath, specific: String): Fn? { - val fn = index.get(path, specific) - return if (fn is Fn) fn else null + return index.get(path, specific) } } diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/SqlMetadata.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/SqlMetadata.kt index 2a772bcb8c..128085745c 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/SqlMetadata.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/SqlMetadata.kt @@ -51,4 +51,15 @@ public open class SqlMetadata( } return ConnectorHandle.Fn(ConnectorPath(cnf), SqlFn(name, variants)) } + + @FnExperimental + override fun getAggregation(path: BindingPath): ConnectorHandle.Agg? { + val cnf = path.steps.map { it.name.uppercase() } + val name = cnf.last() + val variants = info.aggregations.get(cnf).map { it.signature } + if (variants.isEmpty()) { + return null + } + return ConnectorHandle.Agg(ConnectorPath(cnf), SqlAgg(name, variants)) + } } diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggAny.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggAny.kt index 0023d64d1a..ae232d4129 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggAny.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggAny.kt @@ -3,11 +3,13 @@ package org.partiql.spi.connector.sql.builtins +import org.partiql.spi.connector.sql.builtins.internal.AccumulatorAnySome import org.partiql.spi.fn.Agg import org.partiql.spi.fn.AggSignature import org.partiql.spi.fn.FnExperimental import org.partiql.spi.fn.FnParameter import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.PartiQLValueType.ANY import org.partiql.value.PartiQLValueType.BOOL @OptIn(PartiQLValueExperimental::class, FnExperimental::class) @@ -23,7 +25,21 @@ public object Agg_ANY__BOOL__BOOL : Agg { isDecomposable = true ) - override fun accumulator(): Agg.Accumulator { - TODO("Aggregation any not implemented") - } + override fun accumulator(): Agg.Accumulator = AccumulatorAnySome() +} + +@OptIn(PartiQLValueExperimental::class, FnExperimental::class) +public object Agg_ANY__ANY__BOOL : Agg { + + override val signature: AggSignature = AggSignature( + name = "any", + returns = BOOL, + parameters = listOf( + FnParameter("value", ANY), + ), + isNullable = true, + isDecomposable = true + ) + + override fun accumulator(): Agg.Accumulator = AccumulatorAnySome() } diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggAvg.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggAvg.kt index 5a7f2e0fff..cf110ab6a3 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggAvg.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggAvg.kt @@ -3,11 +3,13 @@ package org.partiql.spi.connector.sql.builtins +import org.partiql.spi.connector.sql.builtins.internal.AccumulatorAvg import org.partiql.spi.fn.Agg import org.partiql.spi.fn.AggSignature import org.partiql.spi.fn.FnExperimental import org.partiql.spi.fn.FnParameter import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.PartiQLValueType.ANY import org.partiql.value.PartiQLValueType.DECIMAL_ARBITRARY import org.partiql.value.PartiQLValueType.FLOAT32 import org.partiql.value.PartiQLValueType.FLOAT64 @@ -30,9 +32,7 @@ public object Agg_AVG__INT8__INT8 : Agg { isDecomposable = true ) - override fun accumulator(): Agg.Accumulator { - TODO("Aggregation avg not implemented") - } + override fun accumulator(): Agg.Accumulator = AccumulatorAvg(INT8) } @OptIn(PartiQLValueExperimental::class, FnExperimental::class) @@ -48,9 +48,7 @@ public object Agg_AVG__INT16__INT16 : Agg { isDecomposable = true ) - override fun accumulator(): Agg.Accumulator { - TODO("Aggregation avg not implemented") - } + override fun accumulator(): Agg.Accumulator = AccumulatorAvg(INT16) } @OptIn(PartiQLValueExperimental::class, FnExperimental::class) @@ -66,9 +64,7 @@ public object Agg_AVG__INT32__INT32 : Agg { isDecomposable = true ) - override fun accumulator(): Agg.Accumulator { - TODO("Aggregation avg not implemented") - } + override fun accumulator(): Agg.Accumulator = AccumulatorAvg(INT32) } @OptIn(PartiQLValueExperimental::class, FnExperimental::class) @@ -84,9 +80,7 @@ public object Agg_AVG__INT64__INT64 : Agg { isDecomposable = true ) - override fun accumulator(): Agg.Accumulator { - TODO("Aggregation avg not implemented") - } + override fun accumulator(): Agg.Accumulator = AccumulatorAvg(INT64) } @OptIn(PartiQLValueExperimental::class, FnExperimental::class) @@ -102,9 +96,7 @@ public object Agg_AVG__INT__INT : Agg { isDecomposable = true ) - override fun accumulator(): Agg.Accumulator { - TODO("Aggregation avg not implemented") - } + override fun accumulator(): Agg.Accumulator = AccumulatorAvg(INT) } @OptIn(PartiQLValueExperimental::class, FnExperimental::class) @@ -120,9 +112,7 @@ public object Agg_AVG__DECIMAL_ARBITRARY__DECIMAL_ARBITRARY : Agg { isDecomposable = true ) - override fun accumulator(): Agg.Accumulator { - TODO("Aggregation avg not implemented") - } + override fun accumulator(): Agg.Accumulator = AccumulatorAvg(DECIMAL_ARBITRARY) } @OptIn(PartiQLValueExperimental::class, FnExperimental::class) @@ -138,9 +128,7 @@ public object Agg_AVG__FLOAT32__FLOAT32 : Agg { isDecomposable = true ) - override fun accumulator(): Agg.Accumulator { - TODO("Aggregation avg not implemented") - } + override fun accumulator(): Agg.Accumulator = AccumulatorAvg(FLOAT32) } @OptIn(PartiQLValueExperimental::class, FnExperimental::class) @@ -156,7 +144,21 @@ public object Agg_AVG__FLOAT64__FLOAT64 : Agg { isDecomposable = true ) - override fun accumulator(): Agg.Accumulator { - TODO("Aggregation avg not implemented") - } + override fun accumulator(): Agg.Accumulator = AccumulatorAvg(FLOAT64) +} + +@OptIn(PartiQLValueExperimental::class, FnExperimental::class) +public object Agg_AVG__ANY__ANY : Agg { + + override val signature: AggSignature = AggSignature( + name = "avg", + returns = ANY, + parameters = listOf( + FnParameter("value", ANY), + ), + isNullable = true, + isDecomposable = true + ) + + override fun accumulator(): Agg.Accumulator = AccumulatorAvg() } diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggCount.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggCount.kt index 8cc44e1c5e..0496f236db 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggCount.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggCount.kt @@ -3,6 +3,7 @@ package org.partiql.spi.connector.sql.builtins +import org.partiql.spi.connector.sql.builtins.internal.AccumulatorCount import org.partiql.spi.fn.Agg import org.partiql.spi.fn.AggSignature import org.partiql.spi.fn.FnExperimental @@ -24,7 +25,5 @@ public object Agg_COUNT__ANY__INT32 : Agg { isDecomposable = true ) - override fun accumulator(): Agg.Accumulator { - TODO("Aggregation count not implemented") - } + override fun accumulator(): Agg.Accumulator = AccumulatorCount() } diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggCountStar.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggCountStar.kt index b1ad728ccb..d8088a2017 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggCountStar.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggCountStar.kt @@ -3,6 +3,7 @@ package org.partiql.spi.connector.sql.builtins +import org.partiql.spi.connector.sql.builtins.internal.AccumulatorCountStar import org.partiql.spi.fn.Agg import org.partiql.spi.fn.AggSignature import org.partiql.spi.fn.FnExperimental @@ -20,7 +21,5 @@ public object Agg_COUNT_STAR____INT32 : Agg { isDecomposable = true ) - override fun accumulator(): Agg.Accumulator { - TODO("Aggregation count_star not implemented") - } + override fun accumulator(): Agg.Accumulator = AccumulatorCountStar() } diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggEvery.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggEvery.kt index c5a2014d04..8420c15d8e 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggEvery.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggEvery.kt @@ -3,11 +3,13 @@ package org.partiql.spi.connector.sql.builtins +import org.partiql.spi.connector.sql.builtins.internal.AccumulatorEvery import org.partiql.spi.fn.Agg import org.partiql.spi.fn.AggSignature import org.partiql.spi.fn.FnExperimental import org.partiql.spi.fn.FnParameter import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.PartiQLValueType.ANY import org.partiql.value.PartiQLValueType.BOOL @OptIn(PartiQLValueExperimental::class, FnExperimental::class) @@ -23,7 +25,21 @@ public object Agg_EVERY__BOOL__BOOL : Agg { isDecomposable = true ) - override fun accumulator(): Agg.Accumulator { - TODO("Aggregation every not implemented") - } + override fun accumulator(): Agg.Accumulator = AccumulatorEvery() +} + +@OptIn(PartiQLValueExperimental::class, FnExperimental::class) +public object Agg_EVERY__ANY__BOOL : Agg { + + override val signature: AggSignature = AggSignature( + name = "every", + returns = BOOL, + parameters = listOf( + FnParameter("value", ANY), + ), + isNullable = true, + isDecomposable = true + ) + + override fun accumulator(): Agg.Accumulator = AccumulatorEvery() } diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggGroupAs.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggGroupAs.kt new file mode 100644 index 0000000000..5b586ee155 --- /dev/null +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggGroupAs.kt @@ -0,0 +1,28 @@ +// ktlint-disable filename +@file:Suppress("ClassName") + +package org.partiql.spi.connector.sql.builtins + +import org.partiql.spi.connector.sql.builtins.internal.AccumulatorGroupAs +import org.partiql.spi.fn.Agg +import org.partiql.spi.fn.AggSignature +import org.partiql.spi.fn.FnExperimental +import org.partiql.spi.fn.FnParameter +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.PartiQLValueType + +@OptIn(PartiQLValueExperimental::class, FnExperimental::class) +public object Agg_GROUP_AS__ANY__ANY : Agg { + + override val signature: AggSignature = AggSignature( + name = "group_as", + returns = PartiQLValueType.ANY, + parameters = listOf( + FnParameter("value", PartiQLValueType.ANY), + ), + isNullable = true, + isDecomposable = true + ) + + override fun accumulator(): Agg.Accumulator = AccumulatorGroupAs() +} diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggMax.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggMax.kt index f17fed0f7c..71a094dc98 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggMax.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggMax.kt @@ -3,11 +3,13 @@ package org.partiql.spi.connector.sql.builtins +import org.partiql.spi.connector.sql.builtins.internal.AccumulatorMax import org.partiql.spi.fn.Agg import org.partiql.spi.fn.AggSignature import org.partiql.spi.fn.FnExperimental import org.partiql.spi.fn.FnParameter import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.PartiQLValueType.ANY import org.partiql.value.PartiQLValueType.DECIMAL_ARBITRARY import org.partiql.value.PartiQLValueType.FLOAT32 import org.partiql.value.PartiQLValueType.FLOAT64 @@ -30,9 +32,7 @@ public object Agg_MAX__INT8__INT8 : Agg { isDecomposable = true ) - override fun accumulator(): Agg.Accumulator { - TODO("Aggregation max not implemented") - } + override fun accumulator(): Agg.Accumulator = AccumulatorMax() } @OptIn(PartiQLValueExperimental::class, FnExperimental::class) @@ -48,9 +48,7 @@ public object Agg_MAX__INT16__INT16 : Agg { isDecomposable = true ) - override fun accumulator(): Agg.Accumulator { - TODO("Aggregation max not implemented") - } + override fun accumulator(): Agg.Accumulator = AccumulatorMax() } @OptIn(PartiQLValueExperimental::class, FnExperimental::class) @@ -66,9 +64,7 @@ public object Agg_MAX__INT32__INT32 : Agg { isDecomposable = true ) - override fun accumulator(): Agg.Accumulator { - TODO("Aggregation max not implemented") - } + override fun accumulator(): Agg.Accumulator = AccumulatorMax() } @OptIn(PartiQLValueExperimental::class, FnExperimental::class) @@ -84,9 +80,7 @@ public object Agg_MAX__INT64__INT64 : Agg { isDecomposable = true ) - override fun accumulator(): Agg.Accumulator { - TODO("Aggregation max not implemented") - } + override fun accumulator(): Agg.Accumulator = AccumulatorMax() } @OptIn(PartiQLValueExperimental::class, FnExperimental::class) @@ -102,9 +96,7 @@ public object Agg_MAX__INT__INT : Agg { isDecomposable = true ) - override fun accumulator(): Agg.Accumulator { - TODO("Aggregation max not implemented") - } + override fun accumulator(): Agg.Accumulator = AccumulatorMax() } @OptIn(PartiQLValueExperimental::class, FnExperimental::class) @@ -120,9 +112,7 @@ public object Agg_MAX__DECIMAL_ARBITRARY__DECIMAL_ARBITRARY : Agg { isDecomposable = true ) - override fun accumulator(): Agg.Accumulator { - TODO("Aggregation max not implemented") - } + override fun accumulator(): Agg.Accumulator = AccumulatorMax() } @OptIn(PartiQLValueExperimental::class, FnExperimental::class) @@ -138,9 +128,7 @@ public object Agg_MAX__FLOAT32__FLOAT32 : Agg { isDecomposable = true ) - override fun accumulator(): Agg.Accumulator { - TODO("Aggregation max not implemented") - } + override fun accumulator(): Agg.Accumulator = AccumulatorMax() } @OptIn(PartiQLValueExperimental::class, FnExperimental::class) @@ -156,7 +144,21 @@ public object Agg_MAX__FLOAT64__FLOAT64 : Agg { isDecomposable = true ) - override fun accumulator(): Agg.Accumulator { - TODO("Aggregation max not implemented") - } + override fun accumulator(): Agg.Accumulator = AccumulatorMax() +} + +@OptIn(PartiQLValueExperimental::class, FnExperimental::class) +public object Agg_MAX__ANY__ANY : Agg { + + override val signature: AggSignature = AggSignature( + name = "max", + returns = ANY, + parameters = listOf( + FnParameter("value", ANY), + ), + isNullable = true, + isDecomposable = true + ) + + override fun accumulator(): Agg.Accumulator = AccumulatorMax() } diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggMin.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggMin.kt index 10047d6864..03b2f7d009 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggMin.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggMin.kt @@ -3,11 +3,13 @@ package org.partiql.spi.connector.sql.builtins +import org.partiql.spi.connector.sql.builtins.internal.AccumulatorMin import org.partiql.spi.fn.Agg import org.partiql.spi.fn.AggSignature import org.partiql.spi.fn.FnExperimental import org.partiql.spi.fn.FnParameter import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.PartiQLValueType.ANY import org.partiql.value.PartiQLValueType.DECIMAL_ARBITRARY import org.partiql.value.PartiQLValueType.FLOAT32 import org.partiql.value.PartiQLValueType.FLOAT64 @@ -30,9 +32,7 @@ public object Agg_MIN__INT8__INT8 : Agg { isDecomposable = true ) - override fun accumulator(): Agg.Accumulator { - TODO("Aggregation min not implemented") - } + override fun accumulator(): Agg.Accumulator = AccumulatorMin() } @OptIn(PartiQLValueExperimental::class, FnExperimental::class) @@ -48,9 +48,7 @@ public object Agg_MIN__INT16__INT16 : Agg { isDecomposable = true ) - override fun accumulator(): Agg.Accumulator { - TODO("Aggregation min not implemented") - } + override fun accumulator(): Agg.Accumulator = AccumulatorMin() } @OptIn(PartiQLValueExperimental::class, FnExperimental::class) @@ -66,9 +64,7 @@ public object Agg_MIN__INT32__INT32 : Agg { isDecomposable = true ) - override fun accumulator(): Agg.Accumulator { - TODO("Aggregation min not implemented") - } + override fun accumulator(): Agg.Accumulator = AccumulatorMin() } @OptIn(PartiQLValueExperimental::class, FnExperimental::class) @@ -84,9 +80,7 @@ public object Agg_MIN__INT64__INT64 : Agg { isDecomposable = true ) - override fun accumulator(): Agg.Accumulator { - TODO("Aggregation min not implemented") - } + override fun accumulator(): Agg.Accumulator = AccumulatorMin() } @OptIn(PartiQLValueExperimental::class, FnExperimental::class) @@ -102,9 +96,7 @@ public object Agg_MIN__INT__INT : Agg { isDecomposable = true ) - override fun accumulator(): Agg.Accumulator { - TODO("Aggregation min not implemented") - } + override fun accumulator(): Agg.Accumulator = AccumulatorMin() } @OptIn(PartiQLValueExperimental::class, FnExperimental::class) @@ -120,9 +112,7 @@ public object Agg_MIN__DECIMAL_ARBITRARY__DECIMAL_ARBITRARY : Agg { isDecomposable = true ) - override fun accumulator(): Agg.Accumulator { - TODO("Aggregation min not implemented") - } + override fun accumulator(): Agg.Accumulator = AccumulatorMin() } @OptIn(PartiQLValueExperimental::class, FnExperimental::class) @@ -138,9 +128,7 @@ public object Agg_MIN__FLOAT32__FLOAT32 : Agg { isDecomposable = true ) - override fun accumulator(): Agg.Accumulator { - TODO("Aggregation min not implemented") - } + override fun accumulator(): Agg.Accumulator = AccumulatorMin() } @OptIn(PartiQLValueExperimental::class, FnExperimental::class) @@ -156,7 +144,21 @@ public object Agg_MIN__FLOAT64__FLOAT64 : Agg { isDecomposable = true ) - override fun accumulator(): Agg.Accumulator { - TODO("Aggregation min not implemented") - } + override fun accumulator(): Agg.Accumulator = AccumulatorMin() +} + +@OptIn(PartiQLValueExperimental::class, FnExperimental::class) +public object Agg_MIN__ANY__ANY : Agg { + + override val signature: AggSignature = AggSignature( + name = "min", + returns = ANY, + parameters = listOf( + FnParameter("value", ANY), + ), + isNullable = true, + isDecomposable = true + ) + + override fun accumulator(): Agg.Accumulator = AccumulatorMin() } diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggSome.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggSome.kt index 1a1b279252..1480c27a8f 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggSome.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggSome.kt @@ -3,11 +3,13 @@ package org.partiql.spi.connector.sql.builtins +import org.partiql.spi.connector.sql.builtins.internal.AccumulatorAnySome import org.partiql.spi.fn.Agg import org.partiql.spi.fn.AggSignature import org.partiql.spi.fn.FnExperimental import org.partiql.spi.fn.FnParameter import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.PartiQLValueType.ANY import org.partiql.value.PartiQLValueType.BOOL @OptIn(PartiQLValueExperimental::class, FnExperimental::class) @@ -23,7 +25,21 @@ public object Agg_SOME__BOOL__BOOL : Agg { isDecomposable = true ) - override fun accumulator(): Agg.Accumulator { - TODO("Aggregation some not implemented") - } + override fun accumulator(): Agg.Accumulator = AccumulatorAnySome() +} + +@OptIn(PartiQLValueExperimental::class, FnExperimental::class) +public object Agg_SOME__ANY__BOOL : Agg { + + override val signature: AggSignature = AggSignature( + name = "some", + returns = BOOL, + parameters = listOf( + FnParameter("value", ANY), + ), + isNullable = true, + isDecomposable = true + ) + + override fun accumulator(): Agg.Accumulator = AccumulatorAnySome() } diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggSum.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggSum.kt index b0f73000b6..9fc312159e 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggSum.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggSum.kt @@ -3,11 +3,13 @@ package org.partiql.spi.connector.sql.builtins +import org.partiql.spi.connector.sql.builtins.internal.AccumulatorSum import org.partiql.spi.fn.Agg import org.partiql.spi.fn.AggSignature import org.partiql.spi.fn.FnExperimental import org.partiql.spi.fn.FnParameter import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.PartiQLValueType.ANY import org.partiql.value.PartiQLValueType.DECIMAL_ARBITRARY import org.partiql.value.PartiQLValueType.FLOAT32 import org.partiql.value.PartiQLValueType.FLOAT64 @@ -30,9 +32,7 @@ public object Agg_SUM__INT8__INT8 : Agg { isDecomposable = true ) - override fun accumulator(): Agg.Accumulator { - TODO("Aggregation sum not implemented") - } + override fun accumulator(): Agg.Accumulator = AccumulatorSum(INT8) } @OptIn(PartiQLValueExperimental::class, FnExperimental::class) @@ -48,9 +48,7 @@ public object Agg_SUM__INT16__INT16 : Agg { isDecomposable = true ) - override fun accumulator(): Agg.Accumulator { - TODO("Aggregation sum not implemented") - } + override fun accumulator(): Agg.Accumulator = AccumulatorSum(INT16) } @OptIn(PartiQLValueExperimental::class, FnExperimental::class) @@ -66,9 +64,7 @@ public object Agg_SUM__INT32__INT32 : Agg { isDecomposable = true ) - override fun accumulator(): Agg.Accumulator { - TODO("Aggregation sum not implemented") - } + override fun accumulator(): Agg.Accumulator = AccumulatorSum(INT32) } @OptIn(PartiQLValueExperimental::class, FnExperimental::class) @@ -84,9 +80,7 @@ public object Agg_SUM__INT64__INT64 : Agg { isDecomposable = true ) - override fun accumulator(): Agg.Accumulator { - TODO("Aggregation sum not implemented") - } + override fun accumulator(): Agg.Accumulator = AccumulatorSum(INT64) } @OptIn(PartiQLValueExperimental::class, FnExperimental::class) @@ -102,9 +96,7 @@ public object Agg_SUM__INT__INT : Agg { isDecomposable = true ) - override fun accumulator(): Agg.Accumulator { - TODO("Aggregation sum not implemented") - } + override fun accumulator(): Agg.Accumulator = AccumulatorSum(INT) } @OptIn(PartiQLValueExperimental::class, FnExperimental::class) @@ -120,9 +112,7 @@ public object Agg_SUM__DECIMAL_ARBITRARY__DECIMAL_ARBITRARY : Agg { isDecomposable = true ) - override fun accumulator(): Agg.Accumulator { - TODO("Aggregation sum not implemented") - } + override fun accumulator(): Agg.Accumulator = AccumulatorSum(DECIMAL_ARBITRARY) } @OptIn(PartiQLValueExperimental::class, FnExperimental::class) @@ -138,9 +128,7 @@ public object Agg_SUM__FLOAT32__FLOAT32 : Agg { isDecomposable = true ) - override fun accumulator(): Agg.Accumulator { - TODO("Aggregation sum not implemented") - } + override fun accumulator(): Agg.Accumulator = AccumulatorSum(FLOAT32) } @OptIn(PartiQLValueExperimental::class, FnExperimental::class) @@ -156,7 +144,21 @@ public object Agg_SUM__FLOAT64__FLOAT64 : Agg { isDecomposable = true ) - override fun accumulator(): Agg.Accumulator { - TODO("Aggregation sum not implemented") - } + override fun accumulator(): Agg.Accumulator = AccumulatorSum(FLOAT64) +} + +@OptIn(PartiQLValueExperimental::class, FnExperimental::class) +public object Agg_SUM__ANY__ANY : Agg { + + override val signature: AggSignature = AggSignature( + name = "sum", + returns = ANY, + parameters = listOf( + FnParameter("value", ANY), + ), + isNullable = true, + isDecomposable = true + ) + + override fun accumulator(): Agg.Accumulator = AccumulatorSum() } diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnCollAgg.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnCollAgg.kt new file mode 100644 index 0000000000..c480edb9ec --- /dev/null +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnCollAgg.kt @@ -0,0 +1,89 @@ +// ktlint-disable filename +@file:Suppress("ClassName") + +package org.partiql.spi.connector.sql.builtins + +import org.partiql.spi.connector.sql.builtins.internal.Accumulator +import org.partiql.spi.connector.sql.builtins.internal.AccumulatorAnySome +import org.partiql.spi.connector.sql.builtins.internal.AccumulatorAvg +import org.partiql.spi.connector.sql.builtins.internal.AccumulatorCount +import org.partiql.spi.connector.sql.builtins.internal.AccumulatorEvery +import org.partiql.spi.connector.sql.builtins.internal.AccumulatorMax +import org.partiql.spi.connector.sql.builtins.internal.AccumulatorMin +import org.partiql.spi.connector.sql.builtins.internal.AccumulatorSum +import org.partiql.spi.fn.Agg +import org.partiql.spi.fn.Fn +import org.partiql.spi.fn.FnExperimental +import org.partiql.spi.fn.FnParameter +import org.partiql.spi.fn.FnSignature +import org.partiql.value.BagValue +import org.partiql.value.PartiQLValue +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.PartiQLValueType +import org.partiql.value.check + +@OptIn(PartiQLValueExperimental::class, FnExperimental::class) +internal abstract class Fn_COLL_AGG__BAG__ANY : Fn { + + abstract fun getAccumulator(): Agg.Accumulator + + companion object { + @JvmStatic + internal fun createSignature(name: String) = FnSignature( + name = name, + returns = PartiQLValueType.ANY, + parameters = listOf( + FnParameter("value", PartiQLValueType.BAG), + ), + isNullCall = true, + isNullable = true + ) + } + + override fun invoke(args: Array): PartiQLValue { + val bag = args[0].check>() + val accumulator = getAccumulator() + bag.forEach { element -> accumulator.next(arrayOf(element)) } + return accumulator.value() + } + + object SUM : Fn_COLL_AGG__BAG__ANY() { + override val signature = createSignature("coll_sum") + override fun getAccumulator(): Accumulator = AccumulatorSum() + } + + object AVG : Fn_COLL_AGG__BAG__ANY() { + override val signature = createSignature("coll_avg") + override fun getAccumulator(): Accumulator = AccumulatorAvg() + } + + object MIN : Fn_COLL_AGG__BAG__ANY() { + override val signature = createSignature("coll_min") + override fun getAccumulator(): Accumulator = AccumulatorMin() + } + + object MAX : Fn_COLL_AGG__BAG__ANY() { + override val signature = createSignature("coll_max") + override fun getAccumulator(): Accumulator = AccumulatorMax() + } + + object COUNT : Fn_COLL_AGG__BAG__ANY() { + override val signature = createSignature("coll_count") + override fun getAccumulator(): Accumulator = AccumulatorCount() + } + + object EVERY : Fn_COLL_AGG__BAG__ANY() { + override val signature = createSignature("coll_every") + override fun getAccumulator(): Accumulator = AccumulatorEvery() + } + + object ANY : Fn_COLL_AGG__BAG__ANY() { + override val signature = createSignature("coll_any") + override fun getAccumulator(): Accumulator = AccumulatorAnySome() + } + + object SOME : Fn_COLL_AGG__BAG__ANY() { + override val signature = createSignature("coll_some") + override fun getAccumulator(): Accumulator = AccumulatorAnySome() + } +} diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnEq.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnEq.kt index 881a74284c..fe49428cb4 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnEq.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnEq.kt @@ -67,6 +67,8 @@ import org.partiql.value.check @OptIn(PartiQLValueExperimental::class, FnExperimental::class) internal object Fn_EQ__ANY_ANY__BOOL : Fn { + private val comparator = PartiQLValue.comparator() + override val signature = FnSignature( name = "eq", returns = BOOL, @@ -84,7 +86,10 @@ internal object Fn_EQ__ANY_ANY__BOOL : Fn { override fun invoke(args: Array): PartiQLValue { val lhs = args[0] val rhs = args[1] - return boolValue(lhs == rhs) + return when { + lhs.type == MISSING || rhs.type == MISSING -> boolValue(lhs == rhs) + else -> boolValue(comparator.compare(lhs, rhs) == 0) + } } } diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/internal/Accumulator.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/internal/Accumulator.kt new file mode 100644 index 0000000000..8cfec2bca3 --- /dev/null +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/internal/Accumulator.kt @@ -0,0 +1,236 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + */ + +@file:OptIn(PartiQLValueExperimental::class) + +package org.partiql.spi.connector.sql.builtins.internal + +import com.amazon.ion.Decimal +import org.partiql.errors.TypeCheckException +import org.partiql.spi.fn.Agg +import org.partiql.spi.fn.FnExperimental +import org.partiql.value.BoolValue +import org.partiql.value.DecimalValue +import org.partiql.value.Float32Value +import org.partiql.value.Float64Value +import org.partiql.value.Int16Value +import org.partiql.value.Int32Value +import org.partiql.value.Int64Value +import org.partiql.value.Int8Value +import org.partiql.value.IntValue +import org.partiql.value.PartiQLValue +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.PartiQLValueType +import org.partiql.value.decimalValue +import org.partiql.value.float32Value +import org.partiql.value.float64Value +import org.partiql.value.int16Value +import org.partiql.value.int32Value +import org.partiql.value.int64Value +import org.partiql.value.int8Value +import org.partiql.value.intValue +import org.partiql.value.nullValue +import org.partiql.value.util.coerceNumbers +import java.math.BigDecimal +import java.math.BigInteger +import java.math.MathContext +import java.math.RoundingMode + +@OptIn(FnExperimental::class) +internal abstract class Accumulator : Agg.Accumulator { + + /** Accumulates the next value into this [Accumulator]. */ + @OptIn(PartiQLValueExperimental::class) + override fun next(args: Array) { + val value = args[0] + if (value.isUnknown()) return + nextValue(value) + } + + abstract fun nextValue(value: PartiQLValue) +} + +@OptIn(PartiQLValueExperimental::class) +internal fun comparisonAccumulator(comparator: Comparator): (PartiQLValue?, PartiQLValue) -> PartiQLValue = + { left, right -> + when { + left == null || comparator.compare(left, right) > 0 -> right + else -> left + } + } + +@OptIn(PartiQLValueExperimental::class) +internal fun checkIsNumberType(funcName: String, value: PartiQLValue) { + if (!value.type.isNumber()) { + throw TypeCheckException("Expected NUMBER but received ${value.type}.") + } +} + +internal operator fun Number.plus(other: Number): Number { + val (first, second) = coerceNumbers(this, other) + return when (first) { + is Long -> first.checkOverflowPlus(second as Long) + is Double -> first + second as Double + is BigDecimal -> first.add(second as BigDecimal, MATH_CONTEXT) + is BigInteger -> first.add(second as BigInteger) + else -> throw IllegalStateException() + } +} + +internal operator fun Number.div(other: Number): Number { + val (first, second) = coerceNumbers(this, other) + return when (first) { + is Long -> first.checkOverflowDivision(second as Long) + is Double -> first / second as Double + is BigDecimal -> first.divide(second as BigDecimal, MATH_CONTEXT) + else -> throw IllegalStateException() + } +} + +private fun Long.checkOverflowDivision(other: Long): Number { + // division can only underflow Long.MIN_VALUE / -1 + // because abs(Long.MIN_VALUE) == abs(Long.MAX_VALUE) + 1 + if (this == Long.MIN_VALUE && other == -1L) { + error("Division overflow or underflow.") + } + + return this / other +} + +private fun Long.checkOverflowPlus(other: Long): Number { + // uses to XOR to check if + // this and other are >= 0 then if result < 0 means overflow + // this and other are < 0 then if result > 0 means underflow + // if this and other have different signs then no overflow can happen + + val result: Long = this + other + val overflows = ((this xor other) >= 0) and ((this xor result) < 0) + return when (overflows) { + false -> result + else -> error("Int overflow or underflow") + } +} + +@OptIn(PartiQLValueExperimental::class) +internal fun checkIsBooleanType(funcName: String, value: PartiQLValue) { + if (value.type != PartiQLValueType.BOOL) { + throw TypeCheckException("Expected ${PartiQLValueType.BOOL} but received ${value.type}.") + } +} + +@OptIn(PartiQLValueExperimental::class) +internal fun PartiQLValue.isUnknown(): Boolean = this.type == PartiQLValueType.MISSING || this.isNull + +@OptIn(PartiQLValueExperimental::class) +internal fun PartiQLValue.numberValue(): Number = when (this) { + is IntValue -> this.value!! + is Int8Value -> this.value!! + is Int16Value -> this.value!! + is Int32Value -> this.value!! + is Int64Value -> this.value!! + is DecimalValue -> this.value!! + is Float32Value -> this.value!! + is Float64Value -> this.value!! + else -> error("Cannot convert PartiQLValue ($this) to number.") +} + +@OptIn(PartiQLValueExperimental::class) +internal fun PartiQLValue.booleanValue(): Boolean = when (this) { + is BoolValue -> this.value!! + else -> error("Cannot convert PartiQLValue ($this) to boolean.") +} + +@OptIn(PartiQLValueExperimental::class) +internal fun PartiQLValueType.isNumber(): Boolean = when (this) { + PartiQLValueType.INT, + PartiQLValueType.INT8, + PartiQLValueType.INT16, + PartiQLValueType.INT32, + PartiQLValueType.INT64, + PartiQLValueType.DECIMAL, + PartiQLValueType.DECIMAL_ARBITRARY, + PartiQLValueType.FLOAT32, + PartiQLValueType.FLOAT64 -> true + else -> false +} + +/** + * This is specifically for SUM/AVG + */ +@OptIn(PartiQLValueExperimental::class) +internal fun nullToTargetType(type: PartiQLValueType): PartiQLValue = when (type) { + PartiQLValueType.ANY -> nullValue() + PartiQLValueType.FLOAT32 -> float32Value(null) + PartiQLValueType.FLOAT64 -> float64Value(null) + PartiQLValueType.INT8 -> int8Value(null) + PartiQLValueType.INT16 -> int16Value(null) + PartiQLValueType.INT32 -> int32Value(null) + PartiQLValueType.INT64 -> int64Value(null) + PartiQLValueType.INT -> intValue(null) + PartiQLValueType.DECIMAL_ARBITRARY, PartiQLValueType.DECIMAL -> decimalValue(null) + else -> TODO("Unsupported target type $type") +} + +/** + * This is specifically for SUM/AVG + */ +@OptIn(PartiQLValueExperimental::class) +internal fun Number.toTargetType(type: PartiQLValueType): PartiQLValue = when (type) { + PartiQLValueType.ANY -> this.partiqlValue() + PartiQLValueType.FLOAT32 -> float32Value(this.toFloat()) + PartiQLValueType.FLOAT64 -> float64Value(this.toDouble()) + PartiQLValueType.DECIMAL, PartiQLValueType.DECIMAL_ARBITRARY -> { + when (this) { + is BigDecimal -> decimalValue(this) + is BigInteger -> decimalValue(this.toBigDecimal()) + else -> decimalValue(BigDecimal.valueOf(this.toDouble())) + } + } + PartiQLValueType.INT8 -> int8Value(this.toByte()) + PartiQLValueType.INT16 -> int16Value(this.toShort()) + PartiQLValueType.INT32 -> int32Value(this.toInt()) + PartiQLValueType.INT64 -> int64Value(this.toLong()) + PartiQLValueType.INT -> when (this) { + is BigInteger -> intValue(this) + is BigDecimal -> intValue(this.toBigInteger()) + else -> intValue(BigInteger.valueOf(this.toLong())) + } + else -> TODO("Unsupported target type $type") +} + +@OptIn(PartiQLValueExperimental::class) +internal fun Number.partiqlValue(): PartiQLValue = when (this) { + is Int -> int32Value(this) + is Long -> int64Value(this) + is Double -> float64Value(this) + is BigDecimal -> decimalValue(this) + is BigInteger -> intValue(this) + else -> TODO("Could not convert $this to PartiQL Value") +} + +private val MATH_CONTEXT = MathContext(38, RoundingMode.HALF_EVEN) + +/** + * Factory function to create a [BigDecimal] using correct precision, use it in favor of native BigDecimal constructors + * and factory methods + */ +internal fun bigDecimalOf(num: Number, mc: MathContext = MATH_CONTEXT): BigDecimal = when (num) { + is Decimal -> num + is Int -> BigDecimal(num, mc) + is Long -> BigDecimal(num, mc) + is Double -> BigDecimal(num, mc) + is BigDecimal -> num + Decimal.NEGATIVE_ZERO -> num as Decimal + else -> throw IllegalArgumentException("Unsupported number type: $num, ${num.javaClass}") +} diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/internal/AccumulatorAnySome.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/internal/AccumulatorAnySome.kt new file mode 100644 index 0000000000..fa82aa7d14 --- /dev/null +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/internal/AccumulatorAnySome.kt @@ -0,0 +1,19 @@ +package org.partiql.spi.connector.sql.builtins.internal + +import org.partiql.value.PartiQLValue +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.boolValue +import org.partiql.value.nullValue + +@OptIn(PartiQLValueExperimental::class) +internal class AccumulatorAnySome : Accumulator() { + + private var res: PartiQLValue? = null + + override fun nextValue(value: PartiQLValue) { + checkIsBooleanType("ANY/SOME", value) + res = res?.let { boolValue(it.booleanValue() || value.booleanValue()) } ?: value + } + + override fun value(): PartiQLValue = res ?: nullValue() +} diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/internal/AccumulatorAvg.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/internal/AccumulatorAvg.kt new file mode 100644 index 0000000000..2704075938 --- /dev/null +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/internal/AccumulatorAvg.kt @@ -0,0 +1,25 @@ +package org.partiql.spi.connector.sql.builtins.internal + +import org.partiql.value.PartiQLValue +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.PartiQLValueType + +@OptIn(PartiQLValueExperimental::class) +internal class AccumulatorAvg( + private val targetType: PartiQLValueType = PartiQLValueType.ANY +) : Accumulator() { + + var sum: Number = 0.0 + var count: Long = 0L + + override fun nextValue(value: PartiQLValue) { + checkIsNumberType(funcName = "AVG", value = value) + this.sum += value.numberValue() + this.count += 1L + } + + override fun value(): PartiQLValue = when (count) { + 0L -> nullToTargetType(targetType) + else -> (sum / bigDecimalOf(count)).toTargetType(targetType) + } +} diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/internal/AccumulatorCount.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/internal/AccumulatorCount.kt new file mode 100644 index 0000000000..ec4d926b24 --- /dev/null +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/internal/AccumulatorCount.kt @@ -0,0 +1,17 @@ +package org.partiql.spi.connector.sql.builtins.internal + +import org.partiql.value.PartiQLValue +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.int64Value + +@OptIn(PartiQLValueExperimental::class) +internal class AccumulatorCount : Accumulator() { + + var count: Long = 0L + + override fun nextValue(value: PartiQLValue) { + this.count += 1L + } + + override fun value(): PartiQLValue = int64Value(count) +} diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/internal/AccumulatorCountStar.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/internal/AccumulatorCountStar.kt new file mode 100644 index 0000000000..82768fdb1b --- /dev/null +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/internal/AccumulatorCountStar.kt @@ -0,0 +1,19 @@ +package org.partiql.spi.connector.sql.builtins.internal + +import org.partiql.spi.fn.Agg +import org.partiql.spi.fn.FnExperimental +import org.partiql.value.PartiQLValue +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.int64Value + +@OptIn(PartiQLValueExperimental::class, FnExperimental::class) +internal class AccumulatorCountStar : Agg.Accumulator { + + var count: Long = 0L + + override fun next(args: Array) { + this.count += 1L + } + + override fun value(): PartiQLValue = int64Value(count) +} diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/internal/AccumulatorEvery.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/internal/AccumulatorEvery.kt new file mode 100644 index 0000000000..900fc8238f --- /dev/null +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/internal/AccumulatorEvery.kt @@ -0,0 +1,20 @@ +package org.partiql.spi.connector.sql.builtins.internal + +import org.partiql.value.PartiQLValue +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.boolValue +import org.partiql.value.nullValue + +@OptIn(PartiQLValueExperimental::class) +internal class AccumulatorEvery : Accumulator() { + + private var res: PartiQLValue? = null + + @OptIn(PartiQLValueExperimental::class) + override fun nextValue(value: PartiQLValue) { + checkIsBooleanType("EVERY", value) + res = res?.let { boolValue(it.booleanValue() && value.booleanValue()) } ?: value + } + + override fun value(): PartiQLValue = res ?: nullValue() +} diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/internal/AccumulatorGroupAs.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/internal/AccumulatorGroupAs.kt new file mode 100644 index 0000000000..30d1b88778 --- /dev/null +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/internal/AccumulatorGroupAs.kt @@ -0,0 +1,17 @@ +package org.partiql.spi.connector.sql.builtins.internal + +import org.partiql.value.PartiQLValue +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.bagValue + +@OptIn(PartiQLValueExperimental::class) +internal class AccumulatorGroupAs : Accumulator() { + + val values = mutableListOf() + + override fun nextValue(value: PartiQLValue) { + values.add(value) + } + + override fun value(): PartiQLValue = bagValue(values) +} diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/internal/AccumulatorMax.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/internal/AccumulatorMax.kt new file mode 100644 index 0000000000..dfce376ed2 --- /dev/null +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/internal/AccumulatorMax.kt @@ -0,0 +1,17 @@ +package org.partiql.spi.connector.sql.builtins.internal + +import org.partiql.value.PartiQLValue +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.nullValue + +@OptIn(PartiQLValueExperimental::class) +internal class AccumulatorMax : Accumulator() { + + var max: PartiQLValue = nullValue() + + override fun nextValue(value: PartiQLValue) { + max = comparisonAccumulator(PartiQLValue.comparator(nullsFirst = true).reversed())(max, value) + } + + override fun value(): PartiQLValue = max +} diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/internal/AccumulatorMin.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/internal/AccumulatorMin.kt new file mode 100644 index 0000000000..75c0972289 --- /dev/null +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/internal/AccumulatorMin.kt @@ -0,0 +1,17 @@ +package org.partiql.spi.connector.sql.builtins.internal + +import org.partiql.value.PartiQLValue +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.nullValue + +@OptIn(PartiQLValueExperimental::class) +internal class AccumulatorMin : Accumulator() { + + var min: PartiQLValue = nullValue() + + override fun nextValue(value: PartiQLValue) { + min = comparisonAccumulator(PartiQLValue.comparator(nullsFirst = false))(min, value) + } + + override fun value(): PartiQLValue = min +} diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/internal/AccumulatorSum.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/internal/AccumulatorSum.kt new file mode 100644 index 0000000000..a9405d9e4e --- /dev/null +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/internal/AccumulatorSum.kt @@ -0,0 +1,25 @@ +package org.partiql.spi.connector.sql.builtins.internal + +import org.partiql.value.PartiQLValue +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.PartiQLValueType + +@OptIn(PartiQLValueExperimental::class) +internal class AccumulatorSum( + private val targetType: PartiQLValueType = PartiQLValueType.ANY +) : Accumulator() { + + var sum: Number? = null + + @OptIn(PartiQLValueExperimental::class) + override fun nextValue(value: PartiQLValue) { + checkIsNumberType(funcName = "SUM", value = value) + if (sum == null) sum = 0L + this.sum = value.numberValue() + this.sum!! + } + + @OptIn(PartiQLValueExperimental::class) + override fun value(): PartiQLValue { + return sum?.toTargetType(targetType) ?: nullToTargetType(targetType) + } +} diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/info/InfoSchema.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/info/InfoSchema.kt index 57a058975c..a29053d85d 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/info/InfoSchema.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/info/InfoSchema.kt @@ -1,19 +1,23 @@ package org.partiql.spi.connector.sql.info import org.partiql.spi.connector.sql.SqlBuiltins +import org.partiql.spi.fn.Agg +import org.partiql.spi.fn.Fn import org.partiql.spi.fn.FnExperimental -import org.partiql.spi.fn.FnIndex +import org.partiql.spi.fn.Index /** * Provides the INFORMATION_SCHEMA views over internal database symbols. */ -public class InfoSchema( - public val functions: FnIndex, +public class InfoSchema @OptIn(FnExperimental::class) constructor( + public val functions: Index, + public val aggregations: Index ) { /** * INFORMATION_SCHEMA.ROUTINES */ + @OptIn(FnExperimental::class) private val routines: InfoView = InfoViewRoutines(functions) public fun get(table: String): InfoView? = when (table) { @@ -26,10 +30,13 @@ public class InfoSchema( @OptIn(FnExperimental::class) @JvmStatic public fun default(): InfoSchema { - val functions = FnIndex.builder() + val functions = Index.fnBuilder() .addAll(SqlBuiltins.builtins) .build() - return InfoSchema(functions) + val aggregations = Index.aggBuilder() + .addAll(SqlBuiltins.aggregations) + .build() + return InfoSchema(functions, aggregations) } } } diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/info/InfoViewRoutines.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/info/InfoViewRoutines.kt index 64e0fe29fa..157c3ab556 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/info/InfoViewRoutines.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/info/InfoViewRoutines.kt @@ -1,6 +1,8 @@ package org.partiql.spi.connector.sql.info -import org.partiql.spi.fn.FnIndex +import org.partiql.spi.fn.Fn +import org.partiql.spi.fn.FnExperimental +import org.partiql.spi.fn.Index import org.partiql.types.BagType import org.partiql.types.StaticType import org.partiql.types.StructType @@ -13,7 +15,7 @@ import org.partiql.value.nullValue /** * This provides the INFORMATION_SCHEMA.ROUTINES view for an [SqlConnector]. */ -internal class InfoViewRoutines(private val index: FnIndex) : InfoView { +internal class InfoViewRoutines @OptIn(FnExperimental::class) constructor(private val index: Index) : InfoView { override val schema: StaticType = BagType( elementType = StructType( diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/fn/Agg.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/fn/Agg.kt index 53b5b642e6..485ce2f0eb 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/fn/Agg.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/fn/Agg.kt @@ -33,7 +33,7 @@ public interface Agg { * @return */ @OptIn(PartiQLValueExperimental::class) - public fun next(args: Array): PartiQLValue + public fun next(args: Array) /** * Return the accumulator value. diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/fn/FnIndex.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/fn/FnIndex.kt deleted file mode 100644 index fdd1f98e01..0000000000 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/fn/FnIndex.kt +++ /dev/null @@ -1,60 +0,0 @@ -package org.partiql.spi.fn - -import org.partiql.spi.connector.ConnectorPath - -/** - * Utility class for an optimized function lookup data structure. Right now this is read only. - */ -@OptIn(FnExperimental::class) -public interface FnIndex { - - /** - * Search for all functions matching the normalized path. - * - * @param path - * @return - */ - public fun get(path: List): List - - /** - * Lookup a function signature by its specific name. - * - * @param specific - * @return - */ - public fun get(path: ConnectorPath, specific: String): Fn? - - public class Builder { - - /** - * A catalog's builtins exposed via INFORMATION_SCHEMA. - */ - private val builtins: MutableList = mutableListOf() - - public fun add(fn: Fn): Builder = this.apply { - builtins.add(fn) - } - - public fun addAll(fns: List): Builder = this.apply { - builtins.addAll(fns) - } - - /** - * Creates a map of function name to variants; variants are keyed by their specific. - * - * @return - */ - public fun build(): FnIndex { - val fns = builtins - .groupBy { it.signature.name.uppercase() } - .mapValues { e -> e.value.associateBy { f -> f.signature.specific } } - return FnIndexMap(fns) - } - } - - public companion object { - - @JvmStatic - public fun builder(): Builder = Builder() - } -} diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/fn/Index.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/fn/Index.kt new file mode 100644 index 0000000000..738f2ed901 --- /dev/null +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/fn/Index.kt @@ -0,0 +1,76 @@ +package org.partiql.spi.fn + +import org.partiql.spi.connector.ConnectorPath + +/** + * Utility class for an optimized function lookup data structure. Right now this is read only. + */ +@OptIn(FnExperimental::class) +public interface Index { + + /** + * Search for all functions matching the normalized path. + * + * @param path + * @return + */ + public fun get(path: List): List + + /** + * Lookup a function signature by its specific name. + * + * @param specific + * @return + */ + public fun get(path: ConnectorPath, specific: String): T? + + public abstract class Builder { + + /** + * A catalog's builtins exposed via INFORMATION_SCHEMA. + */ + internal val builtins: MutableList = mutableListOf() + + public fun add(fn: T): Builder = this.apply { + builtins.add(fn) + } + + public fun addAll(fns: List): Builder = this.apply { + builtins.addAll(fns) + } + + /** + * Creates a map of function name to variants; variants are keyed by their specific. + * + * @return + */ + public abstract fun build(): Index + + public class Fn : Builder() { + override fun build(): Index { + val fns = builtins + .groupBy { it.signature.name.uppercase() } + .mapValues { e -> e.value.associateBy { f -> f.signature.specific } } + return IndexMap(fns) + } + } + + public class Agg : Builder() { + override fun build(): Index { + val fns = builtins + .groupBy { it.signature.name.uppercase() } + .mapValues { e -> e.value.associateBy { f -> f.signature.specific } } + return IndexMap(fns) + } + } + } + + public companion object { + + @JvmStatic + public fun fnBuilder(): Builder = Builder.Fn() + + @JvmStatic + public fun aggBuilder(): Builder = Builder.Agg() + } +} diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/fn/FnIndexMap.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/fn/IndexMap.kt similarity index 54% rename from partiql-spi/src/main/kotlin/org/partiql/spi/fn/FnIndexMap.kt rename to partiql-spi/src/main/kotlin/org/partiql/spi/fn/IndexMap.kt index f583c1ade2..660f099663 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/fn/FnIndexMap.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/fn/IndexMap.kt @@ -3,20 +3,19 @@ package org.partiql.spi.fn import org.partiql.spi.connector.ConnectorPath /** - * An implementation of [FnIndex] which uses the normalized paths as map keys. + * An implementation of [Index] which uses the normalized paths as map keys. * * @property map */ -@OptIn(FnExperimental::class) -internal class FnIndexMap(private val map: Map>) : FnIndex { +internal class IndexMap(private val map: Map>) : Index { - override fun get(path: List): List { + override fun get(path: List): List { val key = path.joinToString(".") val variants = map[key] ?: emptyMap() return variants.values.toList() } - override fun get(path: ConnectorPath, specific: String): Fn? { + override fun get(path: ConnectorPath, specific: String): T? { val key = path.steps.joinToString(".") val variants = map[key] ?: emptyMap() return variants[specific] diff --git a/partiql-types/src/main/kotlin/org/partiql/errors/TypeCheckException.kt b/partiql-types/src/main/kotlin/org/partiql/errors/TypeCheckException.kt index f0e495fda1..c9942b21ff 100644 --- a/partiql-types/src/main/kotlin/org/partiql/errors/TypeCheckException.kt +++ b/partiql-types/src/main/kotlin/org/partiql/errors/TypeCheckException.kt @@ -3,7 +3,7 @@ package org.partiql.errors /** * A [TypeCheckException] represents an invalid operation due to argument types. */ -public class TypeCheckException : RuntimeException() +public class TypeCheckException(message: String? = null) : RuntimeException(message) /** * A [DataException] represents an unrecoverable query runtime exception. diff --git a/partiql-types/src/main/kotlin/org/partiql/value/PartiQLValue.kt b/partiql-types/src/main/kotlin/org/partiql/value/PartiQLValue.kt index 08a69c49b6..5256ded050 100644 --- a/partiql-types/src/main/kotlin/org/partiql/value/PartiQLValue.kt +++ b/partiql-types/src/main/kotlin/org/partiql/value/PartiQLValue.kt @@ -602,5 +602,5 @@ public fun PartiQLValue.toIon(): IonElement = accept(ToIon, Unit) @PartiQLValueExperimental @Throws(TypeCheckException::class) public inline fun PartiQLValue.check(): T { - if (this is T) return this else throw TypeCheckException() + if (this is T) return this else throw TypeCheckException("Expected ${T::class.java} but received $this.") } diff --git a/partiql-types/src/main/kotlin/org/partiql/value/util/NumberExtensions.kt b/partiql-types/src/main/kotlin/org/partiql/value/util/NumberExtensions.kt index cab5a82d2a..1cf34e2a64 100644 --- a/partiql-types/src/main/kotlin/org/partiql/value/util/NumberExtensions.kt +++ b/partiql-types/src/main/kotlin/org/partiql/value/util/NumberExtensions.kt @@ -16,6 +16,7 @@ package org.partiql.value.util import com.amazon.ion.Decimal import java.math.BigDecimal +import java.math.BigInteger import java.math.MathContext import java.math.RoundingMode @@ -38,6 +39,7 @@ internal fun bigDecimalOf(num: Number, mc: MathContext = MATH_CONTEXT): BigDecim is Long -> BigDecimal(num, mc) is Float -> BigDecimal(num.toDouble(), mc) is Double -> BigDecimal(num, mc) + is BigInteger -> BigDecimal(num, mc) is BigDecimal -> num else -> throw IllegalArgumentException("Unsupported number type: $num, ${num.javaClass}") } @@ -45,6 +47,9 @@ internal fun bigDecimalOf(num: Number, mc: MathContext = MATH_CONTEXT): BigDecim private val CONVERSION_MAP = mapOf>, Class>( setOf(Int::class.javaObjectType, Int::class.javaObjectType) to Int::class.javaObjectType, setOf(Int::class.javaObjectType, Long::class.javaObjectType) to Long::class.javaObjectType, + setOf(Int::class.javaObjectType, BigInteger::class.javaObjectType) to BigInteger::class.javaObjectType, + setOf(Long::class.javaObjectType, BigInteger::class.javaObjectType) to BigInteger::class.javaObjectType, + setOf(BigInteger::class.javaObjectType, BigInteger::class.javaObjectType) to BigInteger::class.javaObjectType, // Int w/ Float -> Double setOf(Int::class.javaObjectType, Float::class.javaObjectType) to Double::class.javaObjectType, setOf(Int::class.javaObjectType, Double::class.javaObjectType) to Double::class.javaObjectType, @@ -60,6 +65,9 @@ private val CONVERSION_MAP = mapOf>, Class>( setOf(Long::class.javaObjectType, Double::class.javaObjectType) to Double::class.javaObjectType, setOf(Long::class.javaObjectType, BigDecimal::class.javaObjectType) to BigDecimal::class.javaObjectType, + setOf(BigInteger::class.javaObjectType, Double::class.javaObjectType) to Double::class.javaObjectType, + setOf(BigInteger::class.javaObjectType, BigDecimal::class.javaObjectType) to BigDecimal::class.javaObjectType, + setOf(Double::class.javaObjectType, Double::class.javaObjectType) to Double::class.javaObjectType, setOf(Double::class.javaObjectType, BigDecimal::class.javaObjectType) to BigDecimal::class.javaObjectType, @@ -71,6 +79,12 @@ private val CONVERTERS = mapOf, (Number) -> Number>( Long::class.javaObjectType to Number::toLong, Float::class.javaObjectType to Number::toFloat, Double::class.javaObjectType to Number::toDouble, + BigInteger::class.javaObjectType to { num -> + when (num) { + is BigInteger -> num + else -> BigInteger.valueOf(num.toLong()) + } + }, BigDecimal::class.java to { num -> when (num) { is Int -> bigDecimalOf(num) @@ -78,8 +92,9 @@ private val CONVERTERS = mapOf, (Number) -> Number>( is Float -> bigDecimalOf(num) is Double -> bigDecimalOf(num) is BigDecimal -> bigDecimalOf(num) + is BigInteger -> bigDecimalOf(num) else -> throw IllegalArgumentException( - "Unsupported number for decimal conversion: $num" + "Unsupported number for decimal conversion: $num (${num.javaClass.simpleName})" ) } } @@ -92,7 +107,8 @@ internal fun Number.isZero() = when (this) { is Float -> this == 0.0f || this == -0.0f is Double -> this == 0.0 || this == -0.0 is BigDecimal -> BigDecimal.ZERO.compareTo(this) == 0 - else -> throw IllegalStateException("$this") + is BigInteger -> BigInteger.ZERO.compareTo(this) == 0 + else -> throw IllegalStateException("$this (${this.javaClass.simpleName})") } @Suppress("UNCHECKED_CAST") @@ -107,8 +123,9 @@ internal fun Number.coerce(type: Class): T where T : Number { * compatible type. * * This is only supported on limited types needed by the expression system. + * TODO: Make no longer public. */ -internal fun coerceNumbers(first: Number, second: Number): Pair { +public fun coerceNumbers(first: Number, second: Number): Pair { fun typeFor(n: Number): Class<*> = if (n is Decimal) { BigDecimal::class.javaObjectType } else { @@ -129,6 +146,7 @@ internal operator fun Number.compareTo(other: Number): Int { is Float -> first.compareTo(second as Float) is Double -> first.compareTo(second as Double) is BigDecimal -> first.compareTo(second as BigDecimal) + is BigInteger -> first.compareTo(second as BigInteger) else -> throw IllegalStateException() } } diff --git a/plugins/partiql-local/src/main/kotlin/org/partiql/plugins/local/LocalConnector.kt b/plugins/partiql-local/src/main/kotlin/org/partiql/plugins/local/LocalConnector.kt index a2a32567a0..9072c659c8 100644 --- a/plugins/partiql-local/src/main/kotlin/org/partiql/plugins/local/LocalConnector.kt +++ b/plugins/partiql-local/src/main/kotlin/org/partiql/plugins/local/LocalConnector.kt @@ -17,6 +17,7 @@ package org.partiql.plugins.local import com.amazon.ionelement.api.StructElement import org.partiql.spi.BindingPath import org.partiql.spi.connector.Connector +import org.partiql.spi.connector.ConnectorAggProvider import org.partiql.spi.connector.ConnectorBindings import org.partiql.spi.connector.ConnectorFnProvider import org.partiql.spi.connector.ConnectorHandle @@ -71,6 +72,11 @@ public class LocalConnector( TODO("Not yet implemented") } + @FnExperimental + override fun getAggregations(): ConnectorAggProvider { + TODO("Not yet implemented") + } + internal class Factory : Connector.Factory { private val default: Path = Paths.get(System.getProperty("user.home")).resolve(".partiql/local") @@ -113,6 +119,9 @@ public class LocalConnector( TODO("Not yet implemented") } + @FnExperimental + override fun getAggregation(path: BindingPath): ConnectorHandle.Agg? = null + internal fun listObjects(): List = catalog.listObjects() } }