Skip to content

Commit

Permalink
Adds support for aggregations (GROUP BY)
Browse files Browse the repository at this point in the history
Adds support for COLL_AGGs
  • Loading branch information
johnedquinn committed Feb 21, 2024
1 parent f1aeb6f commit 1c7dce1
Show file tree
Hide file tree
Showing 67 changed files with 1,698 additions and 447 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

package org.partiql.ast.normalize

import org.partiql.ast.AstNode
import org.partiql.ast.Expr
import org.partiql.ast.GroupBy
import org.partiql.ast.Statement
Expand All @@ -30,6 +31,13 @@ object NormalizeGroupBy : AstPass {

private object Visitor : AstRewriter<Int>() {

/**
 * Rewrites a [GroupBy] node by visiting each of its keys in order.
 *
 * Every key is visited with its 1-based position in the key list as the context value;
 * the incoming [ctx] is not forwarded to the keys.
 *
 * @param node the GROUP BY clause to rewrite.
 * @param ctx unused by this override.
 * @return a copy of [node] carrying the rewritten keys.
 */
override fun visitGroupBy(node: GroupBy, ctx: Int): AstNode {
    val rewrittenKeys = node.keys
        .withIndex()
        .map { (position, groupKey) -> visitGroupByKey(groupKey, position + 1) }
    return node.copy(keys = rewrittenKeys)
}

override fun visitGroupByKey(node: GroupBy.Key, ctx: Int): GroupBy.Key {
val expr = visitExpr(node.expr, 0) as Expr
val alias = when (node.asAlias) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package org.partiql.eval.internal

import org.partiql.eval.PartiQLEngine
import org.partiql.eval.internal.operator.Operator
import org.partiql.eval.internal.operator.rel.RelAggregate
import org.partiql.eval.internal.operator.rel.RelDistinct
import org.partiql.eval.internal.operator.rel.RelExclude
import org.partiql.eval.internal.operator.rel.RelFilter
Expand Down Expand Up @@ -45,6 +46,7 @@ import org.partiql.plan.Rex
import org.partiql.plan.Statement
import org.partiql.plan.debug.PlanPrinter
import org.partiql.plan.visitor.PlanBaseVisitor
import org.partiql.spi.fn.Agg
import org.partiql.spi.fn.FnExperimental
import org.partiql.types.StaticType
import org.partiql.value.PartiQLValueExperimental
Expand Down Expand Up @@ -170,6 +172,30 @@ internal class Compiler(

override fun visitRexOpGlobal(node: Rex.Op.Global, ctx: StaticType?): Operator = symbols.getGlobal(node.ref)

/**
 * Compiles an aggregate relational operator into a [RelAggregate].
 *
 * The input relation is compiled first, then every aggregate call, then every grouping
 * expression (each wrapped via `modeHandled()`).
 *
 * @param node the plan node describing the aggregation.
 * @param ctx type context forwarded to the input, the calls and the grouping expressions.
 * @return the physical aggregation operator.
 */
override fun visitRelOpAggregate(node: Rel.Op.Aggregate, ctx: StaticType?): Operator.Relation {
    val source = visitRel(node.input, ctx)
    val accumulators = node.calls.map { call -> visitRelOpAggregateCall(call, ctx) }
    val groupKeys = node.groups.map { group -> visitRex(group, ctx).modeHandled() }
    return RelAggregate(input = source, keys = groupKeys, functions = accumulators)
}

/**
 * Compiles a single aggregate call into an [Operator.Accumulator].
 *
 * Each argument expression is compiled with its own type as the context (the incoming
 * [ctx] is not used for arguments), the plan-level set quantifier is translated to the
 * operator-level one, and the aggregation implementation is resolved through `symbols`.
 *
 * @param node the aggregate call from the plan.
 * @param ctx type context of the enclosing operator; not consulted here.
 * @return an accumulator descriptor bundling the resolved [Agg], its arguments and quantifier.
 */
@OptIn(FnExperimental::class)
override fun visitRelOpAggregateCall(node: Rel.Op.Aggregate.Call, ctx: StaticType?): Operator.Accumulator {
    // TODO: Should we support multiple arguments?
    val compiledArgs = node.args.map { arg -> visitRex(arg, arg.type).modeHandled() }
    val quantifier = when (node.setQuantifier) {
        Rel.Op.Aggregate.Call.SetQuantifier.ALL -> Operator.Accumulator.SetQuantifier.ALL
        Rel.Op.Aggregate.Call.SetQuantifier.DISTINCT -> Operator.Accumulator.SetQuantifier.DISTINCT
    }
    val resolved = symbols.getAgg(node.agg)
    // Locals are named differently from the overridden properties so the initializers are unambiguous.
    return object : Operator.Accumulator {
        override val delegate: Agg = resolved
        override val args: List<Operator.Expr> = compiledArgs
        override val setQuantifier: Operator.Accumulator.SetQuantifier = quantifier
    }
}

override fun visitRexOpPathKey(node: Rex.Op.Path.Key, ctx: StaticType?): Operator {
val root = visitRex(node.root, ctx)
val key = visitRex(node.key, ctx)
Expand Down Expand Up @@ -206,7 +232,7 @@ internal class Compiler(
val args = node.args.map { visitRex(it, ctx).modeHandled() }.toTypedArray()
val candidates = node.candidates.map { candidate ->
val fn = symbols.getFn(candidate.fn)
val types = fn.signature.parameters.map { it.type }.toTypedArray()
val types = candidate.parameters.toTypedArray()
val coercions = candidate.coercions.toTypedArray()
ExprCallDynamic.Candidate(fn, types, coercions)
}
Expand Down
16 changes: 16 additions & 0 deletions partiql-eval/src/main/kotlin/org/partiql/eval/internal/Symbols.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@ import org.partiql.eval.internal.operator.rex.ExprVarGlobal
import org.partiql.plan.Catalog
import org.partiql.plan.PartiQLPlan
import org.partiql.plan.Ref
import org.partiql.spi.connector.ConnectorAggProvider
import org.partiql.spi.connector.ConnectorBindings
import org.partiql.spi.connector.ConnectorFnProvider
import org.partiql.spi.connector.ConnectorPath
import org.partiql.spi.fn.Agg
import org.partiql.spi.fn.Fn
import org.partiql.spi.fn.FnExperimental

Expand All @@ -25,6 +27,7 @@ internal class Symbols private constructor(private val catalogs: Array<C>) {
val name: String,
val bindings: ConnectorBindings,
val functions: ConnectorFnProvider,
val aggregations: ConnectorAggProvider,
val items: Array<Catalog.Item>,
) {

Expand Down Expand Up @@ -53,6 +56,18 @@ internal class Symbols private constructor(private val catalogs: Array<C>) {
?: error("Catalog `$catalog` has no entry for function $item")
}

/**
 * Resolves a plan reference to an aggregation implementation.
 *
 * The reference's catalog index selects the catalog and its symbol index selects the
 * catalog item, which must be a [Catalog.Item.Agg]. The item's path and specific name
 * are then looked up in that catalog's aggregation provider.
 *
 * @param ref the catalog/symbol reference taken from the plan.
 * @return the matching [Agg].
 * @throws IllegalStateException if the item is absent or not an aggregation, or if the
 *   provider has no entry for it.
 */
fun getAgg(ref: Ref): Agg {
    val catalog = catalogs[ref.catalog]
    // A missing item and an item of another kind are reported identically.
    val item = catalog.items.getOrNull(ref.symbol) as? Catalog.Item.Agg
        ?: error("Invalid reference $ref; missing function entry for catalog `$catalog`.")
    // Lookup in connector
    return catalog.aggregations.getAgg(ConnectorPath(item.path), item.specific)
        ?: error("Catalog `$catalog` has no entry for aggregation function $item")
}

companion object {

/**
Expand All @@ -71,6 +86,7 @@ internal class Symbols private constructor(private val catalogs: Array<C>) {
name = it.name,
bindings = connector.getBindings(),
functions = connector.getFunctions(),
aggregations = connector.getAggregations(),
items = it.items.toTypedArray()
)
}.toTypedArray()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package org.partiql.eval.internal.operator

import org.partiql.eval.internal.Record
import org.partiql.spi.fn.Agg
import org.partiql.spi.fn.FnExperimental
import org.partiql.value.PartiQLValue
import org.partiql.value.PartiQLValueExperimental

Expand All @@ -26,4 +28,19 @@ internal sealed interface Operator {

override fun close()
}

/**
 * Describes one aggregate call for the evaluator: which aggregation to run, the
 * argument expressions feeding it, and whether inputs are de-duplicated first.
 */
interface Accumulator : Operator {

/** The aggregation implementation; RelAggregate calls `accumulator()` on it to obtain fresh state. */
@OptIn(FnExperimental::class)
val delegate: Agg

/** Argument expressions; RelAggregate evaluates them against each input record. */
val args: List<Expr>

/** Whether every argument tuple is accumulated ([SetQuantifier.ALL]) or only unseen ones ([SetQuantifier.DISTINCT]). */
val setQuantifier: SetQuantifier

/** Set quantifier of an aggregate call. */
enum class SetQuantifier {
ALL,
DISTINCT
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
package org.partiql.eval.internal.operator.rel

import org.partiql.eval.internal.Record
import org.partiql.eval.internal.operator.Operator
import org.partiql.spi.fn.Agg
import org.partiql.spi.fn.FnExperimental
import org.partiql.value.ListValue
import org.partiql.value.PartiQLValue
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.PartiQLValueType
import org.partiql.value.listValue
import org.partiql.value.nullValue
import java.util.TreeMap
import java.util.TreeSet

/**
 * Relational operator that groups the rows of [input] by [keys] and folds every group
 * through [functions].
 *
 * All input is consumed eagerly in [open]; result rows are then produced lazily by [next].
 * Each result row holds the aggregate results first, followed by the grouping key values.
 *
 * @property input the relation being aggregated.
 * @property keys grouping key expressions, evaluated once per input record.
 * @property functions aggregate calls evaluated for every group.
 */
internal class RelAggregate(
    val input: Operator.Relation,
    val keys: List<Operator.Expr>,
    val functions: List<Operator.Accumulator>
) : Operator.Relation {

    /** Result rows; assigned by [open] and drained by [next]. */
    lateinit var records: Iterator<Record>

    /** Accumulator state per group, keyed by the list value of that group's evaluated keys. */
    @OptIn(PartiQLValueExperimental::class)
    val aggregationMap = TreeMap<PartiQLValue, List<AccumulatorWrapper>>(PartiQLValue.comparator(nullsFirst = false))

    /**
     * Orders argument tuples for DISTINCT tracking: shorter lists sort first, lists of
     * equal size are compared element by element with the PartiQL value comparator.
     */
    @OptIn(PartiQLValueExperimental::class)
    object PartiQLValueListComparator : Comparator<List<PartiQLValue>> {
        private val delegate = PartiQLValue.comparator(nullsFirst = false)

        override fun compare(o1: List<PartiQLValue>, o2: List<PartiQLValue>): Int {
            val bySize = o1.size.compareTo(o2.size)
            if (bySize != 0) {
                return bySize
            }
            // Sizes match: the first differing position decides the order.
            return o1.indices
                .asSequence()
                .map { position -> delegate.compare(o1[position], o2[position]) }
                .firstOrNull { outcome -> outcome != 0 }
                ?: 0
        }
    }

    /**
     * Pairs an [Agg.Accumulator] with its argument expressions and optional DISTINCT bookkeeping.
     *
     * @property delegate the running aggregation state for one group.
     * @property args the expressions producing the values fed to [delegate].
     * @property seen argument tuples already accumulated; null means every tuple is accumulated.
     */
    class AccumulatorWrapper @OptIn(PartiQLValueExperimental::class, FnExperimental::class) constructor(
        val delegate: Agg.Accumulator,
        val args: List<Operator.Expr>,
        val seen: TreeSet<List<PartiQLValue>>?
    )

    /**
     * Opens [input], drains it into [aggregationMap], and prepares [records].
     */
    @OptIn(PartiQLValueExperimental::class, FnExperimental::class)
    override fun open() {
        input.open()
        while (true) {
            val row = input.next() ?: break

            // Build the group key; a MISSING key component is grouped as NULL.
            val groupKey = listValue(
                keys.map { keyExpr ->
                    val evaluated = keyExpr.eval(row)
                    if (evaluated.type == PartiQLValueType.MISSING) nullValue() else evaluated
                }
            )

            // First record of a group creates fresh accumulators for every aggregate call.
            val wrappers = aggregationMap.getOrPut(groupKey) {
                functions.map { call ->
                    val distinctTracker = when (call.setQuantifier) {
                        Operator.Accumulator.SetQuantifier.ALL -> null
                        Operator.Accumulator.SetQuantifier.DISTINCT -> TreeSet(PartiQLValueListComparator)
                    }
                    AccumulatorWrapper(
                        delegate = call.delegate.accumulator(),
                        args = call.args,
                        seen = distinctTracker
                    )
                }
            }

            for (wrapper in wrappers) {
                val arguments = wrapper.args.map { argExpr -> argExpr.eval(row) }
                // NULL or MISSING arguments do not contribute to the aggregate.
                if (arguments.any { it.type == PartiQLValueType.MISSING || it.isNull }) {
                    continue
                }
                // For DISTINCT, a tuple that was already recorded is skipped.
                if (wrapper.seen?.add(arguments) == false) {
                    continue
                }
                wrapper.delegate.next(arguments.toTypedArray())
            }
        }

        records = if (keys.isEmpty() && aggregationMap.isEmpty()) {
            // No group was created and there are no grouping keys:
            // emit one row holding each aggregate's value from a fresh accumulator.
            val initialValues: List<PartiQLValue> = functions.map { call -> call.delegate.accumulator().value() }
            iterator<Record> { yield(Record.of(*initialValues.toTypedArray())) }
        } else {
            iterator<Record> {
                for ((groupKey, wrappers) in aggregationMap) {
                    val keysEvaluated = groupKey as ListValue<*>
                    val aggregated = wrappers.map { wrapper -> wrapper.delegate.value() }
                    val row = aggregated + keysEvaluated.map { element -> element }
                    yield(Record.of(*row.toTypedArray()))
                }
            }
        }
    }

    /** Returns the next result row, or null once all rows have been produced. */
    override fun next(): Record? {
        if (!records.hasNext()) {
            return null
        }
        return records.next()
    }

    /** Discards the accumulated group state and closes [input]. */
    @OptIn(PartiQLValueExperimental::class)
    override fun close() {
        aggregationMap.clear()
        input.close()
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ internal class ExprCallDynamic(
return candidate.eval(actualArgs)
}
}
throw TypeCheckException()
val errorString = buildString {
val argString = actualArgs.joinToString(", ")
append("Could not dynamically find function for arguments $argString in $candidates.")
}
throw TypeCheckException(errorString)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,6 @@ internal class ExprPathSymbol(
return v
}
}
throw TypeCheckException()
throw TypeCheckException("Couldn't find symbol '$symbol' in $struct.")
}
}
Loading

0 comments on commit 1c7dce1

Please sign in to comment.