Skip to content

Commit

Permalink
Handles null calls in dynamic dispatch (#1436)
Browse files Browse the repository at this point in the history
  • Loading branch information
RCHowell authored Apr 23, 2024
1 parent cc05d7f commit 7289c9b
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 17 deletions.
29 changes: 27 additions & 2 deletions partiql-eval/src/main/kotlin/org/partiql/eval/internal/Compiler.kt
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ import java.lang.IllegalStateException
internal class Compiler(
private val plan: PartiQLPlan,
private val session: PartiQLEngine.Session,
private val symbols: Symbols
private val symbols: Symbols,
) : PlanBaseVisitor<Operator, StaticType?>() {

fun compile(): Operator.Expr {
Expand Down Expand Up @@ -102,6 +102,7 @@ internal class Compiler(
val type = ctx ?: error("No type provided in ctx")
return ExprCollection(values, type)
}

override fun visitRexOpStruct(node: Rex.Op.Struct, ctx: StaticType?): Operator {
val fields = node.fields.map {
val value = visitRex(it.v, ctx).modeHandled()
Expand Down Expand Up @@ -222,13 +223,37 @@ internal class Compiler(
@OptIn(FnExperimental::class)
override fun visitRexOpCallDynamic(node: Rex.Op.Call.Dynamic, ctx: StaticType?): Operator {
val args = node.args.map { visitRex(it, ctx).modeHandled() }.toTypedArray()
// Check candidate list size
when (node.candidates.size) {
0 -> error("Rex.Op.Call.Dynamic had an empty candidates list: $node.")
// TODO this seems like it should be an error, but is possible if the fn match was non-exhaustive
// 1 -> error("Rex.Op.Call.Dynamic had a single candidate; should be Rex.Op.Call.Static")
}
// Check candidate name and arity for uniformity
var arity: Int = -1
var name: String = "unknown"
// Compile the candidates
val candidates = Array(node.candidates.size) {
val candidate = node.candidates[it]
val fn = symbols.getFn(candidate.fn)
val coercions = candidate.coercions.toTypedArray()
// Check this candidate
val fnArity = fn.signature.parameters.size
val fnName = fn.signature.name.uppercase()
if (arity == -1) {
arity = fnArity
name = fnName
} else {
if (fnArity != arity) {
error("Dynamic call candidate had different arity than others; found $fnArity but expected $arity")
}
if (fnName != name) {
error("Dynamic call candidate had different name than others; found $fnName but expected $name")
}
}
ExprCallDynamic.Candidate(fn, coercions)
}
return ExprCallDynamic(candidates, args)
return ExprCallDynamic(name, candidates, args)
}

override fun visitRexOpCast(node: Rex.Op.Cast, ctx: StaticType?): Operator {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package org.partiql.eval.internal.operator.rex

import org.partiql.errors.TypeCheckException
import org.partiql.eval.internal.Environment
import org.partiql.eval.internal.Record
import org.partiql.eval.internal.helpers.toNull
import org.partiql.eval.internal.operator.Operator
import org.partiql.plan.Ref
import org.partiql.spi.fn.Fn
Expand All @@ -21,7 +21,8 @@ import org.partiql.value.PartiQLValueType
*/
@OptIn(PartiQLValueExperimental::class, FnExperimental::class)
internal class ExprCallDynamic(
candidates: Array<Candidate>,
private val name: String,
private val candidates: Array<Candidate>,
private val args: Array<Operator.Expr>
) : Operator.Expr {

Expand All @@ -35,7 +36,7 @@ internal class ExprCallDynamic(
}
val errorString = buildString {
val argString = actualArgs.joinToString(", ")
append("Could not dynamically find function (${candidateIndex.name}) for arguments $argString.")
append("Could not dynamically find function ($name) for arguments $argString.")
}
throw TypeCheckException(errorString)
}
Expand All @@ -53,8 +54,16 @@ internal class ExprCallDynamic(
val coercions: Array<Ref.Cast?>
) {

/**
* Memoize creation of nulls
*/
private val nil = fn.signature.returns.toNull()

fun eval(originalArgs: Array<PartiQLValue>, env: Environment): PartiQLValue {
val args = originalArgs.mapIndexed { i, arg ->
if (arg.isNull && fn.signature.isNullCall) {
return nil()
}
when (val c = coercions[i]) {
null -> arg
else -> ExprCast(ExprLiteral(arg), c).eval(env)
Expand Down Expand Up @@ -111,12 +120,9 @@ internal class ExprCallDynamic(
*
* @param candidates
*/
class All(
candidates: Array<Candidate>,
) : CandidateIndex {
class All(private val candidates: Array<Candidate>) : CandidateIndex {

private val lookups: List<CandidateIndex>
internal val name: String = candidates.first().fn.signature.name

init {
val lookupsMutable = mutableListOf<CandidateIndex>()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ internal class ExprCast(val arg: Operator.Expr, val cast: Ref.Cast) : Operator.E
PartiQLValueType.INTERVAL,
PartiQLValueType.BAG, PartiQLValueType.LIST,
PartiQLValueType.SEXP,
PartiQLValueType.STRUCT -> error("can not perform cast from INT8 to $t")
PartiQLValueType.STRUCT -> error("can not perform cast from struct to $t")
PartiQLValueType.NULL -> error("cast to null not supported")
PartiQLValueType.MISSING -> error("cast to missing not supported")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,6 @@ internal class ExprSelect(
}
}

/**
* @param record
* @return
*/
@PartiQLValueExperimental
override fun eval(env: Environment): PartiQLValue {
val elements = Elements(input, constructor, env)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.MethodSource
import org.partiql.eval.PartiQLEngine
import org.partiql.eval.PartiQLResult
import org.partiql.eval.internal.PartiQLEngineDefaultTest.SuccessTestCase.Global
import org.partiql.parser.PartiQLParser
import org.partiql.plan.PartiQLPlan
import org.partiql.plan.debug.PlanPrinter
Expand Down Expand Up @@ -1171,7 +1172,7 @@ class PartiQLEngineDefaultTest {
val input: String,
val expected: PartiQLValue,
val mode: PartiQLEngine.Mode = PartiQLEngine.Mode.PERMISSIVE,
val globals: List<Global> = emptyList()
val globals: List<Global> = emptyList(),
) {

private val engine = PartiQLEngine.builder().build()
Expand Down Expand Up @@ -1243,7 +1244,7 @@ class PartiQLEngineDefaultTest {
public class TypingTestCase @OptIn(PartiQLValueExperimental::class) constructor(
val name: String,
val input: String,
val expectedPermissive: PartiQLValue
val expectedPermissive: PartiQLValue,
) {

private val engine = PartiQLEngine.builder().build()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@ class ExprCallDynamicTest {

@OptIn(PartiQLValueExperimental::class)
fun assert() {
val expr = ExprCallDynamic(candidates, args = arrayOf(ExprLiteral(lhs), ExprLiteral(rhs)))
val expr = ExprCallDynamic(
name = "example_function",
candidates = candidates,
args = arrayOf(ExprLiteral(lhs), ExprLiteral(rhs)),
)
val result = expr.eval(Environment.empty).check<Int32Value>()
assertEquals(expectedIndex, result.value)
}
Expand Down Expand Up @@ -72,6 +76,7 @@ class ExprCallDynamicTest {
FnParameter("second", type = it.second),
)
)

override fun invoke(args: Array<PartiQLValue>): PartiQLValue = int32Value(index)
},
coercions = arrayOf(null, null)
Expand Down

0 comments on commit 7289c9b

Please sign in to comment.