diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/Compiler.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/Compiler.kt index 8f512b2866..4906d2b800 100644 --- a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/Compiler.kt +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/Compiler.kt @@ -2,6 +2,15 @@ package org.partiql.eval.internal import org.partiql.eval.PartiQLEngine import org.partiql.eval.internal.operator.Operator +import org.partiql.eval.internal.operator.agg.AccumulatorAnySome +import org.partiql.eval.internal.operator.agg.AccumulatorAvg +import org.partiql.eval.internal.operator.agg.AccumulatorCount +import org.partiql.eval.internal.operator.agg.AccumulatorEvery +import org.partiql.eval.internal.operator.agg.AccumulatorGroupAs +import org.partiql.eval.internal.operator.agg.AccumulatorMax +import org.partiql.eval.internal.operator.agg.AccumulatorMin +import org.partiql.eval.internal.operator.agg.AccumulatorSum +import org.partiql.eval.internal.operator.rel.RelAggregate import org.partiql.eval.internal.operator.rel.RelDistinct import org.partiql.eval.internal.operator.rel.RelExclude import org.partiql.eval.internal.operator.rel.RelFilter @@ -146,7 +155,33 @@ internal class Compiler( return ExprLocal(node.ref) } - override fun visitRexOpVarGlobal(node: Rex.Op.Var.Global, ctx: StaticType?): Operator = symbols.getGlobal(node.ref) + override fun visitRelOpAggregate(node: Rel.Op.Aggregate, ctx: StaticType?): Operator.Relation { + val input = visitRel(node.input, ctx) + val calls = node.calls.map { visitRelOpAggregateCall(it, ctx) } + val groups = node.groups.map { visitRex(it, ctx).modeHandled() } + return RelAggregate(input, groups, calls) + } + + @OptIn(PartiQLValueExperimental::class) + override fun visitRelOpAggregateCall(node: Rel.Op.Aggregate.Call, ctx: StaticType?): Operator.Accumulator { + val args = node.args.map { visitRex(it, it.type).modeHandled() } // TODO: Should we support multiple arguments? + val setQuantifier: Operator.Accumulator.SetQuantifier = when (node.setQuantifier) { + Rel.Op.Aggregate.Call.SetQuantifier.ALL -> Operator.Accumulator.SetQuantifier.ALL + Rel.Op.Aggregate.Call.SetQuantifier.DISTINCT -> Operator.Accumulator.SetQuantifier.DISTINCT + } + return when (node.agg.uppercase()) { + "MIN" -> AccumulatorMin.Factory(args, setQuantifier) + "MAX" -> AccumulatorMax.Factory(args, setQuantifier) + "AVG" -> AccumulatorAvg.Factory(args, setQuantifier) + "COUNT" -> AccumulatorCount.Factory(args, setQuantifier) + "SUM" -> AccumulatorSum.Factory(args, setQuantifier) + "GROUP_AS" -> AccumulatorGroupAs.Factory(args, setQuantifier) + "EVERY" -> AccumulatorEvery.Factory(args, setQuantifier) + "ANY" -> AccumulatorAnySome.Factory(args, setQuantifier) + "SOME" -> AccumulatorAnySome.Factory(args, setQuantifier) + else -> error("Unexpected aggregation: ${node.agg}.") + } + } override fun visitRexOpPathKey(node: Rex.Op.Path.Key, ctx: StaticType?): Operator { val root = visitRex(node.root, ctx) diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/Operator.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/Operator.kt index ade5920f07..77fbdbb093 100644 --- a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/Operator.kt +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/Operator.kt @@ -26,4 +26,31 @@ internal sealed interface Operator { override fun close() } + + interface Accumulator : Operator { + + val setQuantifier: SetQuantifier + fun create(): Instance + + interface Instance { + + /** + * The argument to invoke. + */ + val args: List + + /** Accumulates the next value into this [Instance]. */ + @OptIn(PartiQLValueExperimental::class) + fun next(value: PartiQLValue) + + /** Digests the result of the accumulated values. */ + @OptIn(PartiQLValueExperimental::class) + fun compute(): PartiQLValue + } + + enum class SetQuantifier { + ALL, + DISTINCT + } + } } diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/agg/Accumulator.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/agg/Accumulator.kt new file mode 100644 index 0000000000..c66cfa5f12 --- /dev/null +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/agg/Accumulator.kt @@ -0,0 +1,142 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + */ + +@file:OptIn(PartiQLValueExperimental::class) + +package org.partiql.eval.internal.operator.agg + +import com.amazon.ion.Decimal +import org.partiql.eval.internal.operator.Operator +import org.partiql.value.BoolValue +import org.partiql.value.DecimalValue +import org.partiql.value.Float32Value +import org.partiql.value.Float64Value +import org.partiql.value.Int16Value +import org.partiql.value.Int32Value +import org.partiql.value.Int64Value +import org.partiql.value.Int8Value +import org.partiql.value.PartiQLValue +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.PartiQLValueType +import org.partiql.value.decimalValue +import org.partiql.value.float64Value +import org.partiql.value.int32Value +import org.partiql.value.int64Value +import java.math.BigDecimal +import java.math.MathContext +import java.math.RoundingMode + +internal abstract class Accumulator : Operator.Accumulator.Instance { + + /** Accumulates the next value into this [Accumulator]. */ + @OptIn(PartiQLValueExperimental::class) + override fun next(value: PartiQLValue) { + if (value.isUnknown()) return + nextValue(value) + } + + abstract fun nextValue(value: PartiQLValue) +} + +// TODO: Make this better +@OptIn(PartiQLValueExperimental::class) +internal fun comparisonAccumulator(comparator: Comparator): (PartiQLValue?, PartiQLValue) -> PartiQLValue = + { left, right -> + when { + left == null || comparator.compare(left, right) > 0 -> right + else -> left + } + } + +// TODO: Make this better +@OptIn(PartiQLValueExperimental::class) +internal fun checkIsNumberType(funcName: String, value: PartiQLValue) { + if (!value.type.isNumber()) { + TODO("NEED TO HANDLE") + } +} + +// TODO: Make this better +@OptIn(PartiQLValueExperimental::class) +internal fun checkIsBooleanType(funcName: String, value: PartiQLValue) { + if (value.type != PartiQLValueType.BOOL) { + TODO("NEED TO HANDLE") + } +} + +// TODO: Make this better +@OptIn(PartiQLValueExperimental::class) +internal fun PartiQLValue.isUnknown(): Boolean = this.type == PartiQLValueType.MISSING || this.isNull + +// TODO: Make this better +@OptIn(PartiQLValueExperimental::class) +internal fun PartiQLValue.numberValue(): Number = when (this) { + is Int8Value -> this.value!! + is Int16Value -> this.value!! + is Int32Value -> this.value!! + is Int64Value -> this.value!! + is DecimalValue -> this.value!! + is Float32Value -> this.value!! + is Float64Value -> this.value!! + else -> error("Cannot convert PartiQLValue ($this) to number.") +} + +// TODO: Make this better +@OptIn(PartiQLValueExperimental::class) +internal fun PartiQLValue.booleanValue(): Boolean = when (this) { + is BoolValue -> this.value!! + else -> error("Cannot convert PartiQLValue ($this) to boolean.") +} + +// TODO: Make this better +@OptIn(PartiQLValueExperimental::class) +internal fun PartiQLValueType.isNumber(): Boolean = when (this) { + PartiQLValueType.INT, + PartiQLValueType.INT8, + PartiQLValueType.INT16, + PartiQLValueType.INT32, + PartiQLValueType.INT64, + PartiQLValueType.DECIMAL, + PartiQLValueType.FLOAT32, + PartiQLValueType.FLOAT64 -> true + else -> false +} + +// TODO: Make this better +@OptIn(PartiQLValueExperimental::class) +internal fun Number.partiqlValue(): PartiQLValue = when (this) { + is Int -> int32Value(this) + is Long -> int64Value(this) + is Double -> float64Value(this) + is BigDecimal -> decimalValue(this) + else -> TODO("Error context") +} + +// TODO: Make this better +private val MATH_CONTEXT = MathContext(38, RoundingMode.HALF_EVEN) + +// TODO: Make this better +/** + * Factory function to create a [BigDecimal] using correct precision, use it in favor of native BigDecimal constructors + * and factory methods + */ +internal fun bigDecimalOf(num: Number, mc: MathContext = MATH_CONTEXT): BigDecimal = when (num) { + is Decimal -> num + is Int -> BigDecimal(num, mc) + is Long -> BigDecimal(num, mc) + is Double -> BigDecimal(num, mc) + is BigDecimal -> num + Decimal.NEGATIVE_ZERO -> num as Decimal + else -> throw IllegalArgumentException("Unsupported number type: $num, ${num.javaClass}") +} diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/agg/AccumulatorAnySome.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/agg/AccumulatorAnySome.kt new file mode 100644 index 0000000000..31f45c5809 --- /dev/null +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/agg/AccumulatorAnySome.kt @@ -0,0 +1,29 @@ +package org.partiql.eval.internal.operator.agg + +import org.partiql.eval.internal.operator.Operator +import org.partiql.value.PartiQLValue +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.boolValue +import org.partiql.value.nullValue + +@OptIn(PartiQLValueExperimental::class) +internal class AccumulatorAnySome( + override val args: List +) : Accumulator() { + + private var res: PartiQLValue? = null + + override fun nextValue(value: PartiQLValue) { + checkIsBooleanType("ANY/SOME", value) + res = res?.let { boolValue(it.booleanValue() || value.booleanValue()) } ?: value + } + + override fun compute(): PartiQLValue = res ?: nullValue() + + class Factory( + val args: List, + override val setQuantifier: Operator.Accumulator.SetQuantifier + ) : Operator.Accumulator { + override fun create(): Operator.Accumulator.Instance = AccumulatorAnySome(args) + } +} diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/agg/AccumulatorAvg.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/agg/AccumulatorAvg.kt new file mode 100644 index 0000000000..1c13d1b59f --- /dev/null +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/agg/AccumulatorAvg.kt @@ -0,0 +1,35 @@ +package org.partiql.eval.internal.operator.agg + +import org.partiql.eval.internal.operator.Operator +import org.partiql.lang.util.div +import org.partiql.lang.util.plus +import org.partiql.value.PartiQLValue +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.nullValue + +@OptIn(PartiQLValueExperimental::class) +internal class AccumulatorAvg( + override val args: List +) : Accumulator() { + + var sum: Number = 0.0 + var count: Long = 0L + + override fun nextValue(value: PartiQLValue) { + checkIsNumberType(funcName = "AVG", value = value) + this.sum += value.numberValue() + this.count += 1L + } + + override fun compute(): PartiQLValue = when (count) { + 0L -> nullValue() + else -> (sum / bigDecimalOf(count)).partiqlValue() + } + + class Factory( + val args: List, + override val setQuantifier: Operator.Accumulator.SetQuantifier + ) : Operator.Accumulator { + override fun create(): Operator.Accumulator.Instance = AccumulatorAvg(args) + } +} diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/agg/AccumulatorCount.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/agg/AccumulatorCount.kt new file mode 100644 index 0000000000..4385ffe8f0 --- /dev/null +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/agg/AccumulatorCount.kt @@ -0,0 +1,27 @@ +package org.partiql.eval.internal.operator.agg + +import org.partiql.eval.internal.operator.Operator +import org.partiql.value.PartiQLValue +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.int64Value + +@OptIn(PartiQLValueExperimental::class) +internal class AccumulatorCount( + override val args: List +) : Accumulator() { + + var count: Long = 0L + + override fun nextValue(value: PartiQLValue) { + this.count += 1L + } + + override fun compute(): PartiQLValue = int64Value(count) + + class Factory( + val args: List, + override val setQuantifier: Operator.Accumulator.SetQuantifier + ) : Operator.Accumulator { + override fun create(): Operator.Accumulator.Instance = AccumulatorCount(args) + } +} diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/agg/AccumulatorEvery.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/agg/AccumulatorEvery.kt new file mode 100644 index 0000000000..574a287599 --- /dev/null +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/agg/AccumulatorEvery.kt @@ -0,0 +1,30 @@ +package org.partiql.eval.internal.operator.agg + +import org.partiql.eval.internal.operator.Operator +import org.partiql.value.PartiQLValue +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.boolValue +import org.partiql.value.nullValue + +@OptIn(PartiQLValueExperimental::class) +internal class AccumulatorEvery( + override val args: List +) : Accumulator() { + + private var res: PartiQLValue? = null + + @OptIn(PartiQLValueExperimental::class) + override fun nextValue(value: PartiQLValue) { + checkIsBooleanType("EVERY", value) + res = res?.let { boolValue(it.booleanValue() && value.booleanValue()) } ?: value + } + + override fun compute(): PartiQLValue = res ?: nullValue() + + class Factory( + val args: List, + override val setQuantifier: Operator.Accumulator.SetQuantifier + ) : Operator.Accumulator { + override fun create(): Operator.Accumulator.Instance = AccumulatorEvery(args) + } +} diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/agg/AccumulatorGroupAs.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/agg/AccumulatorGroupAs.kt new file mode 100644 index 0000000000..43293f0c2a --- /dev/null +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/agg/AccumulatorGroupAs.kt @@ -0,0 +1,27 @@ +package org.partiql.eval.internal.operator.agg + +import org.partiql.eval.internal.operator.Operator +import org.partiql.value.PartiQLValue +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.bagValue + +@OptIn(PartiQLValueExperimental::class) +internal class AccumulatorGroupAs( + override val args: List +) : Accumulator() { + + val values = mutableListOf() + + override fun nextValue(value: PartiQLValue) { + values.add(value) + } + + override fun compute(): PartiQLValue = bagValue(values) + + class Factory( + val args: List, + override val setQuantifier: Operator.Accumulator.SetQuantifier + ) : Operator.Accumulator { + override fun create(): Operator.Accumulator.Instance = AccumulatorGroupAs(args) + } +} diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/agg/AccumulatorMax.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/agg/AccumulatorMax.kt new file mode 100644 index 0000000000..6dd3c70bef --- /dev/null +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/agg/AccumulatorMax.kt @@ -0,0 +1,27 @@ +package org.partiql.eval.internal.operator.agg + +import org.partiql.eval.internal.operator.Operator +import org.partiql.value.PartiQLValue +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.nullValue + +@OptIn(PartiQLValueExperimental::class) +internal class AccumulatorMax( + override val args: List +) : Accumulator() { + + var max: PartiQLValue = nullValue() + + override fun nextValue(value: PartiQLValue) { + max = comparisonAccumulator(PartiQLValue.comparator(nullsFirst = false).reversed())(max, value) + } + + override fun compute(): PartiQLValue = max + + class Factory( + val args: List, + override val setQuantifier: Operator.Accumulator.SetQuantifier + ) : Operator.Accumulator { + override fun create(): Operator.Accumulator.Instance = AccumulatorMax(args) + } +} diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/agg/AccumulatorMin.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/agg/AccumulatorMin.kt new file mode 100644 index 0000000000..c3a1bbc77b --- /dev/null +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/agg/AccumulatorMin.kt @@ -0,0 +1,27 @@ +package org.partiql.eval.internal.operator.agg + +import org.partiql.eval.internal.operator.Operator +import org.partiql.value.PartiQLValue +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.nullValue + +@OptIn(PartiQLValueExperimental::class) +internal class AccumulatorMin( + override val args: List +) : Accumulator() { + + var min: PartiQLValue = nullValue() + + override fun nextValue(value: PartiQLValue) { + min = comparisonAccumulator(PartiQLValue.comparator(nullsFirst = false))(min, value) + } + + override fun compute(): PartiQLValue = min + + class Factory( + val args: List, + override val setQuantifier: Operator.Accumulator.SetQuantifier + ) : Operator.Accumulator { + override fun create(): Operator.Accumulator.Instance = AccumulatorMin(args) + } +} diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/agg/AccumulatorSum.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/agg/AccumulatorSum.kt new file mode 100644 index 0000000000..780714f640 --- /dev/null +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/agg/AccumulatorSum.kt @@ -0,0 +1,34 @@ +package org.partiql.eval.internal.operator.agg + +import org.partiql.eval.internal.operator.Operator +import org.partiql.lang.util.plus +import org.partiql.value.PartiQLValue +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.nullValue + +@OptIn(PartiQLValueExperimental::class) +internal class AccumulatorSum( + override val args: List +) : Accumulator() { + + var sum: Number? = null + + @OptIn(PartiQLValueExperimental::class) + override fun nextValue(value: PartiQLValue) { + checkIsNumberType(funcName = "SUM", value = value) + if (sum == null) sum = 0L + this.sum = value.numberValue() + this.sum!! + } + + @OptIn(PartiQLValueExperimental::class) + override fun compute(): PartiQLValue { + return sum?.partiqlValue() ?: nullValue() + } + + class Factory( + val args: List, + override val setQuantifier: Operator.Accumulator.SetQuantifier + ) : Operator.Accumulator { + override fun create(): Operator.Accumulator.Instance = AccumulatorSum(args) + } +} diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelAggregate.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelAggregate.kt new file mode 100644 index 0000000000..3dc3ad7205 --- /dev/null +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelAggregate.kt @@ -0,0 +1,94 @@ +package org.partiql.eval.internal.operator.rel + +import org.partiql.eval.internal.Record +import org.partiql.eval.internal.operator.Operator +import org.partiql.value.ListValue +import org.partiql.value.PartiQLValue +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.PartiQLValueType +import org.partiql.value.listValue +import org.partiql.value.nullValue +import java.util.TreeMap +import java.util.TreeSet + +internal class RelAggregate( + val input: Operator.Relation, + val keys: List, + val functions: List +) : Operator.Relation { + + lateinit var records: Iterator + + @OptIn(PartiQLValueExperimental::class) + val aggregationMap = TreeMap>(PartiQLValue.comparator()) + + @OptIn(PartiQLValueExperimental::class) + val seen: List?> = functions.map { function -> + when (function.setQuantifier) { + Operator.Accumulator.SetQuantifier.DISTINCT -> TreeSet(PartiQLValue.comparator()) + Operator.Accumulator.SetQuantifier.ALL -> null + } + } + + @OptIn(PartiQLValueExperimental::class) + override fun open() { + input.open() + var inputRecord = input.next() + while (inputRecord != null) { + // Initialize the AggregationMap + val evaluatedGroupByKeys = listValue( + keys.map { + val key = it.eval(inputRecord!!) + when (key.type == PartiQLValueType.MISSING) { + true -> nullValue() + false -> key + } + } + ) + val accumulators = aggregationMap.getOrPut(evaluatedGroupByKeys) { + functions.map { it.create() } + } + + // Aggregate Values in Aggregation State + accumulators.forEachIndexed { index, function -> + val valueToAggregate = function.args.map { it.eval(inputRecord!!) }.first() // TODO: should we handle multiple arguments? + // If ALL OR (DISTINCT and not seen) + if (seen[index] == null || (seen[index]!!.add(valueToAggregate))) { + accumulators[index].next(valueToAggregate) + } + } + inputRecord = input.next() + } + + // No Aggregations Created // TODO: How would this be possible? + if (keys.isEmpty() && aggregationMap.isEmpty()) { + val record = mutableListOf() + functions.forEach { function -> + val accumulator = function.create() + record.add(accumulator.compute()) + } + records = iterator { yield(Record.of(*record.toTypedArray())) } + return + } + + records = iterator { + aggregationMap.forEach { (pValue, accumulators) -> + val keysEvaluated = pValue as ListValue<*> + val recordValues = accumulators.map { acc -> acc.compute() } + keysEvaluated.map { value -> value } + yield(Record.of(*recordValues.toTypedArray())) + } + } + } + + override fun next(): Record? { + return if (records.hasNext()) { + records.next() + } else { + null + } + } + + override fun close() { + input.close() + } +} diff --git a/partiql-eval/src/test/kotlin/org/partiql/eval/internal/PartiQLEngineDefaultTest.kt b/partiql-eval/src/test/kotlin/org/partiql/eval/internal/PartiQLEngineDefaultTest.kt index 9105bed264..8c11f6e432 100644 --- a/partiql-eval/src/test/kotlin/org/partiql/eval/internal/PartiQLEngineDefaultTest.kt +++ b/partiql-eval/src/test/kotlin/org/partiql/eval/internal/PartiQLEngineDefaultTest.kt @@ -9,7 +9,7 @@ import org.junit.jupiter.params.provider.MethodSource import org.partiql.eval.PartiQLEngine import org.partiql.eval.PartiQLResult import org.partiql.parser.PartiQLParser -import org.partiql.plan.PlanNode +import org.partiql.plan.PartiQLPlan import org.partiql.plan.debug.PlanPrinter import org.partiql.planner.PartiQLPlanner import org.partiql.planner.PartiQLPlannerBuilder @@ -54,6 +54,33 @@ class PartiQLEngineDefaultTest { @MethodSource("subqueryTestCases") @Execution(ExecutionMode.CONCURRENT) fun subqueryTests(tc: SuccessTestCase) = tc.assert() + @MethodSource("aggregationTestCases") + @Execution(ExecutionMode.CONCURRENT) + fun aggregationTests(tc: SuccessTestCase) = tc.assert() + + @Test + fun singleTest() { + SuccessTestCase( + input = """ + SELECT col1, g + FROM [{ 'col1':1 }, { 'col1':1 }] simple_1_col_1_group + GROUP BY col1 GROUP AS g + """.trimIndent(), + expected = bagValue( + structValue( + "col1" to int32Value(1), + "g" to bagValue( + structValue( + "simple_1_col_1_group" to structValue("col1" to int32Value(1)) + ), + structValue( + "simple_1_col_1_group" to structValue("col1" to int32Value(1)) + ), + ) + ), + ) + ).assert() + } companion object { @@ -314,6 +341,96 @@ class PartiQLEngineDefaultTest { ) ) + @JvmStatic + fun aggregationTestCases() = kotlin.collections.listOf( + SuccessTestCase( + input = """ + SELECT + gk_0, SUM(t.c) AS t_c_sum + FROM << + { 'b': NULL, 'c': 1 }, + { 'b': MISSING, 'c': 2 }, + { 'b': 1, 'c': 1 }, + { 'b': 1, 'c': 2 }, + { 'b': 2, 'c': NULL }, + { 'b': 2, 'c': 2 }, + { 'b': 3, 'c': MISSING }, + { 'b': 3, 'c': 2 }, + { 'b': 4, 'c': MISSING }, + { 'b': 4, 'c': NULL } + >> AS t GROUP BY t.b AS gk_0; + """.trimIndent(), + expected = org.partiql.value.bagValue( + org.partiql.value.structValue( + "gk_0" to org.partiql.value.nullValue(), + "t_c_sum" to org.partiql.value.int64Value(3) + ), + org.partiql.value.structValue( + "gk_0" to org.partiql.value.int32Value(1), + "t_c_sum" to org.partiql.value.int64Value(3) + ), + org.partiql.value.structValue( + "gk_0" to org.partiql.value.int32Value(2), + "t_c_sum" to org.partiql.value.int64Value(2) + ), + org.partiql.value.structValue( + "gk_0" to org.partiql.value.int32Value(3), + "t_c_sum" to org.partiql.value.int64Value(2) + ), + org.partiql.value.structValue( + "gk_0" to org.partiql.value.int32Value(4), + "t_c_sum" to org.partiql.value.nullValue() + ), + ), + mode = org.partiql.eval.PartiQLEngine.Mode.PERMISSIVE + ), + SuccessTestCase( + input = """ + SELECT VALUE { 'sensor': sensor, + 'readings': (SELECT VALUE v.l.co FROM g AS v) + } + FROM [{'sensor':1, 'co':0.4}, {'sensor':1, 'co':0.2}, {'sensor':2, 'co':0.3}] AS l + GROUP BY l.sensor AS sensor GROUP AS g + """.trimIndent(), + expected = org.partiql.value.bagValue( + org.partiql.value.structValue( + "sensor" to org.partiql.value.int32Value(1), + "readings" to org.partiql.value.listValue( + org.partiql.value.decimalValue(0.4.toBigDecimal()), + org.partiql.value.decimalValue(0.2.toBigDecimal()) + ) + ), + org.partiql.value.structValue( + "sensor" to org.partiql.value.int32Value(2), + "readings" to org.partiql.value.listValue( + org.partiql.value.decimalValue(2.toBigDecimal()), + org.partiql.value.decimalValue(0.3.toBigDecimal()) + ) + ), + ) + ), + SuccessTestCase( + input = """ + SELECT col1, g + FROM [{ 'col1':1 }, { 'col1':1 }] simple_1_col_1_group + GROUP BY col1 GROUP AS g + """.trimIndent(), + expected = bagValue( + structValue( + "col1" to int32Value(1), + "g" to bagValue( + structValue( + "simple_1_col_1_group" to structValue("col1" to int32Value(1)) + ), + structValue( + "simple_1_col_1_group" to structValue("col1" to int32Value(1)) + ), + ) + ), + ) + ), + ) + @JvmStatic fun sanityTestsCases() = listOf( SuccessTestCase( @@ -938,11 +1055,11 @@ class PartiQLEngineDefaultTest { } } val output = result.value - assertEquals(expected, output, comparisonString(plan.plan, expected, output)) + assertEquals(expected, output, comparisonString(expected, output, plan.plan)) } @OptIn(PartiQLValueExperimental::class) - private fun comparisonString(plan: PlanNode, expected: PartiQLValue, actual: PartiQLValue): String { + private fun comparisonString(expected: PartiQLValue, actual: PartiQLValue, plan: PartiQLPlan): String { val expectedBuffer = ByteArrayOutputStream() val expectedWriter = PartiQLValueIonWriterBuilder.standardIonTextBuilder().build(expectedBuffer) expectedWriter.append(expected) diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/util/NumberExtensions.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/util/NumberExtensions.kt index a83a094c4c..c9c459c300 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/util/NumberExtensions.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/util/NumberExtensions.kt @@ -46,6 +46,7 @@ fun bigDecimalOf(num: Number, mc: MathContext = MATH_CONTEXT): BigDecimal = when fun bigDecimalOf(text: String, mc: MathContext = MATH_CONTEXT): BigDecimal = BigDecimal(text.trim(), mc) private val CONVERSION_MAP = mapOf>, Class>( + setOf(Int::class.javaObjectType, Long::class.javaObjectType) to Long::class.javaObjectType, setOf(Long::class.javaObjectType, Long::class.javaObjectType) to Long::class.javaObjectType, setOf(Long::class.javaObjectType, Double::class.javaObjectType) to Double::class.javaObjectType, setOf(Long::class.javaObjectType, BigDecimal::class.javaObjectType) to BigDecimal::class.javaObjectType, diff --git a/partiql-plan/src/main/resources/partiql_plan.ion b/partiql-plan/src/main/resources/partiql_plan.ion index ce5e5406e3..184cb9e458 100644 --- a/partiql-plan/src/main/resources/partiql_plan.ion +++ b/partiql-plan/src/main/resources/partiql_plan.ion @@ -297,6 +297,7 @@ rel::{ _: [ call::{ agg: string, + set_quantifier: [ ALL, DISTINCT ], args: list::[rex], }, ], diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/Env.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/Env.kt index d124259852..11080de859 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/Env.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/Env.kt @@ -130,7 +130,7 @@ internal class Env(private val session: PartiQLPlanner.Session) { } @OptIn(FnExperimental::class, PartiQLValueExperimental::class) - fun resolveAgg(name: String, args: List): Rel.Op.Aggregate.Call.Resolved? { + fun resolveAgg(name: String, setQuantifier: Rel.Op.Aggregate.SetQuantifier, args: List): Rel.Op.Aggregate.Call.Resolved? { val match = aggs.resolve(name, args) ?: return null val agg = match.first val mapping = match.second @@ -143,7 +143,7 @@ internal class Env(private val session: PartiQLPlanner.Session) { else -> rex(cast.target.toStaticType(), rexOpCastResolved(cast, arg)) } } - return relOpAggregateCallResolved(ref, coercions) + return relOpAggregateCallResolved(ref, setQuantifier, coercions) } @OptIn(PartiQLValueExperimental::class) diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PathResolverAgg.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PathResolverAgg.kt index 4a711306ae..fb650bcd0c 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PathResolverAgg.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PathResolverAgg.kt @@ -16,6 +16,7 @@ import org.partiql.spi.connector.sql.builtins.Agg_AVG__INT__INT import org.partiql.spi.connector.sql.builtins.Agg_COUNT_STAR____INT32 import org.partiql.spi.connector.sql.builtins.Agg_COUNT__ANY__INT32 import org.partiql.spi.connector.sql.builtins.Agg_EVERY__BOOL__BOOL +import org.partiql.spi.connector.sql.builtins.Agg_GROUP_AS__ANY__ANY import org.partiql.spi.connector.sql.builtins.Agg_MAX__DECIMAL_ARBITRARY__DECIMAL_ARBITRARY import org.partiql.spi.connector.sql.builtins.Agg_MAX__FLOAT32__FLOAT32 import org.partiql.spi.connector.sql.builtins.Agg_MAX__FLOAT64__FLOAT64 @@ -105,6 +106,7 @@ internal object PathResolverAgg { Agg_SUM__DECIMAL_ARBITRARY__DECIMAL_ARBITRARY, Agg_SUM__FLOAT32__FLOAT32, Agg_SUM__FLOAT64__FLOAT64, + Agg_GROUP_AS__ANY__ANY ).map { it.signature }.groupBy { it.name } fun resolve(name: String, args: List): Pair>? { @@ -148,7 +150,7 @@ internal object PathResolverAgg { for (i in args.indices) { val a = args[i] val p = parameters[i] - if (a != p.type) return false + if (p.type != ANY && a != p.type) return false } return true } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/Nodes.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/Nodes.kt index 86b931928b..5f0b10f7e1 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/Nodes.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/Nodes.kt @@ -1133,6 +1133,7 @@ internal data class Rel( internal data class Unresolved( @JvmField internal val name: String, + @JvmField internal val setQuantifier: SetQuantifier, @JvmField internal val args: List, ) : Call() { public override val children: List by lazy { @@ -1153,6 +1154,7 @@ internal data class Rel( internal data class Resolved( @JvmField internal val agg: Ref.Agg, + @JvmField internal val setQuantifier: SetQuantifier, @JvmField internal val args: List, ) : Call() { public override val children: List by lazy { @@ -1172,6 +1174,10 @@ internal data class Rel( } } + internal enum class SetQuantifier { + ALL, DISTINCT, + } + internal companion object { @JvmStatic internal fun builder(): RelOpAggregateBuilder = RelOpAggregateBuilder() diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt index 6c7ea11f27..798510d72c 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt @@ -349,9 +349,12 @@ internal object PlanTransform { } override fun visitRelOpAggregateCallResolved(node: Rel.Op.Aggregate.Call.Resolved, ctx: Unit): PlanNode { - val agg = node.agg.name val args = node.args.map { visitRex(it, ctx) } - return org.partiql.plan.relOpAggregateCall(node.agg.name, args) + val setQuantifier = when (node.setQuantifier) { + Rel.Op.Aggregate.SetQuantifier.ALL -> org.partiql.plan.Rel.Op.Aggregate.Call.SetQuantifier.ALL + Rel.Op.Aggregate.SetQuantifier.DISTINCT -> org.partiql.plan.Rel.Op.Aggregate.Call.SetQuantifier.DISTINCT + } + return org.partiql.plan.relOpAggregateCall(node.agg.name, setQuantifier, args) } override fun visitRelOpExclude(node: Rel.Op.Exclude, ctx: Unit) = org.partiql.plan.Rel.Op.Exclude( diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RelConverter.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RelConverter.kt index a4bbe088f5..4d9c67f4a6 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RelConverter.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RelConverter.kt @@ -66,10 +66,13 @@ import org.partiql.planner.internal.ir.rex import org.partiql.planner.internal.ir.rexOpLit import org.partiql.planner.internal.ir.rexOpPivot import org.partiql.planner.internal.ir.rexOpSelect +import org.partiql.planner.internal.ir.rexOpStruct +import org.partiql.planner.internal.ir.rexOpStructField import org.partiql.planner.internal.ir.rexOpVarLocal import org.partiql.types.StaticType import org.partiql.value.PartiQLValueExperimental import org.partiql.value.boolValue +import org.partiql.value.stringValue import org.partiql.planner.internal.ir.Identifier as InternalId /** @@ -335,16 +338,13 @@ internal object RelConverter { * 1. Ast.Expr.SFW has every Ast.Expr.CallAgg replaced by a synthetic Ast.Expr.Var * 2. Rel which has the appropriate Rex.Agg calls and groups */ + @OptIn(PartiQLValueExperimental::class) private fun convertAgg(input: Rel, select: Expr.SFW, groupBy: GroupBy?): Pair { // Rewrite and extract all aggregations in the SELECT clause val (sel, aggregations) = AggregationTransform.apply(select) // No aggregation planning required for GROUP BY - if (aggregations.isEmpty()) { - if (groupBy != null) { - // GROUP BY with no aggregations is considered an error. - error("GROUP BY with no aggregations in SELECT clause") - } + if (aggregations.isEmpty() && groupBy == null) { return Pair(select, input) } @@ -366,7 +366,34 @@ internal object RelConverter { is InternalId.Qualified -> error("Qualified aggregation calls are not supported.") is InternalId.Symbol -> id.symbol.lowercase() } - relOpAggregateCallUnresolved(name, args) + val setq = when (expr.setq) { + null -> Rel.Op.Aggregate.SetQuantifier.ALL + SetQuantifier.ALL -> Rel.Op.Aggregate.SetQuantifier.ALL + SetQuantifier.DISTINCT -> Rel.Op.Aggregate.SetQuantifier.DISTINCT + } + relOpAggregateCallUnresolved(name, setq, args) + }.toMutableList() + + // Add GROUP_AS aggregation + groupBy?.let { gb -> + gb.asAlias?.let { groupAs -> + val binding = relBinding(groupAs.symbol, StaticType.ANY) + schema.add(binding) + val arg = listOf( + rex( + StaticType.ANY, + rexOpStruct( + listOf( + rexOpStructField( + k = rex(StaticType.STRING, rexOpLit(stringValue(input.type.schema[0].name))), + v = rex(StaticType.ANY, rexOpVarLocal(0)) + ) + ) + ) + ), + ) + calls.add(relOpAggregateCallUnresolved("group_as", Rel.Op.Aggregate.SetQuantifier.ALL, arg)) + } } var groups = emptyList() if (groupBy != null) { @@ -563,6 +590,12 @@ internal object RelConverter { fun apply(node: Expr.SFW): Pair> { val aggs = mutableListOf() val select = super.visitExprSFW(node, aggs) as Expr.SFW +// // Add GROUP_AS aggregation +// node.groupBy?.let { groupBy -> +// groupBy.asAlias?.let { groupAs -> +// aggs.add(exprAgg(identifierSymbol("GROUP_AS", Identifier.CaseSensitivity.SENSITIVE), emptyList(), null)) +// } +// } return Pair(select, aggs) } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt index dd7f3852ef..ca19f4c3b6 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt @@ -1223,7 +1223,7 @@ internal class PlanTyper( return node to ANY } else if (arg.type is MissingType) { handleAlwaysMissing() - return relOpAggregateCallUnresolved(node.name, listOf(rexErr("MISSING"))) to MissingType + return relOpAggregateCallUnresolved(node.name, node.setQuantifier, listOf(rexErr("MISSING"))) to MissingType } else if (arg.type.isMissable()) { isMissable = true } @@ -1231,7 +1231,7 @@ internal class PlanTyper( } // Resolve the function - val call = env.resolveAgg(node.name, args) + val call = env.resolveAgg(node.name, node.setQuantifier, args) if (call == null) { handleUnknownAggregation(node) return node to ANY diff --git a/partiql-planner/src/main/resources/partiql_plan_internal.ion b/partiql-planner/src/main/resources/partiql_plan_internal.ion index c7d7f091a2..baccf62f16 100644 --- a/partiql-planner/src/main/resources/partiql_plan_internal.ion +++ b/partiql-planner/src/main/resources/partiql_plan_internal.ion @@ -318,13 +318,16 @@ rel::{ call::[ unresolved::{ name: string, + set_quantifier: set_quantifier, args: list::[rex], }, resolved::{ agg: '.ref.agg', + set_quantifier: set_quantifier, args: list::[rex], }, - ] + ], + set_quantifier::[ ALL, DISTINCT ] ], }, diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/Agg_GROUP_AS__ANY__ANY.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/Agg_GROUP_AS__ANY__ANY.kt new file mode 100644 index 0000000000..b711fc79be --- /dev/null +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/Agg_GROUP_AS__ANY__ANY.kt @@ -0,0 +1,26 @@ +package org.partiql.spi.connector.sql.builtins + +import org.partiql.spi.fn.Agg +import org.partiql.spi.fn.AggSignature +import org.partiql.spi.fn.FnExperimental +import org.partiql.spi.fn.FnParameter +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.PartiQLValueType + +@OptIn(PartiQLValueExperimental::class, FnExperimental::class) +public object Agg_GROUP_AS__ANY__ANY : Agg { + + override val signature: AggSignature = AggSignature( + name = "group_as", + returns = PartiQLValueType.ANY, + parameters = listOf( + FnParameter("value", PartiQLValueType.ANY), + ), + isNullable = true, + isDecomposable = true + ) + + override fun accumulator(): Agg.Accumulator { + TODO("Aggregation sum not implemented") + } +}