From 9310e4a8b85758ab0a302d5b5dfd189e5c111014 Mon Sep 17 00:00:00 2001 From: David Lurton Date: Fri, 29 Apr 2022 14:57:41 -0700 Subject: [PATCH] Query planner passes: AST->logical, logical->resolved, resolved->physical Not yet integrated with anything. --- lang/src/org/partiql/lang/domains/util.kt | 38 +- .../partiql/lang/eval/EvaluatingCompiler.kt | 2 +- .../lang/eval/builtins/BuiltinFunctions.kt | 31 + .../visitors/PartiqlAstSanityValidator.kt | 2 +- .../partiql/lang/planner/GlobalBindings.kt | 50 ++ .../org/partiql/lang/planner/PassResult.kt | 15 + .../lang/planner/transforms/AstNormalize.kt | 25 + .../AstToLogicalVisitorTransform.kt | 168 +++++ ...gicalResolvedToPhysicalVisitorTransform.kt | 94 +++ ...ogicalToLogicalResolvedVisitorTransform.kt | 393 ++++++++++ .../transforms/PlanningProblemDetails.kt | 46 ++ .../partiql/lang/planner/transforms/Util.kt | 13 + .../planner/transforms/VariableIdAllocator.kt | 45 ++ lang/test/org/partiql/lang/planner/Util.kt | 22 + .../AstToLogicalVisitorTransformTests.kt | 152 ++++ ...ResolvedToPhysicalVisitorTransformTests.kt | 69 ++ ...lToLogicalResolvedVisitorTransformTests.kt | 685 ++++++++++++++++++ 17 files changed, 1836 insertions(+), 14 deletions(-) create mode 100644 lang/src/org/partiql/lang/planner/GlobalBindings.kt create mode 100644 lang/src/org/partiql/lang/planner/PassResult.kt create mode 100644 lang/src/org/partiql/lang/planner/transforms/AstNormalize.kt create mode 100644 lang/src/org/partiql/lang/planner/transforms/AstToLogicalVisitorTransform.kt create mode 100644 lang/src/org/partiql/lang/planner/transforms/LogicalResolvedToPhysicalVisitorTransform.kt create mode 100644 lang/src/org/partiql/lang/planner/transforms/LogicalToLogicalResolvedVisitorTransform.kt create mode 100644 lang/src/org/partiql/lang/planner/transforms/PlanningProblemDetails.kt create mode 100644 lang/src/org/partiql/lang/planner/transforms/Util.kt create mode 100644 lang/src/org/partiql/lang/planner/transforms/VariableIdAllocator.kt create mode 100644 lang/test/org/partiql/lang/planner/Util.kt create mode 100644 lang/test/org/partiql/lang/planner/transforms/AstToLogicalVisitorTransformTests.kt create mode 100644 lang/test/org/partiql/lang/planner/transforms/LogicalResolvedToPhysicalVisitorTransformTests.kt create mode 100644 lang/test/org/partiql/lang/planner/transforms/LogicalToLogicalResolvedVisitorTransformTests.kt diff --git a/lang/src/org/partiql/lang/domains/util.kt b/lang/src/org/partiql/lang/domains/util.kt index 713fdc4f54..b885d0706b 100644 --- a/lang/src/org/partiql/lang/domains/util.kt +++ b/lang/src/org/partiql/lang/domains/util.kt @@ -1,5 +1,6 @@ package org.partiql.lang.domains +import com.amazon.ionelement.api.IonElement import com.amazon.ionelement.api.MetaContainer import com.amazon.ionelement.api.emptyMetaContainer import com.amazon.ionelement.api.metaContainerOf @@ -14,6 +15,19 @@ import org.partiql.lang.eval.BindingCase fun PartiqlAst.Builder.id(name: String) = id(name, caseInsensitive(), unqualified()) +// TODO: once https://github.com/partiql/partiql-ir-generator/issues/6 has been completed, we can delete this. +fun PartiqlLogical.Builder.id(name: String) = + id(name, caseInsensitive(), unqualified()) + +// TODO: once https://github.com/partiql/partiql-ir-generator/issues/6 has been completed, we can delete this. +fun PartiqlLogical.Builder.pathExpr(exp: PartiqlLogical.Expr) = + pathExpr(exp, caseInsensitive()) + +// Workaround for a bug in PIG that is fixed in its next release: +// https://github.com/partiql/partiql-ir-generator/issues/41 +fun List.asAnyElement() = + this.map { it.asAnyElement() } + val MetaContainer.staticType: StaticTypeMeta? get() = this[StaticTypeMeta.TAG] as StaticTypeMeta? /** Constructs a container with the specified metas. */ @@ -60,17 +74,17 @@ fun PartiqlAst.CaseSensitivity.toBindingCase(): BindingCase = when (this) { } /** - * Returns the [SourceLocationMeta] as an error context if the [SourceLocationMeta.TAG] exists in the passed - * [metaContainer]. Otherwise, returns an empty map. + * Converts a [PartiqlLogical.CaseSensitivity] to a [BindingCase]. */ -fun errorContextFrom(metaContainer: MetaContainer?): PropertyValueMap { - if (metaContainer == null) { - return PropertyValueMap() - } - val location = metaContainer[SourceLocationMeta.TAG] as? SourceLocationMeta - return if (location != null) { - org.partiql.lang.eval.errorContextFrom(location) - } else { - PropertyValueMap() - } +fun PartiqlLogical.CaseSensitivity.toBindingCase(): BindingCase = when (this) { + is PartiqlLogical.CaseSensitivity.CaseInsensitive -> BindingCase.INSENSITIVE + is PartiqlLogical.CaseSensitivity.CaseSensitive -> BindingCase.SENSITIVE +} + +/** + * Converts a [PartiqlLogical.CaseSensitivity] to a [BindingCase]. + */ +fun PartiqlPhysical.CaseSensitivity.toBindingCase(): BindingCase = when (this) { + is PartiqlPhysical.CaseSensitivity.CaseInsensitive -> BindingCase.INSENSITIVE + is PartiqlPhysical.CaseSensitivity.CaseSensitive -> BindingCase.SENSITIVE } diff --git a/lang/src/org/partiql/lang/eval/EvaluatingCompiler.kt b/lang/src/org/partiql/lang/eval/EvaluatingCompiler.kt index 646f602587..64c9812e1b 100644 --- a/lang/src/org/partiql/lang/eval/EvaluatingCompiler.kt +++ b/lang/src/org/partiql/lang/eval/EvaluatingCompiler.kt @@ -3058,7 +3058,7 @@ private class SingleProjectionElement(val name: ExprValue, val thunk: ThunkEnv) */ private class MultipleProjectionElement(val thunks: List) : ProjectionElement() -private val MetaContainer.sourceLocationMeta get() = this[SourceLocationMeta.TAG] as? SourceLocationMeta +internal val MetaContainer.sourceLocationMeta get() = this[SourceLocationMeta.TAG] as? SourceLocationMeta private fun StaticType.getTypes() = when (val flattened = this.flatten()) { is AnyOfType -> flattened.types diff --git a/lang/src/org/partiql/lang/eval/builtins/BuiltinFunctions.kt b/lang/src/org/partiql/lang/eval/builtins/BuiltinFunctions.kt index 5a685f37db..8c094b1c99 100644 --- a/lang/src/org/partiql/lang/eval/builtins/BuiltinFunctions.kt +++ b/lang/src/org/partiql/lang/eval/builtins/BuiltinFunctions.kt @@ -15,15 +15,20 @@ package org.partiql.lang.eval.builtins import com.amazon.ion.system.IonSystemBuilder +import org.partiql.lang.eval.DEFAULT_COMPARATOR import org.partiql.lang.eval.EvaluationSession import org.partiql.lang.eval.ExprFunction import org.partiql.lang.eval.ExprValue import org.partiql.lang.eval.ExprValueFactory import org.partiql.lang.eval.stringValue +import org.partiql.lang.eval.unnamedValue import org.partiql.lang.types.AnyOfType import org.partiql.lang.types.FunctionSignature import org.partiql.lang.types.StaticType import org.partiql.lang.types.UnknownArguments +import java.util.TreeSet + +internal const val DYNAMIC_LOOKUP_FUNCTION_NAME = "\$__dynamic_lookup__" internal fun createBuiltinFunctionSignatures(): Map = // Creating a new IonSystem in this instance is not the problem it would normally be since we are @@ -40,6 +45,7 @@ internal fun createBuiltinFunctions(valueFactory: ExprValueFactory) = createCharacterLength("character_length", valueFactory), createCharacterLength("char_length", valueFactory), createUtcNow(valueFactory), + createFilterDistinct(valueFactory), DateAddExprFunction(valueFactory), DateDiffExprFunction(valueFactory), ExtractExprFunction(valueFactory), @@ -52,6 +58,7 @@ internal fun createBuiltinFunctions(valueFactory: ExprValueFactory) = SizeExprFunction(valueFactory), FromUnixTimeFunction(valueFactory), UnixTimestampFunction(valueFactory) + // Note that we do not include DynamicLookupExprFunction here since it is only needed by the plan evaluator. ) internal fun createExists(valueFactory: ExprValueFactory): ExprFunction = object : ExprFunction { @@ -77,6 +84,30 @@ internal fun createUtcNow(valueFactory: ExprValueFactory): ExprFunction = object valueFactory.newTimestamp(session.now) } +internal fun createFilterDistinct(valueFactory: ExprValueFactory): ExprFunction = object : ExprFunction { + override val signature = FunctionSignature( + "filter_distinct", + listOf(StaticType.unionOf(StaticType.BAG, StaticType.LIST, StaticType.SEXP, StaticType.STRUCT)), + returnType = StaticType.BAG + ) + + override fun callWithRequired(session: EvaluationSession, required: List): ExprValue { + val argument = required.first() + // We cannot use a [HashSet] here because [ExprValue] does not implement .equals() and .hashCode() + val encountered = TreeSet(DEFAULT_COMPARATOR) + return valueFactory.newBag( + sequence { + argument.asSequence().forEach { + if (!encountered.contains(it)) { + encountered.add(it.unnamedValue()) + yield(it) + } + } + } + ) + } +} + internal fun createCharacterLength(name: String, valueFactory: ExprValueFactory): ExprFunction = object : ExprFunction { override val signature: FunctionSignature diff --git a/lang/src/org/partiql/lang/eval/visitors/PartiqlAstSanityValidator.kt b/lang/src/org/partiql/lang/eval/visitors/PartiqlAstSanityValidator.kt index c86dd7141a..39d4b89ceb 100644 --- a/lang/src/org/partiql/lang/eval/visitors/PartiqlAstSanityValidator.kt +++ b/lang/src/org/partiql/lang/eval/visitors/PartiqlAstSanityValidator.kt @@ -23,7 +23,6 @@ import org.partiql.lang.ast.IsCountStarMeta import org.partiql.lang.ast.passes.SemanticException import org.partiql.lang.domains.PartiqlAst import org.partiql.lang.domains.addSourceLocation -import org.partiql.lang.domains.errorContextFrom import org.partiql.lang.errors.ErrorCode import org.partiql.lang.errors.Property import org.partiql.lang.errors.PropertyValueMap @@ -31,6 +30,7 @@ import org.partiql.lang.eval.CompileOptions import org.partiql.lang.eval.EvaluationException import org.partiql.lang.eval.TypedOpBehavior import org.partiql.lang.eval.err +import org.partiql.lang.eval.errorContextFrom import org.partiql.pig.runtime.LongPrimitive /** diff --git a/lang/src/org/partiql/lang/planner/GlobalBindings.kt b/lang/src/org/partiql/lang/planner/GlobalBindings.kt new file mode 100644 index 0000000000..1802370c68 --- /dev/null +++ b/lang/src/org/partiql/lang/planner/GlobalBindings.kt @@ -0,0 +1,50 @@ +package org.partiql.lang.planner + +import org.partiql.lang.eval.BindingCase +import org.partiql.lang.eval.BindingName + +/** Indicates the result of an attempt to resolve a global binding. */ +sealed class ResolutionResult { + /** + * A success case, indicates the [uniqueId] of the match to the [BindingName] in the global scope. + * Typically, this is defined by the storage layer. + */ + data class GlobalVariable(val uniqueId: String) : ResolutionResult() + + /** + * A success case, indicates the [index] of the only possible match to the [BindingName] in a local lexical scope. + * This is `internal` because [index] is an implementation detail that shouldn't be accessible outside of this + * library. + */ + internal data class LocalVariable(val index: Int) : ResolutionResult() + + /** A failure case, indicates that resolution did not match any variable. */ + object Undefined : ResolutionResult() +} + +fun interface GlobalBindings { + /** + * Implementations try to resolve a global variable which is typically a database table, as identified by a + * [bindingName]. The [bindingName] includes both the name as specified by the query author and a [BindingCase] + * which indicates if query author included double quotes (") which mean the lookup should be case-sensitive. + * + * Implementations of this function must return: + * + * - [ResolutionResult.GlobalVariable] if [bindingName] matches a global variable (typically a database table). + * - [ResolutionResult.Undefined] if no identifier matches [bindingName]. + * + * When determining if a variable name matches a global variable, it is important to consider if the comparison + * should be case-sensitive or case-insensitive. @see [BindingName.bindingCase]. In the event that more than one + * variable matches a case-insensitive [BindingName], the implementation must still select one of them + * without providing an error. (This is consistent with Postres's behavior in this scenario.) + * + * Note that while [ResolutionResult.LocalVariable] exists, it is intentionally marked `internal` and cannot + * be used by outside of this project.. + */ + fun resolve(bindingName: BindingName): ResolutionResult +} + +private val EMPTY = GlobalBindings { ResolutionResult.Undefined } + +/** Convenience function for obtaining an instance of [GlobalBindings] with no defined variables. */ +fun emptyGlobalBindings(): GlobalBindings = EMPTY diff --git a/lang/src/org/partiql/lang/planner/PassResult.kt b/lang/src/org/partiql/lang/planner/PassResult.kt new file mode 100644 index 0000000000..f0901f6aa0 --- /dev/null +++ b/lang/src/org/partiql/lang/planner/PassResult.kt @@ -0,0 +1,15 @@ +package org.partiql.lang.planner +import org.partiql.lang.errors.Problem + +sealed class PassResult { + /** + * Indicates query planning was successful and includes a list of any warnings that were encountered along the way. + */ + data class Success(val result: TResult, val warnings: List) : PassResult() + + /** + * Indicates query planning was not successful and includes a list of errors and warnings that were encountered + * along the way. + */ + data class Error(val errors: List) : PassResult() +} diff --git a/lang/src/org/partiql/lang/planner/transforms/AstNormalize.kt b/lang/src/org/partiql/lang/planner/transforms/AstNormalize.kt new file mode 100644 index 0000000000..9e7d362b34 --- /dev/null +++ b/lang/src/org/partiql/lang/planner/transforms/AstNormalize.kt @@ -0,0 +1,25 @@ +package org.partiql.lang.planner.transforms + +import org.partiql.lang.domains.PartiqlAst +import org.partiql.lang.eval.visitors.FromSourceAliasVisitorTransform +import org.partiql.lang.eval.visitors.PipelinedVisitorTransform +import org.partiql.lang.eval.visitors.SelectListItemAliasVisitorTransform +import org.partiql.lang.eval.visitors.SelectStarVisitorTransform + +/** + * Executes the [SelectListItemAliasVisitorTransform], [FromSourceAliasVisitorTransform] and + * [SelectStarVisitorTransform] passes on the receiver. + */ +fun PartiqlAst.Statement.normalize(): PartiqlAst.Statement { + // Since these passes all work on PartiqlAst, we can use a PipelinedVisitorTransform which executes each + // specified VisitorTransform in sequence. + val transforms = PipelinedVisitorTransform( + // Synthesizes unspecified `SELECT AS ...` aliases + SelectListItemAliasVisitorTransform(), + // Synthesizes unspecified `FROM AS ...` aliases + FromSourceAliasVisitorTransform(), + // Changes `SELECT * FROM a, b` to SELECT a.*, b.* FROM a, b` + SelectStarVisitorTransform() + ) + return transforms.transformStatement(this) +} diff --git a/lang/src/org/partiql/lang/planner/transforms/AstToLogicalVisitorTransform.kt b/lang/src/org/partiql/lang/planner/transforms/AstToLogicalVisitorTransform.kt new file mode 100644 index 0000000000..e9e7e16312 --- /dev/null +++ b/lang/src/org/partiql/lang/planner/transforms/AstToLogicalVisitorTransform.kt @@ -0,0 +1,168 @@ +package org.partiql.lang.planner.transforms + +import com.amazon.ionelement.api.ionBool +import org.partiql.lang.domains.PartiqlAst +import org.partiql.lang.domains.PartiqlAstToPartiqlLogicalVisitorTransform +import org.partiql.lang.domains.PartiqlLogical + +/** + * Transforms an instance of [PartiqlAst.Statement] to [PartiqlLogical.Statement]. + * + * Performs no semantic checks. + * + * This conversion (and the logical algebra) are early in their lifecycle and so only a very limited subset of + * SFW queries are transformable. See tests for this class to see which queries are transformable. + */ +internal fun PartiqlAst.Statement.toLogicalPlan(): PartiqlLogical.Plan = + PartiqlLogical.build { + plan( + AstToLogicalVisitorTransform.transformStatement(this@toLogicalPlan), + version = PLAN_VERSION_NUMBER.toLong() + ) + } + +private object AstToLogicalVisitorTransform : PartiqlAstToPartiqlLogicalVisitorTransform() { + + override fun transformExprSelect(node: PartiqlAst.Expr.Select): PartiqlLogical.Expr { + checkForUnsupportedSelectClauses(node) + + var algebra: PartiqlLogical.Bexpr = FromSourceToBexpr.convert(node.from) + + algebra = node.fromLet?.let { fromLet -> + PartiqlLogical.build { + let(algebra, fromLet.letBindings.map { transformLetBinding(it) }, node.fromLet.metas) + } + } ?: algebra + + algebra = node.where?.let { + PartiqlLogical.build { filter(transformExpr(it), algebra, it.metas) } + } ?: algebra + + algebra = node.offset?.let { + PartiqlLogical.build { offset(transformExpr(it), algebra, node.offset.metas) } + } ?: algebra + + algebra = node.limit?.let { + PartiqlLogical.build { limit(transformExpr(it), algebra, node.limit.metas) } + } ?: algebra + + return convertProjectionToBindingsToValues(node, algebra) + } + + private fun convertProjectionToBindingsToValues(node: PartiqlAst.Expr.Select, algebra: PartiqlLogical.Bexpr) = + PartiqlLogical.build { + bindingsToValues( + when (val project = node.project) { + is PartiqlAst.Projection.ProjectValue -> transformExpr(project.value) + is PartiqlAst.Projection.ProjectList -> { + struct( + List(project.projectItems.size) { idx -> + when (val projectItem = project.projectItems[idx]) { + is PartiqlAst.ProjectItem.ProjectExpr -> + structField( + lit( + projectItem.asAlias?.toIonElement() + ?: errAstNotNormalized("SELECT-list item alias not specified") + ), + transformExpr(projectItem.expr), + ) + is PartiqlAst.ProjectItem.ProjectAll -> { + structFields(transformExpr(projectItem.expr), projectItem.metas) + } + } + } + ) + } + is PartiqlAst.Projection.ProjectStar -> + // `SELECT * FROM bar AS b` is rewritten to `SELECT b.* FROM bar as b` by + // [SelectStarVisitorTransform]. Therefore, there is no need to support `SELECT *` here. + errAstNotNormalized("Expected SELECT * to be removed") + + is PartiqlAst.Projection.ProjectPivot -> TODO("PIVOT ...") + }, + algebra, + node.project.metas + ) + }.let { q -> + // in case of SELECT DISTINCT, wrap bindingsToValues in call to filter_distinct + when (node.setq) { + null, is PartiqlAst.SetQuantifier.All -> q + is PartiqlAst.SetQuantifier.Distinct -> PartiqlLogical.build { call("filter_distinct", q) } + } + } + + /** + * Throws [NotImplementedError] if any `SELECT` clauses were used that are not mappable to [PartiqlLogical]. + * + * This function is temporary and will be removed when all the clauses of the `SELECT` expression are mappable + * to [PartiqlLogical]. + */ + private fun checkForUnsupportedSelectClauses(node: PartiqlAst.Expr.Select) { + when { + node.group != null -> TODO("Support for GROUP BY") + node.order != null -> TODO("Support for ORDER BY") + node.having != null -> TODO("Support for HAVING") + } + } + + override fun transformLetBinding(node: PartiqlAst.LetBinding): PartiqlLogical.LetBinding = + PartiqlLogical.build { + letBinding( + transformExpr(node.expr), + varDecl_(node.name, node.name.metas), + node.metas + ) + } + + override fun transformStatementDml(node: PartiqlAst.Statement.Dml): PartiqlLogical.Statement { + TODO("Support for DML") + } + + override fun transformStatementDdl(node: PartiqlAst.Statement.Ddl): PartiqlLogical.Statement { + TODO("Support for DDL") + } + + override fun transformExprStruct(node: PartiqlAst.Expr.Struct): PartiqlLogical.Expr = + PartiqlLogical.build { + struct( + node.fields.map { + structField( + transformExpr(it.first), + transformExpr(it.second) + ) + }, + metas = node.metas + ) + } +} + +private object FromSourceToBexpr : PartiqlAst.FromSource.Converter { + + override fun convertScan(node: PartiqlAst.FromSource.Scan): PartiqlLogical.Bexpr { + val asAlias = node.asAlias ?: errAstNotNormalized("Expected as alias to be non-null") + return PartiqlLogical.build { + scan( + AstToLogicalVisitorTransform.transformExpr(node.expr), + varDecl_(asAlias, asAlias.metas), + node.atAlias?.let { varDecl_(it, it.metas) }, + node.byAlias?.let { varDecl_(it, it.metas) }, + node.metas + ) + } + } + + override fun convertUnpivot(node: PartiqlAst.FromSource.Unpivot): PartiqlLogical.Bexpr { + TODO("Support for UNPIVOT") + } + + override fun convertJoin(node: PartiqlAst.FromSource.Join): PartiqlLogical.Bexpr = + PartiqlLogical.build { + join( + joinType = AstToLogicalVisitorTransform.transformJoinType(node.type), + left = convert(node.left), + right = convert(node.right), + predicate = node.predicate?.let { AstToLogicalVisitorTransform.transformExpr(it) } ?: lit(ionBool(true)), + node.metas + ) + } +} diff --git a/lang/src/org/partiql/lang/planner/transforms/LogicalResolvedToPhysicalVisitorTransform.kt b/lang/src/org/partiql/lang/planner/transforms/LogicalResolvedToPhysicalVisitorTransform.kt new file mode 100644 index 0000000000..b659961d8a --- /dev/null +++ b/lang/src/org/partiql/lang/planner/transforms/LogicalResolvedToPhysicalVisitorTransform.kt @@ -0,0 +1,94 @@ +package org.partiql.lang.planner.transforms + +import org.partiql.lang.domains.PartiqlLogicalResolved +import org.partiql.lang.domains.PartiqlLogicalResolvedToPartiqlPhysicalVisitorTransform +import org.partiql.lang.domains.PartiqlPhysical + +/** + * Transforms an instance of [PartiqlLogicalResolved.Statement] to [PartiqlPhysical.Statement], + * specifying `(impl default)` for each relational operator. + */ +internal fun PartiqlLogicalResolved.Plan.toPhysicalPlan() = + LogicalResolvedToPhysicalVisitorTransform().transformPlan(this) + +internal val DEFAULT_IMPL = PartiqlPhysical.build { impl("default") } + +internal class LogicalResolvedToPhysicalVisitorTransform : PartiqlLogicalResolvedToPartiqlPhysicalVisitorTransform() { + + /** Copies [PartiqlLogicalResolved.Bexpr.Scan] to [PartiqlPhysical.Bexpr.Scan], adding the default impl. */ + override fun transformBexprScan(node: PartiqlLogicalResolved.Bexpr.Scan): PartiqlPhysical.Bexpr { + val thiz = this + return PartiqlPhysical.build { + scan( + i = DEFAULT_IMPL, + expr = thiz.transformExpr(node.expr), + asDecl = thiz.transformVarDecl(node.asDecl), + atDecl = node.atDecl?.let { thiz.transformVarDecl(it) }, + byDecl = node.byDecl?.let { thiz.transformVarDecl(it) }, + metas = node.metas + ) + } + } + + /** Copies [PartiqlLogicalResolved.Bexpr.Filter] to [PartiqlPhysical.Bexpr.Filter], adding the default impl. */ + override fun transformBexprFilter(node: PartiqlLogicalResolved.Bexpr.Filter): PartiqlPhysical.Bexpr { + val thiz = this + return PartiqlPhysical.build { + filter( + i = DEFAULT_IMPL, + predicate = thiz.transformExpr(node.predicate), + source = thiz.transformBexpr(node.source), + metas = node.metas + ) + } + } + + override fun transformBexprJoin(node: PartiqlLogicalResolved.Bexpr.Join): PartiqlPhysical.Bexpr { + val thiz = this + return PartiqlPhysical.build { + join( + i = DEFAULT_IMPL, + joinType = thiz.transformJoinType(node.joinType), + left = thiz.transformBexpr(node.left), + right = thiz.transformBexpr(node.right), + predicate = thiz.transformExpr(node.predicate), + metas = node.metas + ) + } + } + + override fun transformBexprOffset(node: PartiqlLogicalResolved.Bexpr.Offset): PartiqlPhysical.Bexpr { + val thiz = this + return PartiqlPhysical.build { + offset( + i = DEFAULT_IMPL, + rowCount = thiz.transformExpr(node.rowCount), + source = thiz.transformBexpr(node.source), + metas = node.metas + ) + } + } + + override fun transformBexprLimit(node: PartiqlLogicalResolved.Bexpr.Limit): PartiqlPhysical.Bexpr { + val thiz = this + return PartiqlPhysical.build { + limit( + i = DEFAULT_IMPL, + rowCount = thiz.transformExpr(node.rowCount), + source = thiz.transformBexpr(node.source), + metas = node.metas + ) + } + } + + override fun transformBexprLet(node: PartiqlLogicalResolved.Bexpr.Let): PartiqlPhysical.Bexpr { + val thiz = this + return PartiqlPhysical.build { + let( + i = DEFAULT_IMPL, + source = thiz.transformBexpr(node.source), + bindings = node.bindings.map { transformLetBinding(it) } + ) + } + } +} diff --git a/lang/src/org/partiql/lang/planner/transforms/LogicalToLogicalResolvedVisitorTransform.kt b/lang/src/org/partiql/lang/planner/transforms/LogicalToLogicalResolvedVisitorTransform.kt new file mode 100644 index 0000000000..fec1ab15d0 --- /dev/null +++ b/lang/src/org/partiql/lang/planner/transforms/LogicalToLogicalResolvedVisitorTransform.kt @@ -0,0 +1,393 @@ +package org.partiql.lang.planner.transforms + +import com.amazon.ionelement.api.ionSymbol +import org.partiql.lang.ast.sourceLocation +import org.partiql.lang.domains.PartiqlLogical +import org.partiql.lang.domains.PartiqlLogicalResolved +import org.partiql.lang.domains.PartiqlLogicalToPartiqlLogicalResolvedVisitorTransform +import org.partiql.lang.domains.toBindingCase +import org.partiql.lang.errors.Problem +import org.partiql.lang.errors.ProblemHandler +import org.partiql.lang.eval.BindingName +import org.partiql.lang.eval.builtins.DYNAMIC_LOOKUP_FUNCTION_NAME +import org.partiql.lang.planner.GlobalBindings +import org.partiql.lang.planner.ResolutionResult +import org.partiql.pig.runtime.asPrimitive + +/** + * Resolves all variables by rewriting `(id )` to + * `(id )`) or `(global_id )`. The latter is usually a reference to + * a database table. `` is supplied by the integrating PartiQL service by means of the [globals] specified + * by callers of this function. Note that in general, all `(scan (global_id ...) ...)` operators will later be + * rewritten to an optimized physical read operator. + * + * The [problemHandler] is notified of any undefined variables. Resolution does not stop on the first error, rather + * we keep going to provide the end user any additional error messaging, unless [ProblemHandler.handleProblem] throws + * an exception when an error is logged. **If any undefined variables are detected, in order to allow traversal to + * continue, a fake index value is used in place of a real one and the resolved logical plan returned by this function + * is guaranteed to be invalid.** **Therefore, it is the responsibility therefore of callers to check if any problems + * have been logged with [org.partiql.lang.errors.ProblemSeverity.ERROR] and to abort further query planning if + * necessary.** + * + * Local variables are resolved independently within this pass, but we rely on [globals] to resolve global variables. + * + * Ths works in two passes: + * 1. All [PartiqlLogical.VarDecl] nodes are allocated unique indexes (which is stored in a meta). + * 2. Then, during the transform from the `partiql_logical` domain to the `partiql_logical_resolved` domain, we + * determine if the `id` node refers to a global variable or local variable. For global variables, the `id` node is + * replaced with `(global_id )`. For local variables, the original `id` node is replaced with a + * `(id )`), where `` is the index of the corresponding `var_decl`. + */ +internal fun PartiqlLogical.Plan.toResolvedPlan( + problemHandler: ProblemHandler, + globals: GlobalBindings, + allowUndefinedVariables: Boolean = false +): PartiqlLogicalResolved.Plan { + // Allocate a unique id for each `VarDecl` + val (planWithAllocatedVariables, allLocals) = this.allocateVariableIds() + + // Transform to `partiql_logical_resolved` while resolving variables. + val resolvedSt = LogicalToLogicalResolvedVisitorTransform(allowUndefinedVariables, problemHandler, globals) + .transformPlan(planWithAllocatedVariables) + .copy(locals = allLocals) + + return resolvedSt +} + +private fun PartiqlLogical.Expr.Id.asGlobalId(uniqueId: String): PartiqlLogicalResolved.Expr.GlobalId = + PartiqlLogicalResolved.build { + globalId_( + name = name, + uniqueId = uniqueId.asPrimitive(), + metas = this@asGlobalId.metas + ) + } + +private fun PartiqlLogical.Expr.Id.asLocalId(index: Int): PartiqlLogicalResolved.Expr = + PartiqlLogicalResolved.build { + localId_(index.asPrimitive(), this@asLocalId.metas) + } + +private fun PartiqlLogical.Expr.Id.asErrorId(): PartiqlLogicalResolved.Expr = + PartiqlLogicalResolved.build { + localId_((-1).asPrimitive(), this@asErrorId.metas) + } + +/** + * A local scope is a list of variable declarations that are produced by a relational operator and an optional + * reference to a parent scope. This is handled separately from global variables. + * + * This is a [List] of [PartiqlLogical.VarDecl] and not a [Map] or some other more efficient data structure + * because most variable lookups are case-insensitive, which makes storing them in a [Map] and benefiting from it hard. + */ +private data class LocalScope(val varDecls: List) + +private data class LogicalToLogicalResolvedVisitorTransform( + /** If set to `true`, do not log errors about undefined variables. Rewrite such variables to a `dynamic_id` node. */ + val allowUndefinedVariables: Boolean, + /** Where to send error reports. */ + private val problemHandler: ProblemHandler, + /** If a variable is not found using [inputScope], we will attempt to locate the binding here instead. */ + private val globals: GlobalBindings, + +) : PartiqlLogicalToPartiqlLogicalResolvedVisitorTransform() { + /** The current [LocalScope]. */ + private var inputScope: LocalScope = LocalScope(emptyList()) + + private enum class VariableLookupStrategy { + LOCALS_THEN_GLOBALS, + GLOBALS_THEN_LOCALS + } + + /** + * This is set to [VariableLookupStrategy.GLOBALS_THEN_LOCALS] for the `` in `(scan ...)` nodes and + * [VariableLookupStrategy.LOCALS_THEN_GLOBALS] for everything else. This is we resolve globals first within + * a `FROM`. + */ + private var currentVariableLookupStrategy: VariableLookupStrategy = VariableLookupStrategy.LOCALS_THEN_GLOBALS + + private fun withVariableLookupStrategy(nextVariableLookupStrategy: VariableLookupStrategy, block: () -> T): T { + val lastVariableLookupStrategy = this.currentVariableLookupStrategy + this.currentVariableLookupStrategy = nextVariableLookupStrategy + return block().also { + this.currentVariableLookupStrategy = lastVariableLookupStrategy + } + } + + private fun withInputScope(nextScope: LocalScope, block: () -> T): T { + val lastScope = inputScope + inputScope = nextScope + return block().also { + inputScope = lastScope + } + } + + override fun transformPlan(node: PartiqlLogical.Plan): PartiqlLogicalResolved.Plan = + PartiqlLogicalResolved.build { + plan_( + stmt = transformStatement(node.stmt), + version = node.version, + locals = emptyList(), // NOTE: locals will be populated by caller + metas = node.metas + ) + } + + override fun transformBexprScan_expr(node: PartiqlLogical.Bexpr.Scan): PartiqlLogicalResolved.Expr = + withVariableLookupStrategy(VariableLookupStrategy.GLOBALS_THEN_LOCALS) { + super.transformBexprScan_expr(node) + } + + override fun transformBexprJoin_right(node: PartiqlLogical.Bexpr.Join): PartiqlLogicalResolved.Bexpr { + // No need to change the current scope of the node.left. Node.right gets the current scope + + // the left output scope. + val leftOutputScope = getOutputScope(node.left) + val rightInputScope = inputScope.concatenate(leftOutputScope) + return withInputScope(rightInputScope) { + this.transformBexpr(node.right) + } + } + + override fun transformBexprLet(node: PartiqlLogical.Bexpr.Let): PartiqlLogicalResolved.Bexpr { + val thiz = this + return PartiqlLogicalResolved.build { + let( + source = transformBexpr(node.source), + bindings = withInputScope(getOutputScope(node.source)) { + // This "wonderful" (depending on your definition of the term) bit of code performs a fold + // combined with a map... The accumulator is a Pair, + // LocalScope>. + // accumulator.first: the current list of let bindings that have been transformed so far + // accumulator.second: an instance of LocalScope that includes all the variables defined up to + // this point, not including the current let binding. + val initial = emptyList() to thiz.inputScope + val (newBindings: List, _: LocalScope) = + node.bindings.fold(initial) { accumulator, current -> + // Each let binding's expression should be resolved within the scope of the *last* + // let binding (or the current scope if this is the first let binding). + val resolvedValueExpr = withInputScope(accumulator.second) { + thiz.transformExpr(current.value) + } + val nextScope = LocalScope(listOf(current.decl)).concatenate(accumulator.second) + val transformedLetBindings = accumulator.first + PartiqlLogicalResolved.build { + letBinding(resolvedValueExpr, transformVarDecl(current.decl)) + } + transformedLetBindings to nextScope + } + newBindings + } + ) + } + } + + // We are currently using bindings_to_values to denote a sub-query, which works for all the use cases we are + // presented with today, as every SELECT statement is replaced with `bindings_to_values at the top level. + override fun transformExprBindingsToValues(node: PartiqlLogical.Expr.BindingsToValues): PartiqlLogicalResolved.Expr = + // If we are in the expr of a scan node, we need to reset the lookup strategy + withVariableLookupStrategy(VariableLookupStrategy.LOCALS_THEN_GLOBALS) { + super.transformExprBindingsToValues(node) + } + + /** + * Grabs the index meta added by [VariableIdAllocator] and stores it as an element in + * [PartiqlLogicalResolved.VarDecl]. + */ + override fun transformVarDecl(node: PartiqlLogical.VarDecl): PartiqlLogicalResolved.VarDecl = + PartiqlLogicalResolved.build { + varDecl(node.indexMeta.toLong()) + } + + /** + * Returns [ResolutionResult.LocalVariable] if [bindingName] refers to a local variable. + * + * Otherwise, returns [ResolutionResult.Undefined]. (Elsewhere, [globals] will be checked next.) + */ + private fun lookupLocalVariable(bindingName: BindingName): ResolutionResult { + val found = this.inputScope.varDecls.firstOrNull { bindingName.isEquivalentTo(it.name.text) } + return if (found == null) { + ResolutionResult.Undefined + } else { + ResolutionResult.LocalVariable(found.indexMeta) + } + } + + /** + * Resolves the logical `(id ...)` node node to a `(local_id ...)`, `(global_id ...)`, or dynamic `(id...)` + * variable. + */ + override fun transformExprId(node: PartiqlLogical.Expr.Id): PartiqlLogicalResolved.Expr { + val bindingName = BindingName(node.name.text, node.case.toBindingCase()) + + val resolutionResult = if ( + this.currentVariableLookupStrategy == VariableLookupStrategy.GLOBALS_THEN_LOCALS && + node.qualifier is PartiqlLogical.ScopeQualifier.Unqualified + ) { + // look up variable in globals first, then locals + when (val globalResolutionResult = globals.resolve(bindingName)) { + ResolutionResult.Undefined -> lookupLocalVariable(bindingName) + else -> globalResolutionResult + } + } else { + // look up variable in locals first, then globals. + when (val localResolutionResult = lookupLocalVariable(bindingName)) { + ResolutionResult.Undefined -> globals.resolve(bindingName) + else -> localResolutionResult + } + } + return when (resolutionResult) { + is ResolutionResult.GlobalVariable -> { + node.asGlobalId(resolutionResult.uniqueId) + } + is ResolutionResult.LocalVariable -> { + node.asLocalId(resolutionResult.index) + } + ResolutionResult.Undefined -> { + if (this.allowUndefinedVariables) { + node.asDynamicLookupCallsite( + currentDynamicResolutionCandidates() + .map { + PartiqlLogicalResolved.build { + localId(it.indexMeta.toLong()) + } + } + ) + } else { + node.asErrorId().also { + problemHandler.handleProblem( + Problem( + node.metas.sourceLocation ?: error("MetaContainer is missing SourceLocationMeta"), + PlanningProblemDetails.UndefinedVariable( + node.name.text, + node.case is PartiqlLogical.CaseSensitivity.CaseSensitive + ) + ) + ) + } + } + } + } + } + + /** + * Returns a list of variables accessible from the current scope which contain variables that may contain + * an unqualified variable, in the order that they should be searched. + */ + fun currentDynamicResolutionCandidates(): List = + inputScope.varDecls.filter { it.includeInDynamicResolution } + + override fun transformExprBindingsToValues_exp(node: PartiqlLogical.Expr.BindingsToValues): PartiqlLogicalResolved.Expr { + val bindings = getOutputScope(node.query).concatenate(this.inputScope) + return withInputScope(bindings) { + this.transformExpr(node.exp) + } + } + + override fun transformBexprFilter_predicate(node: PartiqlLogical.Bexpr.Filter): PartiqlLogicalResolved.Expr { + val bindings = getOutputScope(node.source) + return withInputScope(bindings) { + this.transformExpr(node.predicate) + } + } + + override fun transformBexprJoin_predicate(node: PartiqlLogical.Bexpr.Join): PartiqlLogicalResolved.Expr { + val bindings = getOutputScope(node) + return withInputScope(bindings) { + this.transformExpr(node.predicate) + } + } + + /** + * This should be called any time we create a [LocalScope] with more than one variable to prevent duplicate + * variable names. When checking for duplication, the letter case of the variable names is not considered. + * + * Example: + * + * ``` + * SELECT * FROM foo AS X AT x + * duplicate variable: ^ + * ``` + */ + private fun checkForDuplicateVariables(varDecls: List) { + val usedVariableNames = hashSetOf() + varDecls.forEach { varDecl -> + val loweredVariableName = varDecl.name.text.toLowerCase() + if (usedVariableNames.contains(loweredVariableName)) { + this.problemHandler.handleProblem( + Problem( + varDecl.metas.sourceLocation ?: error("VarDecl was missing source location meta"), + PlanningProblemDetails.VariablePreviouslyDefined(varDecl.name.text) + ) + ) + } + usedVariableNames.add(loweredVariableName) + } + } + + /** + * Computes a [LocalScope] for containing all of the variables that are output from [bexpr]. + */ + private fun getOutputScope(bexpr: PartiqlLogical.Bexpr): LocalScope = + when (bexpr) { + is PartiqlLogical.Bexpr.Filter -> getOutputScope(bexpr.source) + is PartiqlLogical.Bexpr.Limit -> getOutputScope(bexpr.source) + is PartiqlLogical.Bexpr.Offset -> getOutputScope(bexpr.source) + is PartiqlLogical.Bexpr.Scan -> { + LocalScope( + listOfNotNull(bexpr.asDecl.markForDynamicResolution(), bexpr.atDecl, bexpr.byDecl).also { + checkForDuplicateVariables(it) + } + ) + } + is PartiqlLogical.Bexpr.Join -> { + val (leftBexpr, rightBexpr) = when (bexpr.joinType) { + is PartiqlLogical.JoinType.Full, + is PartiqlLogical.JoinType.Inner, + is PartiqlLogical.JoinType.Left -> bexpr.left to bexpr.right + // right join is same as left join but right and left operands are swapped. + is PartiqlLogical.JoinType.Right -> bexpr.right to bexpr.left + } + val leftScope = getOutputScope(leftBexpr) + val rightScope = getOutputScope(rightBexpr) + // right scope is first to allow RHS variables to "shadow" LHS variables. + rightScope.concatenate(leftScope) + } + is PartiqlLogical.Bexpr.Let -> { + val sourceScope = getOutputScope(bexpr.source) + // Note that .reversed() is important here to ensure that variable shadowing works correctly. + val letVariables = bexpr.bindings.reversed().map { it.decl } + sourceScope.concatenate(letVariables) + } + } + + private fun LocalScope.concatenate(other: LocalScope): LocalScope = + this.concatenate(other.varDecls) + + private fun LocalScope.concatenate(other: List): LocalScope { + val concatenatedScopeVariables = this.varDecls + other + return LocalScope(concatenatedScopeVariables) + } + + private fun PartiqlLogical.Expr.Id.asDynamicLookupCallsite( + search: List + ): PartiqlLogicalResolved.Expr { + val caseSensitivityString = when (case) { + is PartiqlLogical.CaseSensitivity.CaseInsensitive -> "case_insensitive" + is PartiqlLogical.CaseSensitivity.CaseSensitive -> "case_sensitive" + } + return PartiqlLogicalResolved.build { + call( + funcName = DYNAMIC_LOOKUP_FUNCTION_NAME, + args = listOf( + lit(name.toIonElement()), + lit(ionSymbol(caseSensitivityString)), + lit(ionSymbol(currentVariableLookupStrategy.toString().toLowerCase())), + ) + search, + metas = this@asDynamicLookupCallsite.metas + ) + } + } +} + +/** Marks a variable for dynamic resolution--i.e. if undefined, this vardecl will be included in any dynamic_id lookup. */ +fun PartiqlLogical.VarDecl.markForDynamicResolution() = this.withMeta("\$include_in_dynamic_resolution", Unit) +/** Returns true of the [VarDecl] has been marked to participate in unqualified field resolution */ +val PartiqlLogical.VarDecl.includeInDynamicResolution get() = this.metas.containsKey("\$include_in_dynamic_resolution") diff --git a/lang/src/org/partiql/lang/planner/transforms/PlanningProblemDetails.kt b/lang/src/org/partiql/lang/planner/transforms/PlanningProblemDetails.kt new file mode 100644 index 0000000000..5819479d2e --- /dev/null +++ b/lang/src/org/partiql/lang/planner/transforms/PlanningProblemDetails.kt @@ -0,0 +1,46 @@ +package org.partiql.lang.planner.transforms + +import org.partiql.lang.errors.ProblemDetails +import org.partiql.lang.errors.ProblemSeverity + +/** + * Contains detailed information about errors that may occur during query planning. + * + * This information can be used to generate end-user readable error messages and is also easy to assert + * equivalence in unit tests. + */ +sealed class PlanningProblemDetails( + override val severity: ProblemSeverity, + val messageFormatter: () -> String +) : ProblemDetails { + + override val message: String get() = messageFormatter() + + data class ParseError(val parseErrorMessage: String) : + PlanningProblemDetails(ProblemSeverity.ERROR, { parseErrorMessage }) + + data class CompileError(val errorMessage: String) : + PlanningProblemDetails(ProblemSeverity.ERROR, { errorMessage }) + + data class UndefinedVariable(val variableName: String, val caseSensitive: Boolean) : + PlanningProblemDetails( + ProblemSeverity.ERROR, + { + "Undefined variable '$variableName'." + + if (caseSensitive) { + // Individuals that are new to SQL often try to use double quotes for string literals. + // Let's help them out a bit. + " Hint: did you intend to use single-quotes (') here? Remember that double-quotes (\") denote " + + "quoted identifiers and single-quotes denote strings." + } else { + "" + } + } + ) + + data class VariablePreviouslyDefined(val variableName: String) : + PlanningProblemDetails( + ProblemSeverity.ERROR, + { "The variable '$variableName' was previously defined." } + ) +} diff --git a/lang/src/org/partiql/lang/planner/transforms/Util.kt b/lang/src/org/partiql/lang/planner/transforms/Util.kt new file mode 100644 index 0000000000..062a3ab97d --- /dev/null +++ b/lang/src/org/partiql/lang/planner/transforms/Util.kt @@ -0,0 +1,13 @@ + +package org.partiql.lang.planner.transforms + +/** + * This is the version number of the logical and physical plans supported by this version of PartiQL. + * + * It would be nice to embed this in the PIG domain somehow, but this isn't supported, so we have to include it + * here for now. + */ +const val PLAN_VERSION_NUMBER = 1 + +internal fun errAstNotNormalized(message: String): Nothing = + error("$message - have the basic visitor transforms been executed first?") diff --git a/lang/src/org/partiql/lang/planner/transforms/VariableIdAllocator.kt b/lang/src/org/partiql/lang/planner/transforms/VariableIdAllocator.kt new file mode 100644 index 0000000000..11c2e8addb --- /dev/null +++ b/lang/src/org/partiql/lang/planner/transforms/VariableIdAllocator.kt @@ -0,0 +1,45 @@ +package org.partiql.lang.planner.transforms + +import org.partiql.lang.domains.PartiqlLogical +import org.partiql.lang.domains.PartiqlLogicalResolved + +/** + * Allocates register indexes for all local variables in the plan. + * + * Returns pair containing a logical plan where all `var_decl`s have a [VARIABLE_ID_META_TAG] meta indicating the + * variable index (which can be utilized later when establishing variable scoping) and list of all local variables + * declared within the plan, which becomes the `locals` sub-node of the `plan` node. + */ +internal fun PartiqlLogical.Plan.allocateVariableIds(): Pair> { + + var allLocals = mutableListOf() + val planWithAllocatedVariables = VariableIdAllocator(allLocals).transformPlan(this) + return planWithAllocatedVariables to allLocals.toList() +} + +private const val VARIABLE_ID_META_TAG = "\$variable_id" + +internal val PartiqlLogical.VarDecl.indexMeta + get() = this.metas[VARIABLE_ID_META_TAG] as? Int ?: error("Meta $VARIABLE_ID_META_TAG was not present") + +/** + * Allocates a unique index to every `var_decl` in the logical plan. We use metas for this step to avoid a having + * create another permuted domain. + */ +private class VariableIdAllocator( + val allLocals: MutableList +) : PartiqlLogical.VisitorTransform() { + private var nextVariableId = 0 + + override fun transformVarDecl(node: PartiqlLogical.VarDecl): PartiqlLogical.VarDecl = + node.withMeta(VARIABLE_ID_META_TAG, nextVariableId).also { + + allLocals.add( + PartiqlLogicalResolved.build { + localVariable(node.name.text, nextVariableId.toLong()) + } + ) + + nextVariableId++ + } +} diff --git a/lang/test/org/partiql/lang/planner/Util.kt b/lang/test/org/partiql/lang/planner/Util.kt new file mode 100644 index 0000000000..0c184155b2 --- /dev/null +++ b/lang/test/org/partiql/lang/planner/Util.kt @@ -0,0 +1,22 @@ +package org.partiql.lang.planner + +import org.partiql.lang.ast.SourceLocationMeta +import org.partiql.lang.errors.Problem +import org.partiql.lang.errors.ProblemDetails + +/** + * Creates a fake implementation of [GlobalBindings] with the specified [globalVariableNames]. + * + * The fake unique identifier of bound variables is computed to be `fake_uid_for_${globalVariableName}`. + */ +fun createFakeGlobalBindings(vararg globalVariableNames: Pair) = + GlobalBindings { bindingName -> + val matches = globalVariableNames.filter { bindingName.isEquivalentTo(it.first) } + when (matches.size) { + 0 -> ResolutionResult.Undefined + else -> ResolutionResult.GlobalVariable(matches.first().second) + } + } + +fun problem(line: Int, charOffset: Int, detail: ProblemDetails): Problem = + Problem(SourceLocationMeta(line.toLong(), charOffset.toLong()), detail) diff --git a/lang/test/org/partiql/lang/planner/transforms/AstToLogicalVisitorTransformTests.kt b/lang/test/org/partiql/lang/planner/transforms/AstToLogicalVisitorTransformTests.kt new file mode 100644 index 0000000000..ccfdfe4628 --- /dev/null +++ b/lang/test/org/partiql/lang/planner/transforms/AstToLogicalVisitorTransformTests.kt @@ -0,0 +1,152 @@ +package org.partiql.lang.planner.transforms + +import com.amazon.ion.system.IonSystemBuilder +import com.amazon.ionelement.api.ionBool +import com.amazon.ionelement.api.ionInt +import com.amazon.ionelement.api.ionString +import com.amazon.ionelement.api.toIonValue +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.assertDoesNotThrow +import org.junit.jupiter.api.assertThrows +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ArgumentsSource +import org.partiql.lang.domains.PartiqlLogical +import org.partiql.lang.domains.id +import org.partiql.lang.domains.pathExpr +import org.partiql.lang.syntax.SqlParser +import org.partiql.lang.util.ArgumentsProviderBase +import org.partiql.lang.util.SexpAstPrettyPrinter + +/** + * Test cases in this class might seem a little light--that's because [AstToLogicalVisitorTransform] is getting + * heavily exercised during many other integration tests. These should be considered "smoke tests". + */ +class AstToLogicalVisitorTransformTests { + private val ion = IonSystemBuilder.standard().build() + private val parser = SqlParser(ion) + + private fun parseAndTransform(sql: String): PartiqlLogical.Statement { + val parseAstStatement = parser.parseAstStatement(sql) + println(SexpAstPrettyPrinter.format(parseAstStatement.toIonElement().asAnyElement().toIonValue(ion))) + return parseAstStatement.toLogicalPlan().stmt + } + + data class TestCase(val sql: String, val expectedAlgebra: PartiqlLogical.Statement) + + private fun runTestCase(tc: TestCase) { + val algebra = assertDoesNotThrow("Parsing TestCase.sql should not throw") { + parseAndTransform(tc.sql) + } + println(SexpAstPrettyPrinter.format(algebra.toIonElement().asAnyElement().toIonValue(ion))) + Assertions.assertEquals(tc.expectedAlgebra, algebra) + } + + @ParameterizedTest + @ArgumentsSource(ArgumentsForToLogicalTests::class) + fun `to logical`(tc: TestCase) = runTestCase(tc) + + class ArgumentsForToLogicalTests : ArgumentsProviderBase() { + override fun getParameters() = listOf( + TestCase( + // Note: + // `SELECT * FROM bar AS b` is rewritten to `SELECT b.* FROM bar as b` by [SelectStarVisitorTransform]. + // Therefore, there is no need to support `SELECT *` in `ToLogicalVisitorTransform`. + "SELECT b.* FROM bar AS b", + PartiqlLogical.build { + query( + bindingsToValues( + struct(structFields(id("b"))), + scan(id("bar"), varDecl("b")) + ) + ) + } + ), + TestCase( + // Note: This is supported by the AST -> logical -> physical transformation but should be rejected + // by the planner since it is a full table scan, which we won't support initially. + "SELECT b.* FROM bar AS b WHERE TRUE = TRUE", + PartiqlLogical.build { + query( + bindingsToValues( + struct(structFields(id("b"))), + filter( + eq(lit(ionBool(true)), lit(ionBool(true))), + scan(id("bar"), varDecl("b")) + ) + ) + ) + } + ), + TestCase( + "SELECT b.* FROM bar AS b WHERE b.primaryKey = 42", + PartiqlLogical.build { + query( + bindingsToValues( + struct(structFields(id("b"))), + filter( + eq(path(id("b"), pathExpr(lit(ionString("primaryKey")))), lit(ionInt(42))), + scan(id("bar"), varDecl("b")) + ) + ) + ) + } + ), + TestCase( + "SELECT DISTINCT b.* FROM bar AS b", + PartiqlLogical.build { + query( + call( + "filter_distinct", + bindingsToValues( + struct(structFields(id("b"))), + scan(id("bar"), varDecl("b")) + ) + ) + ) + } + ), + ) + } + + data class TodoTestCase(val sql: String) + @ParameterizedTest + @ArgumentsSource(ArgumentsForToToDoTests::class) + fun todo(tc: TodoTestCase) { + assertThrows("Parsing TestCase.sql should throw NotImplementedError") { + parseAndTransform(tc.sql) + } + } + + /** + * A list of statements that cannot be converted into the logical algebra yet by [ToLogicalVisitorTransform]. This + * is temporary--in the near future, we will accomplish this with a better language restriction feature which + * blocks all language features except those explicitly allowed. This will be needed to constrain possible queries + * to features supported by specific PartiQL-services. + */ + class ArgumentsForToToDoTests : ArgumentsProviderBase() { + override fun getParameters() = listOf( + // SELECT queries + TodoTestCase("SELECT b.* FROM UNPIVOT x as y"), + TodoTestCase("SELECT b.* FROM bar AS b GROUP BY a"), + TodoTestCase("SELECT b.* FROM bar AS b HAVING x"), + TodoTestCase("SELECT b.* FROM bar AS b ORDER BY y"), + TodoTestCase("PIVOT v AT n FROM data AS d"), + + // DML + TodoTestCase("CREATE TABLE foo"), + TodoTestCase("DROP TABLE foo"), + TodoTestCase("CREATE INDEX ON foo (x)"), + TodoTestCase("DROP INDEX bar ON foo"), + + // DDL + TodoTestCase("INSERT INTO foo VALUE 1"), + TodoTestCase("INSERT INTO foo VALUE 1"), + TodoTestCase("FROM x WHERE a = b SET k = 5"), + TodoTestCase("FROM x INSERT INTO foo VALUES (1, 2)"), + TodoTestCase("UPDATE x SET k = 5"), + TodoTestCase("UPDATE x INSERT INTO k << 1 >>"), + TodoTestCase("DELETE FROM y"), + TodoTestCase("REMOVE y"), + ) + } +} diff --git a/lang/test/org/partiql/lang/planner/transforms/LogicalResolvedToPhysicalVisitorTransformTests.kt b/lang/test/org/partiql/lang/planner/transforms/LogicalResolvedToPhysicalVisitorTransformTests.kt new file mode 100644 index 0000000000..c8f56e3789 --- /dev/null +++ b/lang/test/org/partiql/lang/planner/transforms/LogicalResolvedToPhysicalVisitorTransformTests.kt @@ -0,0 +1,69 @@ +package org.partiql.lang.planner.transforms + +import com.amazon.ionelement.api.ionBool +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ArgumentsSource +import org.partiql.lang.domains.PartiqlLogicalResolved +import org.partiql.lang.domains.PartiqlPhysical +import org.partiql.lang.util.ArgumentsProviderBase + +class LogicalResolvedToPhysicalVisitorTransformTests { + data class TestCase(val input: PartiqlLogicalResolved.Bexpr, val expected: PartiqlPhysical.Bexpr) + + @ParameterizedTest + @ArgumentsSource(ArgumentsForToPhysicalTests::class) + fun `to physical`(tc: TestCase) { + assertEquals(tc.expected, LogicalResolvedToPhysicalVisitorTransform().transformBexpr(tc.input)) + } + + class ArgumentsForToPhysicalTests : ArgumentsProviderBase() { + override fun getParameters() = listOf( + TestCase( + PartiqlLogicalResolved.build { + scan( + expr = globalId("foo", "foo"), + asDecl = varDecl(0), + atDecl = varDecl(1), + byDecl = varDecl(2) + ) + }, + PartiqlPhysical.build { + scan( + i = DEFAULT_IMPL, + expr = globalId("foo", "foo"), + asDecl = varDecl(0), + atDecl = varDecl(1), + byDecl = varDecl(2) + ) + } + ), + TestCase( + PartiqlLogicalResolved.build { + filter( + predicate = lit(ionBool(true)), + source = scan( + expr = globalId("foo", "foo"), + asDecl = varDecl(0), + atDecl = varDecl(1), + byDecl = varDecl(2) + ) + ) + }, + PartiqlPhysical.build { + filter( + i = DEFAULT_IMPL, + predicate = lit(ionBool(true)), + source = scan( + i = DEFAULT_IMPL, + expr = globalId("foo", "foo"), + asDecl = varDecl(0), + atDecl = varDecl(1), + byDecl = varDecl(2) + ) + ) + } + ) + ) + } +} diff --git a/lang/test/org/partiql/lang/planner/transforms/LogicalToLogicalResolvedVisitorTransformTests.kt b/lang/test/org/partiql/lang/planner/transforms/LogicalToLogicalResolvedVisitorTransformTests.kt new file mode 100644 index 0000000000..5d18eb7b94 --- /dev/null +++ b/lang/test/org/partiql/lang/planner/transforms/LogicalToLogicalResolvedVisitorTransformTests.kt @@ -0,0 +1,685 @@ +package org.partiql.lang.planner.transforms + +import com.amazon.ion.system.IonSystemBuilder +import com.amazon.ionelement.api.ionSymbol +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.assertDoesNotThrow +import org.junit.jupiter.api.fail +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ArgumentsSource +import org.partiql.lang.domains.PartiqlLogical +import org.partiql.lang.domains.PartiqlLogicalResolved +import org.partiql.lang.errors.Problem +import org.partiql.lang.errors.ProblemCollector +import org.partiql.lang.eval.BindingCase +import org.partiql.lang.eval.builtins.DYNAMIC_LOOKUP_FUNCTION_NAME +import org.partiql.lang.eval.sourceLocationMeta +import org.partiql.lang.planner.createFakeGlobalBindings +import org.partiql.lang.planner.problem +import org.partiql.lang.syntax.SqlParser +import org.partiql.lang.util.ArgumentsProviderBase +import org.partiql.lang.util.toIntExact + +private fun localVariable(name: String, index: Int) = + PartiqlLogicalResolved.build { localVariable(name, index.toLong()) } + +/** Shortcut for creating a dynamic lookup call site for the expected plans below. */ +private fun PartiqlLogicalResolved.Builder.dynamicLookup( + name: String, + case: BindingCase, + globalsFirst: Boolean = false, + vararg searchTargets: PartiqlLogicalResolved.Expr +) = + call( + DYNAMIC_LOOKUP_FUNCTION_NAME, + listOf( + lit(ionSymbol(name)), + lit( + ionSymbol( + when (case) { + BindingCase.SENSITIVE -> "case_sensitive" + BindingCase.INSENSITIVE -> "case_insensitive" + } + ) + ), + lit( + ionSymbol( + when { + globalsFirst -> "globals_then_locals" + else -> "locals_then_globals" + } + ) + ) + ) + searchTargets + ) + +class LogicalToLogicalResolvedVisitorTransformTests { + data class TestCase( + val sql: String, + val expectation: Expectation, + val allowUndefinedVariables: Boolean = false + ) + + data class ResolvedId( + val line: Int, + val charOffset: Int, + val expr: PartiqlLogicalResolved.Expr + ) { + constructor( + line: Int, + charOffset: Int, + build: PartiqlLogicalResolved.Builder.() -> PartiqlLogicalResolved.Expr + ) : this(line, charOffset, PartiqlLogicalResolved.BUILDER().build()) + + override fun toString(): String { + return "($line, $charOffset): $expr" + } + } + + sealed class Expectation { + data class Success( + val expectedIds: List, + val expectedLocalVariables: List + ) : Expectation() { + constructor(vararg expectedIds: ResolvedId) : this(expectedIds.toList(), emptyList()) + fun withLocals(vararg expectedLocalVariables: PartiqlLogicalResolved.LocalVariable) = + this.copy(expectedLocalVariables = expectedLocalVariables.toList()) + } + data class Problems(val problems: List) : Expectation() { + constructor(vararg problems: Problem) : this(problems.toList()) + } + } + + /** Mock table resolver. That can resolve f, foo, or UPPERCASE_FOO, while respecting case-sensitivity. */ + private val globalBindings = createFakeGlobalBindings( + *listOf( + "shadow", + "foo", + "bar", + "bat", + "UPPERCASE_FOO", + "case_AMBIGUOUS_foo", + "case_ambiguous_FOO" + ).map { + it to "fake_uid_for_$it" + }.toTypedArray() + ) + + private val ion = IonSystemBuilder.standard().build() + private val parser = SqlParser(ion) + + private fun runTestCase(tc: TestCase) { + val plan: PartiqlLogical.Plan = assertDoesNotThrow { + parser.parseAstStatement(tc.sql).toLogicalPlan() + } + + val problemHandler = ProblemCollector() + + when (tc.expectation) { + is Expectation.Success -> { + val resolved = plan.toResolvedPlan(problemHandler, globalBindings, tc.allowUndefinedVariables) + + // extract all of the dynamic, global and local ids from the resolved logical plan. + val actualResolvedIds = + object : PartiqlLogicalResolved.VisitorFold>() { + override fun visitExpr( + node: PartiqlLogicalResolved.Expr, + accumulator: List + ): List = + when (node) { + is PartiqlLogicalResolved.Expr.GlobalId, + is PartiqlLogicalResolved.Expr.LocalId -> accumulator + node + is PartiqlLogicalResolved.Expr.Call -> { + if (node.funcName.text == DYNAMIC_LOOKUP_FUNCTION_NAME) { + accumulator + node + } else { + accumulator + } + } + else -> accumulator + } + + // Don't include children of dynamic lookup callsites + override fun walkExprCall( + node: PartiqlLogicalResolved.Expr.Call, + accumulator: List + ): List { + return if (node.funcName.text == DYNAMIC_LOOKUP_FUNCTION_NAME) { + accumulator + } else { + super.walkExprCall(node, accumulator) + } + } + }.walkPlan(resolved, emptyList()) + + assertEquals( + tc.expectation.expectedIds.size, actualResolvedIds.size, + "Number of expected resovled variables must match actual" + ) + + val remainingActualResolvedIds = actualResolvedIds.map { + val location = it.metas.sourceLocationMeta ?: error("$it missing source location meta") + ResolvedId(location.lineNum.toIntExact(), location.charOffset.toIntExact()) { it } + }.filter { expectedId: ResolvedId -> + tc.expectation.expectedIds.none { actualId -> actualId == expectedId } + } + + if (remainingActualResolvedIds.isNotEmpty()) { + val sb = StringBuilder() + sb.appendLine("Unexpected ids:") + remainingActualResolvedIds.forEach { + sb.appendLine(it) + } + sb.appendLine("Expected ids:") + tc.expectation.expectedIds.forEach { + sb.appendLine(it) + } + + fail("Unmatched resolved ids were found.\n$sb") + } + + assertEquals( + tc.expectation.expectedLocalVariables, + resolved.locals, + "Expected and actual local variables must match" + ) + } + is Expectation.Problems -> { + assertDoesNotThrow("Should not throw when variables are undefined") { + plan.toResolvedPlan(problemHandler, globalBindings) + } + assertEquals(tc.expectation.problems, problemHandler.problems) + } + } + } + + @ParameterizedTest + @ArgumentsSource(CaseInsensitiveGlobalsCases::class) + fun `case-insensitive globals`(tc: TestCase) = runTestCase(tc) + class CaseInsensitiveGlobalsCases : ArgumentsProviderBase() { + override fun getParameters() = listOf( + // Case-insensitive resolution of global variables... + TestCase( + // all uppercase + sql = "FOO", + expectation = Expectation.Success(ResolvedId(1, 1) { globalId("FOO", "fake_uid_for_foo") }) + ), + TestCase( + // all lower case + "foo", + Expectation.Success(ResolvedId(1, 1) { globalId("foo", "fake_uid_for_foo") }) + ), + TestCase( + // mixed case + "fOo", + Expectation.Success(ResolvedId(1, 1) { globalId("fOo", "fake_uid_for_foo") }) + ), + TestCase( + // undefined + """ foobar """, + Expectation.Problems( + problem( + 1, + 2, + PlanningProblemDetails.UndefinedVariable("foobar", caseSensitive = false) + ) + ) + ), + + // Ambiguous case-insensitive lookup + TestCase( + // ambiguous + """case_ambiguous_foo """, + // In this case, we resolve to the first matching binding. This is consistent with Postres 9.6. + Expectation.Success( + ResolvedId(1, 1) { + globalId( + "case_ambiguous_foo", + "fake_uid_for_case_AMBIGUOUS_foo" + ) + } + ) + ), + + // Case-insensitive resolution of global variables with all uppercase letters... + TestCase( + // all uppercase + "UPPERCASE_FOO", + Expectation.Success( + ResolvedId(1, 1) { + globalId( + "UPPERCASE_FOO", + "fake_uid_for_UPPERCASE_FOO" + ) + } + ) + ), + TestCase( + // all lower case + "uppercase_foo", + Expectation.Success( + ResolvedId(1, 1) { + globalId( + "uppercase_foo", + "fake_uid_for_UPPERCASE_FOO" + ) + } + ) + ), + TestCase( + // mixed case + "UpPeRcAsE_fOo", + Expectation.Success( + ResolvedId(1, 1) { + globalId( + "UpPeRcAsE_fOo", + "fake_uid_for_UPPERCASE_FOO" + ) + } + ) + ), + + // undefined variables allowed + TestCase( + // undefined allowed (case-insensitive) + """some_undefined """, + Expectation.Success( + ResolvedId(1, 1) { + dynamicLookup("some_undefined", BindingCase.INSENSITIVE, globalsFirst = false) + } + ), + allowUndefinedVariables = true + ), + ) + } + + @ParameterizedTest + @ArgumentsSource(CaseSensitiveGlobalsCases::class) + fun `case-sensitive globals`(tc: TestCase) = runTestCase(tc) + class CaseSensitiveGlobalsCases : ArgumentsProviderBase() { + override fun getParameters() = listOf( + // Case-sensitive resolution of global variable with all lowercase letters + TestCase( + // all uppercase + "\"FOO\"", + Expectation.Problems( + problem( + 1, + 1, + PlanningProblemDetails.UndefinedVariable("FOO", caseSensitive = true) + ) + ) + ), + TestCase( + // all lowercase + "\"foo\"", + Expectation.Success(ResolvedId(1, 1) { globalId("foo", "fake_uid_for_foo") }) + ), + TestCase( + // mixed + "\"foO\"", + Expectation.Problems( + problem( + 1, + 1, + PlanningProblemDetails.UndefinedVariable("foO", caseSensitive = true) + ) + ) + ), + + // Case-sensitive resolution of global variables with all uppercase letters + TestCase( + // all uppercase + "\"UPPERCASE_FOO\"", + Expectation.Success( + ResolvedId(1, 1) { + globalId( + "UPPERCASE_FOO", + "fake_uid_for_UPPERCASE_FOO" + ) + } + ) + ), + TestCase( + // all lowercase + "\"uppercase_foo\"", + Expectation.Problems( + problem(1, 1, PlanningProblemDetails.UndefinedVariable("uppercase_foo", caseSensitive = true)) + ) + ), + TestCase( + // mixed + "\"UpPeRcAsE_fOo\"", + Expectation.Problems( + problem(1, 1, PlanningProblemDetails.UndefinedVariable("UpPeRcAsE_fOo", caseSensitive = true)) + ) + ), + TestCase( + // not ambiguous when case-sensitive + "\"case_AMBIGUOUS_foo\"", + Expectation.Success( + ResolvedId(1, 1) { + globalId("case_AMBIGUOUS_foo", "fake_uid_for_case_AMBIGUOUS_foo") + } + ) + ), + TestCase( + // not ambiguous when case-sensitive + "\"case_ambiguous_FOO\"", + Expectation.Success( + ResolvedId(1, 1) { + globalId("case_ambiguous_FOO", "fake_uid_for_case_ambiguous_FOO") + } + ) + ), + TestCase( + // undefined + """ FOOBAR """, + Expectation.Problems( + problem( + 1, + 2, + PlanningProblemDetails.UndefinedVariable("FOOBAR", caseSensitive = false) + ) + ) + ), + + TestCase( + // undefined allowed (case-sensitive) + "\"some_undefined\"", + Expectation.Success( + ResolvedId(1, 1) { + dynamicLookup("some_undefined", BindingCase.SENSITIVE) + } + ), + allowUndefinedVariables = true + ) + ) + } + + @ParameterizedTest + @ArgumentsSource(CaseInsensitiveLocalsVariablesCases::class) + fun `case-insensitive local variables`(tc: TestCase) = runTestCase(tc) + class CaseInsensitiveLocalsVariablesCases : ArgumentsProviderBase() { + override fun getParameters() = listOf( + // Case-insensitive resolution of local variables with all lowercase letters... + TestCase( + // all uppercase + "SELECT FOO.* FROM 1 AS foo WHERE FOO", + Expectation.Success( + ResolvedId(1, 8) { localId(0) }, + ResolvedId(1, 34) { localId(0) } + ).withLocals(localVariable("foo", 0)) + ), + TestCase( + // all lowercase + "SELECT foo.* FROM 1 AS foo WHERE foo", + Expectation.Success( + ResolvedId(1, 8) { localId(0) }, + ResolvedId(1, 34) { localId(0) } + ).withLocals(localVariable("foo", 0)) + ), + TestCase( + // mixed case + "SELECT FoO.* FROM 1 AS foo WHERE fOo", + Expectation.Success( + ResolvedId(1, 8) { localId(0) }, + ResolvedId(1, 34) { localId(0) } + ).withLocals(localVariable("foo", 0)) + ), + TestCase( + // foobar is undefined (select list) + "SELECT foobar.* FROM [] AS foo", + Expectation.Problems( + problem(1, 8, PlanningProblemDetails.UndefinedVariable("foobar", caseSensitive = false)) + ) + ), + TestCase( + // barbat is undefined (where clause) + "SELECT foo.* FROM [] AS foo WHERE barbat", + Expectation.Problems( + problem(1, 35, PlanningProblemDetails.UndefinedVariable("barbat", caseSensitive = false)) + ) + ) + ) + } + + @ParameterizedTest + @ArgumentsSource(CaseSensitiveLocalVariablesCases::class) + fun `case-sensitive locals variables`(tc: TestCase) = runTestCase(tc) + class CaseSensitiveLocalVariablesCases : ArgumentsProviderBase() { + override fun getParameters() = listOf( + // Case-insensitive resolution of local variables with all lowercase letters... + TestCase( + // all uppercase + "SELECT \"FOO\".* FROM 1 AS foo WHERE \"FOO\"", + Expectation.Problems( + problem(1, 8, PlanningProblemDetails.UndefinedVariable("FOO", caseSensitive = true)), + problem(1, 36, PlanningProblemDetails.UndefinedVariable("FOO", caseSensitive = true)) + ) + ), + TestCase( + // all lowercase + "SELECT \"foo\".* FROM 1 AS foo WHERE \"foo\"", + Expectation.Success( + ResolvedId(1, 8) { localId(0) }, + ResolvedId(1, 36) { localId(0) }, + ).withLocals(localVariable("foo", 0)) + ), + TestCase( + // mixed case + "SELECT \"FoO\".* FROM 1 AS foo WHERE \"fOo\"", + Expectation.Problems( + problem(1, 8, PlanningProblemDetails.UndefinedVariable("FoO", caseSensitive = true)), + problem(1, 36, PlanningProblemDetails.UndefinedVariable("fOo", caseSensitive = true)) + ) + ), + TestCase( + // "foobar" is undefined (select list) + "SELECT \"foobar\".* FROM [] AS foo ", + Expectation.Problems( + problem(1, 8, PlanningProblemDetails.UndefinedVariable("foobar", caseSensitive = true)) + ) + ), + TestCase( + // "barbat" is undefined (where clause) + "SELECT \"foo\".* FROM [] AS foo WHERE \"barbat\"", + Expectation.Problems( + problem(1, 37, PlanningProblemDetails.UndefinedVariable("barbat", caseSensitive = true)) + ) + ) + ) + } + + @ParameterizedTest + @ArgumentsSource(DuplicateVariableCases::class) + fun `duplicate variables`(tc: TestCase) = runTestCase(tc) + class DuplicateVariableCases : ArgumentsProviderBase() { + override fun getParameters() = listOf( + // Duplicate variables with same case + TestCase( + "SELECT {}.* FROM 1 AS a AT a", + Expectation.Problems(problem(1, 28, PlanningProblemDetails.VariablePreviouslyDefined("a"))), + ), + TestCase( + "SELECT {}.* FROM 1 AS a BY a", + Expectation.Problems(problem(1, 28, PlanningProblemDetails.VariablePreviouslyDefined("a"))), + ), + TestCase( + "SELECT {}.* FROM 1 AS notdup AT a BY a", + Expectation.Problems(problem(1, 38, PlanningProblemDetails.VariablePreviouslyDefined("a"))), + ), + TestCase( + "SELECT {}.* FROM 1 AS a AT a BY a", + Expectation.Problems( + problem(1, 28, PlanningProblemDetails.VariablePreviouslyDefined("a")), + problem(1, 33, PlanningProblemDetails.VariablePreviouslyDefined("a")) + ), + ), + // Duplicate variables with different cases + TestCase( + "SELECT {}.* FROM 1 AS a AT A", + Expectation.Problems(problem(1, 28, PlanningProblemDetails.VariablePreviouslyDefined("A"))), + ), + TestCase( + "SELECT {}.* FROM 1 AS A BY a", + Expectation.Problems(problem(1, 28, PlanningProblemDetails.VariablePreviouslyDefined("a"))), + ), + TestCase( + "SELECT {}.* FROM 1 AS notdup AT a BY A", + Expectation.Problems(problem(1, 38, PlanningProblemDetails.VariablePreviouslyDefined("A"))), + ), + TestCase( + "SELECT {}.* FROM 1 AS foo AT fOo BY foO", + Expectation.Problems( + problem(1, 30, PlanningProblemDetails.VariablePreviouslyDefined("fOo")), + problem(1, 37, PlanningProblemDetails.VariablePreviouslyDefined("foO")) + ), + ) + // Future test cases: duplicate variables across joins, i.e. `foo AS a, bar AS a`, etc. + ) + } + + @ParameterizedTest + @ArgumentsSource(MiscLocalVariableCases::class) + fun `misc local variable`(tc: TestCase) = runTestCase(tc) + class MiscLocalVariableCases : ArgumentsProviderBase() { + private fun createScanTestCase(varName: String, expectedIndex: Int) = + TestCase( + "SELECT $varName.* FROM foo AS a AT b BY c", + Expectation.Success( + ResolvedId(1, 8) { localId(expectedIndex.toLong()) }, + ResolvedId(1, 17) { globalId("foo", "fake_uid_for_foo") } + ).withLocals(localVariable("a", 0), localVariable("b", 1), localVariable("c", 2)) + ) + + override fun getParameters() = listOf( + // Demonstrates that FROM source AS aliases work + createScanTestCase("a", 0), + // Demonstrates that FROM source AT aliases work + createScanTestCase("b", 1), + // Demonstrates that FROM source BY aliases work + createScanTestCase("c", 2), + + // Covers local variables in select list, global variables in FROM source, local variables in WHERE clause + TestCase( + "SELECT b.* FROM bar AS b WHERE b.primaryKey = 42", + Expectation.Success( + ResolvedId(1, 8) { localId(0) }, + ResolvedId(1, 17) { globalId("bar", "fake_uid_for_bar") }, + ResolvedId(1, 32) { localId(0) }, + ).withLocals(localVariable("b", 0)) + ), + + // Demonstrate that globals-first variable lookup only happens in the FROM clause. + TestCase( + "SELECT shadow.* FROM shadow AS shadow", // `shadow` defined here shadows the global `shadow` + Expectation.Success( + ResolvedId(1, 8) { localId(0) }, + ResolvedId(1, 22) { globalId("shadow", "fake_uid_for_shadow") } + ).withLocals(localVariable("shadow", 0)) + ), + + // JOIN with shadowing + TestCase( + // first `AS s` shadowed by second `AS s`. + "SELECT s.* FROM 1 AS s, @s AS s", + Expectation.Success( + ResolvedId(1, 8) { localId(1) }, + ResolvedId(1, 26) { localId(0) } + ).withLocals(localVariable("s", 0), localVariable("s", 1)) + // Note that these two variables (^) have the same name but different indexes. + ), + ) + } + + @ParameterizedTest + @ArgumentsSource(DynamicIdSearchCases::class) + fun `dynamic_lookup search order cases`(tc: TestCase) = runTestCase(tc) + class DynamicIdSearchCases : ArgumentsProviderBase() { + // The important thing being asserted here is the contents of the dynamicId.search, which + // defines the places we'll look for variables that are unresolved at compile time. + override fun getParameters() = listOf( + // Not in an SFW query (empty search path) + TestCase( + "undefined1 + undefined2", + Expectation.Success( + ResolvedId(1, 1) { dynamicLookup("undefined1", BindingCase.INSENSITIVE, globalsFirst = false) }, + ResolvedId(1, 14) { dynamicLookup("undefined2", BindingCase.INSENSITIVE, globalsFirst = false) } + ), + allowUndefinedVariables = true + ), + + // In select list and where clause + TestCase( + "SELECT undefined1 AS u FROM 1 AS f WHERE undefined2", // 1 from source + Expectation.Success( + ResolvedId(1, 8) { dynamicLookup("undefined1", BindingCase.INSENSITIVE, globalsFirst = false, localId(0)) }, + ResolvedId(1, 42) { dynamicLookup("undefined2", BindingCase.INSENSITIVE, globalsFirst = false, localId(0)) } + ).withLocals(localVariable("f", 0)), + allowUndefinedVariables = true + ), + TestCase( + sql = "SELECT undefined1 AS u FROM 1 AS a, 2 AS b WHERE undefined2", // 2 from sources + Expectation.Success( + ResolvedId(1, 8) { dynamicLookup("undefined1", BindingCase.INSENSITIVE, globalsFirst = false, localId(1), localId(0)) }, + ResolvedId(1, 50) { dynamicLookup("undefined2", BindingCase.INSENSITIVE, globalsFirst = false, localId(1), localId(0)) } + ).withLocals(localVariable("a", 0), localVariable("b", 1)), + allowUndefinedVariables = true + ), + TestCase( + sql = "SELECT undefined1 AS u FROM 1 AS f, 1 AS b, 1 AS t WHERE undefined2", // 3 from sources + Expectation.Success( + ResolvedId(1, 8) { + dynamicLookup("undefined1", BindingCase.INSENSITIVE, globalsFirst = false, localId(2), localId(1), localId(0)) + }, + ResolvedId(1, 58) { + dynamicLookup("undefined2", BindingCase.INSENSITIVE, globalsFirst = false, localId(2), localId(1), localId(0)) + } + ).withLocals(localVariable("f", 0), localVariable("b", 1), localVariable("t", 2)), + allowUndefinedVariables = true + ) + ) + } + + @ParameterizedTest + @ArgumentsSource(SubqueryCases::class) + fun `sub-queries`(tc: TestCase) = runTestCase(tc) + class SubqueryCases : ArgumentsProviderBase() { + override fun getParameters() = listOf( + TestCase( + // inner query does not reference variables outer query + "SELECT b.* FROM (SELECT a.* FROM 1 AS a) AS b", + Expectation.Success( + ResolvedId(1, 8) { localId(1) }, + ResolvedId(1, 25) { localId(0) }, + ).withLocals(localVariable("a", 0), localVariable("b", 1)) + ), + TestCase( + // inner query references variable from outer query. + "SELECT a.*, b.* FROM 1 AS a, (SELECT a.*, b.* FROM 1 AS x) AS b", + Expectation.Success( + // The variables reference in the outer query + ResolvedId(1, 8) { localId(0) }, + ResolvedId(1, 13) { localId(2) }, + // The variables reference in the inner query + ResolvedId(1, 38) { localId(0) }, + // Note that `b` from the outer query is not accessible inside the query so we fall back on dynamic lookup + ResolvedId(1, 43) { dynamicLookup("b", BindingCase.INSENSITIVE, globalsFirst = false, localId(1), localId(0)) } + ).withLocals(localVariable("a", 0), localVariable("x", 1), localVariable("b", 2)), + allowUndefinedVariables = true + ), + + // In FROM source + TestCase( + "SELECT f.*, u.* FROM 1 AS f, undefined AS u", + Expectation.Success( + ResolvedId(1, 8) { localId(0) }, + ResolvedId(1, 13) { localId(1) }, + ResolvedId(1, 30) { dynamicLookup("undefined", BindingCase.INSENSITIVE, globalsFirst = true, localId(0)) } + ).withLocals(localVariable("f", 0), localVariable("u", 1)), + allowUndefinedVariables = true + ), + ) + } +}