Skip to content

Commit

Permalink
Updates rules for variable resolution
Browse files Browse the repository at this point in the history
Adds support for casting from dynamic

Updates tests to give greater visibility into errors
  • Loading branch information
johnedquinn committed Jul 11, 2024
1 parent 7ed91d2 commit 304514e
Show file tree
Hide file tree
Showing 10 changed files with 174 additions and 142 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.partiql.value.Int64Value
import org.partiql.value.Int8Value
import org.partiql.value.IntValue
import org.partiql.value.ListValue
import org.partiql.value.MissingValue
import org.partiql.value.NullValue
import org.partiql.value.NumericValue
import org.partiql.value.PartiQLValue
Expand All @@ -46,6 +47,7 @@ import org.partiql.value.int64Value
import org.partiql.value.int8Value
import org.partiql.value.intValue
import org.partiql.value.listValue
import org.partiql.value.missingValue
import org.partiql.value.sexpValue
import org.partiql.value.stringValue
import org.partiql.value.structValue
Expand All @@ -59,7 +61,8 @@ import java.math.BigInteger
internal class ExprCast(val arg: Operator.Expr, val cast: Ref.Cast) : Operator.Expr {
@OptIn(PartiQLValueExperimental::class)
override fun eval(env: Environment): Datum {
val arg = arg.eval(env).toPartiQLValue()
val argDatum = arg.eval(env)
val arg = argDatum.toPartiQLValue()
try {
val partiqlValue = when (PType.fromPartiQLValueType(arg.type).kind) {
PType.Kind.DYNAMIC -> TODO("Not Possible")
Expand All @@ -86,9 +89,9 @@ internal class ExprCast(val arg: Operator.Expr, val cast: Ref.Cast) : Operator.E
PType.Kind.BAG -> castFromCollection(arg as BagValue<*>, cast.target)
PType.Kind.LIST -> castFromCollection(arg as ListValue<*>, cast.target)
PType.Kind.SEXP -> castFromCollection(arg as SexpValue<*>, cast.target)
PType.Kind.STRUCT -> TODO("CAST FROM STRUCT not yet implemented")
PType.Kind.STRUCT -> castFromStruct(argDatum, cast.target).toPartiQLValue()
PType.Kind.ROW -> TODO("CAST FROM ROW not yet implemented")
PType.Kind.UNKNOWN -> TODO("CAST FROM UNKNOWN not yet implemented")
PType.Kind.UNKNOWN -> castFromUnknown(arg, cast.target)
PType.Kind.VARCHAR -> TODO("CAST FROM VARCHAR not yet implemented")
}
return Datum.of(partiqlValue)
Expand All @@ -97,6 +100,22 @@ internal class ExprCast(val arg: Operator.Expr, val cast: Ref.Cast) : Operator.E
}
}

/**
* For now, we cannot cast from struct to anything else. Throw a type check exception.
*/
private fun castFromStruct(value: Datum, t: PType): Datum {
throw TypeCheckException()
}

@OptIn(PartiQLValueExperimental::class)
private fun castFromUnknown(value: PartiQLValue, t: PType): PartiQLValue {
return when (value) {
is NullValue -> castFromNull(value, t)
is MissingValue -> missingValue() // TODO: Is this allowed?
else -> error("This shouldn't have happened")
}
}

@OptIn(PartiQLValueExperimental::class)
private fun castFromNull(value: NullValue, t: PType): PartiQLValue {
return when (t.kind) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1161,7 +1161,16 @@ class PartiQLEngineDefaultTest {

internal fun assert() {
val permissiveResult = run(mode = PartiQLEngine.Mode.PERMISSIVE)
assert(expectedPermissive == permissiveResult.first) {
val assertionCondition = try {
expectedPermissive == permissiveResult.first
} catch (t: Throwable) {
val str = buildString {
appendLine("Test Name: $name")
PlanPrinter.append(this, permissiveResult.second)
}
throw RuntimeException(str, t)
}
assert(assertionCondition) {
comparisonString(expectedPermissive, permissiveResult.first, permissiveResult.second)
}
var error: Throwable? = null
Expand Down Expand Up @@ -1194,7 +1203,13 @@ class PartiQLEngineDefaultTest {
val prepared = engine.prepare(plan.plan, PartiQLEngine.Session(mapOf("memory" to connector), mode = mode))
when (val result = engine.execute(prepared)) {
is PartiQLResult.Value -> return result.value to plan.plan
is PartiQLResult.Error -> throw result.cause
is PartiQLResult.Error -> {
val str = buildString {
appendLine("Execution resulted in an unexpected error. Plan:")
PlanPrinter.append(this, plan.plan)
}
throw RuntimeException(str, result.cause)
}
}
}

Expand All @@ -1218,51 +1233,26 @@ class PartiQLEngineDefaultTest {
}

@Test
@Disabled
fun developmentTest() {
val tc = SuccessTestCase(
input = """
SELECT *
EXCLUDE
t.a.b.c[*].field_x
FROM [{
'a': {
'b': {
'c': [
{ -- c[0]; field_x to be removed
'field_x': 0,
'field_y': 0
},
{ -- c[1]; field_x to be removed
'field_x': 1,
'field_y': 1
},
{ -- c[2]; field_x to be removed
'field_x': 2,
'field_y': 2
}
]
}
}
}] AS t
""".trimIndent(),
expected = bagValue(
structValue(
"a" to structValue(
"b" to structValue(
"c" to listValue(
structValue(
"field_y" to int32Value(0)
),
structValue(
"field_y" to int32Value(1)
),
structValue(
"field_y" to int32Value(2)
)
)
)
)
)
SELECT VALUE
CASE x + 1
WHEN NULL THEN 'shouldnt be null'
WHEN MISSING THEN 'shouldnt be missing'
WHEN i THEN 'ONE'
WHEN f THEN 'TWO'
WHEN d THEN 'THREE'
ELSE '?'
END
FROM << i, f, d, null, missing >> AS x
""",
expected = boolValue(true),
globals = listOf(
Global("i", "1"),
Global("f", "2e0"),
Global("d", "3.")
)
)
tc.assert()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package org.partiql.planner.internal
import org.partiql.planner.internal.casts.Coercions
import org.partiql.planner.internal.ir.Ref
import org.partiql.planner.internal.typer.CompilerType
import org.partiql.planner.internal.typer.PlanTyper.Companion.toCType
import org.partiql.spi.fn.FnExperimental
import org.partiql.spi.fn.FnSignature
import org.partiql.types.PType.Kind
Expand Down Expand Up @@ -144,10 +145,14 @@ internal object FnResolver {
exactInputTypes++
continue
}
// 2. Match ANY, no coercion needed
// TODO: Rewrite args in this scenario
arg.kind == Kind.UNKNOWN || p.type.kind == Kind.DYNAMIC || arg.kind == Kind.DYNAMIC -> continue
// 3. Check for a coercion
// 2. Match ANY parameter, no coercion needed
p.type.kind == Kind.DYNAMIC -> continue
arg.kind == Kind.UNKNOWN -> continue
// 3. Allow for ANY arguments
arg.kind == Kind.DYNAMIC -> {
mapping[i] = Ref.Cast(arg, p.type.toCType(), Ref.Cast.Safety.UNSAFE, true)
}
// 4. Check for a coercion
else -> when (val coercion = Coercions.get(arg, p.type)) {
null -> return null // short-circuit
else -> mapping[i] = coercion
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -436,15 +436,15 @@ internal class PlanTyper(private val env: Env) {
override fun visitRelOpJoin(node: Rel.Op.Join, ctx: Rel.Type?): Rel {
// Rewrite LHS and RHS
val lhs = visitRel(node.lhs, ctx)
val stack = outer + listOf(TypeEnv(lhs.type.schema, outer))
val stack = outer + listOf(TypeEnv(env, lhs.type.schema, outer))
val rhs = RelTyper(stack, Scope.GLOBAL).visitRel(node.rhs, ctx)

// Calculate output schema given JOIN type
val schema = lhs.type.schema + rhs.type.schema
val type = relType(schema, ctx!!.props)

// Type the condition on the output schema
val condition = node.rex.type(TypeEnv(type.schema, outer))
val condition = node.rex.type(TypeEnv(env, type.schema, outer))

val op = relOpJoin(lhs, rhs, condition, node.type)
return rel(type, op)
Expand Down Expand Up @@ -497,7 +497,7 @@ internal class PlanTyper(private val env: Env) {
val resolvedRoot = when (val root = path.root) {
is Rex.Op.Var.Unresolved -> {
// resolve `root` to local binding
val locals = TypeEnv(input.type.schema, outer)
val locals = TypeEnv(env, input.type.schema, outer)
val path = root.identifier.toBindingPath()
val resolved = locals.resolve(path)
if (resolved == null) {
Expand Down Expand Up @@ -537,7 +537,7 @@ internal class PlanTyper(private val env: Env) {
val input = visitRel(node.input, ctx)

// type the calls and groups
val typer = RexTyper(TypeEnv(input.type.schema, outer), Scope.LOCAL)
val typer = RexTyper(TypeEnv(env, input.type.schema, outer), Scope.LOCAL)

// typing of aggregate calls is slightly more complicated because they are not expressions.
val calls = node.calls.mapIndexed { i, call ->
Expand Down Expand Up @@ -607,8 +607,8 @@ internal class PlanTyper(private val env: Env) {
Rex.Op.Var.Scope.LOCAL -> Scope.LOCAL
}
val resolvedVar = when (scope) {
Scope.LOCAL -> locals.resolve(path) ?: env.resolveObj(path)
Scope.GLOBAL -> env.resolveObj(path) ?: locals.resolve(path)
Scope.LOCAL -> locals.resolve(path, TypeEnv.LookupStrategy.LOCALS_FIRST)
Scope.GLOBAL -> locals.resolve(path, TypeEnv.LookupStrategy.GLOBALS_FIRST)
}
if (resolvedVar == null) {
val id = PlanUtils.externalize(node.identifier)
Expand Down Expand Up @@ -1062,7 +1062,7 @@ internal class PlanTyper(private val env: Env) {
override fun visitRexOpPivot(node: Rex.Op.Pivot, ctx: CompilerType?): Rex {
val stack = locals.outer + listOf(locals)
val rel = node.rel.type(stack)
val typeEnv = TypeEnv(rel.type.schema, stack)
val typeEnv = TypeEnv(env, rel.type.schema, stack)
val typer = RexTyper(typeEnv, Scope.LOCAL)
val key = typer.visitRex(node.key, null)
val value = typer.visitRex(node.value, null)
Expand All @@ -1072,7 +1072,7 @@ internal class PlanTyper(private val env: Env) {

override fun visitRexOpSubquery(node: Rex.Op.Subquery, ctx: CompilerType?): Rex {
val rel = node.rel.type(locals.outer + listOf(locals))
val newTypeEnv = TypeEnv(schema = rel.type.schema, outer = locals.outer + listOf(locals))
val newTypeEnv = TypeEnv(env, schema = rel.type.schema, outer = locals.outer + listOf(locals))
val constructor = node.constructor.type(newTypeEnv)
val subquery = rexOpSubquery(constructor, rel, node.coercion)
return when (node.coercion) {
Expand Down Expand Up @@ -1119,7 +1119,7 @@ internal class PlanTyper(private val env: Env) {
// TODO: Should we support the ROW type?
override fun visitRexOpSelect(node: Rex.Op.Select, ctx: CompilerType?): Rex {
val rel = node.rel.type(locals.outer + listOf(locals))
val newTypeEnv = TypeEnv(schema = rel.type.schema, outer = locals.outer + listOf(locals))
val newTypeEnv = TypeEnv(env, schema = rel.type.schema, outer = locals.outer + listOf(locals))
val constructor = node.constructor.type(newTypeEnv)
val type = when (rel.isOrdered()) {
true -> PType.typeList(constructor.type)
Expand Down Expand Up @@ -1296,7 +1296,7 @@ internal class PlanTyper(private val env: Env) {
* This types the [Rex] given the input record ([input]) and [stack] of [TypeEnv] (representing the outer scopes).
*/
private fun Rex.type(input: List<Rel.Binding>, stack: List<TypeEnv>, strategy: Scope = Scope.LOCAL) =
RexTyper(TypeEnv(input, stack), strategy).visitRex(this, this.type)
RexTyper(TypeEnv(env, input, stack), strategy).visitRex(this, this.type)

/**
* This types the [Rex] given a [TypeEnv]. We use the [TypeEnv.schema] as the input schema and the [TypeEnv.outer]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.partiql.planner.internal.typer

import org.partiql.planner.internal.Env
import org.partiql.planner.internal.ir.Rel
import org.partiql.planner.internal.ir.Rex
import org.partiql.planner.internal.ir.rex
Expand All @@ -22,10 +23,16 @@ import org.partiql.value.stringValue
* @property outer refers to the outer variable scopes that we have access to.
*/
internal data class TypeEnv(
private val globals: Env,
public val schema: List<Rel.Binding>,
public val outer: List<TypeEnv>
) {

enum class LookupStrategy {
LOCALS_FIRST,
GLOBALS_FIRST
}

internal fun getScope(depth: Int): TypeEnv {
return when (depth) {
0 -> this
Expand All @@ -34,24 +41,48 @@ internal data class TypeEnv(
}

/**
* We resolve a local with the following rules. See, PartiQL Specification p.35.
*
* 1) Check if the path root unambiguously matches a local binding name, set as root.
* 2) Check if the path root unambiguously matches a local binding struct value field.
*
* Search Algorithm (LOCALS_FIRST):
* 1. Match Binding Name
* - Match Locals
* - Match Globals
* 2. Match Nested Field
* - Match Locals
* Search Algorithm (GLOBALS_FIRST):
* 1. Match Binding Name
* - Match Globals
* - Match Locals
* 2. Match Nested Field
* - Match Locals
*/
fun resolve(path: BindingPath, strategy: LookupStrategy = LookupStrategy.LOCALS_FIRST): Rex? {
return when (strategy) {
LookupStrategy.LOCALS_FIRST -> resolveLocalName(path) ?: globals.resolveObj(path) ?: resolveLocalField(path)
LookupStrategy.GLOBALS_FIRST -> globals.resolveObj(path) ?: resolveLocalName(path) ?: resolveLocalField(path)
}
}

/**
* Attempts to resolve using just the local binding name.
*/
private fun resolveLocalName(path: BindingPath): Rex? {
val head: BindingName = path.steps[0]
val tail: List<BindingName> = path.steps.drop(1)
val r = matchRoot(head) ?: return null
// Convert any remaining binding names (tail) to an untyped path expression.
return if (tail.isEmpty()) r else r.toPath(tail)
}

/**
* Check if the path root unambiguously matches a local binding struct value field.
* Convert any remaining binding names (tail) to a path expression.
*
* @param path
* @return
*/
fun resolve(path: BindingPath): Rex? {
private fun resolveLocalField(path: BindingPath): Rex? {
val head: BindingName = path.steps[0]
var tail: List<BindingName> = path.steps.drop(1)
var r = matchRoot(head)
if (r == null) {
r = matchStruct(head) ?: return null
tail = path.steps
}
val r = matchStruct(head) ?: return null
val tail = path.steps
// Convert any remaining binding names (tail) to an untyped path expression.
return if (tail.isEmpty()) r else r.toPath(tail)
}
Expand Down
Loading

0 comments on commit 304514e

Please sign in to comment.