Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handles null calls in dynamic dispatch #1436

Merged
merged 2 commits into from
Apr 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions partiql-eval/src/main/kotlin/org/partiql/eval/internal/Compiler.kt
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ import java.lang.IllegalStateException
internal class Compiler(
private val plan: PartiQLPlan,
private val session: PartiQLEngine.Session,
private val symbols: Symbols
private val symbols: Symbols,
) : PlanBaseVisitor<Operator, StaticType?>() {

fun compile(): Operator.Expr {
Expand Down Expand Up @@ -102,6 +102,7 @@ internal class Compiler(
val type = ctx ?: error("No type provided in ctx")
return ExprCollection(values, type)
}

override fun visitRexOpStruct(node: Rex.Op.Struct, ctx: StaticType?): Operator {
val fields = node.fields.map {
val value = visitRex(it.v, ctx).modeHandled()
Expand Down Expand Up @@ -222,13 +223,37 @@ internal class Compiler(
@OptIn(FnExperimental::class)
override fun visitRexOpCallDynamic(node: Rex.Op.Call.Dynamic, ctx: StaticType?): Operator {
val args = node.args.map { visitRex(it, ctx).modeHandled() }.toTypedArray()
// Check candidate list size
when (node.candidates.size) {
0 -> error("Rex.Op.Call.Dynamic had an empty candidates list: $node.")
// TODO this seems like it should be an error, but is possible if the fn match was non-exhaustive
// 1 -> error("Rex.Op.Call.Dynamic had a single candidate; should be Rex.Op.Call.Static")
}
// Check candidate name and arity for uniformity
var arity: Int = -1
var name: String = "unknown"
// Compile the candidates
val candidates = Array(node.candidates.size) {
val candidate = node.candidates[it]
val fn = symbols.getFn(candidate.fn)
val coercions = candidate.coercions.toTypedArray()
// Check this candidate
val fnArity = fn.signature.parameters.size
val fnName = fn.signature.name.uppercase()
if (arity == -1) {
arity = fnArity
name = fnName
} else {
if (fnArity != arity) {
error("Dynamic call candidate had different arity than others; found $fnArity but expected $arity")
}
if (fnName != name) {
error("Dynamic call candidate had different name than others; found $fnName but expected $name")
}
}
ExprCallDynamic.Candidate(fn, coercions)
}
return ExprCallDynamic(candidates, args)
return ExprCallDynamic(name, candidates, args)
}

override fun visitRexOpCast(node: Rex.Op.Cast, ctx: StaticType?): Operator {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package org.partiql.eval.internal.operator.rex

import org.partiql.errors.TypeCheckException
import org.partiql.eval.internal.Environment
import org.partiql.eval.internal.Record
import org.partiql.eval.internal.helpers.toNull
import org.partiql.eval.internal.operator.Operator
import org.partiql.plan.Ref
import org.partiql.spi.fn.Fn
Expand All @@ -21,7 +21,8 @@ import org.partiql.value.PartiQLValueType
*/
@OptIn(PartiQLValueExperimental::class, FnExperimental::class)
internal class ExprCallDynamic(
candidates: Array<Candidate>,
private val name: String,
private val candidates: Array<Candidate>,
private val args: Array<Operator.Expr>
) : Operator.Expr {

Expand All @@ -35,7 +36,7 @@ internal class ExprCallDynamic(
}
val errorString = buildString {
val argString = actualArgs.joinToString(", ")
append("Could not dynamically find function (${candidateIndex.name}) for arguments $argString.")
append("Could not dynamically find function ($name) for arguments $argString.")
}
throw TypeCheckException(errorString)
}
Expand All @@ -53,8 +54,16 @@ internal class ExprCallDynamic(
val coercions: Array<Ref.Cast?>
) {

/**
* Memoize creation of nulls
*/
private val nil = fn.signature.returns.toNull()

fun eval(originalArgs: Array<PartiQLValue>, env: Environment): PartiQLValue {
val args = originalArgs.mapIndexed { i, arg ->
if (arg.isNull && fn.signature.isNullCall) {
return nil()
}
when (val c = coercions[i]) {
null -> arg
else -> ExprCast(ExprLiteral(arg), c).eval(env)
Expand Down Expand Up @@ -111,12 +120,9 @@ internal class ExprCallDynamic(
*
* @param candidates
*/
class All(
candidates: Array<Candidate>,
) : CandidateIndex {
class All(private val candidates: Array<Candidate>) : CandidateIndex {

private val lookups: List<CandidateIndex>
internal val name: String = candidates.first().fn.signature.name

init {
val lookupsMutable = mutableListOf<CandidateIndex>()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ internal class ExprCast(val arg: Operator.Expr, val cast: Ref.Cast) : Operator.E
PartiQLValueType.INTERVAL,
PartiQLValueType.BAG, PartiQLValueType.LIST,
PartiQLValueType.SEXP,
PartiQLValueType.STRUCT -> error("can not perform cast from INT8 to $t")
PartiQLValueType.STRUCT -> error("can not perform cast from struct to $t")
PartiQLValueType.NULL -> error("cast to null not supported")
PartiQLValueType.MISSING -> error("cast to missing not supported")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,6 @@ internal class ExprSelect(
}
}

/**
* @param record
* @return
*/
@PartiQLValueExperimental
override fun eval(env: Environment): PartiQLValue {
val elements = Elements(input, constructor, env)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.MethodSource
import org.partiql.eval.PartiQLEngine
import org.partiql.eval.PartiQLResult
import org.partiql.eval.internal.PartiQLEngineDefaultTest.SuccessTestCase.Global
import org.partiql.parser.PartiQLParser
import org.partiql.plan.PartiQLPlan
import org.partiql.plan.debug.PlanPrinter
Expand Down Expand Up @@ -1171,7 +1172,7 @@ class PartiQLEngineDefaultTest {
val input: String,
val expected: PartiQLValue,
val mode: PartiQLEngine.Mode = PartiQLEngine.Mode.PERMISSIVE,
val globals: List<Global> = emptyList()
val globals: List<Global> = emptyList(),
) {

private val engine = PartiQLEngine.builder().build()
Expand Down Expand Up @@ -1243,7 +1244,7 @@ class PartiQLEngineDefaultTest {
public class TypingTestCase @OptIn(PartiQLValueExperimental::class) constructor(
val name: String,
val input: String,
val expectedPermissive: PartiQLValue
val expectedPermissive: PartiQLValue,
) {

private val engine = PartiQLEngine.builder().build()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@ class ExprCallDynamicTest {

@OptIn(PartiQLValueExperimental::class)
fun assert() {
val expr = ExprCallDynamic(candidates, args = arrayOf(ExprLiteral(lhs), ExprLiteral(rhs)))
val expr = ExprCallDynamic(
name = "example_function",
candidates = candidates,
args = arrayOf(ExprLiteral(lhs), ExprLiteral(rhs)),
)
val result = expr.eval(Environment.empty).check<Int32Value>()
assertEquals(expectedIndex, result.value)
}
Expand Down Expand Up @@ -72,6 +76,7 @@ class ExprCallDynamicTest {
FnParameter("second", type = it.second),
)
)

override fun invoke(args: Array<PartiQLValue>): PartiQLValue = int32Value(index)
},
coercions = arrayOf(null, null)
Expand Down
Loading