Skip to content

Commit

Permalink
Add RelSort and comparator between PartiQLValues (#1343)
Browse files Browse the repository at this point in the history
  • Loading branch information
alancai98 authored Feb 6, 2024
1 parent 27081a7 commit 97f9926
Show file tree
Hide file tree
Showing 21 changed files with 1,127 additions and 78 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package org.partiql.eval

import org.partiql.eval.internal.Compiler
import org.partiql.eval.internal.Record
import org.partiql.eval.internal.Symbols
import org.partiql.plan.PartiQLPlan
import org.partiql.value.PartiQLValue
import org.partiql.value.PartiQLValueExperimental
Expand All @@ -11,7 +12,10 @@ internal class PartiQLEngineDefault : PartiQLEngine {
@OptIn(PartiQLValueExperimental::class)
override fun prepare(plan: PartiQLPlan, session: PartiQLEngine.Session): PartiQLStatement<*> {
try {
val compiler = Compiler(plan, session)
// 1. Validate all references
val symbols = Symbols.build(plan, session)
// 2. Compile with built symbols
val compiler = Compiler(plan, session, symbols)
val expression = compiler.compile()
return object : PartiQLStatement.Query {
override fun execute(): PartiQLValue {
Expand Down
92 changes: 52 additions & 40 deletions partiql-eval/src/main/kotlin/org/partiql/eval/internal/Compiler.kt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import org.partiql.eval.internal.operator.rel.RelScan
import org.partiql.eval.internal.operator.rel.RelScanIndexed
import org.partiql.eval.internal.operator.rel.RelScanIndexedPermissive
import org.partiql.eval.internal.operator.rel.RelScanPermissive
import org.partiql.eval.internal.operator.rel.RelSort
import org.partiql.eval.internal.operator.rex.ExprCallDynamic
import org.partiql.eval.internal.operator.rex.ExprCallStatic
import org.partiql.eval.internal.operator.rex.ExprCase
Expand All @@ -40,73 +41,74 @@ import org.partiql.plan.Statement
import org.partiql.plan.debug.PlanPrinter
import org.partiql.plan.visitor.PlanBaseVisitor
import org.partiql.spi.fn.FnExperimental
import org.partiql.types.StaticType
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.PartiQLValueType
import java.lang.IllegalStateException

internal class Compiler(
private val plan: PartiQLPlan,
private val session: PartiQLEngine.Session
) : PlanBaseVisitor<Operator, Symbols>() {
private val session: PartiQLEngine.Session,
private val symbols: Symbols
) : PlanBaseVisitor<Operator, StaticType?>() {

fun compile(): Operator.Expr {
// 1. Validate all references
val symbols = Symbols.build(plan, session)
// 2. Compile with built symbols
return visitPartiQLPlan(plan, symbols)
return visitPartiQLPlan(plan, null)
}

override fun defaultReturn(node: PlanNode, ctx: Symbols): Operator {
override fun defaultReturn(node: PlanNode, ctx: StaticType?): Operator {
TODO("Not yet implemented")
}

override fun visitRexOpErr(node: Rex.Op.Err, ctx: Symbols): Operator {
override fun visitRexOpErr(node: Rex.Op.Err, ctx: StaticType?): Operator {
val message = buildString {
this.appendLine(node.message)
PlanPrinter.append(this, plan)
}
throw IllegalStateException(message)
}

override fun visitRelOpErr(node: Rel.Op.Err, ctx: Symbols): Operator {
override fun visitRelOpErr(node: Rel.Op.Err, ctx: StaticType?): Operator {
throw IllegalStateException(node.message)
}

// TODO: Re-look at
override fun visitPartiQLPlan(node: PartiQLPlan, ctx: Symbols): Operator.Expr {
override fun visitPartiQLPlan(node: PartiQLPlan, ctx: StaticType?): Operator.Expr {
return visitStatement(node.statement, ctx) as Operator.Expr
}

// TODO: Re-look at
override fun visitStatementQuery(node: Statement.Query, ctx: Symbols): Operator.Expr {
override fun visitStatementQuery(node: Statement.Query, ctx: StaticType?): Operator.Expr {
return visitRex(node.root, ctx).modeHandled()
}

// REX

override fun visitRex(node: Rex, ctx: Symbols): Operator.Expr {
return super.visitRexOp(node.op, ctx) as Operator.Expr
override fun visitRex(node: Rex, ctx: StaticType?): Operator.Expr {
return super.visitRexOp(node.op, node.type) as Operator.Expr
}

override fun visitRexOpCollection(node: Rex.Op.Collection, ctx: Symbols): Operator {
override fun visitRexOpCollection(node: Rex.Op.Collection, ctx: StaticType?): Operator {
val values = node.values.map { visitRex(it, ctx).modeHandled() }
return ExprCollection(values)
val type = ctx ?: error("No type provided in ctx")
return ExprCollection(values, type)
}
override fun visitRexOpStruct(node: Rex.Op.Struct, ctx: Symbols): Operator {
override fun visitRexOpStruct(node: Rex.Op.Struct, ctx: StaticType?): Operator {
val fields = node.fields.map {
val value = visitRex(it.v, ctx).modeHandled()
ExprStruct.Field(visitRex(it.k, ctx), value)
}
return ExprStruct(fields)
}

override fun visitRexOpSelect(node: Rex.Op.Select, ctx: Symbols): Operator {
override fun visitRexOpSelect(node: Rex.Op.Select, ctx: StaticType?): Operator {
val rel = visitRel(node.rel, ctx)
val ordered = node.rel.type.props.contains(Rel.Prop.ORDERED)
val constructor = visitRex(node.constructor, ctx).modeHandled()
return ExprSelect(rel, constructor)
return ExprSelect(rel, constructor, ordered)
}

override fun visitRexOpPivot(node: Rex.Op.Pivot, ctx: Symbols): Operator {
override fun visitRexOpPivot(node: Rex.Op.Pivot, ctx: StaticType?): Operator {
val rel = visitRel(node.rel, ctx)
val key = visitRex(node.key, ctx)
val value = visitRex(node.value, ctx)
Expand All @@ -115,33 +117,33 @@ internal class Compiler(
PartiQLEngine.Mode.STRICT -> ExprPivot(rel, key, value)
}
}
override fun visitRexOpVar(node: Rex.Op.Var, ctx: Symbols): Operator {
override fun visitRexOpVar(node: Rex.Op.Var, ctx: StaticType?): Operator {
return ExprVar(node.ref)
}

override fun visitRexOpGlobal(node: Rex.Op.Global, ctx: Symbols): Operator = ctx.getGlobal(node.ref)
override fun visitRexOpGlobal(node: Rex.Op.Global, ctx: StaticType?): Operator = symbols.getGlobal(node.ref)

override fun visitRexOpPathKey(node: Rex.Op.Path.Key, ctx: Symbols): Operator {
override fun visitRexOpPathKey(node: Rex.Op.Path.Key, ctx: StaticType?): Operator {
val root = visitRex(node.root, ctx)
val key = visitRex(node.key, ctx)
return ExprPathKey(root, key)
}

override fun visitRexOpPathSymbol(node: Rex.Op.Path.Symbol, ctx: Symbols): Operator {
override fun visitRexOpPathSymbol(node: Rex.Op.Path.Symbol, ctx: StaticType?): Operator {
val root = visitRex(node.root, ctx)
val symbol = node.key
return ExprPathSymbol(root, symbol)
}

override fun visitRexOpPathIndex(node: Rex.Op.Path.Index, ctx: Symbols): Operator {
override fun visitRexOpPathIndex(node: Rex.Op.Path.Index, ctx: StaticType?): Operator {
val root = visitRex(node.root, ctx)
val index = visitRex(node.key, ctx)
return ExprPathIndex(root, index)
}

@OptIn(FnExperimental::class, PartiQLValueExperimental::class)
override fun visitRexOpCallStatic(node: Rex.Op.Call.Static, ctx: Symbols): Operator {
val fn = ctx.getFn(node.fn)
override fun visitRexOpCallStatic(node: Rex.Op.Call.Static, ctx: StaticType?): Operator {
val fn = symbols.getFn(node.fn)
val args = node.args.map { visitRex(it, ctx) }.toTypedArray()
val fnTakesInMissing = fn.signature.parameters.any {
it.type == PartiQLValueType.MISSING || it.type == PartiQLValueType.ANY
Expand All @@ -153,54 +155,54 @@ internal class Compiler(
}

@OptIn(FnExperimental::class, PartiQLValueExperimental::class)
override fun visitRexOpCallDynamic(node: Rex.Op.Call.Dynamic, ctx: Symbols): Operator {
override fun visitRexOpCallDynamic(node: Rex.Op.Call.Dynamic, ctx: StaticType?): Operator {
val args = node.args.map { visitRex(it, ctx).modeHandled() }.toTypedArray()
val candidates = node.candidates.map { candidate ->
val fn = ctx.getFn(candidate.fn)
val fn = symbols.getFn(candidate.fn)
val types = fn.signature.parameters.map { it.type }.toTypedArray()
val coercions = candidate.coercions.toTypedArray()
ExprCallDynamic.Candidate(fn, types, coercions)
}
return ExprCallDynamic(candidates, args)
}

override fun visitRexOpCast(node: Rex.Op.Cast, ctx: Symbols): Operator {
override fun visitRexOpCast(node: Rex.Op.Cast, ctx: StaticType?): Operator {
return ExprCast(visitRex(node.arg, ctx), node.cast)
}

// REL
override fun visitRel(node: Rel, ctx: Symbols): Operator.Relation {
override fun visitRel(node: Rel, ctx: StaticType?): Operator.Relation {
return super.visitRelOp(node.op, ctx) as Operator.Relation
}

override fun visitRelOpScan(node: Rel.Op.Scan, ctx: Symbols): Operator {
override fun visitRelOpScan(node: Rel.Op.Scan, ctx: StaticType?): Operator {
val rex = visitRex(node.rex, ctx)
return when (session.mode) {
PartiQLEngine.Mode.PERMISSIVE -> RelScanPermissive(rex)
PartiQLEngine.Mode.STRICT -> RelScan(rex)
}
}

override fun visitRelOpProject(node: Rel.Op.Project, ctx: Symbols): Operator {
override fun visitRelOpProject(node: Rel.Op.Project, ctx: StaticType?): Operator {
val input = visitRel(node.input, ctx)
val projections = node.projections.map { visitRex(it, ctx).modeHandled() }
return RelProject(input, projections)
}

override fun visitRelOpScanIndexed(node: Rel.Op.ScanIndexed, ctx: Symbols): Operator {
override fun visitRelOpScanIndexed(node: Rel.Op.ScanIndexed, ctx: StaticType?): Operator {
val rex = visitRex(node.rex, ctx)
return when (session.mode) {
PartiQLEngine.Mode.PERMISSIVE -> RelScanIndexedPermissive(rex)
PartiQLEngine.Mode.STRICT -> RelScanIndexed(rex)
}
}

override fun visitRexOpTupleUnion(node: Rex.Op.TupleUnion, ctx: Symbols): Operator {
override fun visitRexOpTupleUnion(node: Rex.Op.TupleUnion, ctx: StaticType?): Operator {
val args = node.args.map { visitRex(it, ctx) }.toTypedArray()
return ExprTupleUnion(args)
}

override fun visitRelOpJoin(node: Rel.Op.Join, ctx: Symbols): Operator {
override fun visitRelOpJoin(node: Rel.Op.Join, ctx: StaticType?): Operator {
val lhs = visitRel(node.lhs, ctx)
val rhs = visitRel(node.rhs, ctx)
val condition = visitRex(node.rex, ctx)
Expand All @@ -212,7 +214,7 @@ internal class Compiler(
}
}

override fun visitRexOpCase(node: Rex.Op.Case, ctx: Symbols): Operator {
override fun visitRexOpCase(node: Rex.Op.Case, ctx: StaticType?): Operator {
val branches = node.branches.map { branch ->
visitRex(branch.condition, ctx) to visitRex(branch.rex, ctx)
}
Expand All @@ -221,26 +223,36 @@ internal class Compiler(
}

@OptIn(PartiQLValueExperimental::class)
override fun visitRexOpLit(node: Rex.Op.Lit, ctx: Symbols): Operator {
override fun visitRexOpLit(node: Rex.Op.Lit, ctx: StaticType?): Operator {
return ExprLiteral(node.value)
}

override fun visitRelOpDistinct(node: Rel.Op.Distinct, ctx: Symbols): Operator {
override fun visitRelOpDistinct(node: Rel.Op.Distinct, ctx: StaticType?): Operator {
val input = visitRel(node.input, ctx)
return RelDistinct(input)
}

override fun visitRelOpFilter(node: Rel.Op.Filter, ctx: Symbols): Operator {
override fun visitRelOpFilter(node: Rel.Op.Filter, ctx: StaticType?): Operator {
val input = visitRel(node.input, ctx)
val condition = visitRex(node.predicate, ctx)
return RelFilter(input, condition)
}

override fun visitRelOpExclude(node: Rel.Op.Exclude, ctx: Symbols): Operator {
override fun visitRelOpExclude(node: Rel.Op.Exclude, ctx: StaticType?): Operator {
val input = visitRel(node.input, ctx)
return RelExclude(input, node.paths)
}

override fun visitRelOpSort(node: Rel.Op.Sort, ctx: StaticType?): Operator {
val input = visitRel(node.input, ctx)
val compiledSpecs = node.specs.map { spec ->
val expr = visitRex(spec.rex, ctx)
val order = spec.order
expr to order
}
return RelSort(input, compiledSpecs)
}

// HELPERS

private fun Operator.Expr.modeHandled(): Operator.Expr {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ internal class RelOffset(
input.next() ?: return null
seen += 1
}
init = true
}
return input.next()
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package org.partiql.eval.internal.operator.rel

import org.partiql.eval.internal.Record
import org.partiql.eval.internal.operator.Operator
import org.partiql.plan.Rel
import org.partiql.value.PartiQLValue
import org.partiql.value.PartiQLValueExperimental
import java.util.Collections

@OptIn(PartiQLValueExperimental::class)
internal class RelSort(
private val input: Operator.Relation,
private val specs: List<Pair<Operator.Expr, Rel.Op.Sort.Order>>

) : Operator.Relation {
private var records: Iterator<Record> = Collections.emptyIterator()
private var init: Boolean = false

private val nullsFirstComparator = PartiQLValue.comparator(nullsFirst = true)
private val nullsLastComparator = PartiQLValue.comparator(nullsFirst = false)

override fun open() {
input.open()
init = false
records = Collections.emptyIterator()
}

private val comparator = object : Comparator<Record> {
override fun compare(l: Record, r: Record): Int {
specs.forEach { spec ->
val lVal = spec.first.eval(l)
val rVal = spec.first.eval(r)

// DESC_NULLS_FIRST(l, r) == ASC_NULLS_LAST(r, l)
// DESC_NULLS_LAST(l, r) == ASC_NULLS_FIRST(r, l)
val cmpResult = when (spec.second) {
Rel.Op.Sort.Order.ASC_NULLS_FIRST -> nullsFirstComparator.compare(lVal, rVal)
Rel.Op.Sort.Order.ASC_NULLS_LAST -> nullsLastComparator.compare(lVal, rVal)
Rel.Op.Sort.Order.DESC_NULLS_FIRST -> nullsLastComparator.compare(rVal, lVal)
Rel.Op.Sort.Order.DESC_NULLS_LAST -> nullsFirstComparator.compare(rVal, lVal)
}
if (cmpResult != 0) {
return cmpResult
}
}
return 0 // Equal
}
}

override fun next(): Record? {
if (!init) {
val sortedRecords = mutableListOf<Record>()
while (true) {
val row = input.next() ?: break
sortedRecords.add(row)
}
sortedRecords.sortWith(comparator)
records = sortedRecords.iterator()
init = true
}
return when (records.hasNext()) {
true -> records.next()
false -> null
}
}

override fun close() {
input.close()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,28 @@ package org.partiql.eval.internal.operator.rex

import org.partiql.eval.internal.Record
import org.partiql.eval.internal.operator.Operator
import org.partiql.types.BagType
import org.partiql.types.ListType
import org.partiql.types.SexpType
import org.partiql.types.StaticType
import org.partiql.value.PartiQLValue
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.bagValue
import org.partiql.value.listValue
import org.partiql.value.sexpValue

internal class ExprCollection(
private val values: List<Operator.Expr>
private val values: List<Operator.Expr>,
private val type: StaticType
) : Operator.Expr {

@PartiQLValueExperimental
override fun eval(record: Record): PartiQLValue {
return bagValue(
values.map { it.eval(record) }
)
return when (type) {
is BagType -> bagValue(values.map { it.eval(record) })
is SexpType -> sexpValue(values.map { it.eval(record) })
is ListType -> listValue(values.map { it.eval(record) })
else -> error("Unsupported type for collection $type")
}
}
}
Loading

0 comments on commit 97f9926

Please sign in to comment.