Skip to content

Commit

Permalink
DRAFT -- Adds support for aggregations
Browse files Browse the repository at this point in the history
  • Loading branch information
johnedquinn committed Feb 12, 2024
1 parent 8f3c7f2 commit 5ee9c0b
Show file tree
Hide file tree
Showing 23 changed files with 752 additions and 17 deletions.
37 changes: 37 additions & 0 deletions partiql-eval/src/main/kotlin/org/partiql/eval/internal/Compiler.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,15 @@ package org.partiql.eval.internal

import org.partiql.eval.PartiQLEngine
import org.partiql.eval.internal.operator.Operator
import org.partiql.eval.internal.operator.agg.AccumulatorAnySome
import org.partiql.eval.internal.operator.agg.AccumulatorAvg
import org.partiql.eval.internal.operator.agg.AccumulatorCount
import org.partiql.eval.internal.operator.agg.AccumulatorEvery
import org.partiql.eval.internal.operator.agg.AccumulatorGroupAs
import org.partiql.eval.internal.operator.agg.AccumulatorMax
import org.partiql.eval.internal.operator.agg.AccumulatorMin
import org.partiql.eval.internal.operator.agg.AccumulatorSum
import org.partiql.eval.internal.operator.rel.RelAggregate
import org.partiql.eval.internal.operator.rel.RelDistinct
import org.partiql.eval.internal.operator.rel.RelExclude
import org.partiql.eval.internal.operator.rel.RelFilter
Expand Down Expand Up @@ -123,6 +132,34 @@ internal class Compiler(
return ExprVar(node.ref)
}

override fun visitRelOpAggregate(node: Rel.Op.Aggregate, ctx: StaticType?): Operator.Relation {
val input = visitRel(node.input, ctx)
val calls = node.calls.map { visitRelOpAggregateCall(it, ctx) }
val groups = node.groups.map { visitRex(it, ctx).modeHandled() }
return RelAggregate(input, groups, calls)
}

@OptIn(PartiQLValueExperimental::class)
override fun visitRelOpAggregateCall(node: Rel.Op.Aggregate.Call, ctx: StaticType?): Operator.Accumulator {
val args = node.args.map { visitRex(it, it.type).modeHandled() } // TODO: Should we support multiple arguments?
val setQuantifier: Operator.Accumulator.SetQuantifier = when (node.setQuantifier) {
Rel.Op.Aggregate.Call.SetQuantifier.ALL -> Operator.Accumulator.SetQuantifier.ALL
Rel.Op.Aggregate.Call.SetQuantifier.DISTINCT -> Operator.Accumulator.SetQuantifier.DISTINCT
}
return when (node.agg.uppercase()) {
"MIN" -> AccumulatorMin.Factory(args, setQuantifier)
"MAX" -> AccumulatorMax.Factory(args, setQuantifier)
"AVG" -> AccumulatorAvg.Factory(args, setQuantifier)
"COUNT" -> AccumulatorCount.Factory(args, setQuantifier)
"SUM" -> AccumulatorSum.Factory(args, setQuantifier)
"GROUP_AS" -> AccumulatorGroupAs.Factory(args, setQuantifier)
"EVERY" -> AccumulatorEvery.Factory(args, setQuantifier)
"ANY" -> AccumulatorAnySome.Factory(args, setQuantifier)
"SOME" -> AccumulatorAnySome.Factory(args, setQuantifier)
else -> error("Unexpected aggregation: ${node.agg}.")
}
}

override fun visitRexOpGlobal(node: Rex.Op.Global, ctx: StaticType?): Operator = symbols.getGlobal(node.ref)

