diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/Env.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/Env.kt index 888cb4f238..6fddc843fc 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/Env.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/Env.kt @@ -4,9 +4,15 @@ import org.partiql.planner.PartiQLPlanner import org.partiql.planner.internal.ir.Agg import org.partiql.planner.internal.ir.Catalog import org.partiql.planner.internal.ir.Fn -import org.partiql.planner.internal.ir.Rel import org.partiql.planner.internal.ir.Rex +import org.partiql.planner.internal.ir.catalogSymbolRef +import org.partiql.planner.internal.ir.rex +import org.partiql.planner.internal.ir.rexOpGlobal +import org.partiql.planner.internal.ir.rexOpLit +import org.partiql.planner.internal.ir.rexOpPathKey +import org.partiql.planner.internal.ir.rexOpPathSymbol import org.partiql.planner.internal.typer.FnResolver +import org.partiql.planner.internal.typer.TypeEnv import org.partiql.spi.BindingCase import org.partiql.spi.BindingName import org.partiql.spi.BindingPath @@ -15,8 +21,9 @@ import org.partiql.spi.connector.ConnectorObjectHandle import org.partiql.spi.connector.ConnectorObjectPath import org.partiql.spi.connector.ConnectorSession import org.partiql.types.StaticType -import org.partiql.types.StructType -import org.partiql.types.TupleConstraint +import org.partiql.types.function.FunctionSignature +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.stringValue /** * Handle for associating a catalog name with catalog related metadata objects. @@ -24,84 +31,19 @@ import org.partiql.types.TupleConstraint internal typealias Handle = Pair /** - * TypeEnv represents the environment in which we type expressions and resolve variables while planning. + * Metadata for a resolved global variable * - * TODO TypeEnv should be a stack of locals; also the strategy has been kept here because it's easier to - * pass through the traversal like this, but is conceptually odd to associate with the TypeEnv. - * @property schema - * @property strategy - */ -internal class TypeEnv( - val schema: List, - val strategy: ResolutionStrategy, -) { - - /** - * Return a copy with GLOBAL lookup strategy - */ - fun global() = TypeEnv(schema, ResolutionStrategy.GLOBAL) - - /** - * Return a copy with LOCAL lookup strategy - */ - fun local() = TypeEnv(schema, ResolutionStrategy.LOCAL) - - /** - * Debug string - */ - override fun toString() = buildString { - append("(") - append("strategy=$strategy") - append(", ") - val bindings = "< " + schema.joinToString { "${it.name}: ${it.type}" } + " >" - append("bindings=$bindings") - append(")") - } -} - -/** - * Metadata regarding a resolved variable. + * @property type Resolved StaticType + * @property ordinal The relevant catalog's index offset in the [Env.catalogs] list * @property depth The depth/level of the path match. + * @property position The relevant value's index offset in the [Catalog.values] list */ -internal sealed interface ResolvedVar { - - public val type: StaticType - public val ordinal: Int - public val depth: Int - - /** - * Metadata for a resolved local variable. - * - * @property type Resolved StaticType - * @property ordinal Index offset in [TypeEnv] - * @property resolvedSteps The fully resolved path steps.s - */ - class Local( - override val type: StaticType, - override val ordinal: Int, - val rootType: StaticType, - val resolvedSteps: List, - ) : ResolvedVar { - // the depth are always going to be 1 because this is local variable. - // the global path, however the path length maybe, going to be replaced by a binding name. - override val depth: Int = 1 - } - - /** - * Metadata for a resolved global variable - * - * @property type Resolved StaticType - * @property ordinal The relevant catalog's index offset in the [Env.catalogs] list - * @property depth The depth/level of the path match. - * @property position The relevant value's index offset in the [Catalog.values] list - */ - class Global( - override val type: StaticType, - override val ordinal: Int, - override val depth: Int, - val position: Int, - ) : ResolvedVar -} +internal class ResolvedVar( + val type: StaticType, + val ordinal: Int, + val depth: Int, + val position: Int, +) /** * Variable resolution strategies — https://partiql.org/assets/PartiQL-Specification.pdf#page=35 @@ -220,7 +162,7 @@ internal class Env( type ) // Return resolution metadata - ResolvedVar.Global(type, catalogIndex, depth, valueIndex) + ResolvedVar(type, catalogIndex, depth, valueIndex) } } } @@ -267,31 +209,13 @@ internal class Env( } } - private fun BindingPath.toCaseSensitive(): BindingPath { - return this.copy(steps = this.steps.map { it.copy(bindingCase = BindingCase.SENSITIVE) }) - } - /** * Attempt to resolve a [BindingPath] in the global + local type environments. */ - fun resolve(path: BindingPath, locals: TypeEnv, scope: Rex.Op.Var.Scope): ResolvedVar? { - val strategy = when (scope) { - Rex.Op.Var.Scope.DEFAULT -> locals.strategy - Rex.Op.Var.Scope.LOCAL -> ResolutionStrategy.LOCAL - } + fun resolve(path: BindingPath, locals: TypeEnv, strategy: ResolutionStrategy): Rex? { return when (strategy) { - ResolutionStrategy.LOCAL -> { - var type: ResolvedVar? = null - type = type ?: resolveLocalBind(path, locals.schema) - type = type ?: resolveGlobalBind(path) - type - } - ResolutionStrategy.GLOBAL -> { - var type: ResolvedVar? = null - type = type ?: resolveGlobalBind(path) - type = type ?: resolveLocalBind(path, locals.schema) - type - } + ResolutionStrategy.LOCAL -> locals.resolve(path) ?: resolveGlobalBind(path) + ResolutionStrategy.GLOBAL -> resolveGlobalBind(path) ?: locals.resolve(path) } } @@ -305,7 +229,7 @@ internal class Env( * TODO: Add global bindings * TODO: Replace paths with global variable references if found */ - private fun resolveGlobalBind(path: BindingPath): ResolvedVar? { + private fun resolveGlobalBind(path: BindingPath): Rex? { val currentCatalog = session.currentCatalog?.let { BindingName(it, BindingCase.SENSITIVE) } val currentCatalogPath = BindingPath(session.currentDirectory.map { BindingName(it, BindingCase.SENSITIVE) }) val absoluteCatalogPath = BindingPath(currentCatalogPath.steps + path.steps) @@ -320,122 +244,13 @@ internal class Env( ?: getGlobalType(currentCatalog, path, path) ?: getGlobalType(currentCatalog, path, absoluteCatalogPath) } - } - return resolvedVar - } - - /** - * Check locals, else search structs. - */ - internal fun resolveLocalBind(path: BindingPath, locals: List): ResolvedVar? { - if (path.steps.isEmpty()) { - return null - } - - // 1. Check locals for root - locals.forEachIndexed { ordinal, binding -> - val root = path.steps[0] - if (root.isEquivalentTo(binding.name)) { - return ResolvedVar.Local(binding.type, ordinal, binding.type, path.steps) - } - } - - // 2. Check if this variable is referencing a struct field, carrying ordinals - val matches = mutableListOf() - for (ordinal in locals.indices) { - val rootType = locals[ordinal].type - val pathPrefix = BindingName(locals[ordinal].name, BindingCase.SENSITIVE) - if (rootType is StructType) { - val varType = inferStructLookup(rootType, path) - if (varType != null) { - // we found this path within a struct! - val match = ResolvedVar.Local( - varType.resolvedType, - ordinal, - rootType, - listOf(pathPrefix) + varType.replacementPath.steps, - ) - matches.add(match) - } - } - } - - // 0 -> no match - // 1 -> resolved - // N -> ambiguous - return when (matches.size) { - 0 -> null - 1 -> matches.single() - else -> null // TODO emit ambiguous error - } + } ?: return null + // rewrite as path expression for any remaining steps. + val root = rex(resolvedVar.type, rexOpGlobal(catalogSymbolRef(resolvedVar.ordinal, resolvedVar.position))) + val tail = path.steps.drop(resolvedVar.depth) + return if (tail.isEmpty()) root else root.toPath(tail) } - /** - * Searches for the path within the given struct, returning null if not found. - * - * @return a [ResolvedPath] that contains the disambiguated [ResolvedPath.replacementPath] and the path's - * [StaticType]. Returns NULL if unable to find the [path] given the [struct]. - */ - private fun inferStructLookup(struct: StructType, path: BindingPath): ResolvedPath? { - var curr: StaticType = struct - val replacementSteps = path.steps.map { step -> - // Assume ORDERED for now - val currentStruct = curr as? StructType ?: return null - val (replacement, stepType) = inferStructLookup(currentStruct, step) ?: return null - curr = stepType - replacement - } - // Lookup final field - return ResolvedPath( - BindingPath(replacementSteps), - curr - ) - } - - /** - * Represents a disambiguated [BindingPath] and its inferred [StaticType]. - */ - private class ResolvedPath( - val replacementPath: BindingPath, - val resolvedType: StaticType, - ) - - /** - * @return a disambiguated [key] and the resulting [StaticType]. - */ - private fun inferStructLookup(struct: StructType, key: BindingName): Pair? { - val isClosed = struct.constraints.contains(TupleConstraint.Open(false)) - val isOrdered = struct.constraints.contains(TupleConstraint.Ordered) - return when { - // 1. Struct is closed and ordered - isClosed && isOrdered -> { - struct.fields.firstOrNull { entry -> key.isEquivalentTo(entry.key) }?.let { - (sensitive(it.key) to it.value) - } - } - // 2. Struct is closed - isClosed -> { - val matches = struct.fields.filter { entry -> key.isEquivalentTo(entry.key) } - when (matches.size) { - 0 -> null - 1 -> matches.first().let { (sensitive(it.key) to it.value) } - else -> { - val firstKey = matches.first().key - val sharedKey = when (matches.all { it.key == firstKey }) { - true -> sensitive(firstKey) - false -> key - } - sharedKey to StaticType.unionOf(matches.map { it.value }.toSet()).flatten() - } - } - } - // 3. Struct is open - else -> key to StaticType.ANY - } - } - - private fun sensitive(str: String): BindingName = BindingName(str, BindingCase.SENSITIVE) - /** * Logic for determining how many BindingNames were “matched” by the ConnectorMetadata * 1. Matched = RelativePath - Not Found @@ -450,4 +265,13 @@ internal class Env( ): Int { return originalPath.steps.size + outputCatalogPath.steps.size - inputCatalogPath.steps.size } + + @OptIn(PartiQLValueExperimental::class) + private fun Rex.toPath(steps: List): Rex = steps.fold(this) { curr, step -> + val op = when (step.bindingCase) { + BindingCase.SENSITIVE -> rexOpPathKey(curr, rex(StaticType.STRING, rexOpLit(stringValue(step.name)))) + BindingCase.INSENSITIVE -> rexOpPathSymbol(curr, step.name) + } + rex(StaticType.ANY, op) + } } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt index f2d9ca5662..1c15078553 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt @@ -22,8 +22,6 @@ import org.partiql.errors.UNKNOWN_PROBLEM_LOCATION import org.partiql.planner.PlanningProblemDetails import org.partiql.planner.internal.Env import org.partiql.planner.internal.ResolutionStrategy -import org.partiql.planner.internal.ResolvedVar -import org.partiql.planner.internal.TypeEnv import org.partiql.planner.internal.ir.Agg import org.partiql.planner.internal.ir.Fn import org.partiql.planner.internal.ir.Identifier @@ -32,7 +30,6 @@ import org.partiql.planner.internal.ir.Rel import org.partiql.planner.internal.ir.Rex import org.partiql.planner.internal.ir.Statement import org.partiql.planner.internal.ir.aggResolved -import org.partiql.planner.internal.ir.catalogSymbolRef import org.partiql.planner.internal.ir.fnResolved import org.partiql.planner.internal.ir.identifierSymbol import org.partiql.planner.internal.ir.rel @@ -60,7 +57,6 @@ import org.partiql.planner.internal.ir.rexOpCallStatic import org.partiql.planner.internal.ir.rexOpCaseBranch import org.partiql.planner.internal.ir.rexOpCollection import org.partiql.planner.internal.ir.rexOpErr -import org.partiql.planner.internal.ir.rexOpGlobal import org.partiql.planner.internal.ir.rexOpLit import org.partiql.planner.internal.ir.rexOpPathIndex import org.partiql.planner.internal.ir.rexOpPathKey @@ -70,7 +66,6 @@ import org.partiql.planner.internal.ir.rexOpSelect import org.partiql.planner.internal.ir.rexOpStruct import org.partiql.planner.internal.ir.rexOpStructField import org.partiql.planner.internal.ir.rexOpTupleUnion -import org.partiql.planner.internal.ir.rexOpVarResolved import org.partiql.planner.internal.ir.statementQuery import org.partiql.planner.internal.ir.util.PlanRewriter import org.partiql.spi.BindingCase @@ -124,10 +119,7 @@ internal class PlanTyper( throw IllegalArgumentException("PartiQLPlanner only supports Query statements") } // root TypeEnv has no bindings - val typeEnv = TypeEnv( - schema = emptyList(), - strategy = ResolutionStrategy.GLOBAL, - ) + val typeEnv = TypeEnv(schema = emptyList()) val root = statement.root.type(typeEnv) return statementQuery(root) } @@ -136,8 +128,12 @@ internal class PlanTyper( * Types the relational operators of a query expression. * * @property outer represents the outer TypeEnv of a query expression — only used by scan variable resolution. + * @property strategy */ - private inner class RelTyper(private val outer: TypeEnv) : PlanRewriter() { + private inner class RelTyper( + private val outer: TypeEnv, + private val strategy: ResolutionStrategy, + ) : PlanRewriter() { override fun visitRel(node: Rel, ctx: Rel.Type?) = visitRelOp(node.op, node.type) as Rel @@ -146,7 +142,7 @@ internal class PlanTyper( */ override fun visitRelOpScan(node: Rel.Op.Scan, ctx: Rel.Type?): Rel { // descend, with GLOBAL resolution strategy - val rex = node.rex.type(outer.global()) + val rex = node.rex.type(outer, ResolutionStrategy.GLOBAL) // compute rel type val valueT = getElementTypeForFromSource(rex.type) val type = ctx!!.copyWithSchema(listOf(valueT)) @@ -165,7 +161,7 @@ internal class PlanTyper( */ override fun visitRelOpScanIndexed(node: Rel.Op.ScanIndexed, ctx: Rel.Type?): Rel { // descend, with GLOBAL resolution strategy - val rex = node.rex.type(outer.global()) + val rex = node.rex.type(outer, ResolutionStrategy.GLOBAL) // compute rel type val valueT = getElementTypeForFromSource(rex.type) val indexT = StaticType.INT8 @@ -180,7 +176,7 @@ internal class PlanTyper( */ override fun visitRelOpUnpivot(node: Rel.Op.Unpivot, ctx: Rel.Type?): Rel { // descend, with GLOBAL resolution strategy - val rex = node.rex.type(outer.global()) + val rex = node.rex.type(outer, ResolutionStrategy.GLOBAL) // only UNPIVOT a struct if (rex.type !is StructType) { @@ -215,7 +211,7 @@ internal class PlanTyper( // compute input schema val input = visitRel(node.input, ctx) // type sub-nodes - val typeEnv = TypeEnv(input.type.schema, ResolutionStrategy.LOCAL) + val typeEnv = TypeEnv(input.type.schema) val predicate = node.predicate.type(typeEnv) // compute output schema val type = input.type @@ -228,10 +224,10 @@ internal class PlanTyper( // compute input schema val input = visitRel(node.input, ctx) // type sub-nodes - val typeEnv = TypeEnv(input.type.schema, ResolutionStrategy.LOCAL) + val typeEnv = TypeEnv(input.type.schema) val specs = node.specs.map { val rex = it.rex.type(typeEnv) - it.copy(rex) + it.copy(rex = rex) } // output schema of a sort is the same as the input val type = input.type.copy(props = setOf(Rel.Prop.ORDERED)) @@ -256,8 +252,7 @@ internal class PlanTyper( // compute input schema val input = visitRel(node.input, ctx) // type limit expression using outer scope with global resolution - val typeEnv = outer.global() - val limit = node.limit.type(typeEnv) + val limit = node.limit.type(outer, ResolutionStrategy.GLOBAL) // check types assertAsInt(limit.type) // compute output schema @@ -271,8 +266,7 @@ internal class PlanTyper( // compute input schema val input = visitRel(node.input, ctx) // type offset expression using outer scope with global resolution - val typeEnv = outer.global() - val offset = node.offset.type(typeEnv) + val offset = node.offset.type(outer, ResolutionStrategy.GLOBAL) // check types assertAsInt(offset.type) // compute output schema @@ -286,7 +280,7 @@ internal class PlanTyper( // compute input schema val input = visitRel(node.input, ctx) // type sub-nodes - val typeEnv = TypeEnv(input.type.schema, ResolutionStrategy.LOCAL) + val typeEnv = TypeEnv(input.type.schema) val projections = node.projections.map { it.type(typeEnv) } @@ -315,7 +309,7 @@ internal class PlanTyper( val type = relType(schema, ctx!!.props) // Type the condition on the output schema - val condition = node.rex.type(TypeEnv(type.schema, ResolutionStrategy.LOCAL)) + val condition = node.rex.type(TypeEnv(type.schema)) val op = relOpJoin(lhs, rhs, condition, node.type) return rel(type, op) @@ -359,20 +353,22 @@ internal class PlanTyper( val schema = node.items.fold((init)) { bindings, item -> excludeBindings(bindings, item) } // rewrite - val type = ctx!!.copy(schema) + val type = ctx!!.copy(schema = schema) // resolve exclude path roots val newItems = node.items.map { item -> val resolvedRoot = when (val root = item.root) { is Rex.Op.Var.Unresolved -> { // resolve `root` to local binding - val bindingPath = root.identifier.toBindingPath() - when (val resolved = env.resolveLocalBind(bindingPath, init)) { - null -> { - handleUnresolvedExcludeRoot(root.identifier) - root - } - else -> rexOpVarResolved(resolved.ordinal) + val locals = TypeEnv(input.type.schema) + val path = root.identifier.toBindingPath() + val resolved = locals.resolve(path) + if (resolved == null) { + handleUnresolvedExcludeRoot(root.identifier) + root + } else { + // root of exclude is always a symbol + resolved.op as Rex.Op.Var } } is Rex.Op.Var.Resolved -> root @@ -390,7 +386,7 @@ internal class PlanTyper( val input = visitRel(node.input, ctx) // type the calls and groups - val typer = RexTyper(locals = TypeEnv(input.type.schema, ResolutionStrategy.LOCAL)) + val typer = RexTyper(TypeEnv(input.type.schema), ResolutionStrategy.LOCAL) // typing of aggregate calls is slightly more complicated because they are not expressions. val calls = node.calls.mapIndexed { i, call -> @@ -427,7 +423,10 @@ internal class PlanTyper( * @property locals TypeEnv in which this rex tree is evaluated. */ @OptIn(PartiQLValueExperimental::class) - private inner class RexTyper(private val locals: TypeEnv) : PlanRewriter() { + private inner class RexTyper( + private val locals: TypeEnv, + private val strategy: ResolutionStrategy, + ) : PlanRewriter() { override fun visitRex(node: Rex, ctx: StaticType?): Rex = visitRexOp(node.op, node.type) as Rex @@ -444,52 +443,16 @@ internal class PlanTyper( override fun visitRexOpVarUnresolved(node: Rex.Op.Var.Unresolved, ctx: StaticType?): Rex { val path = node.identifier.toBindingPath() - val resolvedVar = env.resolve(path, locals, node.scope) - + val strategy = when (node.scope) { + Rex.Op.Var.Scope.DEFAULT -> strategy + Rex.Op.Var.Scope.LOCAL -> ResolutionStrategy.LOCAL + } + val resolvedVar = env.resolve(path, locals, strategy) if (resolvedVar == null) { handleUndefinedVariable(path.steps.last()) return rex(ANY, rexOpErr("Undefined variable ${node.identifier}")) } - val type = resolvedVar.type - return when (resolvedVar) { - is ResolvedVar.Global -> { - val variable = rex(type, rexOpGlobal(catalogSymbolRef(resolvedVar.ordinal, resolvedVar.position))) - when (resolvedVar.depth) { - path.steps.size -> variable - else -> { - val foldedPath = foldPath(path.steps, resolvedVar.depth, path.steps.size, variable) - visitRex(foldedPath, ctx) - } - } - } - is ResolvedVar.Local -> { - val variable = rex(type, rexOpVarResolved(resolvedVar.ordinal)) - when { - path.isEquivalentTo(resolvedVar.resolvedSteps) && path.steps.size == resolvedVar.depth -> variable - else -> { - val foldedPath = foldPath(resolvedVar.resolvedSteps, resolvedVar.depth, resolvedVar.resolvedSteps.size, variable) - visitRex(foldedPath, ctx) - } - } - } - } - } - - private fun foldPath(path: List, start: Int, end: Int, global: Rex) = - path.subList(start, end).fold(global) { current, step -> - when (step.bindingCase) { - BindingCase.SENSITIVE -> rex(ANY, rexOpPathKey(current, rex(STRING, rexOpLit(stringValue(step.name))))) - BindingCase.INSENSITIVE -> rex(ANY, rexOpPathSymbol(current, step.name)) - } - } - - private fun BindingPath.isEquivalentTo(other: List): Boolean { - this.steps.forEachIndexed { index, bindingName -> - if (bindingName != other[index]) { - return false - } - } - return true + return visitRex(resolvedVar, null) } override fun visitRexOpGlobal(node: Rex.Op.Global, ctx: StaticType?): Rex { @@ -559,10 +522,16 @@ internal class PlanTyper( val paths = root.type.allTypes.map { type -> val struct = type as? StructType ?: return@map rex(MISSING, rexOpLit(missingValue())) - val (pathType, replacementId) = inferStructLookup(struct, identifierSymbol(node.key, Identifier.CaseSensitivity.INSENSITIVE)) + val (pathType, replacementId) = inferStructLookup( + struct, + identifierSymbol(node.key, Identifier.CaseSensitivity.INSENSITIVE) + ) when (replacementId.caseSensitivity) { Identifier.CaseSensitivity.INSENSITIVE -> rex(pathType, rexOpPathSymbol(root, replacementId.symbol)) - Identifier.CaseSensitivity.SENSITIVE -> rex(pathType, rexOpPathKey(root, rexString(replacementId.symbol))) + Identifier.CaseSensitivity.SENSITIVE -> rex( + pathType, + rexOpPathKey(root, rexString(replacementId.symbol)) + ) } } val type = unionOf(paths.map { it.type }.toSet()).flatten() @@ -901,7 +870,7 @@ internal class PlanTyper( override fun visitRexOpPivot(node: Rex.Op.Pivot, ctx: StaticType?): Rex { val rel = node.rel.type(locals) - val typeEnv = TypeEnv(rel.type.schema, ResolutionStrategy.LOCAL) + val typeEnv = TypeEnv(rel.type.schema) val key = node.key.type(typeEnv) val value = node.value.type(typeEnv) val type = StructType( @@ -960,7 +929,7 @@ internal class PlanTyper( override fun visitRexOpSelect(node: Rex.Op.Select, ctx: StaticType?): Rex { val rel = node.rel.type(locals) - val typeEnv = TypeEnv(rel.type.schema, ResolutionStrategy.LOCAL) + val typeEnv = TypeEnv(rel.type.schema) var constructor = node.constructor.type(typeEnv) var constructorType = constructor.type // add the ordered property to the constructor @@ -1243,9 +1212,11 @@ internal class PlanTyper( // HELPERS - private fun Rel.type(typeEnv: TypeEnv): Rel = RelTyper(typeEnv).visitRel(this, null) + private fun Rel.type(locals: TypeEnv, strategy: ResolutionStrategy = ResolutionStrategy.LOCAL): Rel = + RelTyper(locals, strategy).visitRel(this, null) - private fun Rex.type(typeEnv: TypeEnv) = RexTyper(typeEnv).visitRex(this, this.type) + private fun Rex.type(locals: TypeEnv, strategy: ResolutionStrategy = ResolutionStrategy.LOCAL) = + RexTyper(locals, strategy).visitRex(this, this.type) private fun rexErr(message: String) = rex(MISSING, rexOpErr(message)) diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeEnv.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeEnv.kt new file mode 100644 index 0000000000..d413abde04 --- /dev/null +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeEnv.kt @@ -0,0 +1,155 @@ +package org.partiql.planner.internal.typer + +import org.partiql.planner.internal.ir.Rel +import org.partiql.planner.internal.ir.Rex +import org.partiql.planner.internal.ir.rex +import org.partiql.planner.internal.ir.rexOpLit +import org.partiql.planner.internal.ir.rexOpPathKey +import org.partiql.planner.internal.ir.rexOpPathSymbol +import org.partiql.planner.internal.ir.rexOpVarResolved +import org.partiql.spi.BindingCase +import org.partiql.spi.BindingName +import org.partiql.spi.BindingPath +import org.partiql.types.StaticType +import org.partiql.types.StructType +import org.partiql.types.TupleConstraint +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.stringValue + +/** + * TypeEnv represents a variables type environment. + */ +internal class TypeEnv(public val schema: List) { + + /** + * We resolve a local with the following rules. See, PartiQL Specification p.35. + * + * 1) Check if the path root unambiguously matches a local binding name, set as root. + * 2) Check if the path root unambiguously matches a local binding struct value field. + * + * Convert any remaining binding names (tail) to a path expression. + * + * @param path + * @return + */ + fun resolve(path: BindingPath): Rex? { + val head: BindingName = path.steps[0] + var tail: List = path.steps.drop(1) + var r = matchRoot(head) + if (r == null) { + r = matchStruct(head) ?: return null + tail = path.steps + } + // Convert any remaining binding names (tail) to an untyped path expression. + return if (tail.isEmpty()) r else r.toPath(tail) + } + + /** + * Debugging string, ex: < x: int, y: string > + * + * @return + */ + override fun toString(): String = "< " + schema.joinToString { "${it.name}: ${it.type}" } + " >" + + /** + * Check if `name` unambiguously matches a local binding name and return its reference; otherwise return null. + * + * @param name + * @return + */ + private fun matchRoot(name: BindingName): Rex? { + var r: Rex? = null + for (i in schema.indices) { + val local = schema[i] + val type = local.type + if (name.isEquivalentTo(local.name)) { + if (r != null) { + // TODO root was already matched, emit ambiguous error. + return null + } + r = rex(type, rexOpVarResolved(i)) + } + } + return r + } + + /** + * Check if `name` unambiguously matches a field within a struct and return its reference; otherwise return null. + * + * @param name + * @return + */ + private fun matchStruct(name: BindingName): Rex? { + var c: Rex? = null + var known = false + for (i in schema.indices) { + val local = schema[i] + val type = local.type + if (type is StructType) { + when (type.containsKey(name)) { + true -> { + if (c != null && known) { + // TODO root was already definitively matched, emit ambiguous error. + return null + } + c = rex(type, rexOpVarResolved(i)) + known = true + } + null -> { + if (c != null) { + if (known) { + continue + } else { + // TODO we have more than one possible match, emit ambiguous error. + return null + } + } + c = rex(type, rexOpVarResolved(i)) + known = false + } + false -> continue + } + } + } + return c + } + + /** + * Converts a list of [BindingName] to a path expression. + * + * 1) Case SENSITIVE identifiers become string literal key lookups. + * 2) Case INSENSITIVE identifiers become symbol lookups. + * + * @param steps + * @return + */ + @OptIn(PartiQLValueExperimental::class) + private fun Rex.toPath(steps: List): Rex = steps.fold(this) { curr, step -> + val op = when (step.bindingCase) { + BindingCase.SENSITIVE -> rexOpPathKey(curr, rex(StaticType.STRING, rexOpLit(stringValue(step.name)))) + BindingCase.INSENSITIVE -> rexOpPathSymbol(curr, step.name) + } + rex(StaticType.ANY, op) + } + + /** + * Searches for the [BindingName] within the given [StructType]. + * + * Returns + * - true iff known to contain key + * - false iff known to NOT contain key + * - null iff NOT known to contain key + * + * @param name + * @return + */ + private fun StructType.containsKey(name: BindingName): Boolean? { + for (f in fields) { + if (name.isEquivalentTo(f.key)) { + return true + } + } + val closed = constraints.contains(TupleConstraint.Open(false)) + return if (closed) false else null + } +} diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/EnvTest.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/EnvTest.kt index cbdb3892e0..3f83b5b309 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/EnvTest.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/EnvTest.kt @@ -5,7 +5,7 @@ import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test import org.partiql.planner.PartiQLPlanner import org.partiql.planner.internal.ir.Catalog -import org.partiql.planner.internal.ir.Rex +import org.partiql.planner.internal.typer.TypeEnv import org.partiql.plugins.local.LocalConnector import org.partiql.spi.BindingCase import org.partiql.spi.BindingName @@ -22,7 +22,7 @@ class EnvTest { private val root = this::class.java.getResource("/catalogs/default/pql")!!.toURI().toPath() - private val EMPTY_TYPE_ENV = TypeEnv(schema = emptyList(), ResolutionStrategy.GLOBAL) + private val EMPTY_TYPE_ENV = TypeEnv(schema = emptyList()) private val GLOBAL_OS = Catalog( name = "pql", @@ -52,7 +52,7 @@ class EnvTest { @Test fun testGlobalMatchingSensitiveName() { val path = BindingPath(listOf(BindingName("os", BindingCase.SENSITIVE))) - assertNotNull(env.resolve(path, EMPTY_TYPE_ENV, Rex.Op.Var.Scope.DEFAULT)) + assertNotNull(env.resolve(path, EMPTY_TYPE_ENV, ResolutionStrategy.GLOBAL)) assertEquals(1, env.catalogs.size) assert(env.catalogs.contains(GLOBAL_OS)) } @@ -60,7 +60,7 @@ class EnvTest { @Test fun testGlobalMatchingInsensitiveName() { val path = BindingPath(listOf(BindingName("oS", BindingCase.INSENSITIVE))) - assertNotNull(env.resolve(path, EMPTY_TYPE_ENV, Rex.Op.Var.Scope.DEFAULT)) + assertNotNull(env.resolve(path, EMPTY_TYPE_ENV, ResolutionStrategy.GLOBAL)) assertEquals(1, env.catalogs.size) assert(env.catalogs.contains(GLOBAL_OS)) } @@ -68,14 +68,14 @@ class EnvTest { @Test fun testGlobalNotMatchingSensitiveName() { val path = BindingPath(listOf(BindingName("oS", BindingCase.SENSITIVE))) - assertNull(env.resolve(path, EMPTY_TYPE_ENV, Rex.Op.Var.Scope.DEFAULT)) + assertNull(env.resolve(path, EMPTY_TYPE_ENV, ResolutionStrategy.GLOBAL)) assert(env.catalogs.isEmpty()) } @Test fun testGlobalNotMatchingInsensitiveName() { val path = BindingPath(listOf(BindingName("nonexistent", BindingCase.INSENSITIVE))) - assertNull(env.resolve(path, EMPTY_TYPE_ENV, Rex.Op.Var.Scope.DEFAULT)) + assertNull(env.resolve(path, EMPTY_TYPE_ENV, ResolutionStrategy.GLOBAL)) assert(env.catalogs.isEmpty()) } } diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/TypeEnvTest.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/TypeEnvTest.kt new file mode 100644 index 0000000000..67504b0f89 --- /dev/null +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/TypeEnvTest.kt @@ -0,0 +1,114 @@ +package org.partiql.planner.internal.typer + +import org.junit.jupiter.api.parallel.Execution +import org.junit.jupiter.api.parallel.ExecutionMode +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.MethodSource +import org.partiql.planner.internal.ir.Rex +import org.partiql.planner.internal.ir.relBinding +import org.partiql.spi.BindingCase +import org.partiql.spi.BindingName +import org.partiql.spi.BindingPath +import org.partiql.types.BoolType +import org.partiql.types.StaticType +import org.partiql.types.StructType +import org.partiql.types.TupleConstraint +import kotlin.test.assertEquals +import kotlin.test.fail + +internal class TypeEnvTest { + + companion object { + + /** + * < + * A : { B: } }, + * a : { b: } }, + * X : { ... } }, + * x : { y: , ... } }, + * Y : { ... } }, + * T : { x: , x: } }, + * > + */ + @JvmStatic + val locals = TypeEnv( + listOf( + relBinding("A", struct("B" to BoolType())), + relBinding("a", struct("b" to BoolType())), + relBinding("X", struct(open = true)), + relBinding("x", struct("Y" to BoolType(), open = true)), + relBinding("y", struct(open = true)), + relBinding("T", struct("x" to BoolType(), "x" to BoolType())), + ) + ) + + private fun struct(vararg fields: Pair, open: Boolean = false): StructType { + return StructType( + fields = fields.map { StructType.Field(it.first, it.second) }, + constraints = setOf(TupleConstraint.Open(open)), + ) + } + + @JvmStatic + public fun cases() = listOf>( + // root matching + """ A.B """ to null, + """ A."B" """ to null, + """ "A".B """ to 0, + """ "A"."B" """ to 0, + """ "a".B """ to 1, + """ "a"."B" """ to 1, + """ x """ to null, + // """ x.y """ to 3, + """ y """ to 4, + + // struct searching + """ b """ to null, + """ "B" """ to 0, + """ "b" """ to 1, + """ "Y" """ to 3, + + // other + """ T.x """ to 5 + ) + } + + @ParameterizedTest + @MethodSource("cases") + @Execution(ExecutionMode.CONCURRENT) + fun resolve(case: Pair) { + val path = case.first.path() + val expected = case.second + val rex = locals.resolve(path) + if (rex == null) { + if (expected == null) { + return // pass + } else { + fail("could not resolve variable") + } + } + // For now, just traverse to the root + var root = rex.op + while (root !is Rex.Op.Var.Resolved) { + root = when (root) { + is Rex.Op.Path.Symbol -> root.root.op + is Rex.Op.Path.Key -> root.root.op + else -> { + fail("Expected path step of symbol or key, but found $root") + } + } + } + // + assertEquals(expected, root.ref) + } + + private fun String.path(): BindingPath { + val steps = trim().split(".").map { + when (it.startsWith("\"")) { + true -> BindingName(it.drop(1).dropLast(1), BindingCase.SENSITIVE) + else -> BindingName(it, BindingCase.INSENSITIVE) + } + } + return BindingPath(steps) + } +}