Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RelSort and comparator between PartiQLValue #1343

Merged
merged 6 commits into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package org.partiql.eval

import org.partiql.eval.internal.Compiler
import org.partiql.eval.internal.Record
import org.partiql.eval.internal.Symbols
import org.partiql.plan.PartiQLPlan
import org.partiql.value.PartiQLValue
import org.partiql.value.PartiQLValueExperimental
Expand All @@ -11,7 +12,10 @@ internal class PartiQLEngineDefault : PartiQLEngine {
@OptIn(PartiQLValueExperimental::class)
override fun prepare(plan: PartiQLPlan, session: PartiQLEngine.Session): PartiQLStatement<*> {
try {
val compiler = Compiler(plan, session)
// 1. Validate all references
val symbols = Symbols.build(plan, session)
// 2. Compile with built symbols
val compiler = Compiler(plan, session, symbols)
val expression = compiler.compile()
return object : PartiQLStatement.Query {
override fun execute(): PartiQLValue {
Expand Down
92 changes: 52 additions & 40 deletions partiql-eval/src/main/kotlin/org/partiql/eval/internal/Compiler.kt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import org.partiql.eval.internal.operator.rel.RelScan
import org.partiql.eval.internal.operator.rel.RelScanIndexed
import org.partiql.eval.internal.operator.rel.RelScanIndexedPermissive
import org.partiql.eval.internal.operator.rel.RelScanPermissive
import org.partiql.eval.internal.operator.rel.RelSort
import org.partiql.eval.internal.operator.rex.ExprCallDynamic
import org.partiql.eval.internal.operator.rex.ExprCallStatic
import org.partiql.eval.internal.operator.rex.ExprCase
Expand All @@ -40,73 +41,74 @@ import org.partiql.plan.Statement
import org.partiql.plan.debug.PlanPrinter
import org.partiql.plan.visitor.PlanBaseVisitor
import org.partiql.spi.fn.FnExperimental
import org.partiql.types.StaticType
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.PartiQLValueType
import java.lang.IllegalStateException

internal class Compiler(
private val plan: PartiQLPlan,
private val session: PartiQLEngine.Session
) : PlanBaseVisitor<Operator, Symbols>() {
private val session: PartiQLEngine.Session,
private val symbols: Symbols
) : PlanBaseVisitor<Operator, StaticType?>() {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Following rebase of partiql-eval) Change ctx to have type StaticType?. Initialize symbols as part of Compiler constructor since it is not scope sensitive.


fun compile(): Operator.Expr {
// 1. Validate all references
val symbols = Symbols.build(plan, session)
// 2. Compile with built symbols
return visitPartiQLPlan(plan, symbols)
return visitPartiQLPlan(plan, null)
}

override fun defaultReturn(node: PlanNode, ctx: Symbols): Operator {
override fun defaultReturn(node: PlanNode, ctx: StaticType?): Operator {
TODO("Not yet implemented")
}

override fun visitRexOpErr(node: Rex.Op.Err, ctx: Symbols): Operator {
override fun visitRexOpErr(node: Rex.Op.Err, ctx: StaticType?): Operator {
val message = buildString {
this.appendLine(node.message)
PlanPrinter.append(this, plan)
}
throw IllegalStateException(message)
}

override fun visitRelOpErr(node: Rel.Op.Err, ctx: Symbols): Operator {
override fun visitRelOpErr(node: Rel.Op.Err, ctx: StaticType?): Operator {
throw IllegalStateException(node.message)
}

// TODO: Re-look at
override fun visitPartiQLPlan(node: PartiQLPlan, ctx: Symbols): Operator.Expr {
override fun visitPartiQLPlan(node: PartiQLPlan, ctx: StaticType?): Operator.Expr {
return visitStatement(node.statement, ctx) as Operator.Expr
}

// TODO: Re-look at
override fun visitStatementQuery(node: Statement.Query, ctx: Symbols): Operator.Expr {
override fun visitStatementQuery(node: Statement.Query, ctx: StaticType?): Operator.Expr {
return visitRex(node.root, ctx).modeHandled()
}

// REX

override fun visitRex(node: Rex, ctx: Symbols): Operator.Expr {
return super.visitRexOp(node.op, ctx) as Operator.Expr
override fun visitRex(node: Rex, ctx: StaticType?): Operator.Expr {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the current plan, the rex node's type is only accessible on the Rex and not the inner classes (e.g. Rex.Op.Collection). Passing it along in the context allows the visitor for the inner classes to access the type, which is needed by some nodes like Rex.Op.Collection.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would really like for our internal IRs to keep the types in the node by defining a type field as part of the PlanNode base class. This would require more codegen work which probably isn't worth the time at the moment. Please let me know if you have an additional ideas here

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm I'd need to play around with the generated code a bit to see if there's a better way. Only concern I'd have is that the type field for Rex is a different type (i.e. StaticType) than the type field for Rel (i.e. Rel.Type which has schema and props). Probably that representation in the PlanNode base class would be an enum that could be either of those type definitions?

Anyways, perhaps we should tackle this in another issue/PR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you check out PlanTyper, you'll see we use two different visitors since the type is parameterized. There are many ways around this, but really I wish we didn't have the union types. What would be better is for the code generator to support abstract fields.

The best situation however would be handwritten nodes with annotation based generation like Lombok. This would give us the most control.

Copy link
Member Author

@alancai98 alancai98 Feb 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you check out PlanTyper, you'll see we use two different visitors since the type is parameterized.

Oh I see. The separate visitors for Rel and Rex doesn't seem too cumbersome.

The best situation however would be handwritten nodes with annotation based generation like Lombok. This would give us the most control.

Agree w/ a mix of code-generated nodes and handwritten nodes would give us the most flexibility when it comes to these interfaces.

return super.visitRexOp(node.op, node.type) as Operator.Expr
}

override fun visitRexOpCollection(node: Rex.Op.Collection, ctx: Symbols): Operator {
override fun visitRexOpCollection(node: Rex.Op.Collection, ctx: StaticType?): Operator {
val values = node.values.map { visitRex(it, ctx).modeHandled() }
return ExprCollection(values)
val type = ctx ?: error("No type provided in ctx")
return ExprCollection(values, type)
}
override fun visitRexOpStruct(node: Rex.Op.Struct, ctx: Symbols): Operator {
override fun visitRexOpStruct(node: Rex.Op.Struct, ctx: StaticType?): Operator {
val fields = node.fields.map {
val value = visitRex(it.v, ctx).modeHandled()
ExprStruct.Field(visitRex(it.k, ctx), value)
}
return ExprStruct(fields)
}

override fun visitRexOpSelect(node: Rex.Op.Select, ctx: Symbols): Operator {
override fun visitRexOpSelect(node: Rex.Op.Select, ctx: StaticType?): Operator {
val rel = visitRel(node.rel, ctx)
val ordered = node.rel.type.props.contains(Rel.Prop.ORDERED)
val constructor = visitRex(node.constructor, ctx).modeHandled()
return ExprSelect(rel, constructor)
return ExprSelect(rel, constructor, ordered)
}

override fun visitRexOpPivot(node: Rex.Op.Pivot, ctx: Symbols): Operator {
override fun visitRexOpPivot(node: Rex.Op.Pivot, ctx: StaticType?): Operator {
val rel = visitRel(node.rel, ctx)
val key = visitRex(node.key, ctx)
val value = visitRex(node.value, ctx)
Expand All @@ -115,33 +117,33 @@ internal class Compiler(
PartiQLEngine.Mode.STRICT -> ExprPivot(rel, key, value)
}
}
override fun visitRexOpVar(node: Rex.Op.Var, ctx: Symbols): Operator {
override fun visitRexOpVar(node: Rex.Op.Var, ctx: StaticType?): Operator {
return ExprVar(node.ref)
}

override fun visitRexOpGlobal(node: Rex.Op.Global, ctx: Symbols): Operator = ctx.getGlobal(node.ref)
override fun visitRexOpGlobal(node: Rex.Op.Global, ctx: StaticType?): Operator = symbols.getGlobal(node.ref)

override fun visitRexOpPathKey(node: Rex.Op.Path.Key, ctx: Symbols): Operator {
override fun visitRexOpPathKey(node: Rex.Op.Path.Key, ctx: StaticType?): Operator {
val root = visitRex(node.root, ctx)
val key = visitRex(node.key, ctx)
return ExprPathKey(root, key)
}

override fun visitRexOpPathSymbol(node: Rex.Op.Path.Symbol, ctx: Symbols): Operator {
override fun visitRexOpPathSymbol(node: Rex.Op.Path.Symbol, ctx: StaticType?): Operator {
val root = visitRex(node.root, ctx)
val symbol = node.key
return ExprPathSymbol(root, symbol)
}

override fun visitRexOpPathIndex(node: Rex.Op.Path.Index, ctx: Symbols): Operator {
override fun visitRexOpPathIndex(node: Rex.Op.Path.Index, ctx: StaticType?): Operator {
val root = visitRex(node.root, ctx)
val index = visitRex(node.key, ctx)
return ExprPathIndex(root, index)
}

@OptIn(FnExperimental::class, PartiQLValueExperimental::class)
override fun visitRexOpCallStatic(node: Rex.Op.Call.Static, ctx: Symbols): Operator {
val fn = ctx.getFn(node.fn)
override fun visitRexOpCallStatic(node: Rex.Op.Call.Static, ctx: StaticType?): Operator {
val fn = symbols.getFn(node.fn)
val args = node.args.map { visitRex(it, ctx) }.toTypedArray()
val fnTakesInMissing = fn.signature.parameters.any {
it.type == PartiQLValueType.MISSING || it.type == PartiQLValueType.ANY
Expand All @@ -153,54 +155,54 @@ internal class Compiler(
}

@OptIn(FnExperimental::class, PartiQLValueExperimental::class)
override fun visitRexOpCallDynamic(node: Rex.Op.Call.Dynamic, ctx: Symbols): Operator {
override fun visitRexOpCallDynamic(node: Rex.Op.Call.Dynamic, ctx: StaticType?): Operator {
val args = node.args.map { visitRex(it, ctx).modeHandled() }.toTypedArray()
val candidates = node.candidates.map { candidate ->
val fn = ctx.getFn(candidate.fn)
val fn = symbols.getFn(candidate.fn)
val types = fn.signature.parameters.map { it.type }.toTypedArray()
val coercions = candidate.coercions.toTypedArray()
ExprCallDynamic.Candidate(fn, types, coercions)
}
return ExprCallDynamic(candidates, args)
}

override fun visitRexOpCast(node: Rex.Op.Cast, ctx: Symbols): Operator {
override fun visitRexOpCast(node: Rex.Op.Cast, ctx: StaticType?): Operator {
return ExprCast(visitRex(node.arg, ctx), node.cast)
}

// REL
override fun visitRel(node: Rel, ctx: Symbols): Operator.Relation {
override fun visitRel(node: Rel, ctx: StaticType?): Operator.Relation {
return super.visitRelOp(node.op, ctx) as Operator.Relation
}

override fun visitRelOpScan(node: Rel.Op.Scan, ctx: Symbols): Operator {
override fun visitRelOpScan(node: Rel.Op.Scan, ctx: StaticType?): Operator {
val rex = visitRex(node.rex, ctx)
return when (session.mode) {
PartiQLEngine.Mode.PERMISSIVE -> RelScanPermissive(rex)
PartiQLEngine.Mode.STRICT -> RelScan(rex)
}
}

override fun visitRelOpProject(node: Rel.Op.Project, ctx: Symbols): Operator {
override fun visitRelOpProject(node: Rel.Op.Project, ctx: StaticType?): Operator {
val input = visitRel(node.input, ctx)
val projections = node.projections.map { visitRex(it, ctx).modeHandled() }
return RelProject(input, projections)
}

override fun visitRelOpScanIndexed(node: Rel.Op.ScanIndexed, ctx: Symbols): Operator {
override fun visitRelOpScanIndexed(node: Rel.Op.ScanIndexed, ctx: StaticType?): Operator {
val rex = visitRex(node.rex, ctx)
return when (session.mode) {
PartiQLEngine.Mode.PERMISSIVE -> RelScanIndexedPermissive(rex)
PartiQLEngine.Mode.STRICT -> RelScanIndexed(rex)
}
}

override fun visitRexOpTupleUnion(node: Rex.Op.TupleUnion, ctx: Symbols): Operator {
override fun visitRexOpTupleUnion(node: Rex.Op.TupleUnion, ctx: StaticType?): Operator {
val args = node.args.map { visitRex(it, ctx) }.toTypedArray()
return ExprTupleUnion(args)
}

override fun visitRelOpJoin(node: Rel.Op.Join, ctx: Symbols): Operator {
override fun visitRelOpJoin(node: Rel.Op.Join, ctx: StaticType?): Operator {
val lhs = visitRel(node.lhs, ctx)
val rhs = visitRel(node.rhs, ctx)
val condition = visitRex(node.rex, ctx)
Expand All @@ -212,7 +214,7 @@ internal class Compiler(
}
}

override fun visitRexOpCase(node: Rex.Op.Case, ctx: Symbols): Operator {
override fun visitRexOpCase(node: Rex.Op.Case, ctx: StaticType?): Operator {
val branches = node.branches.map { branch ->
visitRex(branch.condition, ctx) to visitRex(branch.rex, ctx)
}
Expand All @@ -221,26 +223,36 @@ internal class Compiler(
}

@OptIn(PartiQLValueExperimental::class)
override fun visitRexOpLit(node: Rex.Op.Lit, ctx: Symbols): Operator {
override fun visitRexOpLit(node: Rex.Op.Lit, ctx: StaticType?): Operator {
return ExprLiteral(node.value)
}

override fun visitRelOpDistinct(node: Rel.Op.Distinct, ctx: Symbols): Operator {
override fun visitRelOpDistinct(node: Rel.Op.Distinct, ctx: StaticType?): Operator {
val input = visitRel(node.input, ctx)
return RelDistinct(input)
}

override fun visitRelOpFilter(node: Rel.Op.Filter, ctx: Symbols): Operator {
override fun visitRelOpFilter(node: Rel.Op.Filter, ctx: StaticType?): Operator {
val input = visitRel(node.input, ctx)
val condition = visitRex(node.predicate, ctx)
return RelFilter(input, condition)
}

override fun visitRelOpExclude(node: Rel.Op.Exclude, ctx: Symbols): Operator {
override fun visitRelOpExclude(node: Rel.Op.Exclude, ctx: StaticType?): Operator {
val input = visitRel(node.input, ctx)
return RelExclude(input, node.paths)
}

override fun visitRelOpSort(node: Rel.Op.Sort, ctx: StaticType?): Operator {
val input = visitRel(node.input, ctx)
val compiledSpecs = node.specs.map { spec ->
val expr = visitRex(spec.rex, ctx)
val order = spec.order
expr to order
}
return RelSort(input, compiledSpecs)
}

// HELPERS

private fun Operator.Expr.modeHandled(): Operator.Expr {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ internal class RelOffset(
input.next() ?: return null
seen += 1
}
init = true
}
return input.next()
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package org.partiql.eval.internal.operator.rel

import org.partiql.eval.internal.Record
import org.partiql.eval.internal.operator.Operator
import org.partiql.plan.Rel
import org.partiql.value.PartiQLValue
import org.partiql.value.PartiQLValueExperimental
import java.util.Collections

@OptIn(PartiQLValueExperimental::class)
internal class RelSort(
private val input: Operator.Relation,
private val specs: List<Pair<Operator.Expr, Rel.Op.Sort.Order>>

) : Operator.Relation {
private var records: Iterator<Record> = Collections.emptyIterator()
private var init: Boolean = false

private val nullsFirstComparator = PartiQLValue.comparator(nullsFirst = true)
private val nullsLastComparator = PartiQLValue.comparator(nullsFirst = false)

override fun open() {
input.open()
init = false
records = Collections.emptyIterator()
}

private val comparator = object : Comparator<Record> {
override fun compare(l: Record, r: Record): Int {
specs.forEach { spec ->
val lVal = spec.first.eval(l)
val rVal = spec.first.eval(r)

// DESC_NULLS_FIRST(l, r) == ASC_NULLS_LAST(r, l)
// DESC_NULLS_LAST(l, r) == ASC_NULLS_FIRST(r, l)
val cmpResult = when (spec.second) {
Rel.Op.Sort.Order.ASC_NULLS_FIRST -> nullsFirstComparator.compare(lVal, rVal)
Rel.Op.Sort.Order.ASC_NULLS_LAST -> nullsLastComparator.compare(lVal, rVal)
Rel.Op.Sort.Order.DESC_NULLS_FIRST -> nullsLastComparator.compare(rVal, lVal)
Rel.Op.Sort.Order.DESC_NULLS_LAST -> nullsFirstComparator.compare(rVal, lVal)
}
if (cmpResult != 0) {
return cmpResult
}
}
return 0 // Equal
}
}

override fun next(): Record? {
if (!init) {
val sortedRecords = mutableListOf<Record>()
while (true) {
val row = input.next() ?: break
sortedRecords.add(row)
}
sortedRecords.sortWith(comparator)
records = sortedRecords.iterator()
init = true
}
return when (records.hasNext()) {
true -> records.next()
false -> null
}
}

override fun close() {
input.close()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,28 @@ package org.partiql.eval.internal.operator.rex

import org.partiql.eval.internal.Record
import org.partiql.eval.internal.operator.Operator
import org.partiql.types.BagType
import org.partiql.types.ListType
import org.partiql.types.SexpType
import org.partiql.types.StaticType
import org.partiql.value.PartiQLValue
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.bagValue
import org.partiql.value.listValue
import org.partiql.value.sexpValue

internal class ExprCollection(
private val values: List<Operator.Expr>
private val values: List<Operator.Expr>,
private val type: StaticType
) : Operator.Expr {

@PartiQLValueExperimental
override fun eval(record: Record): PartiQLValue {
return bagValue(
values.map { it.eval(record) }
)
return when (type) {
is BagType -> bagValue(values.map { it.eval(record) })
is SexpType -> sexpValue(values.map { it.eval(record) })
is ListType -> listValue(values.map { it.eval(record) })
else -> error("Unsupported type for collection $type")
}
}
}
Loading
Loading