override fun visitRexOpPathKey(node: Rex.Op.Path.Key, ctx: StaticType?): Operator {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,31 @@ internal sealed interface Operator {

override fun close()
}

interface Accumulator : Operator {

val setQuantifier: SetQuantifier
fun create(): Instance

interface Instance {

/**
* The argument to invoke.
*/
val args: List<Expr>

/** Accumulates the next value into this [Instance]. */
@OptIn(PartiQLValueExperimental::class)
fun next(value: PartiQLValue)

/** Digests the result of the accumulated values. */
@OptIn(PartiQLValueExperimental::class)
fun compute(): PartiQLValue
}

enum class SetQuantifier {
ALL,
DISTINCT
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at:
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific
* language governing permissions and limitations under the License.
*/

@file:OptIn(PartiQLValueExperimental::class)

package org.partiql.eval.internal.operator.agg

import com.amazon.ion.Decimal
import org.partiql.eval.internal.operator.Operator
import org.partiql.value.BoolValue
import org.partiql.value.DecimalValue
import org.partiql.value.Float32Value
import org.partiql.value.Float64Value
import org.partiql.value.Int16Value
import org.partiql.value.Int32Value
import org.partiql.value.Int64Value
import org.partiql.value.Int8Value
import org.partiql.value.PartiQLValue
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.PartiQLValueType
import org.partiql.value.decimalValue
import org.partiql.value.float64Value
import org.partiql.value.int32Value
import org.partiql.value.int64Value
import java.math.BigDecimal
import java.math.MathContext
import java.math.RoundingMode

internal abstract class Accumulator : Operator.Accumulator.Instance {

/** Accumulates the next value into this [Accumulator]. */
@OptIn(PartiQLValueExperimental::class)
override fun next(value: PartiQLValue) {
if (value.isUnknown()) return
nextValue(value)
}

abstract fun nextValue(value: PartiQLValue)
}

// TODO: Make this better
@OptIn(PartiQLValueExperimental::class)
internal fun comparisonAccumulator(comparator: Comparator<PartiQLValue>): (PartiQLValue?, PartiQLValue) -> PartiQLValue =
{ left, right ->
when {
left == null || comparator.compare(left, right) > 0 -> right
else -> left
}
}

// TODO: Make this better
@OptIn(PartiQLValueExperimental::class)
internal fun checkIsNumberType(funcName: String, value: PartiQLValue) {
if (!value.type.isNumber()) {
TODO("NEED TO HANDLE")
}
}

// TODO: Make this better
@OptIn(PartiQLValueExperimental::class)
internal fun checkIsBooleanType(funcName: String, value: PartiQLValue) {
if (value.type != PartiQLValueType.BOOL) {
TODO("NEED TO HANDLE")
}
}

// TODO: Make this better
@OptIn(PartiQLValueExperimental::class)
internal fun PartiQLValue.isUnknown(): Boolean = this.type == PartiQLValueType.MISSING || this.isNull

// TODO: Make this better
@OptIn(PartiQLValueExperimental::class)
internal fun PartiQLValue.numberValue(): Number = when (this) {
is Int8Value -> this.value!!
is Int16Value -> this.value!!
is Int32Value -> this.value!!
is Int64Value -> this.value!!
is DecimalValue -> this.value!!
is Float32Value -> this.value!!
is Float64Value -> this.value!!
else -> error("Cannot convert PartiQLValue ($this) to number.")
}

// TODO: Make this better
@OptIn(PartiQLValueExperimental::class)
internal fun PartiQLValue.booleanValue(): Boolean = when (this) {
is BoolValue -> this.value!!
else -> error("Cannot convert PartiQLValue ($this) to boolean.")
}

// TODO: Make this better
@OptIn(PartiQLValueExperimental::class)
internal fun PartiQLValueType.isNumber(): Boolean = when (this) {
PartiQLValueType.INT,
PartiQLValueType.INT8,
PartiQLValueType.INT16,
PartiQLValueType.INT32,
PartiQLValueType.INT64,
PartiQLValueType.DECIMAL,
PartiQLValueType.FLOAT32,
PartiQLValueType.FLOAT64 -> true
else -> false
}

// TODO: Make this better
@OptIn(PartiQLValueExperimental::class)
internal fun Number.partiqlValue(): PartiQLValue = when (this) {
is Int -> int32Value(this)
is Long -> int64Value(this)
is Double -> float64Value(this)
is BigDecimal -> decimalValue(this)
else -> TODO("Error context")
}

// TODO: Make this better
private val MATH_CONTEXT = MathContext(38, RoundingMode.HALF_EVEN)

// TODO: Make this better
/**
* Factory function to create a [BigDecimal] using correct precision, use it in favor of native BigDecimal constructors
* and factory methods
*/
internal fun bigDecimalOf(num: Number, mc: MathContext = MATH_CONTEXT): BigDecimal = when (num) {
is Decimal -> num
is Int -> BigDecimal(num, mc)
is Long -> BigDecimal(num, mc)
is Double -> BigDecimal(num, mc)
is BigDecimal -> num
Decimal.NEGATIVE_ZERO -> num as Decimal
else -> throw IllegalArgumentException("Unsupported number type: $num, ${num.javaClass}")
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package org.partiql.eval.internal.operator.agg

import org.partiql.eval.internal.operator.Operator
import org.partiql.value.PartiQLValue
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.boolValue
import org.partiql.value.nullValue

@OptIn(PartiQLValueExperimental::class)
internal class AccumulatorAnySome(
override val args: List<Operator.Expr>
) : Accumulator() {

private var res: PartiQLValue? = null

override fun nextValue(value: PartiQLValue) {
checkIsBooleanType("ANY/SOME", value)
res = res?.let { boolValue(it.booleanValue() || value.booleanValue()) } ?: value
}

override fun compute(): PartiQLValue = res ?: nullValue()

class Factory(
val args: List<Operator.Expr>,
override val setQuantifier: Operator.Accumulator.SetQuantifier
) : Operator.Accumulator {
override fun create(): Operator.Accumulator.Instance = AccumulatorAnySome(args)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package org.partiql.eval.internal.operator.agg

import org.partiql.eval.internal.operator.Operator
import org.partiql.lang.util.div
import org.partiql.lang.util.plus
import org.partiql.value.PartiQLValue
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.nullValue

@OptIn(PartiQLValueExperimental::class)
internal class AccumulatorAvg(
override val args: List<Operator.Expr>
) : Accumulator() {

var sum: Number = 0.0
var count: Long = 0L

override fun nextValue(value: PartiQLValue) {
checkIsNumberType(funcName = "AVG", value = value)
this.sum += value.numberValue()
this.count += 1L
}

override fun compute(): PartiQLValue = when (count) {
0L -> nullValue()
else -> (sum / bigDecimalOf(count)).partiqlValue()
}

class Factory(
val args: List<Operator.Expr>,
override val setQuantifier: Operator.Accumulator.SetQuantifier
) : Operator.Accumulator {
override fun create(): Operator.Accumulator.Instance = AccumulatorAvg(args)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package org.partiql.eval.internal.operator.agg

import org.partiql.eval.internal.operator.Operator
import org.partiql.value.PartiQLValue
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.int64Value

@OptIn(PartiQLValueExperimental::class)
internal class AccumulatorCount(
override val args: List<Operator.Expr>
) : Accumulator() {

var count: Long = 0L

override fun nextValue(value: PartiQLValue) {
this.count += 1L
}

override fun compute(): PartiQLValue = int64Value(count)

class Factory(
val args: List<Operator.Expr>,
override val setQuantifier: Operator.Accumulator.SetQuantifier
) : Operator.Accumulator {
override fun create(): Operator.Accumulator.Instance = AccumulatorCount(args)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package org.partiql.eval.internal.operator.agg

import org.partiql.eval.internal.operator.Operator
import org.partiql.value.PartiQLValue
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.boolValue
import org.partiql.value.nullValue

@OptIn(PartiQLValueExperimental::class)
internal class AccumulatorEvery(
override val args: List<Operator.Expr>
) : Accumulator() {

private var res: PartiQLValue? = null

@OptIn(PartiQLValueExperimental::class)
override fun nextValue(value: PartiQLValue) {
checkIsBooleanType("EVERY", value)
res = res?.let { boolValue(it.booleanValue() && value.booleanValue()) } ?: value
}

override fun compute(): PartiQLValue = res ?: nullValue()

class Factory(
val args: List<Operator.Expr>,
override val setQuantifier: Operator.Accumulator.SetQuantifier
) : Operator.Accumulator {
override fun create(): Operator.Accumulator.Instance = AccumulatorEvery(args)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package org.partiql.eval.internal.operator.agg

import org.partiql.eval.internal.operator.Operator
import org.partiql.value.PartiQLValue
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.bagValue

@OptIn(PartiQLValueExperimental::class)
internal class AccumulatorGroupAs(
override val args: List<Operator.Expr>
) : Accumulator() {

val values = mutableListOf<PartiQLValue>()

override fun nextValue(value: PartiQLValue) {
values.add(value)
}

override fun compute(): PartiQLValue = bagValue(values)

class Factory(
val args: List<Operator.Expr>,
override val setQuantifier: Operator.Accumulator.SetQuantifier
) : Operator.Accumulator {
override fun create(): Operator.Accumulator.Instance = AccumulatorGroupAs(args)
}
}
Loading

0 comments on commit 5ee9c0b

Please sign in to comment.