Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Do not merge] Run Eval Engine Against the existing unit test. #1376

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions buildSrc/src/main/kotlin/partiql.conventions.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ java {
tasks.test {
useJUnitPlatform() // Enable JUnit5
jvmArgs.addAll(listOf("-Duser.language=en", "-Duser.country=US"))
// disable timeout during debugging session for all junit5 tests
jvmArgs.add("-Djunit.jupiter.execution.timeout.mode=disabled_on_debug")
maxHeapSize = "4g"
testLogging {
events.add(TestLogEvent.FAILED)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

package org.partiql.ast.normalize

import org.partiql.ast.AstNode
import org.partiql.ast.Expr
import org.partiql.ast.GroupBy
import org.partiql.ast.Statement
Expand All @@ -30,6 +31,13 @@ object NormalizeGroupBy : AstPass {

private object Visitor : AstRewriter<Int>() {

override fun visitGroupBy(node: GroupBy, ctx: Int): AstNode {
val keys = node.keys.mapIndexed { index, key ->
visitGroupByKey(key, index + 1)
}
return node.copy(keys = keys)
}

override fun visitGroupByKey(node: GroupBy.Key, ctx: Int): GroupBy.Key {
val expr = visitExpr(node.expr, 0) as Expr
val alias = when (node.asAlias) {
Expand Down
14 changes: 13 additions & 1 deletion partiql-eval/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ plugins {
id(Plugins.conventions)
id(Plugins.library)
id(Plugins.publish)
id(Plugins.testFixtures)
}

dependencies {
Expand All @@ -32,10 +33,21 @@ dependencies {
testImplementation(project(":plugins:partiql-local"))
testImplementation(project(":plugins:partiql-memory"))
testImplementation(testFixtures(project(":partiql-planner")))
testImplementation(testFixtures(project(":partiql-lang")))
testImplementation(Deps.junit4)
testImplementation(Deps.junit4Params)
testImplementation(Deps.junitVintage) // Enables JUnit4

testFixturesImplementation(project(":partiql-lang")) // To be decoupled
testFixturesImplementation(project(":lib:isl"))
testFixturesImplementation(Deps.kotlinTest)
testFixturesImplementation(Deps.kotlinTestJunit)
testFixturesImplementation(Deps.assertj)
testFixturesImplementation(Deps.junit4)
testFixturesImplementation(Deps.junit4Params)
testFixturesImplementation(Deps.junitApi)
testFixturesImplementation(Deps.junitParams)
testFixturesImplementation(Deps.junitVintage) // Enables JUnit4
testFixturesImplementation(Deps.mockk)
}

// Disabled for partiql-eval project at initialization.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package org.partiql.eval.internal

import org.partiql.eval.PartiQLEngine
import org.partiql.eval.internal.operator.Operator
import org.partiql.eval.internal.operator.rel.RelAggregate
import org.partiql.eval.internal.operator.rel.RelDistinct
import org.partiql.eval.internal.operator.rel.RelExclude
import org.partiql.eval.internal.operator.rel.RelFilter
Expand Down Expand Up @@ -45,6 +46,7 @@ import org.partiql.plan.Rex
import org.partiql.plan.Statement
import org.partiql.plan.debug.PlanPrinter
import org.partiql.plan.visitor.PlanBaseVisitor
import org.partiql.spi.fn.Agg
import org.partiql.spi.fn.FnExperimental
import org.partiql.types.StaticType
import org.partiql.value.PartiQLValueExperimental
Expand Down Expand Up @@ -170,6 +172,30 @@ internal class Compiler(

override fun visitRexOpGlobal(node: Rex.Op.Global, ctx: StaticType?): Operator = symbols.getGlobal(node.ref)

override fun visitRelOpAggregate(node: Rel.Op.Aggregate, ctx: StaticType?): Operator.Relation {
val input = visitRel(node.input, ctx)
val calls = node.calls.map {
visitRelOpAggregateCall(it, ctx)
}
val groups = node.groups.map { visitRex(it, ctx).modeHandled() }
return RelAggregate(input, groups, calls)
}

@OptIn(FnExperimental::class)
override fun visitRelOpAggregateCall(node: Rel.Op.Aggregate.Call, ctx: StaticType?): Operator.Aggregation {
val args = node.args.map { visitRex(it, it.type).modeHandled() }
val setQuantifier: Operator.Aggregation.SetQuantifier = when (node.setQuantifier) {
Rel.Op.Aggregate.Call.SetQuantifier.ALL -> Operator.Aggregation.SetQuantifier.ALL
Rel.Op.Aggregate.Call.SetQuantifier.DISTINCT -> Operator.Aggregation.SetQuantifier.DISTINCT
}
val agg = symbols.getAgg(node.agg)
return object : Operator.Aggregation {
override val delegate: Agg = agg
override val args: List<Operator.Expr> = args
override val setQuantifier: Operator.Aggregation.SetQuantifier = setQuantifier
}
}

override fun visitRexOpPathKey(node: Rex.Op.Path.Key, ctx: StaticType?): Operator {
val root = visitRex(node.root, ctx)
val key = visitRex(node.key, ctx)
Expand Down Expand Up @@ -206,7 +232,7 @@ internal class Compiler(
val args = node.args.map { visitRex(it, ctx).modeHandled() }.toTypedArray()
val candidates = node.candidates.map { candidate ->
val fn = symbols.getFn(candidate.fn)
val types = fn.signature.parameters.map { it.type }.toTypedArray()
val types = candidate.parameters.toTypedArray()
val coercions = candidate.coercions.toTypedArray()
ExprCallDynamic.Candidate(fn, types, coercions)
}
Expand Down
16 changes: 16 additions & 0 deletions partiql-eval/src/main/kotlin/org/partiql/eval/internal/Symbols.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@ import org.partiql.eval.internal.operator.rex.ExprVarGlobal
import org.partiql.plan.Catalog
import org.partiql.plan.PartiQLPlan
import org.partiql.plan.Ref
import org.partiql.spi.connector.ConnectorAggProvider
import org.partiql.spi.connector.ConnectorBindings
import org.partiql.spi.connector.ConnectorFnProvider
import org.partiql.spi.connector.ConnectorPath
import org.partiql.spi.fn.Agg
import org.partiql.spi.fn.Fn
import org.partiql.spi.fn.FnExperimental

Expand All @@ -25,6 +27,7 @@ internal class Symbols private constructor(private val catalogs: Array<C>) {
val name: String,
val bindings: ConnectorBindings,
val functions: ConnectorFnProvider,
val aggregations: ConnectorAggProvider,
val items: Array<Catalog.Item>,
) {

Expand Down Expand Up @@ -53,6 +56,18 @@ internal class Symbols private constructor(private val catalogs: Array<C>) {
?: error("Catalog `$catalog` has no entry for function $item")
}

fun getAgg(ref: Ref): Agg {
val catalog = catalogs[ref.catalog]
val item = catalog.items.getOrNull(ref.symbol)
if (item == null || item !is Catalog.Item.Agg) {
error("Invalid reference $ref; missing aggregation entry for catalog `$catalog`.")
}
// Lookup in connector
val path = ConnectorPath(item.path)
return catalog.aggregations.getAgg(path, item.specific)
?: error("Catalog `$catalog` has no entry for aggregation function $item")
}

companion object {

/**
Expand All @@ -71,6 +86,7 @@ internal class Symbols private constructor(private val catalogs: Array<C>) {
name = it.name,
bindings = connector.getBindings(),
functions = connector.getFunctions(),
aggregations = connector.getAggregations(),
items = it.items.toTypedArray()
)
}.toTypedArray()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package org.partiql.eval.internal.operator

import org.partiql.eval.internal.Record
import org.partiql.spi.fn.Agg
import org.partiql.spi.fn.FnExperimental
import org.partiql.value.PartiQLValue
import org.partiql.value.PartiQLValueExperimental

Expand All @@ -26,4 +28,19 @@ internal sealed interface Operator {

override fun close()
}

interface Aggregation : Operator {

@OptIn(FnExperimental::class)
val delegate: Agg

val args: List<Expr>

val setQuantifier: SetQuantifier

enum class SetQuantifier {
ALL,
DISTINCT
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
package org.partiql.eval.internal.operator.rel

import org.partiql.eval.internal.Record
import org.partiql.eval.internal.operator.Operator
import org.partiql.spi.fn.Agg
import org.partiql.spi.fn.FnExperimental
import org.partiql.value.PartiQLValue
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.PartiQLValueType
import org.partiql.value.nullValue
import java.util.TreeMap
import java.util.TreeSet

internal class RelAggregate(
val input: Operator.Relation,
val keys: List<Operator.Expr>,
val functions: List<Operator.Aggregation>
) : Operator.Relation {

lateinit var records: Iterator<Record>

@OptIn(PartiQLValueExperimental::class)
val aggregationMap = TreeMap<List<PartiQLValue>, List<AccumulatorWrapper>>(PartiQLValueListComparator)

@OptIn(PartiQLValueExperimental::class)
object PartiQLValueListComparator : Comparator<List<PartiQLValue>> {
private val delegate = PartiQLValue.comparator(nullsFirst = false)
override fun compare(o1: List<PartiQLValue>, o2: List<PartiQLValue>): Int {
if (o1.size < o2.size) {
return -1
}
if (o1.size > o2.size) {
return 1
}
for (index in 0..o2.lastIndex) {
val element1 = o1[index]
val element2 = o2[index]
val compared = delegate.compare(element1, element2)
if (compared != 0) {
return compared
}
}
return 0
}
}

/**
* Wraps an [Agg.Accumulator] to help with filtering distinct values.
*
* @property seen maintains which values have already been seen. If null, we accumulate all values coming through.
*/
class AccumulatorWrapper @OptIn(PartiQLValueExperimental::class, FnExperimental::class) constructor(
val delegate: Agg.Accumulator,
val args: List<Operator.Expr>,
val seen: TreeSet<List<PartiQLValue>>?
)

@OptIn(PartiQLValueExperimental::class, FnExperimental::class)
override fun open() {
input.open()
var inputRecord = input.next()
while (inputRecord != null) {
// Initialize the AggregationMap
val evaluatedGroupByKeys = keys.map {
val key = it.eval(inputRecord!!)
when (key.type == PartiQLValueType.MISSING) {
true -> nullValue()
false -> key
}
}
val accumulators = aggregationMap.getOrPut(evaluatedGroupByKeys) {
functions.map {
AccumulatorWrapper(
delegate = it.delegate.accumulator(),
args = it.args,
seen = when (it.setQuantifier) {
Operator.Aggregation.SetQuantifier.DISTINCT -> TreeSet(PartiQLValueListComparator)
Operator.Aggregation.SetQuantifier.ALL -> null
}
)
}
}

// Aggregate Values in Aggregation State
accumulators.forEachIndexed { index, function ->
val valueToAggregate = function.args.map { it.eval(inputRecord!!) }
// Skip over aggregation if NULL/MISSING
if (valueToAggregate.any { it.type == PartiQLValueType.MISSING || it.isNull }) {
return@forEachIndexed
}
// Skip over aggregation if DISTINCT and SEEN
if (function.seen != null && (function.seen.add(valueToAggregate).not())) {
return@forEachIndexed
}
accumulators[index].delegate.next(valueToAggregate.toTypedArray())
}
inputRecord = input.next()
}

// No Aggregations Created
if (keys.isEmpty() && aggregationMap.isEmpty()) {
val record = mutableListOf<PartiQLValue>()
functions.forEach { function ->
val accumulator = function.delegate.accumulator()
record.add(accumulator.value())
}
records = iterator { yield(Record.of(*record.toTypedArray())) }
return
}

records = iterator {
aggregationMap.forEach { (keysEvaluated, accumulators) ->
val recordValues = accumulators.map { acc -> acc.delegate.value() } + keysEvaluated.map { value -> value }
yield(Record.of(*recordValues.toTypedArray()))
}
}
}

override fun next(): Record? {
return if (records.hasNext()) {
records.next()
} else {
null
}
}

@OptIn(PartiQLValueExperimental::class)
override fun close() {
aggregationMap.clear()
input.close()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ internal class ExprCallDynamic(
return candidate.eval(actualArgs)
}
}
throw TypeCheckException()
val errorString = buildString {
val argString = actualArgs.joinToString(", ")
append("Could not dynamically find function for arguments $argString in $candidates.")
}
throw TypeCheckException(errorString)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,6 @@ internal class ExprPathSymbol(
return v
}
}
throw TypeCheckException()
throw TypeCheckException("Couldn't find symbol '$symbol' in $struct.")
}
}
Loading
Loading