From 19837a24b92e28b08359e0c997109a6133f5ce7d Mon Sep 17 00:00:00 2001 From: "R. C. Howell" Date: Fri, 6 Oct 2023 16:00:46 -0700 Subject: [PATCH] Remove PlanFactory --- .../src/main/kotlin/org/partiql/ast/Ast.kt | 36 - .../org/partiql/ast/normalize/AstPass.kt | 11 + .../ast/normalize/NormalizeFromSource.kt | 1 - .../ast/normalize/NormalizeSelectList.kt | 1 - .../ast/normalize/NormalizeSelectStar.kt | 1 - .../lang/planner/transforms/AstToPlan.kt | 89 - .../lang/planner/transforms/plan/PlanTyper.kt | 1882 ---------- .../lang/planner/transforms/plan/PlanUtils.kt | 82 - .../planner/transforms/plan/RelConverter.kt | 518 --- .../planner/transforms/plan/RexConverter.kt | 692 ---- .../PartiQLSchemaInferencerTests.kt | 3197 +++++++++-------- .../src/main/kotlin/org/partiql/plan/Plan.kt | 19 - .../main/kotlin/org/partiql/planner/Env.kt | 10 +- .../partiql/planner/PartiQLPlannerDefault.kt | 4 +- .../partiql/planner/transforms/AstToPlan.kt | 34 +- .../planner/transforms/RelConverter.kt | 131 +- .../planner/transforms/RexConverter.kt | 143 +- .../org/partiql/planner/typer/PlanTyper.kt | 138 +- 18 files changed, 1901 insertions(+), 5088 deletions(-) delete mode 100644 partiql-ast/src/main/kotlin/org/partiql/ast/Ast.kt create mode 100644 partiql-ast/src/main/kotlin/org/partiql/ast/normalize/AstPass.kt delete mode 100644 partiql-lang/src/main/kotlin/org/partiql/lang/planner/transforms/AstToPlan.kt delete mode 100644 partiql-lang/src/main/kotlin/org/partiql/lang/planner/transforms/plan/PlanTyper.kt delete mode 100644 partiql-lang/src/main/kotlin/org/partiql/lang/planner/transforms/plan/PlanUtils.kt delete mode 100644 partiql-lang/src/main/kotlin/org/partiql/lang/planner/transforms/plan/RelConverter.kt delete mode 100644 partiql-lang/src/main/kotlin/org/partiql/lang/planner/transforms/plan/RexConverter.kt delete mode 100644 partiql-plan/src/main/kotlin/org/partiql/plan/Plan.kt diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/Ast.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/Ast.kt deleted file mode 100644 index 8f8bc70b6b..0000000000 --- a/partiql-ast/src/main/kotlin/org/partiql/ast/Ast.kt +++ /dev/null @@ -1,36 +0,0 @@ -package org.partiql.ast - -import org.partiql.ast.builder.AstFactoryImpl -import org.partiql.ast.sql.SqlBlock -import org.partiql.ast.sql.SqlDialect -import org.partiql.ast.sql.SqlLayout -import org.partiql.ast.sql.sql - -/** - * Singleton instance of the default factory; also accessible via `AstFactory.DEFAULT`. - */ -object Ast : AstBaseFactory() - -/** - * AstBaseFactory can be used to create a factory which extends from the factory provided by AstFactory.DEFAULT. - */ -public abstract class AstBaseFactory : AstFactoryImpl() { - // internal default overrides here -} - -/** - * Wraps a rewriter with a default entry point. - */ -public interface AstPass { - - public fun apply(statement: Statement): Statement -} - -/* - * Pretty-print this [AstNode] as SQL text with the given [SqlLayout] - */ -@JvmOverloads -public fun AstNode.sql( - layout: SqlLayout = SqlLayout.DEFAULT, - dialect: SqlDialect = SqlDialect.PARTIQL, -): String = accept(dialect, SqlBlock.Nil).sql(layout) diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/AstPass.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/AstPass.kt new file mode 100644 index 0000000000..e8ad093e70 --- /dev/null +++ b/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/AstPass.kt @@ -0,0 +1,11 @@ +package org.partiql.ast.normalize + +import org.partiql.ast.Statement + +/** + * Wraps a rewriter with a default entry point. + */ +public interface AstPass { + + public fun apply(statement: Statement): Statement +} diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeFromSource.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeFromSource.kt index 4b1915a4a2..ef6e9dde3b 100644 --- a/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeFromSource.kt +++ b/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeFromSource.kt @@ -1,7 +1,6 @@ package org.partiql.ast.normalize import org.partiql.ast.AstNode -import org.partiql.ast.AstPass import org.partiql.ast.Expr import org.partiql.ast.From import org.partiql.ast.Statement diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeSelectList.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeSelectList.kt index 5acc0c5d36..238b77bafb 100644 --- a/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeSelectList.kt +++ b/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeSelectList.kt @@ -1,6 +1,5 @@ package org.partiql.ast.normalize -import org.partiql.ast.AstPass import org.partiql.ast.Expr import org.partiql.ast.Select import org.partiql.ast.Statement diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeSelectStar.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeSelectStar.kt index 7105e2e382..7ad039559f 100644 --- a/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeSelectStar.kt +++ b/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeSelectStar.kt @@ -1,6 +1,5 @@ package org.partiql.ast.normalize -import org.partiql.ast.AstPass import org.partiql.ast.Expr import org.partiql.ast.From import org.partiql.ast.Identifier diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/planner/transforms/AstToPlan.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/planner/transforms/AstToPlan.kt deleted file mode 100644 index d79a20f0d4..0000000000 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/planner/transforms/AstToPlan.kt +++ /dev/null @@ -1,89 +0,0 @@ -package org.partiql.lang.planner.transforms - -import org.partiql.lang.domains.PartiqlAst -import org.partiql.lang.eval.CompileOptions -import org.partiql.lang.eval.TypedOpBehavior -import org.partiql.lang.eval.visitors.AggregationVisitorTransform -import org.partiql.lang.eval.visitors.FromSourceAliasVisitorTransform -import org.partiql.lang.eval.visitors.OrderBySortSpecVisitorTransform -import org.partiql.lang.eval.visitors.PartiqlAstSanityValidator -import org.partiql.lang.eval.visitors.PipelinedVisitorTransform -import org.partiql.lang.eval.visitors.SelectListItemAliasVisitorTransform -import org.partiql.lang.eval.visitors.SelectStarVisitorTransform -import org.partiql.lang.planner.transforms.plan.RelConverter -import org.partiql.lang.planner.transforms.plan.RexConverter -import org.partiql.plan.PartiQLPlan -import org.partiql.plan.Rex -import org.partiql.plan.partiQLPlan - -/** - * Translate the PIG AST to an implementation of the PartiQL Plan Representation. - */ -object AstToPlan { - - /** - * Converts a PartiqlAst.Statement to a [PartiQLPlan] - */ - fun transform(statement: PartiqlAst.Statement): PartiQLPlan { - val ast = statement.normalize() - if (ast !is PartiqlAst.Statement.Query) { - unsupported(ast) - } - val root = transform(ast.expr) - return partiQLPlan( - version = PartiQLPlan.Version.PARTIQL_V0, - root = root, - ) - } - - // --- Internal --------------------------------------------- - - /** - * Common place to throw exceptions with access to the AST node. - * Error handling pattern is undecided - */ - internal fun unsupported(node: PartiqlAst.PartiqlAstNode): Nothing { - throw UnsupportedOperationException("node: $node") - } - - /** - * Normalizes a statement AST node. Copied from EvaluatingCompiler, and include the validation. - * - * Notes: - * - AST normalization assumes operating on statement rather than a query statement, but the normalization - * only changes the SFW nodes. There's room to simplify here. Also, you have to enter the transform at - * `transformStatement` or nothing happens. I initially had `transformQuery` but that doesn't work because - * the pipelinedVisitorTransform traversal can only be entered on statement. - */ - private fun PartiqlAst.Statement.normalize(): PartiqlAst.Statement { - val transform = PipelinedVisitorTransform( - SelectListItemAliasVisitorTransform(), - FromSourceAliasVisitorTransform(), - OrderBySortSpecVisitorTransform(), - AggregationVisitorTransform(), - SelectStarVisitorTransform() - ) - // normalize - val ast = transform.transformStatement(this) - // validate - val validatorCompileOptions = CompileOptions.build { typedOpBehavior(TypedOpBehavior.HONOR_PARAMETERS) } - PartiqlAstSanityValidator().validate(this, validatorCompileOptions) - return ast - } - - /** - * Convert Partiql.Ast.Expr to a Rex/Rel tree - */ - private fun transform(query: PartiqlAst.Expr): Rex = when (query) { - is PartiqlAst.Expr.Select -> { - // - val rex = RelConverter.convert(query) - rex - } - else -> { - // - val rex = RexConverter.convert(query) - rex - } - } -} diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/planner/transforms/plan/PlanTyper.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/planner/transforms/plan/PlanTyper.kt deleted file mode 100644 index 2f3d7798e2..0000000000 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/planner/transforms/plan/PlanTyper.kt +++ /dev/null @@ -1,1882 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at: - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific - * language governing permissions and limitations under the License. - */ - -package org.partiql.lang.planner.transforms.plan - -import com.amazon.ionelement.api.ElementType -import com.amazon.ionelement.api.StringElement -import com.amazon.ionelement.api.TextElement -import org.partiql.errors.Problem -import org.partiql.errors.ProblemHandler -import org.partiql.errors.UNKNOWN_PROBLEM_LOCATION -import org.partiql.lang.ast.passes.SemanticProblemDetails -import org.partiql.lang.ast.passes.inference.cast -import org.partiql.lang.eval.ExprValueType -import org.partiql.lang.eval.builtins.SCALAR_BUILTINS_DEFAULT -import org.partiql.lang.planner.PlanningProblemDetails -import org.partiql.lang.planner.transforms.PlannerSession -import org.partiql.lang.planner.transforms.impl.Metadata -import org.partiql.lang.planner.transforms.plan.PlanUtils.addType -import org.partiql.lang.planner.transforms.plan.PlanUtils.grabType -import org.partiql.lang.types.FunctionSignature -import org.partiql.lang.types.StaticTypeUtils -import org.partiql.lang.types.TypedOpParameter -import org.partiql.lang.types.UnknownArguments -import org.partiql.lang.util.cartesianProduct -import org.partiql.plan.Arg -import org.partiql.plan.Attribute -import org.partiql.plan.Binding -import org.partiql.plan.Case -import org.partiql.plan.ExcludeExpr -import org.partiql.plan.ExcludeStep -import org.partiql.plan.PlanNode -import org.partiql.plan.Property -import org.partiql.plan.Rel -import org.partiql.plan.Rex -import org.partiql.plan.Step -import org.partiql.plan.attribute -import org.partiql.plan.binding -import org.partiql.plan.rexId -import org.partiql.plan.util.PlanRewriter -import org.partiql.spi.BindingCase -import org.partiql.spi.BindingName -import org.partiql.spi.BindingPath -import org.partiql.types.AnyOfType -import org.partiql.types.AnyType -import org.partiql.types.BagType -import org.partiql.types.BoolType -import org.partiql.types.CollectionType -import org.partiql.types.DecimalType -import org.partiql.types.FloatType -import org.partiql.types.IntType -import org.partiql.types.ListType -import org.partiql.types.MissingType -import org.partiql.types.NullType -import org.partiql.types.NumberConstraint -import org.partiql.types.SexpType -import org.partiql.types.SingleType -import org.partiql.types.StaticType -import org.partiql.types.StringType -import org.partiql.types.StructType -import org.partiql.types.SymbolType -import org.partiql.types.TupleConstraint - -/** - * Types a given logical - */ -internal object PlanTyper : PlanRewriter() { - - /** - * Given a [Rex], types the logical plan by adding the output Type Environment to each relational operator. - * - * Along with typing, this also validates expressions for typing issues. - */ - internal fun type(node: Rex, ctx: Context): Rex { - return visitRex(node, ctx) as Rex - } - - /** - * Used for maintaining state through the visitors - */ - internal class Context( - internal val input: Rel?, - internal val session: PlannerSession, - internal val metadata: Metadata, - internal val scopingOrder: ScopingOrder, - internal val customFunctionSignatures: List, - internal val tolerance: MinimumTolerance = MinimumTolerance.FULL, - internal val problemHandler: ProblemHandler - ) { - internal val inputTypeEnv = input?.let { PlanUtils.getTypeEnv(it) } ?: emptyList() - internal val allFunctions: Map> = - (SCALAR_BUILTINS_DEFAULT.map { it.signature.name to it.signature } + customFunctionSignatures.map { it.name to it }) - .groupBy({ it.first }, { it.second }) - } - - /** - * Scoping - */ - internal enum class ScopingOrder { - GLOBALS_THEN_LEXICAL, - LEXICAL_THEN_GLOBALS - } - - /** - * [FULL] -- CANNOT tolerate references to unresolved variables - * [PARTIAL] -- CAN tolerate references to unresolved variables - */ - internal enum class MinimumTolerance { - FULL, - PARTIAL - } - - // - // - // RELATIONAL ALGEBRA OPERATORS - // - // - - override fun visitRelBag(node: Rel.Bag, ctx: Context): PlanNode { - TODO("BAG OPERATORS are not supported by the PartiQLTypeEnvInferencer yet.") - } - - override fun visitRel(node: Rel, ctx: Context): Rel = super.visitRel(node, ctx) as Rel - - override fun visitRelJoin(node: Rel.Join, ctx: Context): Rel.Join { - val lhs = visitRel(node.lhs, ctx) - val rhs = typeRel(node.rhs, lhs, ctx) - val newJoin = node.copy( - common = node.common.copy( - typeEnv = lhs.getTypeEnv() + rhs.getTypeEnv(), - ) - ) - val predicateType = when (val condition = node.condition) { - null -> StaticType.BOOL - else -> { - val predicate = typeRex(condition, newJoin, ctx) - // verify `JOIN` predicate is bool. If it's unknown, gives a null or missing error. If it could - // never be a bool, gives an incompatible data type for expression error - assertType(expected = StaticType.BOOL, actual = predicate.grabType() ?: handleMissingType(ctx), ctx) - - // continuation type (even in the case of an error) is [StaticType.BOOL] - StaticType.BOOL - } - } - return newJoin.copy( - condition = node.condition?.addType(predicateType) - ) - } - - /** - * Initial implementation of `EXCLUDE` schema inference. Until an RFC is finalized for `EXCLUDE` - * (https://github.com/partiql/partiql-spec/issues/39), this behavior is considered experimental and subject to - * change. - * - * So far this implementation includes - * - Excluding tuple attrs (e.g. t.a.b.c) - * - Excluding tuple wildcards (e.g. t.a.*.b) - * - Excluding collection indexes (e.g. t.a[0].b -- behavior subject to change; see below discussion) - * - Excluding collection wildcards (e.g. t.a[*].b) - * - * There are still discussion points regarding the following edge cases - * - EXCLUDE on a tuple attribute that doesn't exist -- give an error/warning? - * - currently no error - * - EXCLUDE on a tuple attribute that has duplicates -- give an error/warning? exclude one? exclude both? - * - currently excludes both w/ no error - * - EXCLUDE on a collection index as the last step -- mark element type as optional? - * - currently element type as-is - * - EXCLUDE on a collection index w/ remaining path steps -- mark last step's type as optional? - * - currently marks last step's type as optional - * - EXCLUDE on a binding tuple variable (e.g. SELECT ... EXCLUDE t FROM t) -- error? - * - currently a parser error - * - EXCLUDE on a union type -- give an error/warning? no-op? exclude on each type in union? - * - currently exclude on each union type - * - If SELECT list includes an attribute that is excluded, we could consider giving an error in PlanTyper or - * some other semantic pass - * - currently does not give an error - */ - override fun visitRelExclude(node: Rel.Exclude, ctx: Context): Rel.Exclude { - val input = visitRel(node.input, ctx) - val exprs = node.exprs - val typeEnv = input.getTypeEnv() - val newTypeEnv = exprs.fold(typeEnv) { tEnv, expr -> - excludeExpr(tEnv, expr, ctx) - } - return node.copy( - input = input, - common = node.common.copy( - typeEnv = newTypeEnv, - properties = input.getProperties() - ) - ) - } - - private fun attrEqualsExcludeRoot(attr: Attribute, expr: ExcludeExpr): Boolean { - val rootId = expr.root - return attr.name == rootId || (expr.rootCase == Case.INSENSITIVE && attr.name.equals(expr.root, ignoreCase = true)) - } - - private fun excludeExpr(attrs: List, expr: ExcludeExpr, ctx: Context): List { - val resultAttrs = mutableListOf() - val attrsExist = attrs.find { attr -> attrEqualsExcludeRoot(attr, expr) } != null - if (!attrsExist) { - handleUnresolvedExcludeExprRoot(expr.root, ctx) - } - attrs.forEach { attr -> - if (attrEqualsExcludeRoot(attr, expr)) { - if (expr.steps.isEmpty()) { - throw IllegalStateException("Empty `ExcludeExpr.steps` encountered. This should have been caught by the parser.") - } else { - val newType = excludeExprSteps(attr.type, expr.steps, lastStepAsOptional = false, ctx) - resultAttrs.add( - attr.copy( - type = newType - ) - ) - } - } else { - resultAttrs.add( - attr - ) - } - } - return resultAttrs - } - - private fun excludeExprSteps(type: StaticType, steps: List, lastStepAsOptional: Boolean, ctx: Context): StaticType { - fun excludeExprStepsStruct(s: StructType, steps: List, lastStepAsOptional: Boolean): StaticType { - val outputFields = mutableListOf() - val first = steps.first() - s.fields.forEach { field -> - when (first) { - is ExcludeStep.TupleAttr -> { - if (field.key == first.attr || (first.case == Case.INSENSITIVE && field.key.equals(first.attr, ignoreCase = true))) { - if (steps.size == 1) { - if (lastStepAsOptional) { - val newField = StructType.Field(field.key, field.value.asOptional()) - outputFields.add(newField) - } - } else { - outputFields.add(StructType.Field(field.key, excludeExprSteps(field.value, steps.drop(1), lastStepAsOptional, ctx))) - } - } else { - outputFields.add(field) - } - } - is ExcludeStep.TupleWildcard -> { - if (steps.size == 1) { - if (lastStepAsOptional) { - val newField = StructType.Field(field.key, field.value.asOptional()) - outputFields.add(newField) - } - } else { - outputFields.add(StructType.Field(field.key, excludeExprSteps(field.value, steps.drop(1), lastStepAsOptional, ctx))) - } - } - else -> { - // currently no change to field.value and no error thrown; could consider an error/warning in - // the future - outputFields.add(StructType.Field(field.key, field.value)) - } - } - } - return s.copy(fields = outputFields) - } - - fun excludeExprStepsCollection(c: CollectionType, steps: List, lastStepAsOptional: Boolean): StaticType { - var elementType = c.elementType - when (steps.first()) { - is ExcludeStep.CollectionIndex -> { - if (steps.size > 1) { - elementType = excludeExprSteps(elementType, steps.drop(1), lastStepAsOptional = true, ctx) - } - } - is ExcludeStep.CollectionWildcard -> { - if (steps.size > 1) { - elementType = - excludeExprSteps(elementType, steps.drop(1), lastStepAsOptional = lastStepAsOptional, ctx) - } - // currently no change to elementType if collection wildcard is last element; this behavior could - // change based on RFC definition - } - else -> { - // currently no change to elementType and no error thrown; could consider an error/warning in - // the future - } - } - return when (c) { - is BagType -> c.copy(elementType) - is ListType -> c.copy(elementType) - is SexpType -> c.copy(elementType) - } - } - - return when (type) { - is StructType -> excludeExprStepsStruct(type, steps, lastStepAsOptional) - is CollectionType -> excludeExprStepsCollection(type, steps, lastStepAsOptional) - is AnyOfType -> { - StaticType.unionOf( - type.types.map { - excludeExprSteps(it, steps, lastStepAsOptional, ctx) - }.toSet() - ) - } - else -> type - }.flatten() - } - - override fun visitRelUnpivot(node: Rel.Unpivot, ctx: Context): Rel.Unpivot { - val from = node - - val asSymbolicName = node.alias - ?: error("Unpivot alias is null. This wouldn't be the case if FromSourceAliasVisitorTransform was executed first.") - - val value = visitRex(from.value, ctx) as Rex - - val fromExprType = value.grabType() ?: handleMissingType(ctx) - - val valueType = getUnpivotValueType(fromExprType) - val typeEnv = mutableListOf(attribute(asSymbolicName, valueType)) - - from.at?.let { - val valueHasMissing = StaticTypeUtils.getTypeDomain(valueType).contains(ExprValueType.MISSING) - val valueOnlyHasMissing = valueHasMissing && StaticTypeUtils.getTypeDomain(valueType).size == 1 - when { - valueOnlyHasMissing -> { - typeEnv.add(attribute(it, StaticType.MISSING)) - } - valueHasMissing -> { - typeEnv.add(attribute(it, StaticType.STRING.asOptional())) - } - else -> { - typeEnv.add(attribute(it, StaticType.STRING)) - } - } - } - - node.by?.let { TODO("BY variable's inference is not implemented yet.") } - - return from.copy( - common = from.common.copy( - typeEnv = typeEnv - ), - value = value - ) - } - - override fun visitRelAggregate(node: Rel.Aggregate, ctx: Context): PlanNode { - val input = visitRel(node.input, ctx) - val calls = node.calls.map { binding(it.name, typeRex(it.value, input, ctx)) } - val groups = node.groups.map { binding(it.name, typeRex(it.value, input, ctx)) } - return node.copy( - calls = calls, - groups = groups, - common = node.common.copy( - typeEnv = groups.toAttributes(ctx) + calls.toAttributes(ctx) - ) - ) - } - - override fun visitRelProject(node: Rel.Project, ctx: Context): PlanNode { - val input = visitRel(node.input, ctx) - val typeEnv = node.bindings.flatMap { binding -> - val type = inferType(binding.value, input, ctx) - when (binding.value.isProjectAll()) { - true -> { - when (val structType = type as? StructType) { - null -> { - handleIncompatibleDataTypeForExprError(StaticType.STRUCT, type, ctx) - listOf(attribute(binding.name, type)) - } - else -> structType.fields.map { entry -> attribute(entry.key, entry.value) } - } - } - false -> listOf(attribute(binding.name, type)) - } - } - return node.copy( - input = input, - common = node.common.copy( - typeEnv = typeEnv, - properties = input.getProperties() - ) - ) - } - - override fun visitRelScan(node: Rel.Scan, ctx: Context): Rel { - val value = visitRex( - node.value, - Context( - ctx.input, - ctx.session, - ctx.metadata, - ScopingOrder.GLOBALS_THEN_LEXICAL, - ctx.customFunctionSignatures, - ctx.tolerance, - ctx.problemHandler - ) - ) as Rex - val asSymbolicName = node.alias ?: error("From Source Alias is null when it should not be.") - val valueType = value.grabType() ?: handleMissingType(ctx) - val sourceType = getElementTypeForFromSource(valueType) - - node.at?.let { TODO("AT is not supported yet.") } - node.by?.let { TODO("BY is not supported yet.") } - - return when (value) { - is Rex.Query.Collection -> when (value.constructor) { - null -> value.rel - else -> { - val typeEnv = listOf(attribute(asSymbolicName, sourceType)) - node.copy( - value = value, - common = node.common.copy( - typeEnv = typeEnv - ) - ) - } - } - else -> { - val typeEnv = listOf(attribute(asSymbolicName, sourceType)) - node.copy( - value = value, - common = node.common.copy( - typeEnv = typeEnv - ) - ) - } - } - } - - override fun visitRelFilter(node: Rel.Filter, ctx: Context): PlanNode { - val input = visitRel(node.input, ctx) - val condition = typeRex(node.condition, input, ctx) - assertType(StaticType.BOOL, condition.grabType() ?: handleMissingType(ctx), ctx) - return node.copy( - condition = condition, - input = input, - common = node.common.copy( - typeEnv = input.getTypeEnv(), - properties = input.getProperties() - ) - ) - } - - override fun visitRelSort(node: Rel.Sort, ctx: Context): PlanNode { - val input = visitRel(node.input, ctx) - return node.copy( - input = input, - common = node.common.copy( - typeEnv = input.getTypeEnv(), - properties = setOf(Property.ORDERED) - ) - ) - } - - override fun visitRelFetch(node: Rel.Fetch, ctx: Context): PlanNode { - val input = visitRel(node.input, ctx) - val limit = typeRex(node.limit, input, ctx) - val offset = typeRex(node.offset, input, ctx) - limit.grabType()?.let { assertAsInt(it, ctx) } - offset.grabType()?.let { assertAsInt(it, ctx) } - return node.copy( - input = input, - common = node.common.copy( - typeEnv = input.getTypeEnv(), - properties = input.getProperties() - ), - limit = limit, - offset = offset - ) - } - - // - // - // EXPRESSIONS - // - // - - override fun visitRexQueryScalarPivot(node: Rex.Query.Scalar.Pivot, ctx: Context): PlanNode { - // TODO: This is to match the StaticTypeInferenceVisitorTransform logic, but needs to be changed - return node.copy( - type = StaticType.STRUCT - ) - } - - override fun visitRexQueryScalarSubquery(node: Rex.Query.Scalar.Subquery, ctx: Context): PlanNode { - val query = visitRex(node.query, ctx) as Rex.Query.Collection - // If it is SELECT VALUE, do not coerce. - if (query.constructor != null) { - val type = query.type as? CollectionType - return node.copy(query = query, type = type?.elementType?.flatten()) - } - val type = when (val queryType = query.grabType() ?: handleMissingType(ctx)) { - is CollectionType -> queryType.elementType - else -> error("Query collection subqueries should always return a CollectionType.") - } - val resultType = when (type) { - is StructType -> { - if (StaticTypeUtils.isClosedSafe(type) == true && type.fields.size == 1) { - type.fields[0].value - } else { - handleCoercionError(ctx, type) - StaticType.ANY - } - } - else -> { - handleCoercionError(ctx, type) - StaticType.ANY - } - } - return node.copy( - query = query, - type = resultType.flatten() - ) - } - - override fun visitRex(node: Rex, ctx: Context): PlanNode = super.visitRex(node, ctx) - - override fun visitRexAgg(node: Rex.Agg, ctx: Context): PlanNode { - val funcName = node.id - val args = node.args.map { visitRex(it, ctx) as Rex } - // unwrap the type if this is a collectionType - val argType = when (val type = args[0].grabType() ?: handleMissingType(ctx)) { - is CollectionType -> type.elementType - else -> type - } - return node.copy( - type = computeReturnTypeForAggFunc(funcName, argType, ctx), - args = args - ) - } - - private fun computeReturnTypeForAggFunc(funcName: String, elementType: StaticType, ctx: Context): StaticType { - val elementTypes = elementType.allTypes - - fun List.convertMissingToNull() = toMutableSet().apply { - if (contains(StaticType.MISSING)) { - remove(StaticType.MISSING) - add(StaticType.NULL) - } - } - - fun StaticType.isUnknownOrNumeric() = isUnknown() || isNumeric() - - return when (funcName) { - "count" -> StaticType.INT - // In case that any element is MISSING or there is no element, we should return NULL - "max", "min" -> StaticType.unionOf(elementTypes.convertMissingToNull()) - "sum" -> when { - elementTypes.none { it.isUnknownOrNumeric() } -> { - handleInvalidInputTypeForAggFun(funcName, elementType, StaticType.unionOf(StaticType.NULL_OR_MISSING, StaticType.NUMERIC).flatten(), ctx) - StaticType.unionOf(StaticType.NULL, StaticType.NUMERIC) - } - // If any single type is mismatched, We should add MISSING to the result types set to indicate there is a chance of data mismatch error - elementTypes.any { !it.isUnknownOrNumeric() } -> StaticType.unionOf( - elementTypes.filter { it.isUnknownOrNumeric() }.toMutableSet().apply { add(StaticType.MISSING) } - ) - // In case that any element is MISSING or there is no element, we should return NULL - else -> StaticType.unionOf(elementTypes.convertMissingToNull()) - } - // "avg" returns DECIMAL or NULL - "avg" -> when { - elementTypes.none { it.isUnknownOrNumeric() } -> { - handleInvalidInputTypeForAggFun(funcName, elementType, StaticType.unionOf(StaticType.NULL_OR_MISSING, StaticType.NUMERIC).flatten(), ctx) - StaticType.unionOf(StaticType.NULL, StaticType.DECIMAL) - } - else -> StaticType.unionOf( - mutableSetOf().apply { - if (elementTypes.any { it.isUnknown() }) { add(StaticType.NULL) } - if (elementTypes.any { it.isNumeric() }) { add(StaticType.DECIMAL) } - // If any single type is mismatched, We should add MISSING to the result types set to indicate there is a chance of data mismatch error - if (elementTypes.any { !it.isUnknownOrNumeric() }) { add(StaticType.MISSING) } - } - ) - } - else -> error("Internal Error: Unsupported aggregate function. This probably indicates a parser bug.") - }.flatten() - } - - private fun handleInvalidInputTypeForAggFun(funcName: String, actualType: StaticType, expectedType: StaticType, ctx: Context) { - ctx.problemHandler.handleProblem( - Problem( - sourceLocation = UNKNOWN_PROBLEM_LOCATION, - details = SemanticProblemDetails.InvalidArgumentTypeForFunction( - functionName = funcName, - expectedType = expectedType, - actualType = actualType - ) - ) - ) - } - - override fun visitRexQueryScalar(node: Rex.Query.Scalar, ctx: Context): PlanNode = super.visitRexQueryScalar(node, ctx) - - override fun visitRexQuery(node: Rex.Query, ctx: Context): PlanNode = super.visitRexQuery(node, ctx) - - override fun visitRexQueryCollection(node: Rex.Query.Collection, ctx: Context): PlanNode { - val input = visitRel(node.rel, ctx) - val typeConstructor = when (input.getProperties().contains(Property.ORDERED)) { - true -> { type: StaticType -> ListType(type) } - false -> { type: StaticType -> BagType(type) } - } - return when (val constructor = node.constructor) { - null -> { - node.copy( - rel = input, - type = typeConstructor.invoke( - StructType( - fields = input.getTypeEnv().map { attribute -> - StructType.Field(attribute.name, attribute.type) - }, - contentClosed = true, - constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true), TupleConstraint.Ordered) - ) - ) - ) - } - else -> { - val constructorType = typeRex(constructor, input, ctx).grabType() ?: handleMissingType(ctx) - return node.copy( - type = typeConstructor.invoke(constructorType) - ) - } - } - } - - override fun visitRexPath(node: Rex.Path, ctx: Context): Rex.Path { - val ids = grabFirstIds(node) - val qualifier = ids.getOrNull(0)?.qualifier ?: Rex.Id.Qualifier.UNQUALIFIED - val path = BindingPath(ids.map { rexIdToBindingName(it) }) - val pathAndType = findBind(path, qualifier, ctx) - val remainingFirstIndex = pathAndType.levelsMatched - 1 - val remaining = when (remainingFirstIndex > node.steps.lastIndex) { - true -> emptyList() - false -> node.steps.subList(remainingFirstIndex, node.steps.size) - } - var currentType = pathAndType.type - remaining.forEach { pathComponent -> - currentType = when (pathComponent) { - is Step.Key -> { - val type = inferPathComponentExprType(currentType, pathComponent, ctx) - type - } - is Step.Wildcard -> currentType - is Step.Unpivot -> error("Not implemented yet") - } - } - return node.copy( - type = currentType - ) - } - - override fun visitRexId(node: Rex.Id, ctx: Context): Rex.Id { - val bindingPath = BindingPath(listOf(rexIdToBindingName(node))) - return node.copy(type = findBind(bindingPath, node.qualifier, ctx).type) - } - - override fun visitRexBinary(node: Rex.Binary, ctx: Context): Rex.Binary { - val lhs = visitRex(node.lhs, ctx).grabType() ?: handleMissingType(ctx) - val rhs = visitRex(node.rhs, ctx).grabType() ?: handleMissingType(ctx) - val args = listOf(lhs, rhs) - val type = when (node.op) { - Rex.Binary.Op.PLUS, Rex.Binary.Op.MINUS, Rex.Binary.Op.TIMES, Rex.Binary.Op.DIV, Rex.Binary.Op.MODULO -> when (hasValidOperandTypes(args, node.op.name, ctx) { it.isNumeric() }) { - true -> computeReturnTypeForNAry(args, PlanTyper::inferBinaryArithmeticOp) - false -> StaticType.NUMERIC // continuation type to prevent incompatible types and unknown errors from propagating - } - Rex.Binary.Op.BITWISE_AND -> when (hasValidOperandTypes(args, node.op.name, ctx) { it is IntType }) { - true -> computeReturnTypeForNAry(args, PlanTyper::inferBinaryArithmeticOp) - false -> StaticType.unionOf(StaticType.INT2, StaticType.INT4, StaticType.INT8, StaticType.INT) // continuation type to prevent incompatible types and unknown errors from propagating - } - - Rex.Binary.Op.CONCAT -> when (hasValidOperandTypes(args, node.op.name, ctx) { it.isText() }) { - true -> computeReturnTypeForNAry(args, PlanTyper::inferConcatOp) - false -> StaticType.STRING // continuation type to prevent incompatible types and unknown errors from propagating - } - Rex.Binary.Op.AND, Rex.Binary.Op.OR -> inferNaryLogicalOp(args, node.op.name, ctx) - Rex.Binary.Op.EQ, Rex.Binary.Op.NEQ -> when (operandsAreComparable(args, node.op.name, ctx)) { - true -> computeReturnTypeForNAry(args, PlanTyper::inferEqNeOp) - false -> StaticType.BOOL // continuation type to prevent incompatible types and unknown errors from propagating - } - Rex.Binary.Op.LT, Rex.Binary.Op.GT, Rex.Binary.Op.LTE, Rex.Binary.Op.GTE -> when (operandsAreComparable(args, node.op.name, ctx)) { - true -> computeReturnTypeForNAry(args, PlanTyper::inferComparatorOp) - false -> StaticType.BOOL // continuation type prevent incompatible types and unknown errors from propagating - } - } - return node.copy(type = type) - } - - override fun visitRexUnary(node: Rex.Unary, ctx: Context): PlanNode { - val valueType = visitRex(node.value, ctx).grabType() ?: handleMissingType(ctx) - val type = when (node.op) { - Rex.Unary.Op.NOT -> when (hasValidOperandTypes(listOf(valueType), node.op.name, ctx) { it is BoolType }) { - true -> computeReturnTypeForUnary(valueType, PlanTyper::inferNotOp) - false -> StaticType.BOOL // continuation type to prevent incompatible types and unknown errors from propagating - } - Rex.Unary.Op.POS -> when (hasValidOperandTypes(listOf(valueType), node.op.name, ctx) { it.isNumeric() }) { - true -> computeReturnTypeForUnary(valueType, PlanTyper::inferUnaryArithmeticOp) - false -> StaticType.NUMERIC - } - Rex.Unary.Op.NEG -> when (hasValidOperandTypes(listOf(valueType), node.op.name, ctx) { it.isNumeric() }) { - true -> computeReturnTypeForUnary(valueType, PlanTyper::inferUnaryArithmeticOp) - false -> StaticType.NUMERIC - } - } - return node.copy(type = type) - } - - // This type comes from RexConverter - override fun visitRexLit(node: Rex.Lit, ctx: Context): Rex.Lit = node - - override fun visitRexCollection(node: Rex.Collection, ctx: Context): PlanNode = super.visitRexCollection(node, ctx) - - override fun visitRexCollectionArray(node: Rex.Collection.Array, ctx: Context): PlanNode { - val typedValues = node.values.map { visitRex(it, ctx) as Rex } - val elementType = AnyOfType(typedValues.map { it.grabType() ?: handleMissingType(ctx) }.toSet()).flatten() - return node.copy(type = ListType(elementType), values = typedValues) - } - - override fun visitRexCollectionBag(node: Rex.Collection.Bag, ctx: Context): PlanNode { - val typedValues = node.values.map { visitRex(it, ctx) } - val elementType = AnyOfType(typedValues.map { it.grabType()!! }.toSet()).flatten() - return node.copy(type = BagType(elementType)) - } - - override fun visitRexCall(node: Rex.Call, ctx: Context): Rex.Call { - val processedNode = processRexCall(node, ctx) - visitRexCallManual(processedNode, ctx)?.let { return it } - val funcName = node.id - val signatures = ctx.allFunctions[funcName] - val arguments = processedNode.args.getTypes(ctx) - if (signatures == null) { - handleNoSuchFunctionError(ctx, funcName) - return node.copy(type = StaticType.ANY) - } - - var types: MutableSet = mutableSetOf() - val funcsMatchingArity = signatures.filter { it.arity.contains(arguments.size) } - if (funcsMatchingArity.isEmpty()) { - handleIncorrectNumberOfArgumentsToFunctionCallError(funcName, getMinMaxArities(signatures).first..getMinMaxArities(signatures).second, arguments.size, ctx) - } else { - if (node.type != null) { - return processedNode.copy(type = node.type) - } - for (sign in funcsMatchingArity) { - when (sign.unknownArguments) { - UnknownArguments.PROPAGATE -> types.add(returnTypeForPropagatingFunction(sign, arguments, ctx)) - UnknownArguments.PASS_THRU -> types.add(returnTypeForPassThruFunction(sign, arguments)) - } - } - } - - return processedNode.copy(type = StaticType.unionOf(types).flatten()) - } - - private fun getMinMaxArities(funcs: List): Pair { - val minArity = funcs.map { it.arity.first }.minOrNull() ?: Int.MAX_VALUE - val maxArity = funcs.map { it.arity.last }.maxOrNull() ?: Int.MIN_VALUE - - return Pair(minArity, maxArity) - } - - override fun visitRexSwitch(node: Rex.Switch, ctx: Context): PlanNode { - val match = node.match?.let { visitRex(it, ctx) as Rex } - val caseValueType = when (match) { - null -> null - else -> { - val type = match.grabType() ?: handleMissingType(ctx) - // comparison never succeeds if caseValue is an unknown - if (type.isUnknown()) { - handleExpressionAlwaysReturnsNullOrMissingError(ctx) - } - type - } - } - val check = when (caseValueType) { - null -> { conditionType: StaticType -> - conditionType.allTypes.none { it is BoolType } - } - else -> { conditionType: StaticType -> - !StaticTypeUtils.areStaticTypesComparable(caseValueType, conditionType) - } - } - val branches = node.branches.map { branch -> - val condition = visitRex(branch.condition, ctx) as Rex - val value = visitRex(branch.value, ctx) as Rex - val conditionType = condition.grabType() ?: handleMissingType(ctx) - // comparison never succeeds if whenExpr is unknown -> null or missing error - if (conditionType.isUnknown()) { - handleExpressionAlwaysReturnsNullOrMissingError(ctx) - } - // if caseValueType is incomparable to whenExprType -> data type mismatch - else if (check.invoke(conditionType)) { - handleIncompatibleDataTypesForOpError( - ctx, - actualTypes = listOfNotNull(caseValueType, conditionType), - op = "CASE" - ) - } - branch.copy(condition = condition, value = value) - } - val valueTypes = branches.map { it.value }.map { it.grabType() ?: handleMissingType(ctx) } - - // keep all the `THEN` expr types even if the comparison doesn't succeed - val default = node.default?.let { visitRex(it, ctx) } - val type = inferCaseWhenBranches(valueTypes, default?.grabType()) - return node.copy( - match = match, - branches = branches, - type = type - ) - } - - override fun visitRexTuple(node: Rex.Tuple, ctx: Context): PlanNode { - val fields = node.fields.map { field -> - field.copy( - name = visitRex(field.name, ctx) as Rex, - value = visitRex(field.value, ctx) as Rex - ) - } - - val structFields = mutableListOf() - var closedContent = true - fields.forEach { field -> - when (val name = field.name) { - is Rex.Lit -> - // A field is only included in the StructType if its key is a text literal - if (name.value is TextElement) { - val value = name.value as TextElement - val type = field.value.grabType() ?: handleMissingType(ctx) - structFields.add(StructType.Field(value.textValue, type)) - } - else -> { - // A field with a non-literal key name is not included. - // If the non-literal could be text, StructType will have open content. - val nameType = field.name.grabType() ?: handleMissingType(ctx) - if (nameType.allTypes.any { it.isText() }) { - closedContent = false - } - } - } - } - - val hasDuplicateKeys = structFields - .groupingBy { it.key } - .eachCount() - .any { it.value > 1 } - - return node.copy( - type = StructType( - structFields, - contentClosed = closedContent, - constraints = setOf(TupleConstraint.Open(closedContent.not()), TupleConstraint.UniqueAttrs(hasDuplicateKeys.not())) - ), - fields = fields - ) - } - - override fun visitArgValue(node: Arg.Value, ctx: Context): PlanNode { - return node.copy( - value = visitRex(node.value, ctx) as Rex - ) - } - - // - // - // HELPER METHODS - // - // - - private fun inferCaseWhenBranches(thenExprsTypes: List, elseExpr: StaticType?): StaticType { - val elseExprType = when (elseExpr) { - // If there is no ELSE clause in the expression, it possible that - // none of the WHEN clauses succeed and the output of CASE WHEN expression - // ends up being NULL - null -> StaticType.NULL - else -> elseExpr - } - - if (thenExprsTypes.any { it is AnyType } || elseExprType is AnyType) { - return StaticType.ANY - } - - val possibleTypes = thenExprsTypes + elseExprType - return AnyOfType(possibleTypes.toSet()).flatten() - } - - /** - * Assumes that [node] has been pre-processed. - */ - private fun visitRexCallManual(node: Rex.Call, ctx: Context): Rex.Call? { - return when (node.id) { - RexConverter.Constants.inCollection -> visitRexCallInCollection(node, ctx) - RexConverter.Constants.between -> visitRexCallBetween(node, ctx) - RexConverter.Constants.like, RexConverter.Constants.likeEscape -> visitRexCallLike(node, ctx) - RexConverter.Constants.canCast, RexConverter.Constants.canLosslessCast, RexConverter.Constants.isType -> node.copy(type = StaticType.BOOL) - RexConverter.Constants.coalesce -> visitRexCallCoalesce(node, ctx) - RexConverter.Constants.nullIf -> visitRexCallNullIf(node, ctx) - RexConverter.Constants.cast -> visitRexCallCast(node, ctx) - RexConverter.Constants.outerBagExcept, - RexConverter.Constants.outerBagIntersect, - RexConverter.Constants.outerBagUnion, - RexConverter.Constants.outerSetExcept, - RexConverter.Constants.outerSetIntersect, - RexConverter.Constants.outerSetUnion -> TODO("Bag Operators have not been implemented yet.") - else -> null - } - } - - private fun processRexCall(node: Rex.Call, ctx: Context): Rex.Call { - val args = node.args.visit(ctx) - return node.copy(args = args) - } - - /** - * [node] must be pre-processed - */ - private fun visitRexCallNullIf(node: Rex.Call, ctx: Context): Rex.Call { - // check for comparability of the two arguments to `NULLIF` - operandsAreComparable(node.args.getTypes(ctx), node.id, ctx) - - // output type will be the first argument's types along with `NULL` (even in the case of an error) - val possibleOutputTypes = node.args[0].grabType()?.asNullable() ?: handleMissingType(ctx) - return node.copy(type = possibleOutputTypes) - } - - /** - * [node] must be pre-processed - */ - private fun visitRexCallCast(node: Rex.Call, ctx: Context): Rex.Call { - val sourceType = node.args[0].grabType() ?: handleMissingType(ctx) - val targetType = node.args[1].grabType() ?: handleMissingType(ctx) - val targetTypeParam = targetType.toTypedOpParameter() - val castOutputType = sourceType.cast(targetType).let { - if (targetTypeParam.validationThunk == null) { - // There is no additional validation for this parameter, return this type as-is - it - } else { - StaticType.unionOf(StaticType.MISSING, it) - } - } - return node.copy(type = castOutputType) - } - - private fun StaticType.toTypedOpParameter(): TypedOpParameter { - return TypedOpParameter(staticType = this) - } - - /** - * [node] must be pre-processed - */ - private fun visitRexCallCoalesce(node: Rex.Call, ctx: Context): Rex.Call { - var allMissing = true - val outputTypes = mutableSetOf() - - val args = node.args.map { visitArg(it, ctx) } - for (arg in args) { - val staticType = arg.grabType() ?: handleMissingType(ctx) - val staticTypes = staticType.allTypes - outputTypes += staticTypes - // If at least one known type is found, remove null and missing from the result - // It means there is at least one type which doesn't contain unknown types. - if (staticTypes.all { type -> !type.isNullOrMissing() }) { - outputTypes.remove(StaticType.MISSING) - outputTypes.remove(StaticType.NULL) - break - } - if (!staticTypes.contains(StaticType.MISSING)) { - allMissing = false - } - } - // If every argument has MISSING as one of it's types, - // then output should contain MISSING and not otherwise. - if (!allMissing) { - outputTypes.remove(StaticType.MISSING) - } - - return node.copy( - type = when (outputTypes.size) { - 1 -> outputTypes.first() - else -> StaticType.unionOf(outputTypes) - } - ) - } - - /** - * [node] must be pre-processed - */ - private fun visitRexCallLike(node: Rex.Call, ctx: Context): Rex.Call { - val argTypes = node.args.getTypes(ctx) - val argsAllTypes = argTypes.map { it.allTypes } - - if (!hasValidOperandTypes(argTypes, "LIKE", ctx) { it.isText() }) { - return node.copy(type = StaticType.BOOL) - } - - val possibleReturnTypes: MutableSet = mutableSetOf() - argsAllTypes.cartesianProduct().forEach { argsChildType -> - val argsSingleType = argsChildType.map { it as SingleType } - when { - // If any one of the operands is null, return NULL - argsSingleType.any { it is NullType } -> possibleReturnTypes.add(StaticType.NULL) - // Arguments for LIKE need to be text type - argsSingleType.all { it.isText() } -> { - possibleReturnTypes.add(StaticType.BOOL) - // If the optional escape character is provided, it can result in failure even if the type is text (string, in this case) - // This is because the escape character needs to be a single character (string with length 1), - // Even if the escape character is of length 1, escape sequence can be incorrect. - if (node.args.getOrNull(2) != null) { - possibleReturnTypes.add(StaticType.MISSING) - } - } - else -> possibleReturnTypes.add(StaticType.MISSING) - } - } - - return node.copy(type = StaticType.unionOf(possibleReturnTypes).flatten()) - } - - /** - * [node] must be pre-processed - */ - private fun visitRexCallInCollection(node: Rex.Call, ctx: Context): Rex.Call { - val operands = node.args.getTypes(ctx) - val lhs = operands[0] - val rhs = operands[1] - var errorAdded = false - - // check if any operands are unknown, then null or missing error - if (operands.any { operand -> operand.isUnknown() }) { - handleExpressionAlwaysReturnsNullOrMissingError(ctx) - errorAdded = true - } - - // if none of the [rhs] types are [CollectionType]s with comparable element types to [lhs], then data type - // mismatch error - if (!rhs.isUnknown() && rhs.allTypes.none { - it is CollectionType && StaticTypeUtils.areStaticTypesComparable(it.elementType, lhs) - } - ) { - handleIncompatibleDataTypesForOpError(ctx, operands, "IN") - errorAdded = true - } - - return when (errorAdded) { - true -> StaticType.BOOL - false -> computeReturnTypeForNAryIn(operands) - }.let { node.copy(type = it) } - } - - private fun computeReturnTypeForNAryIn(argTypes: List): StaticType { - require(argTypes.size >= 2) { "IN must have at least two args" } - val leftTypes = argTypes.first().allTypes - val rightTypes = argTypes.drop(1).flatMap { it.allTypes } - - val finalTypes = leftTypes - .flatMap { left -> - rightTypes.flatMap { right -> - computeReturnTypeForBinaryIn(left, right).allTypes - } - }.distinct() - - return when (finalTypes.size) { - 1 -> finalTypes.first() - else -> StaticType.unionOf(*finalTypes.toTypedArray()) - } - } - - private fun computeReturnTypeForBinaryIn(left: StaticType, right: StaticType): StaticType = - when (right) { - is NullType -> when (left) { - is MissingType -> StaticType.MISSING - else -> StaticType.NULL - } - is MissingType -> StaticType.MISSING - is CollectionType -> when (left) { - is NullType -> StaticType.NULL - is MissingType -> StaticType.MISSING - else -> { - val rightElemTypes = right.elementType.allTypes - val possibleTypes = mutableSetOf() - if (rightElemTypes.any { it is MissingType }) { - possibleTypes.add(StaticType.MISSING) - } - if (rightElemTypes.any { it is NullType }) { - possibleTypes.add(StaticType.NULL) - } - if (rightElemTypes.any { !it.isNullOrMissing() }) { - possibleTypes.add(StaticType.BOOL) - } - StaticType.unionOf(possibleTypes).flatten() - } - } - else -> when (left) { - is NullType -> StaticType.unionOf(StaticType.NULL, StaticType.MISSING) - else -> StaticType.MISSING - } - } - - /** - * [node] must be pre-processed - */ - private fun visitRexCallBetween(node: Rex.Call, ctx: Context): Rex.Call { - val argTypes = listOf(node.args[0], node.args[1], node.args[2]).getTypes(ctx) - if (!operandsAreComparable(argTypes, node.id, ctx)) { - return node.copy(type = StaticType.BOOL) - } - - val argsAllTypes = argTypes.map { it.allTypes } - val possibleReturnTypes: MutableSet = mutableSetOf() - - argsAllTypes.cartesianProduct().forEach { argsChildType -> - val argsSingleType = argsChildType.map { it as SingleType } - when { - // If any one of the operands is null or missing, return NULL - argsSingleType.any { it is NullType || it is MissingType } -> possibleReturnTypes.add(StaticType.NULL) - StaticTypeUtils.areStaticTypesComparable( - argsSingleType[0], - argsSingleType[1] - ) || StaticTypeUtils.areStaticTypesComparable(argsSingleType[0], argsSingleType[2]) -> possibleReturnTypes.add(StaticType.BOOL) - else -> possibleReturnTypes.add(StaticType.MISSING) - } - } - return node.copy(type = StaticType.unionOf(possibleReturnTypes).flatten()) - } - - private fun List.getTypes(ctx: Context): List = this.map { it.grabType() ?: handleMissingType(ctx) } - - private fun List.visit(ctx: Context): List = this.map { arg -> - when (arg) { - is Arg.Value -> { - val rex = visitRex(arg.value, ctx) as Rex - arg.copy(value = rex) - } - is Arg.Type -> arg - } - } - - /** - * Verifies the given [actual] has type [expected]. If [actual] is unknown, a null or missing - * error is given. If [actual] could never be [expected], an incompatible data types for - * expression error is given. - */ - private fun assertType(expected: StaticType, actual: StaticType, ctx: Context) { - // Relates to `verifyExpressionType` - if (actual.isUnknown()) { - handleExpressionAlwaysReturnsNullOrMissingError(ctx) - } else if (actual.allTypes.none { it == expected }) { - handleIncompatibleDataTypeForExprError( - expectedType = expected, - actualType = actual, - ctx = ctx - ) - } - } - - private fun getElementTypeForFromSource(fromSourceType: StaticType): StaticType = - when (fromSourceType) { - is BagType -> fromSourceType.elementType - is ListType -> fromSourceType.elementType - is AnyType -> StaticType.ANY - is AnyOfType -> AnyOfType(fromSourceType.types.map { getElementTypeForFromSource(it) }.toSet()) - // All the other types coerce into a bag of themselves (including null/missing/sexp). - else -> fromSourceType - } - - private fun Rel.getTypeEnv() = PlanUtils.getTypeEnv(this) - - private fun Rel.getProperties() = this.getCommon().properties - - private fun Rel.getCommon() = when (this) { - is Rel.Aggregate -> this.common - is Rel.Bag -> this.common - is Rel.Fetch -> this.common - is Rel.Filter -> this.common - is Rel.Join -> this.common - is Rel.Project -> this.common - is Rel.Scan -> this.common - is Rel.Sort -> this.common - is Rel.Unpivot -> this.common - is Rel.Exclude -> this.common - } - - private fun inferPathComponentExprType( - previousComponentType: StaticType, - currentPathComponent: Step.Key, - ctx: Context - ): StaticType = - when (previousComponentType) { - is AnyType -> StaticType.ANY - is StructType -> inferStructLookupType( - currentPathComponent, - previousComponentType - ).flatten() - is ListType, - is SexpType -> { - val previous = previousComponentType as CollectionType // help Kotlin's type inference to be more specific - val key = visitRex(currentPathComponent.value, ctx = ctx) - if (key.grabType() is IntType) { - previous.elementType - } else { - StaticType.MISSING - } - } - is AnyOfType -> { - when (previousComponentType.types.size) { - 0 -> throw IllegalStateException("Cannot path on an empty StaticType union") - else -> { - val prevTypes = previousComponentType.allTypes - if (prevTypes.any { it is AnyType }) { - StaticType.ANY - } else { - val staticTypes = prevTypes.map { inferPathComponentExprType(it, currentPathComponent, ctx) } - AnyOfType(staticTypes.toSet()).flatten() - } - } - } - } - else -> StaticType.MISSING - } - - private fun inferStructLookupType( - currentPathComponent: Step.Key, - struct: StructType - ): StaticType = - when (val key = currentPathComponent.value) { - is Rex.Lit -> { - if (key.value is StringElement) { - val case = rexCaseToBindingCase(currentPathComponent.case) - ReferenceResolver.inferStructLookup(struct, BindingName(key.value.asAnyElement().stringValue, case)) - ?: when (struct.contentClosed) { - true -> StaticType.MISSING - false -> StaticType.ANY - } - } else { - // Should this branch result in an error? - StaticType.MISSING - } - } - else -> { - StaticType.MISSING - } - } - - private fun rexBindingNameToLangBindingName(name: BindingName) = org.partiql.lang.eval.BindingName( - name.name, - when (name.bindingCase) { - BindingCase.SENSITIVE -> org.partiql.lang.eval.BindingCase.SENSITIVE - BindingCase.INSENSITIVE -> org.partiql.lang.eval.BindingCase.INSENSITIVE - } - ) - - private fun rexIdToBindingName(node: Rex.Id): BindingName = BindingName( - node.name, - rexCaseToBindingCase(node.case) - ) - - private fun List.toAttributes(ctx: Context) = this.map { attribute(it.name, it.grabType() ?: handleMissingType(ctx)) } - - private fun inferConcatOp(leftType: SingleType, rightType: SingleType): SingleType { - fun checkUnconstrainedText(type: SingleType) = type is SymbolType || type is StringType && type.lengthConstraint is StringType.StringLengthConstraint.Unconstrained - - return when { - // Propagate missing as missing. Missing has precedence over null - leftType is MissingType || rightType is MissingType -> StaticType.MISSING - leftType is NullType || rightType is NullType -> StaticType.NULL - !leftType.isText() || !rightType.isText() -> StaticType.MISSING - checkUnconstrainedText(leftType) || checkUnconstrainedText(rightType) -> StaticType.STRING - else -> { // Constrained string types (char & varchar) - val leftLength = ((leftType as StringType).lengthConstraint as StringType.StringLengthConstraint.Constrained).length - val rightLength = ((rightType as StringType).lengthConstraint as StringType.StringLengthConstraint.Constrained).length - val sum = leftLength.value + rightLength.value - val newConstraint = when { - leftLength is NumberConstraint.UpTo || rightLength is NumberConstraint.UpTo -> NumberConstraint.UpTo(sum) - else -> NumberConstraint.Equals(sum) - } - StringType(StringType.StringLengthConstraint.Constrained(newConstraint)) - } - } - } - - private fun inferUnaryArithmeticOp(type: SingleType): SingleType = when (type) { - // Propagate NULL or MISSING - is NullType -> StaticType.NULL - is MissingType -> StaticType.MISSING - is DecimalType, is IntType, is FloatType -> type - else -> StaticType.MISSING - } - - private fun computeReturnTypeForUnary( - argStaticType: StaticType, - unaryOpInferencer: (SingleType) -> SingleType - ): StaticType { - val argSingleTypes = argStaticType.allTypes.map { it as SingleType } - val possibleReturnTypes = argSingleTypes.map { st -> unaryOpInferencer(st) } - - return StaticType.unionOf(possibleReturnTypes.toSet()).flatten() - } - - private fun inferNotOp(type: SingleType): SingleType = when (type) { - // Propagate NULL or MISSING - is NullType -> StaticType.NULL - is MissingType -> StaticType.MISSING - is BoolType -> type - else -> StaticType.MISSING - } - - private fun inferNaryLogicalOp(argsStaticType: List, op: String, ctx: Context): StaticType { - return when (hasValidOperandTypes(argsStaticType, op, ctx) { it is BoolType }) { - true -> { - val argsSingleTypes = argsStaticType.map { argStaticType -> - argStaticType.allTypes.map { singleType -> singleType as SingleType } - } - val argsSingleTypeCombination = argsSingleTypes.cartesianProduct() - val possibleResultTypes = argsSingleTypeCombination.map { argsSingleType -> - getTypeForNAryLogicalOperations(argsSingleType) - }.toSet() - - StaticType.unionOf(possibleResultTypes).flatten() - } - false -> StaticType.BOOL // continuation type to prevent incompatible types and unknown errors from propagating - } - } - - private fun getTypeForNAryLogicalOperations(args: List): StaticType = when { - // Logical operands need to be of Boolean Type - args.all { it == StaticType.BOOL } -> StaticType.BOOL - // If any of the arguments is boolean, then the return type can be boolean because of short-circuiting - // in logical ops. For e.g. "TRUE OR ANY" returns TRUE. "FALSE AND ANY" returns FALSE. But in the case - // where the other arg is an incompatible type (not an unknown or bool), the result type is MISSING. - args.any { it == StaticType.BOOL } -> when { - // If other argument is missing, then return union(bool, missing) - args.any { it is MissingType } -> AnyOfType(setOf(StaticType.MISSING, StaticType.BOOL)) - // If other argument is null, then return union(bool, null) - args.any { it is NullType } -> AnyOfType(setOf(StaticType.NULL, StaticType.BOOL)) - // If other type is anything other than null or missing, then it is an error case - else -> StaticType.MISSING - } - // If any of the operands is MISSING, return MISSING. MISSING has a precedence over NULL - args.any { it is MissingType } -> StaticType.MISSING - // If any of the operands is NULL, return NULL - args.any { it is NullType } -> StaticType.NULL - else -> StaticType.MISSING - } - - private fun computeReturnTypeForNAry( - argsStaticType: List, - binaryOpInferencer: (SingleType, SingleType) -> SingleType - ): StaticType = - argsStaticType.reduce { leftStaticType, rightStaticType -> - val leftSingleTypes = leftStaticType.allTypes.map { it as SingleType } - val rightSingleTypes = rightStaticType.allTypes.map { it as SingleType } - val possibleResultTypes: List = - leftSingleTypes.flatMap { leftSingleType -> - rightSingleTypes.map { rightSingleType -> - binaryOpInferencer(leftSingleType, rightSingleType) - } - } - - StaticType.unionOf(possibleResultTypes.toSet()).flatten() - } - - /** - * Computes return type for functions with [FunctionSignature.unknownArguments] as [UnknownArguments.PROPAGATE] - */ - private fun returnTypeForPropagatingFunction(signature: FunctionSignature, arguments: List, ctx: Context): StaticType { - val requiredArgs = arguments.zip(signature.requiredParameters) - val allArgs = requiredArgs - - return if (functionHasValidArgTypes(signature.name, allArgs, ctx)) { - val finalReturnTypes = signature.returnType.allTypes + allArgs.flatMap { (actualType, expectedType) -> - listOfNotNull( - // if any type is `MISSING`, add `MISSING` to possible return types. - // if the actual type is not a subtype is the expected type, add `MISSING`. In the future, may - // want to give a warning that a data type mismatch could occur - // (https://github.com/partiql/partiql-lang-kotlin/issues/507) - StaticType.MISSING.takeIf { - actualType.allTypes.any { it is MissingType } || !StaticTypeUtils.isSubTypeOf( - actualType.filterNullMissing(), - expectedType - ) - }, - // if any type is `NULL`, add `NULL` to possible return types - StaticType.NULL.takeIf { actualType.allTypes.any { it is NullType } } - ) - } - AnyOfType(finalReturnTypes.toSet()).flatten() - } else { - // otherwise, has an invalid arg type and errors. continuation type of [FunctionSignature.returnType] - signature.returnType - } - } - - /** - * For [this] [StaticType], filters out [NullType] and [MissingType] from [AnyOfType]s. Otherwise, returns [this]. - */ - private fun StaticType.filterNullMissing(): StaticType = - when (this) { - is AnyOfType -> AnyOfType(this.types.filter { !it.isNullOrMissing() }.toSet()).flatten() - else -> this - } - - private fun getUnpivotValueType(fromSourceType: StaticType): StaticType = - when (fromSourceType) { - is StructType -> if (fromSourceType.contentClosed) { - AnyOfType(fromSourceType.fields.map { it.value }.toSet()).flatten() - } else { - // Content is open, so value can be of any type - StaticType.ANY - } - is AnyType -> StaticType.ANY - is AnyOfType -> AnyOfType(fromSourceType.types.map { getUnpivotValueType(it) }.toSet()) - // All the other types coerce into a struct of themselves with synthetic key names - else -> fromSourceType - } - - /** - * Returns true if for every pair (expr, expectedType) in [argsWithExpectedTypes], the expr's [StaticType] is - * not an unknown and has a shared type with expectedType. Returns false otherwise. - * - * If an argument has an unknown type, the [SemanticProblemDetails.NullOrMissingFunctionArgument] error is - * handled by [ProblemHandler]. If an expr has no shared type with the expectedType, the - * [SemanticProblemDetails.InvalidArgumentTypeForFunction] error is handled by [ProblemHandler]. - */ - private fun functionHasValidArgTypes(functionName: String, argsWithExpectedTypes: List>, ctx: Context): Boolean { - var allArgsValid = true - argsWithExpectedTypes.forEach { (actualType, expectedType) -> - if (actualType.isUnknown()) { - handleNullOrMissingFunctionArgument(functionName, ctx) - allArgsValid = false - } else { - val actualNonUnknownType = actualType.filterNullMissing() - if (StaticTypeUtils.getTypeDomain(actualNonUnknownType).intersect(StaticTypeUtils.getTypeDomain(expectedType)).isEmpty() - ) { - handleInvalidArgumentTypeForFunction( - functionName = functionName, - expectedType = expectedType, - actualType = actualType, - ctx - ) - allArgsValid = false - } - } - } - return allArgsValid - } - - /** - * Computes return type for functions with [FunctionSignature.unknownArguments] as [UnknownArguments.PASS_THRU] - */ - private fun returnTypeForPassThruFunction(signature: FunctionSignature, arguments: List): StaticType { - return when { - matchesAllArguments(arguments, signature) -> signature.returnType - matchesAtLeastOneArgument(arguments, signature) -> StaticType.unionOf(signature.returnType, StaticType.MISSING) - else -> StaticType.MISSING - } - } - - /** - * Function assumes the number of [arguments] passed agrees with the [signature] - * Returns true when all the arguments (required, optional, variadic) are subtypes of the expected arguments for the [signature]. - * Returns false otherwise - */ - private fun matchesAllArguments(arguments: List, signature: FunctionSignature): Boolean { - // Checks if the actual StaticType is subtype of expected StaticType ( filtering the null/missing for PROPAGATING functions - fun isSubType(actual: StaticType, expected: StaticType): Boolean { - val lhs = when (signature.unknownArguments) { - UnknownArguments.PROPAGATE -> when (actual) { - is AnyOfType -> actual.copy( - types = actual.types.filter { - !it.isNullOrMissing() - }.toSet() - ) - else -> actual - } - UnknownArguments.PASS_THRU -> actual - } - return StaticTypeUtils.isSubTypeOf(lhs, expected) - } - - val requiredArgumentsMatch = arguments - .zip(signature.requiredParameters) - .all { (actual, expected) -> - isSubType(actual, expected) - } - return requiredArgumentsMatch - } - - internal fun Rex.isProjectAll(): Boolean { - return when (this) { - is Rex.Path -> { - val step = this.steps.lastOrNull() ?: return false - step is Step.Wildcard - } - else -> false - } - } - - /** - * Function assumes the number of [arguments] passed agrees with the [signature] - * - * Returns true if there's at least one valid overlap between actual and expected - * for all the expected arguments (required, optional, variadic) for the [signature]. - * - * Returns false otherwise. - */ - private fun matchesAtLeastOneArgument(arguments: List, signature: FunctionSignature): Boolean { - val requiredArgumentsMatch = arguments - .zip(signature.requiredParameters) - .all { (actual, expected) -> - StaticTypeUtils.getTypeDomain(actual).intersect(StaticTypeUtils.getTypeDomain(expected)).isNotEmpty() - } - return requiredArgumentsMatch - } - - private fun inferEqNeOp(lhs: SingleType, rhs: SingleType): SingleType = when { - // Propagate missing as missing. Missing has precedence over null - lhs is MissingType || rhs is MissingType -> StaticType.MISSING - lhs.isNullable() || rhs.isNullable() -> StaticType.NULL - else -> StaticType.BOOL - } - - // LT, LTE, GT, GTE - private fun inferComparatorOp(lhs: SingleType, rhs: SingleType): SingleType = when { - // Propagate missing as missing. Missing has precedence over null - lhs is MissingType || rhs is MissingType -> StaticType.MISSING - lhs is NullType || rhs is NullType -> StaticType.NULL - StaticTypeUtils.areStaticTypesComparable(lhs, rhs) -> StaticType.BOOL - else -> StaticType.MISSING - } - - /** - * Returns true if all of the provided [argsStaticType] are comparable to each other and are not unknown. Otherwise, - * returns false. - * - * If an operand is not comparable to another, the [SemanticProblemDetails.IncompatibleDatatypesForOp] error is - * handled by [ProblemHandler]. If an operand is unknown, the - * [SemanticProblemDetails.ExpressionAlwaysReturnsNullOrMissing] error is handled by [ProblemHandler]. - * - * TODO: consider if collection comparison semantics should be different (e.g. errors over warnings, - * more details in error message): https://github.com/partiql/partiql-lang-kotlin/issues/505 - */ - private fun operandsAreComparable(argsStaticType: List, op: String, ctx: Context): Boolean { - var hasValidOperands = true - - // check for comparability of all operands. currently only adds one data type mismatch error - outerLoop@ for (i in argsStaticType.indices) { - for (j in i + 1 until argsStaticType.size) { - if (!StaticTypeUtils.areStaticTypesComparable(argsStaticType[i], argsStaticType[j])) { - handleIncompatibleDataTypesForOpError(ctx, argsStaticType, op) - hasValidOperands = false - break@outerLoop - } - } - } - - // check for an unknown operand type - if (argsStaticType.any { operand -> operand.isUnknown() }) { - handleExpressionAlwaysReturnsNullOrMissingError(ctx) - hasValidOperands = false - } - return hasValidOperands - } - - // This could also have been a lookup table of types, however... doing this as a nested `when` allows - // us to not to rely on `.equals` and `.hashcode` implementations of [StaticType], which include metas - // and might introduce unwanted behavior. - private fun inferBinaryArithmeticOp(leftType: SingleType, rightType: SingleType): SingleType = when { - // Propagate missing as missing. Missing has precedence over null - leftType is MissingType || rightType is MissingType -> StaticType.MISSING - leftType is NullType || rightType is NullType -> StaticType.NULL - else -> when (leftType) { - is IntType -> - when (rightType) { - is IntType -> - when { - leftType.rangeConstraint == IntType.IntRangeConstraint.UNCONSTRAINED -> leftType - rightType.rangeConstraint == IntType.IntRangeConstraint.UNCONSTRAINED -> rightType - leftType.rangeConstraint.numBytes > rightType.rangeConstraint.numBytes -> leftType - else -> rightType - } - is FloatType -> StaticType.FLOAT - is DecimalType -> StaticType.DECIMAL // TODO: account for decimal precision - else -> StaticType.MISSING - } - is FloatType -> - when (rightType) { - is IntType -> StaticType.FLOAT - is FloatType -> StaticType.FLOAT - is DecimalType -> StaticType.DECIMAL // TODO: account for decimal precision - else -> StaticType.MISSING - } - is DecimalType -> - when (rightType) { - is IntType -> StaticType.DECIMAL // TODO: account for decimal precision - is FloatType -> StaticType.DECIMAL // TODO: account for decimal precision - is DecimalType -> StaticType.DECIMAL // TODO: account for decimal precision - else -> StaticType.MISSING - } - else -> StaticType.MISSING - } - } - - private fun hasValidOperandTypes( - operandsStaticType: List, - op: String, - ctx: Context, - operandTypeValidator: (StaticType) -> Boolean - ): Boolean { - // check for an incompatible operand type - if (operandsStaticType.any { operandStaticType -> !operandStaticType.isUnknown() && operandStaticType.allTypes.none(operandTypeValidator) }) { - handleIncompatibleDataTypesForOpError(ctx, operandsStaticType, op) - } - - // check for an unknown operand type - if (operandsStaticType.any { operandStaticType -> operandStaticType.isUnknown() }) { - handleExpressionAlwaysReturnsNullOrMissingError(ctx) - } - return true - } - - private fun assertAsInt(type: StaticType, ctx: Context) { - if (type.flatten().allTypes.any { variant -> variant is IntType }.not()) { - handleIncompatibleDataTypeForExprError(StaticType.INT, type, ctx) - } - } - - private fun StaticType.isNullOrMissing(): Boolean = (this is NullType || this is MissingType) - - internal fun StaticType.isText(): Boolean = (this is SymbolType || this is StringType) - - private fun StaticType.isUnknown(): Boolean = (this.isNullOrMissing() || this == StaticType.NULL_OR_MISSING) - - internal fun StaticType.isNumeric(): Boolean = (this is IntType || this is FloatType || this is DecimalType) - - private fun rexCaseToBindingCase(node: Case): BindingCase = when (node) { - Case.SENSITIVE -> BindingCase.SENSITIVE - Case.INSENSITIVE -> BindingCase.INSENSITIVE - } - - private fun findBind(path: BindingPath, qualifier: Rex.Id.Qualifier, ctx: Context): ReferenceResolver.ResolvedType { - val scopingOrder = when (qualifier) { - Rex.Id.Qualifier.LOCALS_FIRST -> ScopingOrder.LEXICAL_THEN_GLOBALS - Rex.Id.Qualifier.UNQUALIFIED -> ctx.scopingOrder - } - return when (scopingOrder) { - ScopingOrder.GLOBALS_THEN_LEXICAL -> ReferenceResolver.resolveGlobalBind(path, ctx) - ?: ReferenceResolver.resolveLocalBind(path, ctx.inputTypeEnv) - ?: handleUnresolvedDescriptor(path.steps.last(), ctx) { - ReferenceResolver.ResolvedType(StaticType.ANY) - } - ScopingOrder.LEXICAL_THEN_GLOBALS -> ReferenceResolver.resolveLocalBind(path, ctx.inputTypeEnv) - ?: ReferenceResolver.resolveGlobalBind(path, ctx) - ?: handleUnresolvedDescriptor(path.steps.last(), ctx) { - ReferenceResolver.ResolvedType(StaticType.ANY) - } - } - } - - private fun handleUnresolvedDescriptor(name: BindingName, ctx: Context, input: () -> T): T { - return when (ctx.tolerance) { - MinimumTolerance.FULL -> { - handleUndefinedVariable(name, ctx) - input.invoke() - } - MinimumTolerance.PARTIAL -> input.invoke() - } - } - - private fun grabFirstIds(node: Rex.Path): List { - if (node.root !is Rex.Id) { return emptyList() } - val steps = node.steps.map { - when (it) { - is Step.Key -> when (val value = it.value) { - is Rex.Lit -> { - val ionElement = value.value.asAnyElement() - when (ionElement.type) { - ElementType.SYMBOL, ElementType.STRING -> { - val stringValue = value.value.asAnyElement().stringValueOrNull - stringValue?.let { str -> - rexId(str, it.case, Rex.Id.Qualifier.UNQUALIFIED, null) - } - } - else -> null - } - } - else -> null - } - else -> null - } - } - val nullPosition = when (val nullIndex = steps.indexOf(null)) { - -1 -> steps.size - else -> nullIndex - } - val firstSteps = steps.subList(0, nullPosition).filterNotNull() - return listOf(node.root as Rex.Id) + firstSteps - } - - private fun inferType(expr: Rex, input: Rel?, ctx: Context): StaticType { - return type( - expr, - Context( - input, - ctx.session, - ctx.metadata, - ScopingOrder.LEXICAL_THEN_GLOBALS, - ctx.customFunctionSignatures, - ctx.tolerance, - ctx.problemHandler - ) - ).grabType() ?: handleMissingType(ctx) - } - - private fun typeRex(expr: Rex, input: Rel?, ctx: Context): Rex { - return type( - expr, - Context( - input, - ctx.session, - ctx.metadata, - ctx.scopingOrder, - ctx.customFunctionSignatures, - ctx.tolerance, - ctx.problemHandler - ) - ) - } - - private fun typeRel(rel: Rel, input: Rel?, ctx: Context): Rel { - return visitRel( - rel, - Context( - input, - ctx.session, - ctx.metadata, - ctx.scopingOrder, - ctx.customFunctionSignatures, - ctx.tolerance, - ctx.problemHandler - ) - ) - } - - private fun handleExpressionAlwaysReturnsNullOrMissingError(ctx: Context) { - ctx.problemHandler.handleProblem( - Problem( - sourceLocation = UNKNOWN_PROBLEM_LOCATION, - details = SemanticProblemDetails.ExpressionAlwaysReturnsNullOrMissing - ) - ) - } - - // TODO: https://github.com/partiql/partiql-lang-kotlin/issues/508 consider not working directly with strings for `op` - private fun handleIncompatibleDataTypesForOpError(ctx: Context, actualTypes: List, op: String) { - ctx.problemHandler.handleProblem( - Problem( - sourceLocation = UNKNOWN_PROBLEM_LOCATION, - details = SemanticProblemDetails.IncompatibleDatatypesForOp( - actualTypes, - op - ) - ) - ) - } - - private fun handleNoSuchFunctionError(ctx: Context, functionName: String) { - ctx.problemHandler.handleProblem( - Problem( - sourceLocation = UNKNOWN_PROBLEM_LOCATION, - details = SemanticProblemDetails.NoSuchFunction(functionName) - ) - ) - } - - private fun handleIncompatibleDataTypeForExprError(expectedType: StaticType, actualType: StaticType, ctx: Context) { - ctx.problemHandler.handleProblem( - Problem( - sourceLocation = UNKNOWN_PROBLEM_LOCATION, - details = SemanticProblemDetails.IncompatibleDataTypeForExpr(expectedType, actualType) - ) - ) - } - - private fun handleIncorrectNumberOfArgumentsToFunctionCallError( - functionName: String, - expectedArity: IntRange, - actualArgCount: Int, - ctx: Context - ) { - ctx.problemHandler.handleProblem( - Problem( - sourceLocation = UNKNOWN_PROBLEM_LOCATION, - details = SemanticProblemDetails.IncorrectNumberOfArgumentsToFunctionCall( - functionName, - expectedArity, - actualArgCount - ) - ) - ) - } - - private fun handleNullOrMissingFunctionArgument(functionName: String, ctx: Context) { - ctx.problemHandler.handleProblem( - Problem( - sourceLocation = UNKNOWN_PROBLEM_LOCATION, - details = SemanticProblemDetails.NullOrMissingFunctionArgument( - functionName = functionName - ) - ) - ) - } - - private fun handleUndefinedVariable(name: BindingName, ctx: Context) { - ctx.problemHandler.handleProblem( - Problem( - sourceLocation = UNKNOWN_PROBLEM_LOCATION, - details = PlanningProblemDetails.UndefinedVariable(name.name, name.bindingCase == BindingCase.SENSITIVE) - ) - ) - } - - private fun handleInvalidArgumentTypeForFunction(functionName: String, expectedType: StaticType, actualType: StaticType, ctx: Context) { - ctx.problemHandler.handleProblem( - Problem( - sourceLocation = UNKNOWN_PROBLEM_LOCATION, - details = SemanticProblemDetails.InvalidArgumentTypeForFunction( - functionName = functionName, - expectedType = expectedType, - actualType = actualType - ) - ) - ) - } - - private fun handleMissingType(ctx: Context): StaticType { - ctx.problemHandler.handleProblem( - Problem( - sourceLocation = UNKNOWN_PROBLEM_LOCATION, - details = PlanningProblemDetails.CompileError("Unable to determine type of node.") - ) - ) - return StaticType.ANY - } - - private fun handleDuplicateAliasesError(ctx: Context) { - ctx.problemHandler.handleProblem( - Problem( - sourceLocation = UNKNOWN_PROBLEM_LOCATION, - details = SemanticProblemDetails.DuplicateAliasesInSelectListItem - ) - ) - } - - private fun handleCoercionError(ctx: Context, actualType: StaticType) { - ctx.problemHandler.handleProblem( - Problem( - sourceLocation = UNKNOWN_PROBLEM_LOCATION, - details = SemanticProblemDetails.CoercionError(actualType) - ) - ) - } - - private fun handleUnresolvedExcludeExprRoot(root: String, ctx: Context) { - ctx.problemHandler.handleProblem( - Problem( - sourceLocation = UNKNOWN_PROBLEM_LOCATION, - details = PlanningProblemDetails.UnresolvedExcludeExprRoot(root) - ) - ) - } -} diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/planner/transforms/plan/PlanUtils.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/planner/transforms/plan/PlanUtils.kt deleted file mode 100644 index 4da66fcb04..0000000000 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/planner/transforms/plan/PlanUtils.kt +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at: - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific - * language governing permissions and limitations under the License. - */ - -package org.partiql.lang.planner.transforms.plan - -import org.partiql.plan.Arg -import org.partiql.plan.Attribute -import org.partiql.plan.Binding -import org.partiql.plan.PlanNode -import org.partiql.plan.Rel -import org.partiql.plan.Rex -import org.partiql.plan.Step -import org.partiql.types.StaticType - -internal object PlanUtils { - internal fun getTypeEnv(input: Rel): List = when (input) { - is Rel.Project -> input.common.typeEnv - is Rel.Aggregate -> input.common.typeEnv - is Rel.Bag -> input.common.typeEnv - is Rel.Fetch -> input.common.typeEnv - is Rel.Filter -> input.common.typeEnv - is Rel.Join -> input.common.typeEnv - is Rel.Scan -> input.common.typeEnv - is Rel.Sort -> input.common.typeEnv - is Rel.Unpivot -> input.common.typeEnv - is Rel.Exclude -> input.common.typeEnv - } - - internal fun Rex.addType(type: StaticType): Rex = when (this) { - is Rex.Agg -> this.copy(type = type) - is Rex.Binary -> this.copy(type = type) - is Rex.Call -> this.copy(type = type) - is Rex.Collection.Array -> this.copy(type = type) - is Rex.Collection.Bag -> this.copy(type = type) - is Rex.Id -> this.copy(type = type) - is Rex.Lit -> this.copy(type = type) - is Rex.Path -> this.copy(type = type) - is Rex.Query.Collection -> this.copy(type = type) - is Rex.Query.Scalar.Pivot -> this.copy(type = type) - is Rex.Query.Scalar.Subquery -> this.copy(type = type) - is Rex.Switch -> this.copy(type = type) - is Rex.Tuple -> this.copy(type = type) - is Rex.Unary -> this.copy(type = type) - } - - internal fun Rex.grabType(): StaticType? = when (this) { - is Rex.Agg -> this.type - is Rex.Binary -> this.type - is Rex.Call -> this.type - is Rex.Collection.Array -> this.type - is Rex.Collection.Bag -> this.type - is Rex.Id -> this.type - is Rex.Lit -> this.type - is Rex.Path -> this.type - is Rex.Query.Collection -> this.type - is Rex.Query.Scalar.Pivot -> this.type - is Rex.Tuple -> this.type - is Rex.Unary -> this.type - is Rex.Query.Scalar.Subquery -> this.type - is Rex.Switch -> this.type - } - - internal fun PlanNode.grabType(): StaticType? = when (this) { - is Rex -> this.grabType() - is Arg.Value -> this.value.grabType() - is Arg.Type -> this.type - is Step.Key -> this.value.grabType() - is Binding -> this.value.grabType() - else -> error("Unable to grab static type of $this") - } -} diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/planner/transforms/plan/RelConverter.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/planner/transforms/plan/RelConverter.kt deleted file mode 100644 index 26ecb6c870..0000000000 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/planner/transforms/plan/RelConverter.kt +++ /dev/null @@ -1,518 +0,0 @@ -package org.partiql.lang.planner.transforms.plan - -import com.amazon.ionelement.api.ionInt -import com.amazon.ionelement.api.ionString -import org.partiql.lang.domains.PartiqlAst -import org.partiql.lang.eval.visitors.VisitorTransformBase -import org.partiql.lang.planner.transforms.plan.RexConverter.convertCase -import org.partiql.plan.Binding -import org.partiql.plan.Case -import org.partiql.plan.ExcludeExpr -import org.partiql.plan.ExcludeStep -import org.partiql.plan.Rel -import org.partiql.plan.Rex -import org.partiql.plan.SortSpec -import org.partiql.plan.binding -import org.partiql.plan.common -import org.partiql.plan.excludeExpr -import org.partiql.plan.excludeStepCollectionIndex -import org.partiql.plan.excludeStepCollectionWildcard -import org.partiql.plan.excludeStepTupleAttr -import org.partiql.plan.excludeStepTupleWildcard -import org.partiql.plan.field -import org.partiql.plan.relAggregate -import org.partiql.plan.relExclude -import org.partiql.plan.relFetch -import org.partiql.plan.relFilter -import org.partiql.plan.relJoin -import org.partiql.plan.relProject -import org.partiql.plan.relScan -import org.partiql.plan.relSort -import org.partiql.plan.relUnpivot -import org.partiql.plan.rexAgg -import org.partiql.plan.rexId -import org.partiql.plan.rexLit -import org.partiql.plan.rexQueryCollection -import org.partiql.plan.rexQueryScalarPivot -import org.partiql.plan.rexTuple -import org.partiql.plan.sortSpec -import org.partiql.types.StaticType - -/** - * Lexically scoped state for use in translating an individual SELECT statement. - */ -internal class RelConverter { - - /** - * As of now, the COMMON property of relation operators is under development, so just use empty for now - */ - private val empty = common( - typeEnv = emptyList(), - properties = emptySet(), - metas = emptyMap() - ) - - companion object { - - /** - * Converts a SELECT-FROM-WHERE AST node to a [Rex.Query] - */ - @JvmStatic - fun convert(select: PartiqlAst.Expr.Select): Rex.Query = with(RelConverter()) { - val rel = convertSelect(select) - val rex = when (val projection = select.project) { - // PIVOT ... FROM - is PartiqlAst.Projection.ProjectPivot -> { - rexQueryScalarPivot( - rel = rel, - value = RexConverter.convert(projection.value), - at = RexConverter.convert(projection.key), - type = null - ) - } - // SELECT VALUE ... FROM - is PartiqlAst.Projection.ProjectValue -> { - rexQueryCollection( - rel = rel, - constructor = RexConverter.convert(projection.value), - type = null - ) - } - // SELECT ... FROM - else -> { - rexQueryCollection( - rel = rel, - constructor = null, - type = null - ) - } - } - rex - } - } - - // synthetic binding name counter - private var i = 0 - - // generate a synthetic binding name - private fun nextBindingName(): String = "\$__v${i++}" - - /** - * Translate SFW AST node to a pipeline of [Rel] operators; this skips the final projection. - * - * Note: - * - This does not append the final projection - * - The AST doesn't support set operators - * - The Parser doesn't have FETCH syntax - */ - private fun convertSelect(node: PartiqlAst.Expr.Select): Rel { - var sel = node - var rel = convertFrom(sel.from) - rel = convertWhere(rel, sel.where) - // kotlin does not have destructuring assignment - val (_sel, _rel) = convertAgg(rel, sel, sel.group) - sel = _sel - rel = _rel - // transform (possibly rewritten) sel node - rel = convertHaving(rel, sel.having) - rel = convertOrderBy(rel, sel.order) - rel = convertFetch(rel, sel.limit, sel.offset) - rel = convertExclude(rel, sel.excludeClause) - // append SQL projection if present - rel = when (val projection = sel.project) { - is PartiqlAst.Projection.ProjectList -> convertProjectList(rel, projection) - is PartiqlAst.Projection.ProjectStar -> error("AST not normalized, found project star") - else -> rel // skip - } - return rel - } - - private fun convertExclude(input: Rel, excludeOp: PartiqlAst.ExcludeOp?): Rel = when (excludeOp) { - null -> input - else -> { - val exprs = excludeOp.exprs.map { convertExcludeExpr(it) } - relExclude( - common = empty, - input = input, - exprs = exprs, - ) - } - } - - private fun convertExcludeExpr(excludeExpr: PartiqlAst.ExcludeExpr): ExcludeExpr { - val root = excludeExpr.root.name.text - val case = convertCase(excludeExpr.root.case) - val steps = excludeExpr.steps.map { convertExcludeSteps(it) } - return excludeExpr(root, case, steps) - } - - private fun convertExcludeSteps(excludeStep: PartiqlAst.ExcludeStep): ExcludeStep { - return when (excludeStep) { - is PartiqlAst.ExcludeStep.ExcludeCollectionWildcard -> excludeStepCollectionWildcard() - is PartiqlAst.ExcludeStep.ExcludeTupleWildcard -> excludeStepTupleWildcard() - is PartiqlAst.ExcludeStep.ExcludeTupleAttr -> excludeStepTupleAttr(excludeStep.attr.name.text, convertCase(excludeStep.attr.case)) - is PartiqlAst.ExcludeStep.ExcludeCollectionIndex -> excludeStepCollectionIndex(excludeStep.index.value.toInt()) - } - } - - /** - * Appends the appropriate [Rel] operator for the given FROM source - */ - private fun convertFrom(from: PartiqlAst.FromSource): Rel = when (from) { - is PartiqlAst.FromSource.Join -> convertJoin(from) - is PartiqlAst.FromSource.Scan -> convertScan(from) - is PartiqlAst.FromSource.Unpivot -> convertUnpivot(from) - } - - /** - * Appends [Rel.Join] where the left and right sides are converted FROM sources - */ - private fun convertJoin(join: PartiqlAst.FromSource.Join): Rel { - val lhs = convertFrom(join.left) - val rhs = convertFrom(join.right) - val condition = if (join.predicate != null) RexConverter.convert(join.predicate!!) else null - return relJoin( - common = empty, - lhs = lhs, - rhs = rhs, - condition = condition, - type = when (join.type) { - is PartiqlAst.JoinType.Full -> Rel.Join.Type.FULL - is PartiqlAst.JoinType.Inner -> Rel.Join.Type.INNER - is PartiqlAst.JoinType.Left -> Rel.Join.Type.LEFT - is PartiqlAst.JoinType.Right -> Rel.Join.Type.RIGHT - } - ) - } - - /** - * Appends [Rel.Scan] which takes no input relational expression - */ - private fun convertScan(scan: PartiqlAst.FromSource.Scan) = relScan( - common = empty, - value = when (val expr = scan.expr) { - is PartiqlAst.Expr.Select -> convert(expr) - else -> RexConverter.convert(scan.expr) - }, - alias = scan.asAlias?.text, - at = scan.atAlias?.text, - by = scan.byAlias?.text - ) - - /** - * Appends [Rel.Unpivot] to range over attribute value pairs - */ - private fun convertUnpivot(scan: PartiqlAst.FromSource.Unpivot) = relUnpivot( - common = empty, - value = RexConverter.convert(scan.expr), - alias = scan.asAlias?.text, - at = scan.atAlias?.text, - by = scan.byAlias?.text - ) - - /** - * Append [Rel.Filter] only if a WHERE condition exists - */ - private fun convertWhere(input: Rel, expr: PartiqlAst.Expr?): Rel = when (expr) { - null -> input - else -> relFilter( - common = empty, - input = input, - condition = RexConverter.convert(expr) - ) - } - - /** - * Append [Rel.Aggregate] only if SELECT contains aggregate expressions. - * - * @return Pair is returned where - * 1. Ast.Expr.Select has every Ast.Expr.CallAgg replaced by a synthetic Ast.Expr.Id - * 2. Rel which has the appropriate Rex.Agg calls and Rex groups - */ - private fun convertAgg( - input: Rel, - select: PartiqlAst.Expr.Select, - groupBy: PartiqlAst.GroupBy? - ): Pair { - // Rewrite and extract all aggregations in the SELECT clause - val (sel, aggregations) = AggregationTransform.apply(select) - - // No aggregation planning required for GROUP BY - if (aggregations.isEmpty()) { - if (groupBy != null) { - // As of now, GROUP BY with no aggregations is considered an error. - error("GROUP BY with no aggregations in SELECT clause") - } - return Pair(select, input) - } - - val calls = aggregations.toMutableList() - var groups = emptyList() - var strategy = Rel.Aggregate.Strategy.FULL - - if (groupBy != null) { - // GROUP AS is implemented as an aggregation function - if (groupBy.groupAsAlias != null) { - calls.add(convertGroupAs(groupBy.groupAsAlias!!.text, sel.from)) - } - groups = groupBy.keyList.keys.map { convertGroupByKey(it) } - strategy = when (groupBy.strategy) { - is PartiqlAst.GroupingStrategy.GroupFull -> Rel.Aggregate.Strategy.FULL - is PartiqlAst.GroupingStrategy.GroupPartial -> Rel.Aggregate.Strategy.PARTIAL - } - } - - val rel = relAggregate( - common = empty, - input = input, - calls = calls, - groups = groups, - strategy = strategy - ) - - return Pair(sel, rel) - } - - /** - * Each GROUP BY becomes a binding available in the output tuples of [Rel.Aggregate] - */ - private fun convertGroupByKey(groupKey: PartiqlAst.GroupKey) = binding( - name = groupKey.asAlias?.text ?: error("not normalized, group key $groupKey missing unique name"), - expr = groupKey.expr - ) - - /** - * Append [Rel.Filter] only if a HAVING condition exists - * - * Notes: - * - This currently does not support aggregation expressions in the WHERE condition - */ - private fun convertHaving(input: Rel, expr: PartiqlAst.Expr?): Rel = when (expr) { - null -> input - else -> relFilter( - common = empty, - input = input, - condition = RexConverter.convert(expr) - ) - } - - /** - * Append [Rel.Sort] only if an ORDER BY clause is present - */ - private fun convertOrderBy(input: Rel, orderBy: PartiqlAst.OrderBy?) = when (orderBy) { - null -> input - else -> relSort( - common = empty, - input = input, - specs = orderBy.sortSpecs.map { convertSortSpec(it) } - ) - } - - /** - * Append [Rel.Fetch] if there is a LIMIT or LIMIT and OFFSET. - * - * Notes: - * - It's unclear if OFFSET without LIMIT should be allowed in PartiQL, so err for now. - */ - private fun convertFetch( - input: Rel, - limit: PartiqlAst.Expr?, - offset: PartiqlAst.Expr? - ): Rel { - if (limit == null) { - if (offset != null) error("offset without limit") - return input - } - return relFetch( - common = empty, - input = input, - limit = RexConverter.convert(limit), - offset = RexConverter.convert(offset ?: PartiqlAst.Expr.Lit(ionInt(0).asAnyElement())) - ) - } - - /** - * Appends a [Rel.Project] which projects the result of each binding rex into its binding name. - * - * @param input - * @param projection - * @return - */ - private fun convertProjectList(input: Rel, projection: PartiqlAst.Projection.ProjectList) = relProject( - common = empty, - input = input, - bindings = projection.projectItems.bindings() - ) - - /** - * Converts Ast.SortSpec to SortSpec. - * - * Notes: - * - ASC NULLS LAST (default) - * - DESC NULLS FIRST (default for DESC) - */ - private fun convertSortSpec(sortSpec: PartiqlAst.SortSpec) = sortSpec( - value = RexConverter.convert(sortSpec.expr), - dir = when (sortSpec.orderingSpec) { - is PartiqlAst.OrderingSpec.Desc -> SortSpec.Dir.DESC - is PartiqlAst.OrderingSpec.Asc -> SortSpec.Dir.ASC - null -> SortSpec.Dir.ASC - }, - nulls = when (sortSpec.nullsSpec) { - is PartiqlAst.NullsSpec.NullsFirst -> SortSpec.Nulls.FIRST - is PartiqlAst.NullsSpec.NullsLast -> SortSpec.Nulls.LAST - null -> SortSpec.Nulls.LAST - } - ) - - /** - * Converts a GROUP AS X clause to a binding of the form: - * ``` - * { 'X': group_as({ 'a_0': e_0, ..., 'a_n': e_n }) } - * ``` - * - * Notes: - * - This was included to be consistent with the existing PartiqlAst and PartiqlLogical representations, - * but perhaps we don't want to represent GROUP AS with an agg function. - */ - private fun convertGroupAs(name: String, from: PartiqlAst.FromSource): Binding { - val fields = from.bindings().map { n -> - field( - name = rexLit(ionString(n), StaticType.STRING), - value = rexId(n, Case.SENSITIVE, Rex.Id.Qualifier.UNQUALIFIED, type = StaticType.STRUCT) - ) - } - return binding( - name = name, - value = rexAgg( - id = "group_as", - args = listOf(rexTuple(fields, StaticType.STRUCT)), - modifier = Rex.Agg.Modifier.ALL, - type = StaticType.STRUCT - ) - ) - } - - /** - * Helper to get all binding names in the FROM clause - */ - private fun PartiqlAst.FromSource.bindings(): List = when (this) { - is PartiqlAst.FromSource.Scan -> { - if (asAlias == null) { - error("not normalized, scan is missing an alias") - } - listOf(asAlias!!.text) - } - is PartiqlAst.FromSource.Join -> left.bindings() + right.bindings() - is PartiqlAst.FromSource.Unpivot -> { - if (asAlias == null) { - error("not normalized, scan is missing an alias") - } - listOf(asAlias!!.text) - } - } - - /** - * Helper to convert ProjectItems to bindings - * - * As of now, bindings is just a list, not a tuple. - * Binding and Tuple/Struct will be consolidated. - */ - private fun List.bindings() = map { - when (it) { - is PartiqlAst.ProjectItem.ProjectAll -> { - val path = PartiqlAst.Expr.Path(it.expr, listOf(PartiqlAst.PathStep.PathWildcard())) - val bindingName = when (val expr = it.expr) { - is PartiqlAst.Expr.Id -> expr.name.text - is PartiqlAst.Expr.Lit -> { - when (expr.value.type.isText) { - true -> expr.value.stringValue - false -> nextBindingName() - } - } - else -> nextBindingName() - } - binding(bindingName, path) - } - is PartiqlAst.ProjectItem.ProjectExpr -> binding( - name = it.asAlias?.text ?: error("not normalized"), - expr = it.expr - ) - } - } - - /** - * Rewrites a SFW node replacing all aggregations with a synthetic field name - * - * See AstToLogicalVisitorTransform.kt CallAggregationReplacer from org.partiql.lang.planner.transforms. - * - * ``` - * SELECT g, h, SUM(t.b) AS sumB - * FROM t - * GROUP BY t.a AS g GROUP AS h - * ``` - * - * into: - * - * ``` - * SELECT g, h, $__v0 AS sumB - * FROM t - * GROUP BY t.a AS g GROUP AS h - * ``` - * - * Where $__v0 is the binding name of SUM(t.b) in the aggregation output - * - * Inner object class to have access to current SELECT-FROM-WHERE converter state - */ - @Suppress("PrivatePropertyName") - private val AggregationTransform = object : VisitorTransformBase() { - - private var level = 0 - private var aggregations = mutableListOf() - - fun apply(node: PartiqlAst.Expr.Select): Pair> { - level = 0 - aggregations = mutableListOf() - val select = transformExprSelect(node) as PartiqlAst.Expr.Select - return Pair(select, aggregations) - } - - override fun transformProjectItemProjectExpr_expr(node: PartiqlAst.ProjectItem.ProjectExpr) = - transformExpr(node.expr) - - override fun transformProjectionProjectValue_value(node: PartiqlAst.Projection.ProjectValue) = - transformExpr(node.value) - - override fun transformExprSelect_having(node: PartiqlAst.Expr.Select): PartiqlAst.Expr? = - when (val having = node.having) { - null -> null - else -> transformExpr(having) - } - - override fun transformSortSpec_expr(node: PartiqlAst.SortSpec) = transformExpr(node.expr) - - override fun transformExprSelect(node: PartiqlAst.Expr.Select) = - if (level++ == 0) super.transformExprSelect(node) else node - - override fun transformExprCallAgg(node: PartiqlAst.Expr.CallAgg): PartiqlAst.Expr { - val name = nextBindingName() - aggregations.add(binding(name, node)) - return PartiqlAst.build { - id( - name = name, - case = caseInsensitive(), - qualifier = unqualified(), - metas = node.metas - ) - } - } - } - - /** - * Binding helper - */ - private fun binding(name: String, expr: PartiqlAst.Expr) = binding( - name = name, - value = RexConverter.convert(expr) - ) -} diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/planner/transforms/plan/RexConverter.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/planner/transforms/plan/RexConverter.kt deleted file mode 100644 index fa177fe2ec..0000000000 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/planner/transforms/plan/RexConverter.kt +++ /dev/null @@ -1,692 +0,0 @@ -package org.partiql.lang.planner.transforms.plan - -import com.amazon.ionelement.api.MetaContainer -import com.amazon.ionelement.api.ionNull -import org.partiql.errors.ErrorCode -import org.partiql.lang.domains.PartiqlAst -import org.partiql.lang.eval.EvaluationSession -import org.partiql.lang.eval.builtins.ExprFunctionCurrentUser -import org.partiql.lang.eval.err -import org.partiql.lang.eval.errorContextFrom -import org.partiql.lang.planner.transforms.AstToPlan -import org.partiql.lang.planner.transforms.plan.PlanTyper.isProjectAll -import org.partiql.plan.Case -import org.partiql.plan.Rel -import org.partiql.plan.Rex -import org.partiql.plan.argType -import org.partiql.plan.argValue -import org.partiql.plan.branch -import org.partiql.plan.field -import org.partiql.plan.rexAgg -import org.partiql.plan.rexBinary -import org.partiql.plan.rexCall -import org.partiql.plan.rexCollectionArray -import org.partiql.plan.rexCollectionBag -import org.partiql.plan.rexId -import org.partiql.plan.rexLit -import org.partiql.plan.rexPath -import org.partiql.plan.rexQueryScalarSubquery -import org.partiql.plan.rexSwitch -import org.partiql.plan.rexTuple -import org.partiql.plan.rexUnary -import org.partiql.plan.stepKey -import org.partiql.plan.stepUnpivot -import org.partiql.plan.stepWildcard -import org.partiql.types.StaticType -import java.util.Locale - -/** - * Some workarounds for transforming a PIG tree without having to create another visitor: - * - Using the VisitorFold with ctx struct to create a parameterized return and scoped arguments/context - * - Using walks to control traversal, also walks have generated if/else blocks for sum types so its more useful - */ -@Suppress("PARAMETER_NAME_CHANGED_ON_OVERRIDE") -internal object RexConverter : PartiqlAst.VisitorFold() { - - /** - * Workaround for PIG visitor where: - * - Args != null when Ctx is the accumulator IN - * - Rex != null when Ctx is the accumulator OUT - * - * Destructuring ordering chosen for val (in, out) = ... - * - * @property node Node to invoke the behavior on - * @property rex Return value - */ - data class Ctx( - val node: PartiqlAst.PartiqlAstNode, - var rex: Rex? = null, - ) - - /** - * Read as `val rex = node.accept(visitor = RexVisitor.INSTANCE, args = emptyList())` - * Only works because RexConverter errs for all non Expr AST nodes, and Expr is one sum type. - */ - internal fun convert(node: PartiqlAst.Expr) = RexConverter.walkExpr(node, Ctx(node)).rex!! - - /** - * List version of `accept` - */ - private fun convert(nodes: List) = nodes.map { convert(it) } - - /** - * Vararg version of `accept` - */ - private fun convert(vararg nodes: PartiqlAst.Expr) = nodes.map { convert(it) } - - private fun arg(name: String, node: PartiqlAst.PartiqlAstNode) = when (node) { - is PartiqlAst.Expr -> argValue( - name = name, - value = convert(node), - ) - is PartiqlAst.Type -> argType( - name = name, - type = TypeConverter.convert(node) - ) - else -> error("Argument must be of type PartiqlAst.Expr or PartiqlAst.Type, found ${node::class.qualifiedName}") - } - - /** - * Convert a list of arguments to arg0, ...., argN - */ - private fun args(nodes: List) = args(*nodes.toTypedArray()) - - /** - * Convert arguments to arg0, ...., argN - */ - private fun args(vararg nodes: PartiqlAst.PartiqlAstNode?) = - nodes.filterNotNull().mapIndexed { i, arg -> arg("arg$i", arg) } - - /** - * Convert keyword pairs of arguments - */ - private fun args(vararg args: Pair) = args.map { arg(it.first, it.second) } - - /** - * Helper so the visitor "body" looks like it has Rex as the return value - */ - private inline fun visit(node: PartiqlAst.PartiqlAstNode, block: () -> Rex) = Ctx(node, block()) - - /** - * !! DEFAULT VISIT !! - * - * The PIG visitor doesn't give us control over the default "visit" - * We can override walkMetas (which appears on every super.walk call) as if it were a default "visit" - * MetaContainer isn't actually a domain node, and we don't have any context as to where the MetaContainer - * is coming from which is why the current node is stuffed into Ctx - */ - override fun walkMetas(node: MetaContainer, ctx: Ctx) = AstToPlan.unsupported(ctx.node) - - override fun walkExprMissing(node: PartiqlAst.Expr.Missing, ctx: Ctx) = visit(node) { - rexLit(ionNull(), StaticType.MISSING) - } - - override fun walkExprLit(node: PartiqlAst.Expr.Lit, ctx: Ctx) = visit(node) { - val ionType = node.value.type.toIonType() - rexLit( - value = node.value, - type = TypeConverter.convert(ionType) - ) - } - - override fun walkExprSessionAttribute(node: PartiqlAst.Expr.SessionAttribute, accumulator: Ctx) = visit(node) { - val functionName = when (node.value.text.uppercase(Locale.getDefault())) { - EvaluationSession.Constants.CURRENT_USER_KEY -> ExprFunctionCurrentUser.FUNCTION_NAME - else -> err( - "Unsupported session attribute: ${node.value.text}", - errorCode = ErrorCode.SEMANTIC_PROBLEM, - errorContext = errorContextFrom(node.metas), - internal = false - ) - } - rexCall( - id = functionName, - args = emptyList(), - type = null - ) - } - - override fun walkExprId(node: PartiqlAst.Expr.Id, ctx: Ctx) = visit(node) { - rexId( - name = node.name.text, - case = convertCase(node.case), - qualifier = when (node.qualifier) { - is PartiqlAst.ScopeQualifier.LocalsFirst -> Rex.Id.Qualifier.LOCALS_FIRST - is PartiqlAst.ScopeQualifier.Unqualified -> Rex.Id.Qualifier.UNQUALIFIED - }, - type = null, - ) - } - - override fun walkExprPath(node: PartiqlAst.Expr.Path, ctx: Ctx) = visit(node) { - rexPath( - root = convert(node.root), - steps = node.steps.map { - when (it) { - is PartiqlAst.PathStep.PathExpr -> stepKey( - value = convert(it.index), - case = convertCase(it.case) - ) - is PartiqlAst.PathStep.PathUnpivot -> stepUnpivot() - is PartiqlAst.PathStep.PathWildcard -> stepWildcard() - } - }, - type = null, - ) - } - - override fun walkExprNot(node: PartiqlAst.Expr.Not, ctx: Ctx) = visit(node) { - rexUnary( - value = convert(node.expr), - op = Rex.Unary.Op.NOT, - type = StaticType.BOOL, - ) - } - - override fun walkExprPos(node: PartiqlAst.Expr.Pos, ctx: Ctx) = visit(node) { - rexUnary( - value = convert(node.expr), - op = Rex.Unary.Op.POS, - type = StaticType.NUMERIC, - ) - } - - override fun walkExprNeg(node: PartiqlAst.Expr.Neg, ctx: Ctx) = visit(node) { - rexUnary( - value = convert(node.expr), - op = Rex.Unary.Op.NEG, - type = StaticType.NUMERIC, - ) - } - - override fun walkExprPlus(node: PartiqlAst.Expr.Plus, ctx: Ctx) = visit(node) { - rexBinary( - lhs = convert(node.operands[0]), - rhs = convert(node.operands[1]), - op = Rex.Binary.Op.PLUS, - type = StaticType.NUMERIC, - ) - } - - override fun walkExprMinus(node: PartiqlAst.Expr.Minus, ctx: Ctx) = visit(node) { - rexBinary( - lhs = convert(node.operands[0]), - rhs = convert(node.operands[1]), - op = Rex.Binary.Op.MINUS, - type = StaticType.NUMERIC, - ) - } - - override fun walkExprTimes(node: PartiqlAst.Expr.Times, ctx: Ctx) = visit(node) { - rexBinary( - lhs = convert(node.operands[0]), - rhs = convert(node.operands[1]), - op = Rex.Binary.Op.TIMES, - type = StaticType.NUMERIC, - ) - } - - override fun walkExprDivide(node: PartiqlAst.Expr.Divide, ctx: Ctx) = visit(node) { - rexBinary( - lhs = convert(node.operands[0]), - rhs = convert(node.operands[1]), - op = Rex.Binary.Op.DIV, - type = StaticType.NUMERIC, - ) - } - - override fun walkExprModulo(node: PartiqlAst.Expr.Modulo, ctx: Ctx) = visit(node) { - rexBinary( - lhs = convert(node.operands[0]), - rhs = convert(node.operands[1]), - op = Rex.Binary.Op.MODULO, - type = StaticType.NUMERIC, - ) - } - - override fun walkExprBitwiseAnd(node: PartiqlAst.Expr.BitwiseAnd, accumulator: Ctx) = visit(node) { - rexBinary( - lhs = convert(node.operands[0]), - rhs = convert(node.operands[1]), - op = Rex.Binary.Op.BITWISE_AND, - type = StaticType.unionOf( - StaticType.INT2, StaticType.INT4, StaticType.INT8, StaticType.INT - ) - ) - } - - override fun walkExprConcat(node: PartiqlAst.Expr.Concat, ctx: Ctx) = visit(node) { - rexBinary( - lhs = convert(node.operands[0]), - rhs = convert(node.operands[1]), - op = Rex.Binary.Op.CONCAT, - type = StaticType.TEXT, - ) - } - - override fun walkExprAnd(node: PartiqlAst.Expr.And, ctx: Ctx) = visit(node) { - rexBinary( - lhs = convert(node.operands[0]), - rhs = convert(node.operands[1]), - op = Rex.Binary.Op.AND, - type = StaticType.BOOL, - ) - } - - override fun walkExprOr(node: PartiqlAst.Expr.Or, ctx: Ctx) = visit(node) { - rexBinary( - lhs = convert(node.operands[0]), - rhs = convert(node.operands[1]), - op = Rex.Binary.Op.OR, - type = StaticType.BOOL, - ) - } - - override fun walkExprEq(node: PartiqlAst.Expr.Eq, ctx: Ctx) = visit(node) { - val (lhs, rhs) = walkComparisonOperands(node.operands) - rexBinary( - lhs = lhs, - rhs = rhs, - op = Rex.Binary.Op.EQ, - type = StaticType.BOOL, - ) - } - - override fun walkExprNe(node: PartiqlAst.Expr.Ne, ctx: Ctx) = visit(node) { - val (lhs, rhs) = walkComparisonOperands(node.operands) - rexBinary( - lhs = lhs, - rhs = rhs, - op = Rex.Binary.Op.NEQ, - type = StaticType.BOOL, - ) - } - - override fun walkExprGt(node: PartiqlAst.Expr.Gt, ctx: Ctx) = visit(node) { - val (lhs, rhs) = walkComparisonOperands(node.operands) - rexBinary( - lhs = lhs, - rhs = rhs, - op = Rex.Binary.Op.GT, - type = StaticType.BOOL, - ) - } - - override fun walkExprGte(node: PartiqlAst.Expr.Gte, ctx: Ctx) = visit(node) { - val (lhs, rhs) = walkComparisonOperands(node.operands) - rexBinary( - lhs = lhs, - rhs = rhs, - op = Rex.Binary.Op.GTE, - type = StaticType.BOOL, - ) - } - - override fun walkExprLt(node: PartiqlAst.Expr.Lt, ctx: Ctx) = visit(node) { - val (lhs, rhs) = walkComparisonOperands(node.operands) - rexBinary( - lhs = lhs, - rhs = rhs, - op = Rex.Binary.Op.LT, - type = StaticType.BOOL, - ) - } - - override fun walkExprLte(node: PartiqlAst.Expr.Lte, ctx: Ctx) = visit(node) { - val (lhs, rhs) = walkComparisonOperands(node.operands) - rexBinary( - lhs = lhs, - rhs = rhs, - op = Rex.Binary.Op.LTE, - type = StaticType.BOOL, - ) - } - - /** - * Converts Comparison Operands. Also coerces them if one of them is an array. - */ - private fun walkComparisonOperands(operands: List): Pair { - var lhs = convert(operands[0]) - var rhs = convert(operands[1]) - if (lhs is Rex.Collection.Array) { rhs = coercePotentialSubquery(rhs) } - if (rhs is Rex.Collection.Array) { lhs = coercePotentialSubquery(lhs) } - return lhs to rhs - } - - /** - * We convert the scalar subquery of a SFW into a scalar subquery of a SELECT VALUE. - */ - private fun coercePotentialSubquery(rex: Rex): Rex { - var rhs = rex - if (rhs is Rex.Query.Scalar.Subquery) { - val sfw = rhs.query as? Rex.Query.Collection ?: error("Malformed plan, all scalar subqueries should hold collection queries") - val constructor = sfw.constructor ?: run { - val relProject = sfw.rel as? Rel.Project ?: error("Malformed plan, the top of a plan should be a projection") - getConstructorFromRelProject(relProject) - } - rhs = rhs.copy( - query = rhs.query.copy( - constructor = constructor - ) - ) - } - return rhs - } - - private fun getConstructorFromRelProject(relProject: Rel.Project): Rex { - return when (relProject.bindings.size) { - 0 -> error("The Projection should not have held empty bindings.") - 1 -> { - val binding = relProject.bindings.first() - if (binding.value.isProjectAll()) { - error("Unimplemented feature: coercion of SELECT *.") - } - relProject.bindings.first().value - } - else -> { - if (relProject.bindings.any { it.value.isProjectAll() }) { - error("Unimplemented feature: coercion of SELECT *.") - } - rexCollectionArray( - relProject.bindings.map { it.value }, - type = StaticType.LIST - ) - } - } - } - - override fun walkExprLike(node: PartiqlAst.Expr.Like, ctx: Ctx) = visit(node) { - when (val escape = node.escape) { - null -> rexCall( - id = Constants.like, - args = args( - "value" to node.value, - "pattern" to node.pattern, - ), - type = StaticType.BOOL, - ) - else -> rexCall( - id = Constants.likeEscape, - args = args( - "value" to node.value, - "pattern" to node.pattern, - "escape" to escape, - ), - type = StaticType.BOOL, - ) - } - } - - override fun walkExprBetween(node: PartiqlAst.Expr.Between, ctx: Ctx) = visit(node) { - rexCall( - id = Constants.between, - args = args("value" to node.value, "from" to node.from, "to" to node.to), - type = StaticType.BOOL, - ) - } - - /** - * Here, we must visit the RHS. If it is a scalar subquery, we need to grab the underlying collection. - */ - override fun walkExprInCollection(node: PartiqlAst.Expr.InCollection, ctx: Ctx) = visit(node) { - val lhs = convert(node.operands[0]) - val potentialSubqueryRex = convert(node.operands[1]) - val potentialSubquery = coercePotentialSubquery(potentialSubqueryRex) - val rhs = (potentialSubquery as? Rex.Query.Scalar.Subquery)?.query ?: potentialSubquery - rexCall( - id = Constants.inCollection, - args = listOf( - argValue("lhs", lhs), - argValue("rhs", rhs), - ), - type = StaticType.BOOL, - ) - } - - override fun walkExprStruct(node: PartiqlAst.Expr.Struct, ctx: Ctx) = visit(node) { - rexTuple( - fields = node.fields.map { - field( - name = convert(it.first), - value = convert(it.second) - ) - }, - type = StaticType.STRUCT, - ) - } - - override fun walkExprBag(node: PartiqlAst.Expr.Bag, ctx: Ctx) = visit(node) { - rexCollectionBag( - values = convert(node.values), - type = StaticType.BAG, - ) - } - - override fun walkExprList(node: PartiqlAst.Expr.List, ctx: Ctx) = visit(node) { - rexCollectionArray( - values = convert(node.values), - type = StaticType.LIST, - ) - } - - override fun walkExprSexp(node: PartiqlAst.Expr.Sexp, accumulator: Ctx) = visit(node) { - rexCollectionArray( - values = convert(node.values), - type = StaticType.LIST, - ) - } - - override fun walkExprCall(node: PartiqlAst.Expr.Call, ctx: Ctx) = visit(node) { - rexCall( - id = node.funcName.text, - args = args(*node.args.toTypedArray()), - type = null, - ) - } - - override fun walkExprCallAgg(node: PartiqlAst.Expr.CallAgg, ctx: Ctx) = visit(node) { - rexAgg( - id = node.funcName.text, - args = listOf(convert(node.arg)), - modifier = when (node.setq) { - is PartiqlAst.SetQuantifier.All -> Rex.Agg.Modifier.ALL - is PartiqlAst.SetQuantifier.Distinct -> Rex.Agg.Modifier.DISTINCT - }, - type = StaticType.NUMERIC, - ) - } - - override fun walkExprIsType(node: PartiqlAst.Expr.IsType, ctx: Ctx) = visit(node) { - rexCall( - id = Constants.isType, - args = args("value" to node.value, "type" to node.type), - type = StaticType.BOOL, - ) - } - - override fun walkExprSimpleCase(node: PartiqlAst.Expr.SimpleCase, ctx: Ctx) = visit(node) { - rexSwitch( - match = convert(node.expr), - branches = node.cases.pairs.map { - branch( - condition = convert(it.first), - value = convert(it.second), - ) - }, - default = if (node.default != null) convert(node.default!!) else null, - type = null - ) - } - - override fun walkExprSearchedCase(node: PartiqlAst.Expr.SearchedCase, ctx: Ctx) = visit(node) { - rexSwitch( - match = null, - branches = node.cases.pairs.map { - branch( - condition = convert(it.first), - value = convert(it.second), - ) - }, - default = if (node.default != null) convert(node.default!!) else null, - type = null - ) - } - - override fun walkExprDate(node: PartiqlAst.Expr.Date, ctx: Ctx): Ctx { - error("Date class undetermined at the moment") - } - - override fun walkExprLitTime(node: PartiqlAst.Expr.LitTime, ctx: Ctx): Ctx { - error("Time class undetermined at the moment") - } - - override fun walkExprBagOp(node: PartiqlAst.Expr.BagOp, ctx: Ctx) = visit(node) { - // Hack for UNION / INTERSECT / EXCEPT because they are missing from the parser - val op = when (node.quantifier) { - is PartiqlAst.SetQuantifier.All -> when (node.op) { - is PartiqlAst.BagOpType.Union, - is PartiqlAst.BagOpType.OuterUnion -> Constants.outerBagUnion - is PartiqlAst.BagOpType.Intersect, - is PartiqlAst.BagOpType.OuterIntersect -> Constants.outerBagIntersect - is PartiqlAst.BagOpType.Except, - is PartiqlAst.BagOpType.OuterExcept -> Constants.outerBagExcept - } - is PartiqlAst.SetQuantifier.Distinct -> when (node.op) { - is PartiqlAst.BagOpType.Union, - is PartiqlAst.BagOpType.OuterUnion -> Constants.outerSetUnion - is PartiqlAst.BagOpType.Intersect, - is PartiqlAst.BagOpType.OuterIntersect -> Constants.outerSetIntersect - is PartiqlAst.BagOpType.Except, - is PartiqlAst.BagOpType.OuterExcept -> Constants.outerSetExcept - } - } - rexCall( - id = op, - args = args("lhs" to node.operands[0], "rhs" to node.operands[1]), - type = StaticType.BAG, - ) - } - - override fun walkExprCast(node: PartiqlAst.Expr.Cast, ctx: Ctx) = visit(node) { - rexCall( - id = Constants.cast, - args = args("value" to node.value, "type" to node.asType), - type = TypeConverter.convert(node.asType), - ) - } - - override fun walkExprCanCast(node: PartiqlAst.Expr.CanCast, ctx: Ctx) = visit(node) { - rexCall( - id = Constants.canCast, - args = args("value" to node.value, "type" to node.asType), - type = StaticType.BOOL, - ) - } - - override fun walkExprCanLosslessCast(node: PartiqlAst.Expr.CanLosslessCast, ctx: Ctx) = visit(node) { - rexCall( - id = Constants.canLosslessCast, - args = args("value" to node.value, "type" to node.asType), - type = StaticType.BOOL, - ) - } - - override fun walkExprNullIf(node: PartiqlAst.Expr.NullIf, ctx: Ctx) = visit(node) { - rexCall( - id = Constants.nullIf, - args = args(node.expr1, node.expr2), - type = StaticType.BOOL, - ) - } - - override fun walkExprCoalesce(node: PartiqlAst.Expr.Coalesce, ctx: Ctx) = visit(node) { - rexCall( - id = Constants.coalesce, - args = args(node.args), - type = null, - ) - } - - override fun walkExprSelect(node: PartiqlAst.Expr.Select, ctx: Ctx) = visit(node) { - when (val query = RelConverter.convert(node)) { - is Rex.Query.Collection -> rexQueryScalarSubquery(query, null) - is Rex.Query.Scalar -> query - } - } - - internal fun convertCase(case: PartiqlAst.CaseSensitivity) = when (case) { - is PartiqlAst.CaseSensitivity.CaseInsensitive -> Case.INSENSITIVE - is PartiqlAst.CaseSensitivity.CaseSensitive -> Case.SENSITIVE - } - - internal object Constants { - - // const val unaryNot = "unary_not" - // - // const val unaryPlus = "unary_plus" - // - // const val unaryMinus = "unary_minus" - // - // const val unaryNegate = "unary_negate" - // - // const val binaryAdd = "binary_add" - // - // const val binarySub = "binary_sb" - // - // const val binaryMult = "binary_mult" - // - // const val binaryDiv = "binary_div" - // - // const val binaryMod = "binary_mod" - // - // const val binaryConcat = "binary_concat" - // - // const val binaryAnd = "binary_and" - // - // const val binaryOr = "binary_or" - // - // const val binaryEq = "binary_eq" - // - // const val binaryNeq = "binary_neq" - // - // const val binaryGt = "binary_gt" - // - // const val binaryGte = "binary_gte" - // - // const val binaryLt = "binary_lt" - // - // const val binaryLte = "binary_lte" - - const val like = "like" - - const val likeEscape = "like_escape" - - const val between = "between" - - const val inCollection = "in_collection" - - const val isType = "is_type" - - const val outerBagUnion = "outer_bag_union" - - const val outerBagIntersect = "outer_bag_intersect" - - const val outerBagExcept = "outer_bag_except" - - const val outerSetUnion = "outer_set_union" - - const val outerSetIntersect = "outer_set_intersect" - - const val outerSetExcept = "outer_set_except" - - const val cast = "cast" - - const val canCast = "can_cast" - - const val canLosslessCast = "can_lossless_cast" - - const val nullIf = "null_if" - - const val coalesce = "coalesce" - } -} diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/planner/transforms/PartiQLSchemaInferencerTests.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/planner/transforms/PartiQLSchemaInferencerTests.kt index 55ef2033ff..d6ad905f77 100644 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/planner/transforms/PartiQLSchemaInferencerTests.kt +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/planner/transforms/PartiQLSchemaInferencerTests.kt @@ -20,7 +20,6 @@ import org.partiql.lang.planner.transforms.PartiQLSchemaInferencerTests.ProblemH import org.partiql.lang.planner.transforms.PartiQLSchemaInferencerTests.TestCase.ErrorTestCase import org.partiql.lang.planner.transforms.PartiQLSchemaInferencerTests.TestCase.SuccessTestCase import org.partiql.lang.planner.transforms.PartiQLSchemaInferencerTests.TestCase.ThrowingExceptionTestCase -import org.partiql.plan.Rex import org.partiql.plan.debug.PlanPrinter import org.partiql.planner.PartiQLPlanner import org.partiql.planner.PlanningProblemDetails @@ -84,6 +83,11 @@ class PartiQLSchemaInferencerTests { @Execution(ExecutionMode.CONCURRENT) fun testJoins(tc: TestCase) = runTest(tc) + @ParameterizedTest + @MethodSource("excludeCases") + @Execution(ExecutionMode.CONCURRENT) + fun testExclude(tc: TestCase) = runTest(tc) + companion object { private val root = this::class.java.getResource("/catalogs")!!.toURI().toPath().pathString @@ -518,57 +522,38 @@ class PartiQLSchemaInferencerTests { } ), ) - } - - sealed class TestCase { - class SuccessTestCase( - val name: String, - val query: String, - val catalog: String? = null, - val catalogPath: List = emptyList(), - val expected: StaticType, - ) : TestCase() { - override fun toString(): String = "$name : $query" - } - - class ErrorTestCase( - val name: String, - val query: String, - val catalog: String? = null, - val catalogPath: List = emptyList(), - val note: String? = null, - val expected: StaticType? = null, - val problemHandler: ProblemHandler? = null, - ) : TestCase() { - override fun toString(): String = "$name : $query" - } - - class ThrowingExceptionTestCase( - val name: String, - val query: String, - val catalog: String? = null, - val catalogPath: List = emptyList(), - val note: String? = null, - val expectedThrowable: KClass, - ) : TestCase() { - override fun toString(): String { - return "$name : $query" - } - } - } - - class TestProvider : ArgumentsProvider { - override fun provideArguments(context: ExtensionContext?): Stream { - return parameters.map { Arguments.of(it) }.stream() - } - private val parameters = listOf( - ErrorTestCase( - name = "Pets should not be accessible #1", - query = "SELECT * FROM pets", + @JvmStatic + fun excludeCases() = listOf( + SuccessTestCase( + name = "EXCLUDE SELECT star", + query = """SELECT * EXCLUDE c.ssn FROM [ + { + 'name': 'Alan', + 'custId': 1, + 'address': { + 'city': 'Seattle', + 'zipcode': 98109, + 'street': '123 Seaplane Dr.' + }, + 'ssn': 123456789 + } + ] AS c""", expected = BagType( StructType( - fields = mapOf("pets" to StaticType.ANY), + fields = mapOf( + "name" to StaticType.STRING, + "custId" to StaticType.INT, + "address" to StructType( + fields = mapOf( + "city" to StaticType.STRING, + "zipcode" to StaticType.INT, + "street" to StaticType.STRING, + ), + contentClosed = true, + constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) + ) + ), contentClosed = true, constraints = setOf( TupleConstraint.Open(false), @@ -576,21 +561,36 @@ class PartiQLSchemaInferencerTests { TupleConstraint.Ordered ) ) - ), - problemHandler = assertProblemExists { - Problem( - UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UndefinedVariable("pets", false) - ) - } + ) ), - ErrorTestCase( - name = "Pets should not be accessible #2", - catalog = CATALOG_AWS, - query = "SELECT * FROM pets", + SuccessTestCase( + name = "EXCLUDE SELECT star multiple paths", + query = """SELECT * EXCLUDE c.ssn, c.address.street FROM [ + { + 'name': 'Alan', + 'custId': 1, + 'address': { + 'city': 'Seattle', + 'zipcode': 98109, + 'street': '123 Seaplane Dr.' + }, + 'ssn': 123456789 + } + ] AS c""", expected = BagType( StructType( - fields = mapOf("pets" to StaticType.ANY), + fields = mapOf( + "name" to StaticType.STRING, + "custId" to StaticType.INT, + "address" to StructType( + fields = mapOf( + "city" to StaticType.STRING, + "zipcode" to StaticType.INT + ), + contentClosed = true, + constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) + ) + ), contentClosed = true, constraints = setOf( TupleConstraint.Open(false), @@ -598,55 +598,69 @@ class PartiQLSchemaInferencerTests { TupleConstraint.Ordered ) ) - ), - problemHandler = assertProblemExists { - Problem( - UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UndefinedVariable("pets", false) - ) - } - ), - SuccessTestCase( - name = "Project all explicitly", - catalog = CATALOG_AWS, - catalogPath = listOf("ddb"), - query = "SELECT * FROM pets", - expected = TABLE_AWS_DDB_PETS - ), - SuccessTestCase( - name = "Project all implicitly", - catalog = CATALOG_AWS, - catalogPath = listOf("ddb"), - query = "SELECT id, breed FROM pets", - expected = TABLE_AWS_DDB_PETS - ), - SuccessTestCase( - name = "Test #4", - catalog = CATALOG_B, - catalogPath = listOf("b"), - query = "b", - expected = TYPE_B_B_B - ), - SuccessTestCase( - name = "Test #5", - catalog = CATALOG_AWS, - catalogPath = listOf("ddb"), - query = "SELECT * FROM b", - expected = TABLE_AWS_DDB_B + ) ), SuccessTestCase( - name = "Test #6", - catalog = CATALOG_AWS, - catalogPath = listOf("b"), - query = "SELECT * FROM b", - expected = TABLE_AWS_B_B - ), - ErrorTestCase( - name = "Test #7", - query = "SELECT * FROM ddb.pets", + name = "EXCLUDE SELECT star list index and list index field", + query = """SELECT * + EXCLUDE + t.a.b.c[0], + t.a.b.c[1].field + FROM [{ + 'a': { + 'b': { + 'c': [ + { + 'field': 0 -- c[0] + }, + { + 'field': 1 -- c[1] + }, + { + 'field': 2 -- c[2] + } + ] + } + }, + 'foo': 'bar' + }] AS t""", expected = BagType( StructType( - fields = mapOf("pets" to StaticType.ANY), + fields = mapOf( + "a" to StructType( + fields = mapOf( + "b" to StructType( + fields = mapOf( + "c" to ListType( + elementType = StructType( + fields = mapOf( + "field" to AnyOfType( + setOf( + INT, + MISSING // c[1]'s `field` was excluded + ) + ) + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true) + ) + ) + ) + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true) + ) + ) + ), + contentClosed = true, + constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) + ), + "foo" to StaticType.STRING + ), contentClosed = true, constraints = setOf( TupleConstraint.Open(false), @@ -654,321 +668,203 @@ class PartiQLSchemaInferencerTests { TupleConstraint.Ordered ) ) - ), - problemHandler = assertProblemExists { - Problem( - UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UndefinedVariable("pets", false) - ) - } - ), - SuccessTestCase( - name = "Test #10", - catalog = CATALOG_B, - query = "b.b", - expected = TYPE_B_B_B - ), - SuccessTestCase( - name = "Test #11", - catalog = CATALOG_B, - catalogPath = listOf("b"), - query = "b.b", - expected = TYPE_B_B_B - ), - SuccessTestCase( - name = "Test #12", - catalog = CATALOG_AWS, - catalogPath = listOf("ddb"), - query = "SELECT * FROM b.b", - expected = TABLE_AWS_B_B + ) ), SuccessTestCase( - name = "Test #13", - catalog = CATALOG_AWS, - catalogPath = listOf("ddb"), - query = "SELECT * FROM ddb.b", - expected = TABLE_AWS_DDB_B - ), - SuccessTestCase( - name = "Test #14", - query = "SELECT * FROM aws.ddb.pets", - expected = TABLE_AWS_DDB_PETS - ), - SuccessTestCase( - name = "Test #15", - catalog = CATALOG_AWS, - query = "SELECT * FROM aws.b.b", - expected = TABLE_AWS_B_B - ), - SuccessTestCase( - name = "Test #16", - catalog = CATALOG_B, - query = "b.b.b", - expected = TYPE_B_B_B - ), - SuccessTestCase( - name = "Test #17", - catalog = CATALOG_B, - query = "b.b.c", - expected = TYPE_B_B_C - ), - SuccessTestCase( - name = "Test #18", - catalog = CATALOG_B, - catalogPath = listOf("b"), - query = "b.b.b", - expected = TYPE_B_B_B - ), - SuccessTestCase( - name = "Test #19", - query = "b.b.b.c", - expected = TYPE_B_B_B_C - ), - SuccessTestCase( - name = "Test #20", - query = "b.b.b.b", - expected = TYPE_B_B_B_B - ), - SuccessTestCase( - name = "Test #21", - catalog = CATALOG_B, - query = "b.b.b.b", - expected = TYPE_B_B_B_B - ), - SuccessTestCase( - name = "Test #22", - catalog = CATALOG_B, - query = "b.b.b.c", - expected = TYPE_B_B_C - ), - SuccessTestCase( - name = "Test #23", - catalog = CATALOG_B, - catalogPath = listOf("b"), - query = "b.b.b.b", - expected = TYPE_B_B_B_B - ), - SuccessTestCase( - name = "Test #24", - query = "b.b.b.b.b", - expected = TYPE_B_B_B_B_B - ), - SuccessTestCase( - name = "Test #24", - catalog = CATALOG_B, - query = "b.b.b.b.b", - expected = TYPE_B_B_B_B_B - ), - SuccessTestCase( - name = "EQ", - catalog = CATALOG_DB, - catalogPath = DB_SCHEMA_MARKETS, - query = "order_info.customer_id = 1", - expected = TYPE_BOOL - ), - SuccessTestCase( - name = "NEQ", - catalog = CATALOG_DB, - catalogPath = DB_SCHEMA_MARKETS, - query = "order_info.customer_id <> 1", - expected = TYPE_BOOL - ), - SuccessTestCase( - name = "GEQ", - catalog = CATALOG_DB, - catalogPath = DB_SCHEMA_MARKETS, - query = "order_info.customer_id >= 1", - expected = TYPE_BOOL - ), - SuccessTestCase( - name = "GT", - catalog = CATALOG_DB, - catalogPath = DB_SCHEMA_MARKETS, - query = "order_info.customer_id > 1", - expected = TYPE_BOOL - ), - SuccessTestCase( - name = "LEQ", - catalog = CATALOG_DB, - catalogPath = DB_SCHEMA_MARKETS, - query = "order_info.customer_id <= 1", - expected = TYPE_BOOL - ), - SuccessTestCase( - name = "LT", - catalog = CATALOG_DB, - catalogPath = DB_SCHEMA_MARKETS, - query = "order_info.customer_id < 1", - expected = TYPE_BOOL - ), - SuccessTestCase( - name = "IN", - catalog = CATALOG_DB, - catalogPath = DB_SCHEMA_MARKETS, - query = "order_info.customer_id IN (1, 2, 3)", - expected = TYPE_BOOL - ), - ErrorTestCase( - name = "IN Failure", - catalog = CATALOG_DB, - catalogPath = DB_SCHEMA_MARKETS, - query = "order_info.customer_id IN 'hello'", - expected = MISSING, - problemHandler = assertProblemExists { - Problem( - UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UnknownFunction( - "in_collection", - listOf(INT, STRING), - ) - ) - } - ), - SuccessTestCase( - name = "BETWEEN", - catalog = CATALOG_DB, - catalogPath = DB_SCHEMA_MARKETS, - query = "order_info.customer_id BETWEEN 1 AND 2", - expected = TYPE_BOOL - ), - ErrorTestCase( - name = "BETWEEN Failure", - catalog = CATALOG_DB, - catalogPath = DB_SCHEMA_MARKETS, - query = "order_info.customer_id BETWEEN 1 AND 'a'", - expected = MISSING, - problemHandler = assertProblemExists { - Problem( - UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UnknownFunction( - "between", - listOf( - INT, - INT, - STRING + name = "EXCLUDE SELECT star collection index as last step", + query = """SELECT * + EXCLUDE + t.a.b.c[0] + FROM [{ + 'a': { + 'b': { + 'c': [0, 1, 2] + } + }, + 'foo': 'bar' + }] AS t""", + expected = BagType( + StructType( + fields = mapOf( + "a" to StructType( + fields = mapOf( + "b" to StructType( + fields = mapOf( + "c" to ListType( + elementType = StaticType.INT + ) + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true) + ) + ) + ), + contentClosed = true, + constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) ), + "foo" to StaticType.STRING + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true), + TupleConstraint.Ordered ) ) - } + ) ), + // EXCLUDE regression test (behavior subject to change pending RFC) SuccessTestCase( - name = "LIKE", - catalog = CATALOG_DB, - catalogPath = DB_SCHEMA_MARKETS, - query = "order_info.ship_option LIKE '%ABC%'", - expected = TYPE_BOOL - ), - ErrorTestCase( - name = "LIKE Failure", - catalog = CATALOG_DB, - catalogPath = DB_SCHEMA_MARKETS, - query = "order_info.ship_option LIKE 3", - expected = MISSING, - problemHandler = assertProblemExists { - Problem( - UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UnknownFunction( - "like", - listOf(STRING, INT), + name = "EXCLUDE SELECT star collection wildcard as last step", + query = """SELECT * + EXCLUDE + t.a[*] + FROM [{ + 'a': [0, 1, 2] + }] AS t""", + expected = BagType( + StructType( + fields = mapOf( + "a" to ListType( + elementType = StaticType.INT // empty list but still preserve typing information + ) + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true), + TupleConstraint.Ordered ) ) - } - ), - SuccessTestCase( - name = "Case Insensitive success", - catalog = CATALOG_DB, - catalogPath = DB_SCHEMA_MARKETS, - query = "order_info.CUSTOMER_ID = 1", - expected = TYPE_BOOL - ), - ErrorTestCase( - name = "Case Sensitive failure", - catalog = CATALOG_DB, - catalogPath = DB_SCHEMA_MARKETS, - query = "order_info.\"CUSTOMER_ID\" = 1", - expected = TYPE_BOOL - ), - SuccessTestCase( - name = "Case Sensitive success", - catalog = CATALOG_DB, - catalogPath = DB_SCHEMA_MARKETS, - query = "order_info.\"customer_id\" = 1", - expected = TYPE_BOOL - ), - SuccessTestCase( - name = "1-Level Junction", - catalog = CATALOG_DB, - catalogPath = DB_SCHEMA_MARKETS, - query = "(order_info.customer_id = 1) AND (order_info.marketplace_id = 2)", - expected = TYPE_BOOL - ), - SuccessTestCase( - name = "2-Level Junction", - catalog = CATALOG_DB, - catalogPath = DB_SCHEMA_MARKETS, - query = "(order_info.customer_id = 1) AND (order_info.marketplace_id = 2) OR (order_info.customer_id = 3) AND (order_info.marketplace_id = 4)", - expected = TYPE_BOOL + ) ), SuccessTestCase( - name = "INT and STR Comparison", - catalog = CATALOG_DB, - catalogPath = DB_SCHEMA_MARKETS, - query = "order_info.customer_id = 'something'", - expected = TYPE_BOOL, - ), - ErrorTestCase( - name = "Nonexisting Comparison", - catalog = CATALOG_DB, - catalogPath = DB_SCHEMA_MARKETS, - query = "non_existing_column = 1", - expected = StaticType.BOOL, - problemHandler = assertProblemExists { - Problem( - UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UndefinedVariable("non_existing_column", false) - ) - } - ), - ErrorTestCase( - name = "Bad comparison", - catalog = CATALOG_DB, - catalogPath = DB_SCHEMA_MARKETS, - query = "order_info.customer_id = 1 AND 1", - expected = MISSING, - problemHandler = assertProblemExists { - Problem( - UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UnknownFunction( - "and", - listOf(StaticType.BOOL, INT), - ) - ) - } - ), - ErrorTestCase( - name = "Bad comparison", - catalog = CATALOG_DB, - catalogPath = DB_SCHEMA_MARKETS, - query = "1 AND order_info.customer_id = 1", - expected = MISSING, - problemHandler = assertProblemExists { - Problem( - UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UnknownFunction( - "and", - listOf(INT, StaticType.BOOL), + name = "EXCLUDE SELECT star list wildcard", + query = """SELECT * + EXCLUDE + t.a.b.c[*].field_x + FROM [{ + 'a': { + 'b': { + 'c': [ + { -- c[0] + 'field_x': 0, + 'field_y': 0 + }, + { -- c[1] + 'field_x': 1, + 'field_y': 1 + }, + { -- c[2] + 'field_x': 2, + 'field_y': 2 + } + ] + } + }, + 'foo': 'bar' + }] AS t""", + expected = BagType( + StructType( + fields = mapOf( + "a" to StructType( + fields = mapOf( + "b" to StructType( + fields = mapOf( + "c" to ListType( + elementType = StructType( + fields = mapOf( + "field_y" to StaticType.INT + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true) + ) + ) + ) + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true) + ) + ) + ), + contentClosed = true, + constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) + ), + "foo" to StaticType.STRING + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true), + TupleConstraint.Ordered ) ) - } + ) ), - ErrorTestCase( - name = "Unknown column", - catalog = CATALOG_DB, - catalogPath = DB_SCHEMA_MARKETS, - query = "SELECT unknown_col FROM orders WHERE customer_id = 1", + SuccessTestCase( + name = "EXCLUDE SELECT star list tuple wildcard", + query = """SELECT * + EXCLUDE + t.a.b.c[*].* + FROM [{ + 'a': { + 'b': { + 'c': [ + { -- c[0] + 'field_x': 0, + 'field_y': 0 + }, + { -- c[1] + 'field_x': 1, + 'field_y': 1 + }, + { -- c[2] + 'field_x': 2, + 'field_y': 2 + } + ] + } + }, + 'foo': 'bar' + }] AS t""", expected = BagType( StructType( - fields = mapOf("unknown_col" to AnyType()), + fields = mapOf( + "a" to StructType( + fields = mapOf( + "b" to StructType( + fields = mapOf( + "c" to ListType( + elementType = StructType( + fields = mapOf( + // all fields gone + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true) + ) + ) + ) + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true) + ) + ) + ), + contentClosed = true, + constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) + ), + "foo" to StaticType.STRING + ), contentClosed = true, constraints = setOf( TupleConstraint.Open(false), @@ -976,83 +872,33 @@ class PartiQLSchemaInferencerTests { TupleConstraint.Ordered ) ) - ), - problemHandler = assertProblemExists { - Problem( - UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UndefinedVariable("unknown_col", false) - ) - } - ), - SuccessTestCase( - name = "ORDER BY int", - catalog = CATALOG_AWS, - catalogPath = listOf("ddb"), - query = "SELECT * FROM pets ORDER BY id", - expected = TABLE_AWS_DDB_PETS_LIST - ), - SuccessTestCase( - name = "ORDER BY str", - catalog = CATALOG_AWS, - catalogPath = listOf("ddb"), - query = "SELECT * FROM pets ORDER BY breed", - expected = TABLE_AWS_DDB_PETS_LIST - ), - SuccessTestCase( - name = "ORDER BY str", - catalog = CATALOG_AWS, - catalogPath = listOf("ddb"), - query = "SELECT * FROM pets ORDER BY unknown_col", - expected = TABLE_AWS_DDB_PETS_LIST - ), - SuccessTestCase( - name = "LIMIT INT", - catalog = CATALOG_AWS, - catalogPath = listOf("ddb"), - query = "SELECT * FROM pets LIMIT 5", - expected = TABLE_AWS_DDB_PETS - ), - ErrorTestCase( - name = "LIMIT STR", - catalog = CATALOG_AWS, - catalogPath = listOf("ddb"), - query = "SELECT * FROM pets LIMIT '5'", - expected = TABLE_AWS_DDB_PETS, - problemHandler = assertProblemExists { - Problem( - UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UnexpectedType(STRING, setOf(INT)) - ) - } - ), - SuccessTestCase( - name = "OFFSET INT", - catalog = CATALOG_AWS, - catalogPath = listOf("ddb"), - query = "SELECT * FROM pets LIMIT 1 OFFSET 5", - expected = TABLE_AWS_DDB_PETS - ), - ErrorTestCase( - name = "OFFSET STR", - catalog = CATALOG_AWS, - catalogPath = listOf("ddb"), - query = "SELECT * FROM pets LIMIT 1 OFFSET '5'", - expected = TABLE_AWS_DDB_PETS, - problemHandler = assertProblemExists { - Problem( - UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UnexpectedType(STRING, setOf(INT)) - ) - } + ) ), SuccessTestCase( - name = "CAST", - catalog = CATALOG_AWS, - catalogPath = listOf("ddb"), - query = "SELECT CAST(breed AS INT) AS cast_breed FROM pets", - expected = BagType( + name = "EXCLUDE SELECT star order by", + query = """SELECT * + EXCLUDE + t.a + FROM [ + { + 'a': 2, + 'foo': 'bar2' + }, + { + 'a': 1, + 'foo': 'bar1' + }, + { + 'a': 3, + 'foo': 'bar3' + } + ] AS t + ORDER BY t.a""", + expected = ListType( StructType( - fields = mapOf("cast_breed" to unionOf(INT, MISSING)), + fields = mapOf( + "foo" to StaticType.STRING + ), contentClosed = true, constraints = setOf( TupleConstraint.Open(false), @@ -1063,13 +909,25 @@ class PartiQLSchemaInferencerTests { ) ), SuccessTestCase( - name = "UPPER", - catalog = CATALOG_AWS, - catalogPath = listOf("ddb"), - query = "SELECT UPPER(breed) AS upper_breed FROM pets", + name = "EXCLUDE SELECT star with JOINs", + query = """SELECT * + EXCLUDE bar.d + FROM + << + {'a': 1, 'b': 11}, + {'a': 2, 'b': 22} + >> AS foo, + << + {'c': 3, 'd': 33}, + {'c': 4, 'd': 44} + >> AS bar""", expected = BagType( StructType( - fields = mapOf("upper_breed" to STRING), + fields = mapOf( + "a" to StaticType.INT, + "b" to StaticType.INT, + "c" to StaticType.INT + ), contentClosed = true, constraints = setOf( TupleConstraint.Open(false), @@ -1080,11 +938,28 @@ class PartiQLSchemaInferencerTests { ) ), SuccessTestCase( - name = "Non-tuples", - query = "SELECT a FROM << [ 1, 1.0 ] >> AS a", + name = "SELECT t.b EXCLUDE ex 1", + query = """SELECT t.b EXCLUDE t.b[*].b_1 + FROM << + { + 'a': {'a_1':1,'a_2':2}, + 'b': [ {'b_1':3,'b_2':4}, {'b_1':5,'b_2':6} ], + 'c': 7, + 'd': 8 + } >> AS t""", expected = BagType( StructType( - fields = mapOf("a" to ListType(unionOf(INT, StaticType.DECIMAL))), + fields = mapOf( + "b" to ListType( + elementType = StructType( + fields = mapOf( + "b_2" to StaticType.INT + ), + contentClosed = true, + constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) + ) + ), + ), contentClosed = true, constraints = setOf( TupleConstraint.Open(false), @@ -1095,29 +970,37 @@ class PartiQLSchemaInferencerTests { ) ), SuccessTestCase( - name = "Non-tuples in SELECT VALUE", - query = "SELECT VALUE a FROM << [ 1, 1.0 ] >> AS a", - expected = - BagType(ListType(unionOf(INT, StaticType.DECIMAL))) - ), - SuccessTestCase( - name = "SELECT VALUE", - query = "SELECT VALUE [1, 1.0] FROM <<>>", - expected = - BagType(ListType(unionOf(INT, StaticType.DECIMAL))) - ), - SuccessTestCase( - name = "Duplicate fields in struct", - query = """ - SELECT t.a AS a + name = "SELECT * EXCLUDE ex 2", + query = """SELECT * EXCLUDE t.b[*].b_1 FROM << - { 'a': 1, 'a': 'hello' } - >> AS t - """, + { + 'a': {'a_1':1,'a_2':2}, + 'b': [ {'b_1':3,'b_2':4}, {'b_1':5,'b_2':6} ], + 'c': 7, + 'd': 8 + } >> AS t""", expected = BagType( StructType( - fields = listOf( - StructType.Field("a", unionOf(INT, STRING)) + fields = mapOf( + "a" to StructType( + fields = mapOf( + "a_1" to StaticType.INT, + "a_2" to StaticType.INT + ), + contentClosed = true, + constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) + ), + "b" to ListType( + elementType = StructType( + fields = mapOf( + "b_2" to StaticType.INT + ), + contentClosed = true, + constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) + ) + ), + "c" to StaticType.INT, + "d" to StaticType.INT ), contentClosed = true, constraints = setOf( @@ -1129,15 +1012,60 @@ class PartiQLSchemaInferencerTests { ) ), SuccessTestCase( - name = "Duplicate fields in ordered STRUCT. NOTE: b.b.d is an ordered struct with two attributes (e). First is INT.", - query = """ - SELECT d.e AS e - FROM << b.b.d >> AS d - """, + name = "SELECT VALUE t.b EXCLUDE", + query = """SELECT VALUE t.b EXCLUDE t.b[*].b_1 + FROM << + { + 'a': {'a_1':1,'a_2':2}, + 'b': [ {'b_1':3,'b_2':4}, {'b_1':5,'b_2':6} ], + 'c': 7, + 'd': 8 + } >> AS t""", + expected = BagType( + ListType( + elementType = StructType( + fields = mapOf( + "b_2" to StaticType.INT + ), + contentClosed = true, + constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) + ) + ), + ) + ), + SuccessTestCase( + name = "SELECT * EXCLUDE collection wildcard and nested tuple attr", + query = """SELECT * EXCLUDE t.a[*].b.c + FROM << + { + 'a': [ + { 'b': { 'c': 0, 'd': 'zero' } }, + { 'b': { 'c': 1, 'd': 'one' } }, + { 'b': { 'c': 2, 'd': 'two' } } + ] + } + >> AS t""", expected = BagType( StructType( - fields = listOf( - StructType.Field("e", INT) + fields = mapOf( + "a" to ListType( + elementType = StructType( + fields = mapOf( + "b" to StructType( + fields = mapOf( + "d" to StaticType.STRING + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true) + ) + ), + ), + contentClosed = true, + constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) + ), + ) ), contentClosed = true, constraints = setOf( @@ -1149,17 +1077,39 @@ class PartiQLSchemaInferencerTests { ) ), SuccessTestCase( - name = "Duplicate fields in struct", - query = """ - SELECT a AS a + name = "SELECT * EXCLUDE collection index and nested tuple attr", + query = """SELECT * EXCLUDE t.a[1].b.c FROM << - { 'a': 1, 'a': 'hello' } - >> AS t - """, + { + 'a': [ + { 'b': { 'c': 0, 'd': 'zero' } }, + { 'b': { 'c': 1, 'd': 'one' } }, + { 'b': { 'c': 2, 'd': 'two' } } + ] + } + >> AS t""", expected = BagType( StructType( - fields = listOf( - StructType.Field("a", unionOf(INT, STRING)) + fields = mapOf( + "a" to ListType( + elementType = StructType( + fields = mapOf( + "b" to StructType( + fields = mapOf( + "c" to StaticType.INT.asOptional(), + "d" to StaticType.STRING + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true) + ) + ), + ), + contentClosed = true, + constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) + ), + ) ), contentClosed = true, constraints = setOf( @@ -1171,15 +1121,36 @@ class PartiQLSchemaInferencerTests { ) ), SuccessTestCase( - name = "AGGREGATE over INTS", - query = "SELECT a, COUNT(*) AS c, SUM(a) AS s, MIN(b) AS m FROM << {'a': 1, 'b': 2} >> GROUP BY a", + name = "SELECT * EXCLUDE collection wildcard and nested tuple wildcard", + query = """SELECT * EXCLUDE t.a[*].b.* + FROM << + { + 'a': [ + { 'b': { 'c': 0, 'd': 'zero' } }, + { 'b': { 'c': 1, 'd': 'one' } }, + { 'b': { 'c': 2, 'd': 'two' } } + ] + } + >> AS t""", expected = BagType( StructType( fields = mapOf( - "a" to INT, - "c" to INT, - "s" to INT, - "m" to INT, + "a" to ListType( + elementType = StructType( + fields = mapOf( + "b" to StructType( + fields = mapOf(), // empty map; all fields of b excluded + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true) + ) + ), + ), + contentClosed = true, + constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) + ), + ) ), contentClosed = true, constraints = setOf( @@ -1191,15 +1162,39 @@ class PartiQLSchemaInferencerTests { ) ), SuccessTestCase( - name = "AGGREGATE over DECIMALS", - query = "SELECT a, COUNT(*) AS c, SUM(a) AS s, MIN(b) AS m FROM << {'a': 1.0, 'b': 2.0}, {'a': 1.0, 'b': 2.0} >> GROUP BY a", + name = "SELECT * EXCLUDE collection index and nested tuple wildcard", + query = """SELECT * EXCLUDE t.a[1].b.* + FROM << + { + 'a': [ + { 'b': { 'c': 0, 'd': 'zero' } }, + { 'b': { 'c': 1, 'd': 'one' } }, + { 'b': { 'c': 2, 'd': 'two' } } + ] + } + >> AS t""", expected = BagType( StructType( fields = mapOf( - "a" to StaticType.DECIMAL, - "c" to INT, - "s" to StaticType.DECIMAL, - "m" to StaticType.DECIMAL, + "a" to ListType( + elementType = StructType( + fields = mapOf( + "b" to StructType( + fields = mapOf( // all fields of b optional + "c" to StaticType.INT.asOptional(), + "d" to StaticType.STRING.asOptional() + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true) + ) + ), + ), + contentClosed = true, + constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) + ), + ) ), contentClosed = true, constraints = setOf( @@ -1211,82 +1206,49 @@ class PartiQLSchemaInferencerTests { ) ), SuccessTestCase( - name = "Current User", - query = "CURRENT_USER", - expected = unionOf(STRING, NULL) - ), - SuccessTestCase( - name = "Trim", - query = "trim(' ')", - expected = STRING - ), - SuccessTestCase( - name = "Current User Concat", - query = "CURRENT_USER || 'hello'", - expected = unionOf(STRING, NULL) - ), - SuccessTestCase( - name = "Current User Concat in WHERE", - query = "SELECT VALUE a FROM [ 0 ] AS a WHERE CURRENT_USER = 'hello'", - expected = BagType(INT) - ), - SuccessTestCase( - name = "TRIM_2", - query = "trim(' ' FROM ' Hello, World! ')", - expected = STRING - ), - SuccessTestCase( - name = "TRIM_1", - query = "trim(' Hello, World! ')", - expected = STRING - ), - SuccessTestCase( - name = "TRIM_3", - query = "trim(LEADING ' ' FROM ' Hello, World! ')", - expected = STRING - ), - ErrorTestCase( - name = "TRIM_2_error", - query = "trim(2 FROM ' Hello, World! ')", - expected = MISSING, - problemHandler = assertProblemExists { - Problem( - UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UnknownFunction( - "trim_chars", - args = listOf(STRING, INT) - ) - ) - } - ), - // EXCLUDE test cases - SuccessTestCase( - name = "EXCLUDE SELECT star", - query = """SELECT * EXCLUDE c.ssn FROM [ - { - 'name': 'Alan', - 'custId': 1, - 'address': { - 'city': 'Seattle', - 'zipcode': 98109, - 'street': '123 Seaplane Dr.' - }, - 'ssn': 123456789 - } - ] AS c""", + name = "SELECT * EXCLUDE collection wildcard and nested collection wildcard", + query = """SELECT * EXCLUDE t.a[*].b.d[*].e + FROM << + { + 'a': [ + { 'b': { 'c': 0, 'd': [{'e': 'zero', 'f': true}] } }, + { 'b': { 'c': 1, 'd': [{'e': 'one', 'f': true}] } }, + { 'b': { 'c': 2, 'd': [{'e': 'two', 'f': true}] } } + ] + } + >> AS t""", expected = BagType( StructType( fields = mapOf( - "name" to StaticType.STRING, - "custId" to StaticType.INT, - "address" to StructType( - fields = mapOf( - "city" to StaticType.STRING, - "zipcode" to StaticType.INT, - "street" to StaticType.STRING, + "a" to ListType( + elementType = StructType( + fields = mapOf( + "b" to StructType( + fields = mapOf( + "c" to StaticType.INT, + "d" to ListType( + elementType = StructType( + fields = mapOf( + "f" to StaticType.BOOL + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true) + ) + ) + ) + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true) + ) + ), + ), + contentClosed = true, + constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) ), - contentClosed = true, - constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) ) ), contentClosed = true, @@ -1299,31 +1261,50 @@ class PartiQLSchemaInferencerTests { ) ), SuccessTestCase( - name = "EXCLUDE SELECT star multiple paths", - query = """SELECT * EXCLUDE c.ssn, c.address.street FROM [ - { - 'name': 'Alan', - 'custId': 1, - 'address': { - 'city': 'Seattle', - 'zipcode': 98109, - 'street': '123 Seaplane Dr.' - }, - 'ssn': 123456789 - } - ] AS c""", + name = "SELECT * EXCLUDE collection index and nested collection wildcard", + query = """SELECT * EXCLUDE t.a[1].b.d[*].e + FROM << + { + 'a': [ + { 'b': { 'c': 0, 'd': [{'e': 'zero', 'f': true}] } }, + { 'b': { 'c': 1, 'd': [{'e': 'one', 'f': true}] } }, + { 'b': { 'c': 2, 'd': [{'e': 'two', 'f': true}] } } + ] + } + >> AS t""", expected = BagType( StructType( fields = mapOf( - "name" to StaticType.STRING, - "custId" to StaticType.INT, - "address" to StructType( - fields = mapOf( - "city" to StaticType.STRING, - "zipcode" to StaticType.INT + "a" to ListType( + elementType = StructType( + fields = mapOf( + "b" to StructType( + fields = mapOf( + "c" to StaticType.INT, + "d" to ListType( + elementType = StructType( + fields = mapOf( + "e" to StaticType.STRING.asOptional(), // last step is optional since only a[1]... is excluded + "f" to StaticType.BOOL + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true) + ) + ) + ) + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true) + ) + ), + ), + contentClosed = true, + constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) ), - contentClosed = true, - constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) ) ), contentClosed = true, @@ -1336,65 +1317,51 @@ class PartiQLSchemaInferencerTests { ) ), SuccessTestCase( - name = "EXCLUDE SELECT star list index and list index field", - query = """SELECT * - EXCLUDE - t.a.b.c[0], - t.a.b.c[1].field - FROM [{ - 'a': { - 'b': { - 'c': [ - { - 'field': 0 -- c[0] - }, - { - 'field': 1 -- c[1] - }, - { - 'field': 2 -- c[2] - } - ] - } - }, - 'foo': 'bar' - }] AS t""", + name = "SELECT * EXCLUDE collection index and nested collection index", + query = """SELECT * EXCLUDE t.a[1].b.d[0].e + FROM << + { + 'a': [ + { 'b': { 'c': 0, 'd': [{'e': 'zero', 'f': true}] } }, + { 'b': { 'c': 1, 'd': [{'e': 'one', 'f': true}] } }, + { 'b': { 'c': 2, 'd': [{'e': 'two', 'f': true}] } } + ] + } + >> AS t""", expected = BagType( StructType( fields = mapOf( - "a" to StructType( - fields = mapOf( - "b" to StructType( - fields = mapOf( - "c" to ListType( - elementType = StructType( - fields = mapOf( - "field" to AnyOfType( - setOf( - INT, - MISSING // c[1]'s `field` was excluded - ) + "a" to ListType( + elementType = StructType( + fields = mapOf( + "b" to StructType( + fields = mapOf( + "c" to StaticType.INT, + "d" to ListType( + elementType = StructType( + fields = mapOf( // same as above + "e" to StaticType.STRING.asOptional(), + "f" to StaticType.BOOL + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true) ) - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true) ) ) + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true) ) ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true) - ) - ) + ), + contentClosed = true, + constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) ), - contentClosed = true, - constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) - ), - "foo" to StaticType.STRING + ) ), contentClosed = true, constraints = setOf( @@ -1406,65 +1373,37 @@ class PartiQLSchemaInferencerTests { ) ), SuccessTestCase( - name = "EXCLUDE SELECT star collection index as last step", - query = """SELECT * - EXCLUDE - t.a.b.c[0] - FROM [{ - 'a': { - 'b': { - 'c': [0, 1, 2] + name = "EXCLUDE case sensitive lookup", + query = """SELECT * EXCLUDE t."a".b['c'] + FROM << + { + 'a': { + 'B': { + 'c': 0, + 'd': 'foo' + } } - }, - 'foo': 'bar' - }] AS t""", + } + >> AS t""", expected = BagType( StructType( fields = mapOf( "a" to StructType( fields = mapOf( - "b" to StructType( + "B" to StructType( fields = mapOf( - "c" to ListType( - elementType = StaticType.INT - ) + "d" to StaticType.STRING ), contentClosed = true, constraints = setOf( TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true) ) - ) + ), ), contentClosed = true, constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) ), - "foo" to StaticType.STRING - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true), - TupleConstraint.Ordered - ) - ) - ) - ), - // EXCLUDE regression test (behavior subject to change pending RFC) - SuccessTestCase( - name = "EXCLUDE SELECT star collection wildcard as last step", - query = """SELECT * - EXCLUDE - t.a[*] - FROM [{ - 'a': [0, 1, 2] - }] AS t""", - expected = BagType( - StructType( - fields = mapOf( - "a" to ListType( - elementType = StaticType.INT // empty list but still preserve typing information - ) ), contentClosed = true, constraints = setOf( @@ -1476,62 +1415,39 @@ class PartiQLSchemaInferencerTests { ) ), SuccessTestCase( - name = "EXCLUDE SELECT star list wildcard", - query = """SELECT * - EXCLUDE - t.a.b.c[*].field_x - FROM [{ - 'a': { - 'b': { - 'c': [ - { -- c[0] - 'field_x': 0, - 'field_y': 0 - }, - { -- c[1] - 'field_x': 1, - 'field_y': 1 - }, - { -- c[2] - 'field_x': 2, - 'field_y': 2 - } - ] + name = "EXCLUDE case sensitive lookup with capitalized and uncapitalized attr", + query = """SELECT * EXCLUDE t."a".b['c'] + FROM << + { + 'a': { + 'B': { + 'c': 0, + 'C': true, + 'd': 'foo' + } } - }, - 'foo': 'bar' - }] AS t""", + } + >> AS t""", expected = BagType( StructType( fields = mapOf( "a" to StructType( fields = mapOf( - "b" to StructType( + "B" to StructType( fields = mapOf( - "c" to ListType( - elementType = StructType( - fields = mapOf( - "field_y" to StaticType.INT - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true) - ) - ) - ) + "C" to StaticType.BOOL, // keep 'C' + "d" to StaticType.STRING ), contentClosed = true, constraints = setOf( TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true) ) - ) + ), ), contentClosed = true, constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) ), - "foo" to StaticType.STRING ), contentClosed = true, constraints = setOf( @@ -1543,62 +1459,38 @@ class PartiQLSchemaInferencerTests { ) ), SuccessTestCase( - name = "EXCLUDE SELECT star list tuple wildcard", - query = """SELECT * - EXCLUDE - t.a.b.c[*].* - FROM [{ - 'a': { - 'b': { - 'c': [ - { -- c[0] - 'field_x': 0, - 'field_y': 0 - }, - { -- c[1] - 'field_x': 1, - 'field_y': 1 - }, - { -- c[2] - 'field_x': 2, - 'field_y': 2 - } - ] + name = "EXCLUDE case sensitive lookup with both capitalized and uncapitalized removed", + query = """SELECT * EXCLUDE t."a".b.c + FROM << + { + 'a': { + 'B': { -- both 'c' and 'C' to be removed + 'c': 0, + 'C': true, + 'd': 'foo' + } } - }, - 'foo': 'bar' - }] AS t""", + } + >> AS t""", expected = BagType( StructType( fields = mapOf( "a" to StructType( fields = mapOf( - "b" to StructType( + "B" to StructType( fields = mapOf( - "c" to ListType( - elementType = StructType( - fields = mapOf( - // all fields gone - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true) - ) - ) - ) + "d" to StaticType.STRING ), contentClosed = true, constraints = setOf( TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true) ) - ) + ), ), contentClosed = true, constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) ), - "foo" to StaticType.STRING ), contentClosed = true, constraints = setOf( @@ -1610,29 +1502,39 @@ class PartiQLSchemaInferencerTests { ) ), SuccessTestCase( - name = "EXCLUDE SELECT star order by", - query = """SELECT * - EXCLUDE - t.a - FROM [ - { - 'a': 2, - 'foo': 'bar2' - }, - { - 'a': 1, - 'foo': 'bar1' - }, + name = "EXCLUDE with both duplicates", + query = """SELECT * EXCLUDE t."a".b.c + FROM << { - 'a': 3, - 'foo': 'bar3' + 'a': { + 'B': { + 'c': 0, + 'c': true, + 'd': 'foo' + } + } } - ] AS t - ORDER BY t.a""", - expected = ListType( + >> AS t""", + expected = BagType( StructType( fields = mapOf( - "foo" to StaticType.STRING + "a" to StructType( + fields = mapOf( + "B" to StructType( + fields = mapOf( + // both "c" removed + "d" to StaticType.STRING + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(false) + ) // UniqueAttrs set to false + ), + ), + contentClosed = true, + constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) + ), ), contentClosed = true, constraints = setOf( @@ -1643,24 +1545,13 @@ class PartiQLSchemaInferencerTests { ) ) ), + // EXCLUDE regression test (behavior subject to change pending RFC) SuccessTestCase( - name = "EXCLUDE SELECT star with JOINs", - query = """SELECT * - EXCLUDE bar.d - FROM - << - {'a': 1, 'b': 11}, - {'a': 2, 'b': 22} - >> AS foo, - << - {'c': 3, 'd': 33}, - {'c': 4, 'd': 44} - >> AS bar""", + name = "EXCLUDE with removed attribute later referenced", + query = "SELECT * EXCLUDE t.a, t.a.b FROM << { 'a': { 'b': 1 }, 'c': 2 } >> AS t", expected = BagType( StructType( fields = mapOf( - "a" to StaticType.INT, - "b" to StaticType.INT, "c" to StaticType.INT ), contentClosed = true, @@ -1672,28 +1563,14 @@ class PartiQLSchemaInferencerTests { ) ) ), + // EXCLUDE regression test (behavior subject to change pending RFC) SuccessTestCase( - name = "SELECT t.b EXCLUDE ex 1", - query = """SELECT t.b EXCLUDE t.b[*].b_1 - FROM << - { - 'a': {'a_1':1,'a_2':2}, - 'b': [ {'b_1':3,'b_2':4}, {'b_1':5,'b_2':6} ], - 'c': 7, - 'd': 8 - } >> AS t""", + name = "EXCLUDE with non-existent attribute reference", + query = "SELECT * EXCLUDE t.attr_does_not_exist FROM << { 'a': 1 } >> AS t", expected = BagType( StructType( fields = mapOf( - "b" to ListType( - elementType = StructType( - fields = mapOf( - "b_2" to StaticType.INT - ), - contentClosed = true, - constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) - ) - ), + "a" to StaticType.INT ), contentClosed = true, constraints = setOf( @@ -1704,38 +1581,49 @@ class PartiQLSchemaInferencerTests { ) ) ), + // EXCLUDE regression test (behavior subject to change pending RFC); could give error/warning SuccessTestCase( - name = "SELECT * EXCLUDE ex 2", - query = """SELECT * EXCLUDE t.b[*].b_1 + name = "exclude union of types", + query = """SELECT t EXCLUDE t.a.b FROM << - { - 'a': {'a_1':1,'a_2':2}, - 'b': [ {'b_1':3,'b_2':4}, {'b_1':5,'b_2':6} ], - 'c': 7, - 'd': 8 - } >> AS t""", + { + 'a': { + 'b': 1, -- `b` to be excluded + 'c': 'foo' + } + }, + { + 'a': NULL + } + >> AS t""", expected = BagType( StructType( fields = mapOf( - "a" to StructType( - fields = mapOf( - "a_1" to StaticType.INT, - "a_2" to StaticType.INT + "t" to StaticType.unionOf( + StructType( + fields = mapOf( + "a" to StructType( + fields = mapOf( + "c" to StaticType.STRING + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true) + ) + ) + ), + contentClosed = true, + constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) ), - contentClosed = true, - constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) - ), - "b" to ListType( - elementType = StructType( + StructType( fields = mapOf( - "b_2" to StaticType.INT + "a" to StaticType.NULL ), contentClosed = true, constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) - ) - ), - "c" to StaticType.INT, - "d" to StaticType.INT + ), + ) ), contentClosed = true, constraints = setOf( @@ -1747,55 +1635,54 @@ class PartiQLSchemaInferencerTests { ) ), SuccessTestCase( - name = "SELECT VALUE t.b EXCLUDE", - query = """SELECT VALUE t.b EXCLUDE t.b[*].b_1 - FROM << - { - 'a': {'a_1':1,'a_2':2}, - 'b': [ {'b_1':3,'b_2':4}, {'b_1':5,'b_2':6} ], - 'c': 7, - 'd': 8 - } >> AS t""", - expected = BagType( - ListType( - elementType = StructType( - fields = mapOf( - "b_2" to StaticType.INT - ), - contentClosed = true, - constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) - ) - ), - ) - ), - SuccessTestCase( - name = "SELECT * EXCLUDE collection wildcard and nested tuple attr", - query = """SELECT * EXCLUDE t.a[*].b.c + name = "exclude union of types exclude same type", + query = """SELECT t EXCLUDE t.a.b FROM << { - 'a': [ - { 'b': { 'c': 0, 'd': 'zero' } }, - { 'b': { 'c': 1, 'd': 'one' } }, - { 'b': { 'c': 2, 'd': 'two' } } - ] + 'a': { + 'b': 1, -- `b` to be excluded + 'c': 'foo' + } + }, + { + 'a': { + 'b': 1, -- `b` to be excluded + 'c': NULL + } } >> AS t""", expected = BagType( StructType( fields = mapOf( - "a" to ListType( - elementType = StructType( + "t" to StaticType.unionOf( + StructType( fields = mapOf( - "b" to StructType( + "a" to StructType( fields = mapOf( - "d" to StaticType.STRING + "c" to StaticType.STRING ), contentClosed = true, constraints = setOf( TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true) ) - ), + ) + ), + contentClosed = true, + constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) + ), + StructType( + fields = mapOf( + "a" to StructType( + fields = mapOf( + "c" to StaticType.NULL + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true) + ) + ) ), contentClosed = true, constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) @@ -1812,38 +1699,40 @@ class PartiQLSchemaInferencerTests { ) ), SuccessTestCase( - name = "SELECT * EXCLUDE collection index and nested tuple attr", - query = """SELECT * EXCLUDE t.a[1].b.c + name = "exclude union of types exclude different type", + query = """SELECT t EXCLUDE t.a.c FROM << { - 'a': [ - { 'b': { 'c': 0, 'd': 'zero' } }, - { 'b': { 'c': 1, 'd': 'one' } }, - { 'b': { 'c': 2, 'd': 'two' } } - ] + 'a': { + 'b': 1, + 'c': 'foo' -- `c` to be excluded + } + }, + { + 'a': { + 'b': 1, + 'c': NULL -- `c` to be excluded + } } >> AS t""", expected = BagType( StructType( fields = mapOf( - "a" to ListType( - elementType = StructType( - fields = mapOf( - "b" to StructType( - fields = mapOf( - "c" to StaticType.INT.asOptional(), - "d" to StaticType.STRING - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true) - ) + "t" to StructType( // union gone + fields = mapOf( + "a" to StructType( + fields = mapOf( + "b" to StaticType.INT ), - ), - contentClosed = true, - constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true) + ) + ) ), + contentClosed = true, + constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) ) ), contentClosed = true, @@ -1855,36 +1744,39 @@ class PartiQLSchemaInferencerTests { ) ) ), + // EXCLUDE regression test (behavior subject to change pending RFC); could give error/warning SuccessTestCase( - name = "SELECT * EXCLUDE collection wildcard and nested tuple wildcard", - query = """SELECT * EXCLUDE t.a[*].b.* + name = "invalid exclude collection wildcard", + query = """SELECT * EXCLUDE t.a[*] FROM << { - 'a': [ - { 'b': { 'c': 0, 'd': 'zero' } }, - { 'b': { 'c': 1, 'd': 'one' } }, - { 'b': { 'c': 2, 'd': 'two' } } - ] + 'a': { + 'b': { + 'c': 0, + 'd': 'foo' + } + } } >> AS t""", expected = BagType( - StructType( + elementType = StructType( fields = mapOf( - "a" to ListType( - elementType = StructType( - fields = mapOf( - "b" to StructType( - fields = mapOf(), // empty map; all fields of b excluded - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true) - ) + "a" to StructType( + fields = mapOf( + "b" to StructType( + fields = mapOf( + "c" to StaticType.INT, + "d" to StaticType.STRING ), - ), - contentClosed = true, - constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true) + ) + ) ), + contentClosed = true, + constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) ) ), contentClosed = true, @@ -1896,39 +1788,39 @@ class PartiQLSchemaInferencerTests { ) ) ), + // EXCLUDE regression test (behavior subject to change pending RFC); could give error/warning SuccessTestCase( - name = "SELECT * EXCLUDE collection index and nested tuple wildcard", - query = """SELECT * EXCLUDE t.a[1].b.* + name = "invalid exclude collection index", + query = """SELECT * EXCLUDE t.a[1] FROM << { - 'a': [ - { 'b': { 'c': 0, 'd': 'zero' } }, - { 'b': { 'c': 1, 'd': 'one' } }, - { 'b': { 'c': 2, 'd': 'two' } } - ] + 'a': { + 'b': { + 'c': 0, + 'd': 'foo' + } + } } >> AS t""", expected = BagType( - StructType( + elementType = StructType( fields = mapOf( - "a" to ListType( - elementType = StructType( - fields = mapOf( - "b" to StructType( - fields = mapOf( // all fields of b optional - "c" to StaticType.INT.asOptional(), - "d" to StaticType.STRING.asOptional() - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true) - ) + "a" to StructType( + fields = mapOf( + "b" to StructType( + fields = mapOf( + "c" to StaticType.INT, + "d" to StaticType.STRING ), - ), - contentClosed = true, - constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true) + ) + ) ), + contentClosed = true, + constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) ) ), contentClosed = true, @@ -1940,50 +1832,30 @@ class PartiQLSchemaInferencerTests { ) ) ), + // EXCLUDE regression test (behavior subject to change pending RFC); could give error/warning SuccessTestCase( - name = "SELECT * EXCLUDE collection wildcard and nested collection wildcard", - query = """SELECT * EXCLUDE t.a[*].b.d[*].e + name = "invalid exclude tuple attr", + query = """SELECT * EXCLUDE t.a.b FROM << { 'a': [ - { 'b': { 'c': 0, 'd': [{'e': 'zero', 'f': true}] } }, - { 'b': { 'c': 1, 'd': [{'e': 'one', 'f': true}] } }, - { 'b': { 'c': 2, 'd': [{'e': 'two', 'f': true}] } } + { 'b': 0 }, + { 'b': 1 }, + { 'b': 2 } ] } >> AS t""", expected = BagType( - StructType( + elementType = StructType( fields = mapOf( "a" to ListType( elementType = StructType( fields = mapOf( - "b" to StructType( - fields = mapOf( - "c" to StaticType.INT, - "d" to ListType( - elementType = StructType( - fields = mapOf( - "f" to StaticType.BOOL - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true) - ) - ) - ) - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true) - ) - ), + "b" to StaticType.INT ), contentClosed = true, constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) - ), + ) ) ), contentClosed = true, @@ -1995,51 +1867,30 @@ class PartiQLSchemaInferencerTests { ) ) ), + // EXCLUDE regression test (behavior subject to change pending RFC); could give error/warning SuccessTestCase( - name = "SELECT * EXCLUDE collection index and nested collection wildcard", - query = """SELECT * EXCLUDE t.a[1].b.d[*].e + name = "invalid exclude tuple wildcard", + query = """SELECT * EXCLUDE t.a.* FROM << { 'a': [ - { 'b': { 'c': 0, 'd': [{'e': 'zero', 'f': true}] } }, - { 'b': { 'c': 1, 'd': [{'e': 'one', 'f': true}] } }, - { 'b': { 'c': 2, 'd': [{'e': 'two', 'f': true}] } } + { 'b': 0 }, + { 'b': 1 }, + { 'b': 2 } ] } >> AS t""", expected = BagType( - StructType( + elementType = StructType( fields = mapOf( "a" to ListType( elementType = StructType( fields = mapOf( - "b" to StructType( - fields = mapOf( - "c" to StaticType.INT, - "d" to ListType( - elementType = StructType( - fields = mapOf( - "e" to StaticType.STRING.asOptional(), // last step is optional since only a[1]... is excluded - "f" to StaticType.BOOL - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true) - ) - ) - ) - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true) - ) - ), + "b" to StaticType.INT ), contentClosed = true, constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) - ), + ) ) ), contentClosed = true, @@ -2051,51 +1902,30 @@ class PartiQLSchemaInferencerTests { ) ) ), - SuccessTestCase( - name = "SELECT * EXCLUDE collection index and nested collection index", - query = """SELECT * EXCLUDE t.a[1].b.d[0].e + // EXCLUDE regression test (behavior subject to change pending RFC); could give error/warning + SuccessTestCase( + name = "invalid exclude tuple attr step", + query = """SELECT * EXCLUDE t.b -- `t.b` does not exist FROM << { - 'a': [ - { 'b': { 'c': 0, 'd': [{'e': 'zero', 'f': true}] } }, - { 'b': { 'c': 1, 'd': [{'e': 'one', 'f': true}] } }, - { 'b': { 'c': 2, 'd': [{'e': 'two', 'f': true}] } } - ] + 'a': << + { 'b': 0 }, + { 'b': 1 }, + { 'b': 2 } + >> } >> AS t""", expected = BagType( - StructType( + elementType = StructType( fields = mapOf( - "a" to ListType( + "a" to BagType( elementType = StructType( fields = mapOf( - "b" to StructType( - fields = mapOf( - "c" to StaticType.INT, - "d" to ListType( - elementType = StructType( - fields = mapOf( // same as above - "e" to StaticType.STRING.asOptional(), - "f" to StaticType.BOOL - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true) - ) - ) - ) - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true) - ) - ), + "b" to StaticType.INT ), contentClosed = true, constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) - ), + ) ) ), contentClosed = true, @@ -2107,27 +1937,78 @@ class PartiQLSchemaInferencerTests { ) ) ), + // EXCLUDE regression test (behavior subject to change pending RFC); could give error/warning + // ErrorTestCase( + // name = "invalid exclude root", + // query = """SELECT * EXCLUDE nonsense.b -- `nonsense` does not exist in binding tuples + // FROM << + // { + // 'a': << + // { 'b': 0 }, + // { 'b': 1 }, + // { 'b': 2 } + // >> + // } + // >> AS t""", + // expected = BagType( + // elementType = StructType( + // fields = mapOf( + // "a" to BagType( + // elementType = StructType( + // fields = mapOf( + // "b" to StaticType.INT + // ), + // contentClosed = true, + // constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) + // ) + // ) + // ), + // contentClosed = true, + // constraints = setOf( + // TupleConstraint.Open(false), + // TupleConstraint.UniqueAttrs(true), + // TupleConstraint.Ordered + // ) + // ) + // ), + // problemHandler = assertProblemExists { + // Problem( + // UNKNOWN_PROBLEM_LOCATION, + // PlanningProblemDetails.UnresolvedExcludeExprRoot("nonsense") + // ) + // } + // ), + // EXCLUDE regression test (behavior subject to change pending RFC); could give error/warning SuccessTestCase( - name = "EXCLUDE case sensitive lookup", - query = """SELECT * EXCLUDE t."a".b['c'] + name = "exclude with unions and last step collection index", + query = """SELECT * EXCLUDE t.a[0].c -- `c`'s type to be unioned with `MISSING` FROM << { - 'a': { - 'B': { - 'c': 0, - 'd': 'foo' + 'a': [ + { + 'b': 0, + 'c': 0 + }, + { + 'b': 1, + 'c': NULL + }, + { + 'b': 2, + 'c': 0.1 } - } + ] } >> AS t""", expected = BagType( - StructType( + elementType = StructType( fields = mapOf( - "a" to StructType( - fields = mapOf( - "B" to StructType( + "a" to ListType( + elementType = StaticType.unionOf( + StructType( fields = mapOf( - "d" to StaticType.STRING + "b" to StaticType.INT, + "c" to StaticType.INT.asOptional() ), contentClosed = true, constraints = setOf( @@ -2135,10 +2016,30 @@ class PartiQLSchemaInferencerTests { TupleConstraint.UniqueAttrs(true) ) ), - ), - contentClosed = true, - constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) - ), + StructType( + fields = mapOf( + "b" to StaticType.INT, + "c" to StaticType.NULL.asOptional() + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true) + ) + ), + StructType( + fields = mapOf( + "b" to StaticType.INT, + "c" to StaticType.DECIMAL.asOptional() + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true) + ) + ) + ) + ) ), contentClosed = true, constraints = setOf( @@ -2150,38 +2051,22 @@ class PartiQLSchemaInferencerTests { ) ), SuccessTestCase( - name = "EXCLUDE case sensitive lookup with capitalized and uncapitalized attr", - query = """SELECT * EXCLUDE t."a".b['c'] - FROM << - { - 'a': { - 'B': { - 'c': 0, - 'C': true, - 'd': 'foo' - } - } - } - >> AS t""", + name = "EXCLUDE using a catalog", + catalog = CATALOG_B, + query = "SELECT * EXCLUDE t.c FROM b.b.b AS t", expected = BagType( - StructType( + elementType = StructType( fields = mapOf( - "a" to StructType( + "b" to StructType( fields = mapOf( - "B" to StructType( - fields = mapOf( - "C" to StaticType.BOOL, // keep 'C' - "d" to StaticType.STRING - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true) - ) - ), + "b" to StaticType.INT ), contentClosed = true, - constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true), + TupleConstraint.Ordered + ) ), ), contentClosed = true, @@ -2191,42 +2076,138 @@ class PartiQLSchemaInferencerTests { TupleConstraint.Ordered ) ) - ) + ) + ), + ) + } + + sealed class TestCase { + class SuccessTestCase( + val name: String, + val query: String, + val catalog: String? = null, + val catalogPath: List = emptyList(), + val expected: StaticType, + ) : TestCase() { + override fun toString(): String = "$name : $query" + } + + class ErrorTestCase( + val name: String, + val query: String, + val catalog: String? = null, + val catalogPath: List = emptyList(), + val note: String? = null, + val expected: StaticType? = null, + val problemHandler: ProblemHandler? = null, + ) : TestCase() { + override fun toString(): String = "$name : $query" + } + + class ThrowingExceptionTestCase( + val name: String, + val query: String, + val catalog: String? = null, + val catalogPath: List = emptyList(), + val note: String? = null, + val expectedThrowable: KClass, + ) : TestCase() { + override fun toString(): String { + return "$name : $query" + } + } + } + + class TestProvider : ArgumentsProvider { + override fun provideArguments(context: ExtensionContext?): Stream { + return parameters.map { Arguments.of(it) }.stream() + } + + private val parameters = listOf( + ErrorTestCase( + name = "Pets should not be accessible #1", + query = "SELECT * FROM pets", + expected = BagType( + StructType( + fields = mapOf("pets" to StaticType.ANY), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true), + TupleConstraint.Ordered + ) + ) + ), + problemHandler = assertProblemExists { + Problem( + UNKNOWN_PROBLEM_LOCATION, + PlanningProblemDetails.UndefinedVariable("pets", false) + ) + } + ), + ErrorTestCase( + name = "Pets should not be accessible #2", + catalog = CATALOG_AWS, + query = "SELECT * FROM pets", + expected = BagType( + StructType( + fields = mapOf("pets" to StaticType.ANY), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true), + TupleConstraint.Ordered + ) + ) + ), + problemHandler = assertProblemExists { + Problem( + UNKNOWN_PROBLEM_LOCATION, + PlanningProblemDetails.UndefinedVariable("pets", false) + ) + } ), SuccessTestCase( - name = "EXCLUDE case sensitive lookup with both capitalized and uncapitalized removed", - query = """SELECT * EXCLUDE t."a".b.c - FROM << - { - 'a': { - 'B': { -- both 'c' and 'C' to be removed - 'c': 0, - 'C': true, - 'd': 'foo' - } - } - } - >> AS t""", + name = "Project all explicitly", + catalog = CATALOG_AWS, + catalogPath = listOf("ddb"), + query = "SELECT * FROM pets", + expected = TABLE_AWS_DDB_PETS + ), + SuccessTestCase( + name = "Project all implicitly", + catalog = CATALOG_AWS, + catalogPath = listOf("ddb"), + query = "SELECT id, breed FROM pets", + expected = TABLE_AWS_DDB_PETS + ), + SuccessTestCase( + name = "Test #4", + catalog = CATALOG_B, + catalogPath = listOf("b"), + query = "b", + expected = TYPE_B_B_B + ), + SuccessTestCase( + name = "Test #5", + catalog = CATALOG_AWS, + catalogPath = listOf("ddb"), + query = "SELECT * FROM b", + expected = TABLE_AWS_DDB_B + ), + SuccessTestCase( + name = "Test #6", + catalog = CATALOG_AWS, + catalogPath = listOf("b"), + query = "SELECT * FROM b", + expected = TABLE_AWS_B_B + ), + ErrorTestCase( + name = "Test #7", + query = "SELECT * FROM ddb.pets", expected = BagType( StructType( - fields = mapOf( - "a" to StructType( - fields = mapOf( - "B" to StructType( - fields = mapOf( - "d" to StaticType.STRING - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true) - ) - ), - ), - contentClosed = true, - constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) - ), - ), + fields = mapOf("pets" to StaticType.ANY), contentClosed = true, constraints = setOf( TupleConstraint.Open(false), @@ -2234,196 +2215,321 @@ class PartiQLSchemaInferencerTests { TupleConstraint.Ordered ) ) - ) + ), + problemHandler = assertProblemExists { + Problem( + UNKNOWN_PROBLEM_LOCATION, + PlanningProblemDetails.UndefinedVariable("pets", false) + ) + } + ), + SuccessTestCase( + name = "Test #10", + catalog = CATALOG_B, + query = "b.b", + expected = TYPE_B_B_B + ), + SuccessTestCase( + name = "Test #11", + catalog = CATALOG_B, + catalogPath = listOf("b"), + query = "b.b", + expected = TYPE_B_B_B + ), + SuccessTestCase( + name = "Test #12", + catalog = CATALOG_AWS, + catalogPath = listOf("ddb"), + query = "SELECT * FROM b.b", + expected = TABLE_AWS_B_B + ), + SuccessTestCase( + name = "Test #13", + catalog = CATALOG_AWS, + catalogPath = listOf("ddb"), + query = "SELECT * FROM ddb.b", + expected = TABLE_AWS_DDB_B + ), + SuccessTestCase( + name = "Test #14", + query = "SELECT * FROM aws.ddb.pets", + expected = TABLE_AWS_DDB_PETS + ), + SuccessTestCase( + name = "Test #15", + catalog = CATALOG_AWS, + query = "SELECT * FROM aws.b.b", + expected = TABLE_AWS_B_B + ), + SuccessTestCase( + name = "Test #16", + catalog = CATALOG_B, + query = "b.b.b", + expected = TYPE_B_B_B + ), + SuccessTestCase( + name = "Test #17", + catalog = CATALOG_B, + query = "b.b.c", + expected = TYPE_B_B_C + ), + SuccessTestCase( + name = "Test #18", + catalog = CATALOG_B, + catalogPath = listOf("b"), + query = "b.b.b", + expected = TYPE_B_B_B + ), + SuccessTestCase( + name = "Test #19", + query = "b.b.b.c", + expected = TYPE_B_B_B_C + ), + SuccessTestCase( + name = "Test #20", + query = "b.b.b.b", + expected = TYPE_B_B_B_B + ), + SuccessTestCase( + name = "Test #21", + catalog = CATALOG_B, + query = "b.b.b.b", + expected = TYPE_B_B_B_B + ), + SuccessTestCase( + name = "Test #22", + catalog = CATALOG_B, + query = "b.b.b.c", + expected = TYPE_B_B_C + ), + SuccessTestCase( + name = "Test #23", + catalog = CATALOG_B, + catalogPath = listOf("b"), + query = "b.b.b.b", + expected = TYPE_B_B_B_B + ), + SuccessTestCase( + name = "Test #24", + query = "b.b.b.b.b", + expected = TYPE_B_B_B_B_B + ), + SuccessTestCase( + name = "Test #24", + catalog = CATALOG_B, + query = "b.b.b.b.b", + expected = TYPE_B_B_B_B_B + ), + SuccessTestCase( + name = "EQ", + catalog = CATALOG_DB, + catalogPath = DB_SCHEMA_MARKETS, + query = "order_info.customer_id = 1", + expected = TYPE_BOOL + ), + SuccessTestCase( + name = "NEQ", + catalog = CATALOG_DB, + catalogPath = DB_SCHEMA_MARKETS, + query = "order_info.customer_id <> 1", + expected = TYPE_BOOL + ), + SuccessTestCase( + name = "GEQ", + catalog = CATALOG_DB, + catalogPath = DB_SCHEMA_MARKETS, + query = "order_info.customer_id >= 1", + expected = TYPE_BOOL + ), + SuccessTestCase( + name = "GT", + catalog = CATALOG_DB, + catalogPath = DB_SCHEMA_MARKETS, + query = "order_info.customer_id > 1", + expected = TYPE_BOOL + ), + SuccessTestCase( + name = "LEQ", + catalog = CATALOG_DB, + catalogPath = DB_SCHEMA_MARKETS, + query = "order_info.customer_id <= 1", + expected = TYPE_BOOL + ), + SuccessTestCase( + name = "LT", + catalog = CATALOG_DB, + catalogPath = DB_SCHEMA_MARKETS, + query = "order_info.customer_id < 1", + expected = TYPE_BOOL + ), + SuccessTestCase( + name = "IN", + catalog = CATALOG_DB, + catalogPath = DB_SCHEMA_MARKETS, + query = "order_info.customer_id IN (1, 2, 3)", + expected = TYPE_BOOL + ), + ErrorTestCase( + name = "IN Failure", + catalog = CATALOG_DB, + catalogPath = DB_SCHEMA_MARKETS, + query = "order_info.customer_id IN 'hello'", + expected = MISSING, + problemHandler = assertProblemExists { + Problem( + UNKNOWN_PROBLEM_LOCATION, + PlanningProblemDetails.UnknownFunction( + "in_collection", + listOf(INT, STRING), + ) + ) + } ), SuccessTestCase( - name = "EXCLUDE with both duplicates", - query = """SELECT * EXCLUDE t."a".b.c - FROM << - { - 'a': { - 'B': { - 'c': 0, - 'c': true, - 'd': 'foo' - } - } - } - >> AS t""", - expected = BagType( - StructType( - fields = mapOf( - "a" to StructType( - fields = mapOf( - "B" to StructType( - fields = mapOf( - // both "c" removed - "d" to StaticType.STRING - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(false) - ) // UniqueAttrs set to false - ), - ), - contentClosed = true, - constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) + name = "BETWEEN", + catalog = CATALOG_DB, + catalogPath = DB_SCHEMA_MARKETS, + query = "order_info.customer_id BETWEEN 1 AND 2", + expected = TYPE_BOOL + ), + ErrorTestCase( + name = "BETWEEN Failure", + catalog = CATALOG_DB, + catalogPath = DB_SCHEMA_MARKETS, + query = "order_info.customer_id BETWEEN 1 AND 'a'", + expected = MISSING, + problemHandler = assertProblemExists { + Problem( + UNKNOWN_PROBLEM_LOCATION, + PlanningProblemDetails.UnknownFunction( + "between", + listOf( + INT, + INT, + STRING ), - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true), - TupleConstraint.Ordered ) ) - ) + } ), - // EXCLUDE regression test (behavior subject to change pending RFC) SuccessTestCase( - name = "EXCLUDE with removed attribute later referenced", - query = "SELECT * EXCLUDE t.a, t.a.b FROM << { 'a': { 'b': 1 }, 'c': 2 } >> AS t", - expected = BagType( - StructType( - fields = mapOf( - "c" to StaticType.INT - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true), - TupleConstraint.Ordered + name = "LIKE", + catalog = CATALOG_DB, + catalogPath = DB_SCHEMA_MARKETS, + query = "order_info.ship_option LIKE '%ABC%'", + expected = TYPE_BOOL + ), + ErrorTestCase( + name = "LIKE Failure", + catalog = CATALOG_DB, + catalogPath = DB_SCHEMA_MARKETS, + query = "order_info.ship_option LIKE 3", + expected = MISSING, + problemHandler = assertProblemExists { + Problem( + UNKNOWN_PROBLEM_LOCATION, + PlanningProblemDetails.UnknownFunction( + "like", + listOf(STRING, INT), ) ) - ) + } ), - // EXCLUDE regression test (behavior subject to change pending RFC) SuccessTestCase( - name = "EXCLUDE with non-existent attribute reference", - query = "SELECT * EXCLUDE t.attr_does_not_exist FROM << { 'a': 1 } >> AS t", - expected = BagType( - StructType( - fields = mapOf( - "a" to StaticType.INT - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true), - TupleConstraint.Ordered - ) - ) - ) + name = "Case Insensitive success", + catalog = CATALOG_DB, + catalogPath = DB_SCHEMA_MARKETS, + query = "order_info.CUSTOMER_ID = 1", + expected = TYPE_BOOL + ), + ErrorTestCase( + name = "Case Sensitive failure", + catalog = CATALOG_DB, + catalogPath = DB_SCHEMA_MARKETS, + query = "order_info.\"CUSTOMER_ID\" = 1", + expected = TYPE_BOOL ), - // EXCLUDE regression test (behavior subject to change pending RFC); could give error/warning SuccessTestCase( - name = "exclude union of types", - query = """SELECT t EXCLUDE t.a.b - FROM << - { - 'a': { - 'b': 1, -- `b` to be excluded - 'c': 'foo' - } - }, - { - 'a': NULL - } - >> AS t""", - expected = BagType( - StructType( - fields = mapOf( - "t" to StaticType.unionOf( - StructType( - fields = mapOf( - "a" to StructType( - fields = mapOf( - "c" to StaticType.STRING - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true) - ) - ) - ), - contentClosed = true, - constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) - ), - StructType( - fields = mapOf( - "a" to StaticType.NULL - ), - contentClosed = true, - constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) - ), - ) - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true), - TupleConstraint.Ordered + name = "Case Sensitive success", + catalog = CATALOG_DB, + catalogPath = DB_SCHEMA_MARKETS, + query = "order_info.\"customer_id\" = 1", + expected = TYPE_BOOL + ), + SuccessTestCase( + name = "1-Level Junction", + catalog = CATALOG_DB, + catalogPath = DB_SCHEMA_MARKETS, + query = "(order_info.customer_id = 1) AND (order_info.marketplace_id = 2)", + expected = TYPE_BOOL + ), + SuccessTestCase( + name = "2-Level Junction", + catalog = CATALOG_DB, + catalogPath = DB_SCHEMA_MARKETS, + query = "(order_info.customer_id = 1) AND (order_info.marketplace_id = 2) OR (order_info.customer_id = 3) AND (order_info.marketplace_id = 4)", + expected = TYPE_BOOL + ), + SuccessTestCase( + name = "INT and STR Comparison", + catalog = CATALOG_DB, + catalogPath = DB_SCHEMA_MARKETS, + query = "order_info.customer_id = 'something'", + expected = TYPE_BOOL, + ), + ErrorTestCase( + name = "Nonexisting Comparison", + catalog = CATALOG_DB, + catalogPath = DB_SCHEMA_MARKETS, + query = "non_existing_column = 1", + expected = StaticType.BOOL, + problemHandler = assertProblemExists { + Problem( + UNKNOWN_PROBLEM_LOCATION, + PlanningProblemDetails.UndefinedVariable("non_existing_column", false) + ) + } + ), + ErrorTestCase( + name = "Bad comparison", + catalog = CATALOG_DB, + catalogPath = DB_SCHEMA_MARKETS, + query = "order_info.customer_id = 1 AND 1", + expected = MISSING, + problemHandler = assertProblemExists { + Problem( + UNKNOWN_PROBLEM_LOCATION, + PlanningProblemDetails.UnknownFunction( + "and", + listOf(StaticType.BOOL, INT), ) ) - ) + } ), - SuccessTestCase( - name = "exclude union of types exclude same type", - query = """SELECT t EXCLUDE t.a.b - FROM << - { - 'a': { - 'b': 1, -- `b` to be excluded - 'c': 'foo' - } - }, - { - 'a': { - 'b': 1, -- `b` to be excluded - 'c': NULL - } - } - >> AS t""", + ErrorTestCase( + name = "Bad comparison", + catalog = CATALOG_DB, + catalogPath = DB_SCHEMA_MARKETS, + query = "1 AND order_info.customer_id = 1", + expected = MISSING, + problemHandler = assertProblemExists { + Problem( + UNKNOWN_PROBLEM_LOCATION, + PlanningProblemDetails.UnknownFunction( + "and", + listOf(INT, StaticType.BOOL), + ) + ) + } + ), + ErrorTestCase( + name = "Unknown column", + catalog = CATALOG_DB, + catalogPath = DB_SCHEMA_MARKETS, + query = "SELECT unknown_col FROM orders WHERE customer_id = 1", expected = BagType( StructType( - fields = mapOf( - "t" to StaticType.unionOf( - StructType( - fields = mapOf( - "a" to StructType( - fields = mapOf( - "c" to StaticType.STRING - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true) - ) - ) - ), - contentClosed = true, - constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) - ), - StructType( - fields = mapOf( - "a" to StructType( - fields = mapOf( - "c" to StaticType.NULL - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true) - ) - ) - ), - contentClosed = true, - constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) - ), - ) - ), + fields = mapOf("unknown_col" to AnyType()), contentClosed = true, constraints = setOf( TupleConstraint.Open(false), @@ -2431,45 +2537,83 @@ class PartiQLSchemaInferencerTests { TupleConstraint.Ordered ) ) - ) + ), + problemHandler = assertProblemExists { + Problem( + UNKNOWN_PROBLEM_LOCATION, + PlanningProblemDetails.UndefinedVariable("unknown_col", false) + ) + } ), SuccessTestCase( - name = "exclude union of types exclude different type", - query = """SELECT t EXCLUDE t.a.c - FROM << - { - 'a': { - 'b': 1, - 'c': 'foo' -- `c` to be excluded - } - }, - { - 'a': { - 'b': 1, - 'c': NULL -- `c` to be excluded - } - } - >> AS t""", + name = "ORDER BY int", + catalog = CATALOG_AWS, + catalogPath = listOf("ddb"), + query = "SELECT * FROM pets ORDER BY id", + expected = TABLE_AWS_DDB_PETS_LIST + ), + SuccessTestCase( + name = "ORDER BY str", + catalog = CATALOG_AWS, + catalogPath = listOf("ddb"), + query = "SELECT * FROM pets ORDER BY breed", + expected = TABLE_AWS_DDB_PETS_LIST + ), + SuccessTestCase( + name = "ORDER BY str", + catalog = CATALOG_AWS, + catalogPath = listOf("ddb"), + query = "SELECT * FROM pets ORDER BY unknown_col", + expected = TABLE_AWS_DDB_PETS_LIST + ), + SuccessTestCase( + name = "LIMIT INT", + catalog = CATALOG_AWS, + catalogPath = listOf("ddb"), + query = "SELECT * FROM pets LIMIT 5", + expected = TABLE_AWS_DDB_PETS + ), + ErrorTestCase( + name = "LIMIT STR", + catalog = CATALOG_AWS, + catalogPath = listOf("ddb"), + query = "SELECT * FROM pets LIMIT '5'", + expected = TABLE_AWS_DDB_PETS, + problemHandler = assertProblemExists { + Problem( + UNKNOWN_PROBLEM_LOCATION, + PlanningProblemDetails.UnexpectedType(STRING, setOf(INT)) + ) + } + ), + SuccessTestCase( + name = "OFFSET INT", + catalog = CATALOG_AWS, + catalogPath = listOf("ddb"), + query = "SELECT * FROM pets LIMIT 1 OFFSET 5", + expected = TABLE_AWS_DDB_PETS + ), + ErrorTestCase( + name = "OFFSET STR", + catalog = CATALOG_AWS, + catalogPath = listOf("ddb"), + query = "SELECT * FROM pets LIMIT 1 OFFSET '5'", + expected = TABLE_AWS_DDB_PETS, + problemHandler = assertProblemExists { + Problem( + UNKNOWN_PROBLEM_LOCATION, + PlanningProblemDetails.UnexpectedType(STRING, setOf(INT)) + ) + } + ), + SuccessTestCase( + name = "CAST", + catalog = CATALOG_AWS, + catalogPath = listOf("ddb"), + query = "SELECT CAST(breed AS INT) AS cast_breed FROM pets", expected = BagType( StructType( - fields = mapOf( - "t" to StructType( // union gone - fields = mapOf( - "a" to StructType( - fields = mapOf( - "b" to StaticType.INT - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true) - ) - ) - ), - contentClosed = true, - constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) - ) - ), + fields = mapOf("cast_breed" to unionOf(INT, MISSING)), contentClosed = true, constraints = setOf( TupleConstraint.Open(false), @@ -2479,41 +2623,14 @@ class PartiQLSchemaInferencerTests { ) ) ), - // EXCLUDE regression test (behavior subject to change pending RFC); could give error/warning SuccessTestCase( - name = "invalid exclude collection wildcard", - query = """SELECT * EXCLUDE t.a[*] - FROM << - { - 'a': { - 'b': { - 'c': 0, - 'd': 'foo' - } - } - } - >> AS t""", + name = "UPPER", + catalog = CATALOG_AWS, + catalogPath = listOf("ddb"), + query = "SELECT UPPER(breed) AS upper_breed FROM pets", expected = BagType( - elementType = StructType( - fields = mapOf( - "a" to StructType( - fields = mapOf( - "b" to StructType( - fields = mapOf( - "c" to StaticType.INT, - "d" to StaticType.STRING - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true) - ) - ) - ), - contentClosed = true, - constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) - ) - ), + StructType( + fields = mapOf("upper_breed" to STRING), contentClosed = true, constraints = setOf( TupleConstraint.Open(false), @@ -2523,41 +2640,12 @@ class PartiQLSchemaInferencerTests { ) ) ), - // EXCLUDE regression test (behavior subject to change pending RFC); could give error/warning SuccessTestCase( - name = "invalid exclude collection index", - query = """SELECT * EXCLUDE t.a[1] - FROM << - { - 'a': { - 'b': { - 'c': 0, - 'd': 'foo' - } - } - } - >> AS t""", + name = "Non-tuples", + query = "SELECT a FROM << [ 1, 1.0 ] >> AS a", expected = BagType( - elementType = StructType( - fields = mapOf( - "a" to StructType( - fields = mapOf( - "b" to StructType( - fields = mapOf( - "c" to StaticType.INT, - "d" to StaticType.STRING - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true) - ) - ) - ), - contentClosed = true, - constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) - ) - ), + StructType( + fields = mapOf("a" to ListType(unionOf(INT, StaticType.DECIMAL))), contentClosed = true, constraints = setOf( TupleConstraint.Open(false), @@ -2567,31 +2655,30 @@ class PartiQLSchemaInferencerTests { ) ) ), - // EXCLUDE regression test (behavior subject to change pending RFC); could give error/warning SuccessTestCase( - name = "invalid exclude tuple attr", - query = """SELECT * EXCLUDE t.a.b + name = "Non-tuples in SELECT VALUE", + query = "SELECT VALUE a FROM << [ 1, 1.0 ] >> AS a", + expected = + BagType(ListType(unionOf(INT, StaticType.DECIMAL))) + ), + SuccessTestCase( + name = "SELECT VALUE", + query = "SELECT VALUE [1, 1.0] FROM <<>>", + expected = + BagType(ListType(unionOf(INT, StaticType.DECIMAL))) + ), + SuccessTestCase( + name = "Duplicate fields in struct", + query = """ + SELECT t.a AS a FROM << - { - 'a': [ - { 'b': 0 }, - { 'b': 1 }, - { 'b': 2 } - ] - } - >> AS t""", + { 'a': 1, 'a': 'hello' } + >> AS t + """, expected = BagType( - elementType = StructType( - fields = mapOf( - "a" to ListType( - elementType = StructType( - fields = mapOf( - "b" to StaticType.INT - ), - contentClosed = true, - constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) - ) - ) + StructType( + fields = listOf( + StructType.Field("a", unionOf(INT, STRING)) ), contentClosed = true, constraints = setOf( @@ -2602,31 +2689,16 @@ class PartiQLSchemaInferencerTests { ) ) ), - // EXCLUDE regression test (behavior subject to change pending RFC); could give error/warning SuccessTestCase( - name = "invalid exclude tuple wildcard", - query = """SELECT * EXCLUDE t.a.* - FROM << - { - 'a': [ - { 'b': 0 }, - { 'b': 1 }, - { 'b': 2 } - ] - } - >> AS t""", + name = "Duplicate fields in ordered STRUCT. NOTE: b.b.d is an ordered struct with two attributes (e). First is INT.", + query = """ + SELECT d.e AS e + FROM << b.b.d >> AS d + """, expected = BagType( - elementType = StructType( - fields = mapOf( - "a" to ListType( - elementType = StructType( - fields = mapOf( - "b" to StaticType.INT - ), - contentClosed = true, - constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) - ) - ) + StructType( + fields = listOf( + StructType.Field("e", INT) ), contentClosed = true, constraints = setOf( @@ -2637,31 +2709,18 @@ class PartiQLSchemaInferencerTests { ) ) ), - // EXCLUDE regression test (behavior subject to change pending RFC); could give error/warning SuccessTestCase( - name = "invalid exclude tuple attr step", - query = """SELECT * EXCLUDE t.b -- `t.b` does not exist + name = "Duplicate fields in struct", + query = """ + SELECT a AS a FROM << - { - 'a': << - { 'b': 0 }, - { 'b': 1 }, - { 'b': 2 } - >> - } - >> AS t""", + { 'a': 1, 'a': 'hello' } + >> AS t + """, expected = BagType( - elementType = StructType( - fields = mapOf( - "a" to BagType( - elementType = StructType( - fields = mapOf( - "b" to StaticType.INT - ), - contentClosed = true, - constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) - ) - ) + StructType( + fields = listOf( + StructType.Field("a", unionOf(INT, STRING)) ), contentClosed = true, constraints = setOf( @@ -2672,109 +2731,16 @@ class PartiQLSchemaInferencerTests { ) ) ), - // EXCLUDE regression test (behavior subject to change pending RFC); could give error/warning - // ErrorTestCase( - // name = "invalid exclude root", - // query = """SELECT * EXCLUDE nonsense.b -- `nonsense` does not exist in binding tuples - // FROM << - // { - // 'a': << - // { 'b': 0 }, - // { 'b': 1 }, - // { 'b': 2 } - // >> - // } - // >> AS t""", - // expected = BagType( - // elementType = StructType( - // fields = mapOf( - // "a" to BagType( - // elementType = StructType( - // fields = mapOf( - // "b" to StaticType.INT - // ), - // contentClosed = true, - // constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) - // ) - // ) - // ), - // contentClosed = true, - // constraints = setOf( - // TupleConstraint.Open(false), - // TupleConstraint.UniqueAttrs(true), - // TupleConstraint.Ordered - // ) - // ) - // ), - // problemHandler = assertProblemExists { - // Problem( - // UNKNOWN_PROBLEM_LOCATION, - // PlanningProblemDetails.UnresolvedExcludeExprRoot("nonsense") - // ) - // } - // ), - // EXCLUDE regression test (behavior subject to change pending RFC); could give error/warning SuccessTestCase( - name = "exclude with unions and last step collection index", - query = """SELECT * EXCLUDE t.a[0].c -- `c`'s type to be unioned with `MISSING` - FROM << - { - 'a': [ - { - 'b': 0, - 'c': 0 - }, - { - 'b': 1, - 'c': NULL - }, - { - 'b': 2, - 'c': 0.1 - } - ] - } - >> AS t""", + name = "AGGREGATE over INTS", + query = "SELECT a, COUNT(*) AS c, SUM(a) AS s, MIN(b) AS m FROM << {'a': 1, 'b': 2} >> GROUP BY a", expected = BagType( - elementType = StructType( + StructType( fields = mapOf( - "a" to ListType( - elementType = StaticType.unionOf( - StructType( - fields = mapOf( - "b" to StaticType.INT, - "c" to StaticType.INT.asOptional() - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true) - ) - ), - StructType( - fields = mapOf( - "b" to StaticType.INT, - "c" to StaticType.NULL.asOptional() - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true) - ) - ), - StructType( - fields = mapOf( - "b" to StaticType.INT, - "c" to StaticType.DECIMAL.asOptional() - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true) - ) - ) - ) - ) + "a" to INT, + "c" to INT, + "s" to INT, + "m" to INT, ), contentClosed = true, constraints = setOf( @@ -2786,23 +2752,15 @@ class PartiQLSchemaInferencerTests { ) ), SuccessTestCase( - name = "EXCLUDE using a catalog", - catalog = CATALOG_B, - query = "SELECT * EXCLUDE t.c FROM b.b.b AS t", + name = "AGGREGATE over DECIMALS", + query = "SELECT a, COUNT(*) AS c, SUM(a) AS s, MIN(b) AS m FROM << {'a': 1.0, 'b': 2.0}, {'a': 1.0, 'b': 2.0} >> GROUP BY a", expected = BagType( - elementType = StructType( + StructType( fields = mapOf( - "b" to StructType( - fields = mapOf( - "b" to StaticType.INT - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true), - TupleConstraint.Ordered - ) - ), + "a" to StaticType.DECIMAL, + "c" to INT, + "s" to StaticType.DECIMAL, + "m" to StaticType.DECIMAL, ), contentClosed = true, constraints = setOf( @@ -2813,6 +2771,55 @@ class PartiQLSchemaInferencerTests { ) ) ), + SuccessTestCase( + name = "Current User", + query = "CURRENT_USER", + expected = unionOf(STRING, NULL) + ), + SuccessTestCase( + name = "Trim", + query = "trim(' ')", + expected = STRING + ), + SuccessTestCase( + name = "Current User Concat", + query = "CURRENT_USER || 'hello'", + expected = unionOf(STRING, NULL) + ), + SuccessTestCase( + name = "Current User Concat in WHERE", + query = "SELECT VALUE a FROM [ 0 ] AS a WHERE CURRENT_USER = 'hello'", + expected = BagType(INT) + ), + SuccessTestCase( + name = "TRIM_2", + query = "trim(' ' FROM ' Hello, World! ')", + expected = STRING + ), + SuccessTestCase( + name = "TRIM_1", + query = "trim(' Hello, World! ')", + expected = STRING + ), + SuccessTestCase( + name = "TRIM_3", + query = "trim(LEADING ' ' FROM ' Hello, World! ')", + expected = STRING + ), + ErrorTestCase( + name = "TRIM_2_error", + query = "trim(2 FROM ' Hello, World! ')", + expected = MISSING, + problemHandler = assertProblemExists { + Problem( + UNKNOWN_PROBLEM_LOCATION, + PlanningProblemDetails.UnknownFunction( + "trim_chars", + args = listOf(STRING, INT) + ) + ) + } + ), ) } diff --git a/partiql-plan/src/main/kotlin/org/partiql/plan/Plan.kt b/partiql-plan/src/main/kotlin/org/partiql/plan/Plan.kt deleted file mode 100644 index bcf047763a..0000000000 --- a/partiql-plan/src/main/kotlin/org/partiql/plan/Plan.kt +++ /dev/null @@ -1,19 +0,0 @@ -package org.partiql.plan - -import org.partiql.plan.builder.PlanFactory -import org.partiql.plan.builder.PlanFactoryImpl - -/** - * Singleton instance of the default factory. Also accessible via `PlanFactory.DEFAULT`. - */ -object Plan : PlanBaseFactory() { - - public inline fun create(block: PlanFactory.() -> T) = this.block() -} - -/** - * PlanBaseFactory can be used to create a factory which extends from the factory provided by PlanFactory.DEFAULT. - */ -public abstract class PlanBaseFactory : PlanFactoryImpl() { - // internal default overrides here -} diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/Env.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/Env.kt index fe08bb3de6..55f807059d 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/Env.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/Env.kt @@ -3,9 +3,11 @@ package org.partiql.planner import org.partiql.plan.Fn import org.partiql.plan.Global import org.partiql.plan.Identifier -import org.partiql.plan.Plan import org.partiql.plan.Rel import org.partiql.plan.Rex +import org.partiql.plan.global +import org.partiql.plan.identifierQualified +import org.partiql.plan.identifierSymbol import org.partiql.planner.typer.FunctionResolver import org.partiql.planner.typer.Mapping import org.partiql.planner.typer.isNullOrMissing @@ -275,7 +277,7 @@ internal class Env( getObjectDescriptor(handle).let { type -> val depth = calculateMatched(originalPath, catalogPath, handle.second.absolutePath) val match = BindingPath(originalPath.steps.subList(0, depth)) - val global = Plan.global(match.toIdentifier(), type) + val global = global(match.toIdentifier(), type) globals.add(global) // Return resolution metadata ResolvedVar.Global(type, globals.size - 1, depth) @@ -411,12 +413,12 @@ internal class Env( return originalPath.steps.size + outputCatalogPath.steps.size - inputCatalogPath.steps.size } - private fun BindingPath.toIdentifier() = Plan.identifierQualified( + private fun BindingPath.toIdentifier() = identifierQualified( root = steps[0].toIdentifier(), steps = steps.subList(1, steps.size).map { it.toIdentifier() } ) - private fun BindingName.toIdentifier() = Plan.identifierSymbol( + private fun BindingName.toIdentifier() = identifierSymbol( symbol = name, caseSensitivity = when (bindingCase) { BindingCase.SENSITIVE -> Identifier.CaseSensitivity.SENSITIVE diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlannerDefault.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlannerDefault.kt index 717971c674..c7354bce65 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlannerDefault.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlannerDefault.kt @@ -4,7 +4,7 @@ import org.partiql.ast.Statement import org.partiql.ast.normalize.normalize import org.partiql.errors.ProblemCallback import org.partiql.plan.PartiQLVersion -import org.partiql.plan.Plan +import org.partiql.plan.partiQLPlan import org.partiql.planner.transforms.AstToPlan import org.partiql.planner.typer.PlanTyper import org.partiql.spi.Plugin @@ -34,7 +34,7 @@ internal class PartiQLPlannerDefault( // 3. Resolve variables val typer = PlanTyper(env, onProblem) - var plan = Plan.partiQLPlan( + var plan = partiQLPlan( version = version, globals = env.globals, statement = typer.resolve(root), diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/transforms/AstToPlan.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/transforms/AstToPlan.kt index 9337dfd19b..f8e11d150e 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/transforms/AstToPlan.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/transforms/AstToPlan.kt @@ -1,11 +1,27 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + */ + package org.partiql.planner.transforms import org.partiql.ast.AstNode import org.partiql.ast.Expr import org.partiql.ast.visitor.AstBaseVisitor -import org.partiql.plan.Plan -import org.partiql.plan.PlanNode -import org.partiql.plan.builder.PlanFactory +import org.partiql.plan.identifierQualified +import org.partiql.plan.identifierSymbol +import org.partiql.plan.statementQuery import org.partiql.planner.Env import org.partiql.ast.Identifier as AstIdentifier import org.partiql.ast.Statement as AstStatement @@ -24,18 +40,14 @@ internal object AstToPlan { @Suppress("PARAMETER_NAME_CHANGED_ON_OVERRIDE") private object ToPlanStatement : AstBaseVisitor() { - private val factory = Plan - - private inline fun transform(block: PlanFactory.() -> T): T = factory.block() - override fun defaultReturn(node: AstNode, env: Env) = throw IllegalArgumentException("Unsupported statement") - override fun visitStatementQuery(node: AstStatement.Query, env: Env) = transform { + override fun visitStatementQuery(node: AstStatement.Query, env: Env): PlanStatement { val rex = when (val expr = node.expr) { is Expr.SFW -> RelConverter.apply(expr, env) else -> RexConverter.apply(expr, env) } - statementQuery(rex) + return statementQuery(rex) } } @@ -49,7 +61,7 @@ internal object AstToPlan { fun convert(identifier: AstIdentifier.Qualified): PlanIdentifier.Qualified { val root = convert(identifier.root) val steps = identifier.steps.map { convert(it) } - return Plan.identifierQualified(root, steps) + return identifierQualified(root, steps) } fun convert(identifier: AstIdentifier.Symbol): PlanIdentifier.Symbol { @@ -58,6 +70,6 @@ internal object AstToPlan { AstIdentifier.CaseSensitivity.SENSITIVE -> PlanIdentifier.CaseSensitivity.SENSITIVE AstIdentifier.CaseSensitivity.INSENSITIVE -> PlanIdentifier.CaseSensitivity.INSENSITIVE } - return Plan.identifierSymbol(symbol, case) + return identifierSymbol(symbol, case) } } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/transforms/RelConverter.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/transforms/RelConverter.kt index 9de081bf56..d565ef3ac6 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/transforms/RelConverter.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/transforms/RelConverter.kt @@ -1,3 +1,19 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + */ + package org.partiql.planner.transforms import org.partiql.ast.AstNode @@ -11,9 +27,38 @@ import org.partiql.ast.Sort import org.partiql.ast.builder.ast import org.partiql.ast.util.AstRewriter import org.partiql.ast.visitor.AstBaseVisitor -import org.partiql.plan.Plan import org.partiql.plan.Rel import org.partiql.plan.Rex +import org.partiql.plan.fnUnresolved +import org.partiql.plan.rel +import org.partiql.plan.relBinding +import org.partiql.plan.relOpAggregate +import org.partiql.plan.relOpAggregateAgg +import org.partiql.plan.relOpErr +import org.partiql.plan.relOpExcept +import org.partiql.plan.relOpFilter +import org.partiql.plan.relOpIntersect +import org.partiql.plan.relOpJoin +import org.partiql.plan.relOpLimit +import org.partiql.plan.relOpOffset +import org.partiql.plan.relOpProject +import org.partiql.plan.relOpScan +import org.partiql.plan.relOpSort +import org.partiql.plan.relOpSortSpec +import org.partiql.plan.relOpUnion +import org.partiql.plan.relOpUnpivot +import org.partiql.plan.relType +import org.partiql.plan.rex +import org.partiql.plan.rexOpLit +import org.partiql.plan.rexOpPath +import org.partiql.plan.rexOpPivot +import org.partiql.plan.rexOpSelect +import org.partiql.plan.rexOpStruct +import org.partiql.plan.rexOpStructField +import org.partiql.plan.rexOpTupleUnion +import org.partiql.plan.rexOpTupleUnionArgSpread +import org.partiql.plan.rexOpTupleUnionArgStruct +import org.partiql.plan.rexOpVarResolved import org.partiql.planner.Env import org.partiql.types.StaticType import org.partiql.value.PartiQLValueExperimental @@ -26,14 +71,12 @@ import org.partiql.value.stringValue internal object RelConverter { // IGNORE — so we don't have to non-null assert on operator inputs - private val nil = Plan.create { - rel(relType(emptyList(), emptySet()), relOpErr()) - } + private val nil = rel(relType(emptyList(), emptySet()), relOpErr()) /** * Here we convert an SFW to composed [Rel]s, then apply the appropriate relation-value projection to get a [Rex]. */ - internal fun apply(sfw: Expr.SFW, env: Env): Rex = Plan.create { + internal fun apply(sfw: Expr.SFW, env: Env): Rex { var rel = sfw.accept(ToRel(env), nil) val rex = when (val projection = sfw.select) { // PIVOT ... FROM @@ -99,14 +142,14 @@ internal object RelConverter { * See https://partiql.org/dql/select.html#sql-select */ @OptIn(PartiQLValueExperimental::class) - private fun defaultConstructor(schema: List): Rex = Plan.create { + private fun defaultConstructor(schema: List): Rex { val fields = schema.mapIndexed { i, b -> val k = rex(StaticType.STRING, rexOpLit(stringValue(b.name))) val v = rex(b.type, rexOpVarResolved(i)) rexOpStructField(k, v) } val op = rexOpStruct(fields) - rex(StaticType.STRUCT, op) + return rex(StaticType.STRUCT, op) } /** @@ -114,7 +157,7 @@ internal object RelConverter { * * See https://partiql.org/assets/PartiQL-Specification.pdf#page=28 */ - private fun tupleUnionConstructor(op: Rel.Op.Project, type: Rel.Type): Pair = with(Plan) { + private fun tupleUnionConstructor(op: Rel.Op.Project, type: Rel.Type): Pair { val projections = mutableListOf() val args = op.projections.mapIndexed { i, item -> val binding = type.schema[i] @@ -133,14 +176,14 @@ internal object RelConverter { } val constructor = rex(StaticType.STRUCT, rexOpTupleUnion(args)) val rel = rel(type, relOpProject(op.input, projections)) - constructor to rel + return constructor to rel } private fun Rex.isProjectAll(): Boolean { return (op is Rex.Op.Path && (op as Rex.Op.Path).steps.last() is Rex.Op.Path.Step.Unpivot) } - private fun Rex.removeUnpivot(): Rex = Plan.create { + private fun Rex.removeUnpivot(): Rex { val rex = this@removeUnpivot val path = op if (path is Rex.Op.Path) { @@ -154,7 +197,7 @@ internal object RelConverter { } } // skip, should be unreachable - rex + return rex } @Suppress("PARAMETER_NAME_CHANGED_ON_OVERRIDE", "LocalVariableName") @@ -190,7 +233,7 @@ internal object RelConverter { return rel } - override fun visitSelectProject(node: Select.Project, input: Rel) = Plan.create { + override fun visitSelectProject(node: Select.Project, input: Rel): Rel { // this ignores aggregations val schema = mutableListOf() val props = emptySet() @@ -202,10 +245,10 @@ internal object RelConverter { } val type = relType(schema, props) val op = relOpProject(input, projections) - rel(type, op) + return rel(type, op) } - override fun visitFromValue(node: From.Value, nil: Rel) = Plan.create { + override fun visitFromValue(node: From.Value, nil: Rel): Rel { val rex = RexConverter.apply(node.expr, env) val binding = when (val a = node.asAlias) { null -> error("AST not normalized, missing AS alias on $node") @@ -214,7 +257,7 @@ internal object RelConverter { type = rex.type ) } - when (node.type) { + return when (node.type) { From.Value.Type.SCAN -> { when (val i = node.atAlias) { null -> convertScan(rex, binding) @@ -246,7 +289,7 @@ internal object RelConverter { * TODO compute basic schema */ @OptIn(PartiQLValueExperimental::class) - override fun visitFromJoin(node: From.Join, nil: Rel) = Plan.create { + override fun visitFromJoin(node: From.Join, nil: Rel): Rel { val lhs = visitFrom(node.lhs, nil) val rhs = visitFrom(node.rhs, nil) val schema = listOf() @@ -263,25 +306,25 @@ internal object RelConverter { } val type = relType(schema, props) val op = relOpJoin(lhs, rhs, condition, joinType) - rel(type, op) + return rel(type, op) } // Helpers - private fun convertScan(rex: Rex, binding: Rel.Binding) = Plan.create { + private fun convertScan(rex: Rex, binding: Rel.Binding): Rel { val schema = listOf(binding) val props = emptySet() val type = relType(schema, props) val op = relOpScan(rex) - rel(type, op) + return rel(type, op) } - private fun convertScanIndexed(rex: Rex, binding: Rel.Binding, index: Rel.Binding) = Plan.create { + private fun convertScanIndexed(rex: Rex, binding: Rel.Binding, index: Rel.Binding): Rel { val schema = listOf(binding, index) val props = setOf(Rel.Prop.ORDERED) val type = relType(schema, props) val op = relOpScan(rex) - rel(type, op) + return rel(type, op) } /** @@ -291,12 +334,12 @@ internal object RelConverter { * @param k * @param v */ - private fun convertUnpivot(rex: Rex, k: Rel.Binding, v: Rel.Binding) = Plan.create { + private fun convertUnpivot(rex: Rex, k: Rel.Binding, v: Rel.Binding): Rel { val schema = listOf(k, v) val props = emptySet() val type = relType(schema, props) val op = relOpUnpivot(rex) - rel(type, op) + return rel(type, op) } private fun convertProjectionItem(item: Select.Project.Item) = when (item) { @@ -314,21 +357,21 @@ internal object RelConverter { else -> a.symbol } val rex = RexConverter.apply(item.expr, env) - val binding = Plan.relBinding(name, rex.type) + val binding = relBinding(name, rex.type) return binding to rex } /** * Append [Rel.Op.Filter] only if a WHERE condition exists */ - private fun convertWhere(input: Rel, expr: Expr?): Rel = Plan.create { + private fun convertWhere(input: Rel, expr: Expr?): Rel { if (expr == null) { return input } val type = input.type val predicate = expr.toRex(env) val op = relOpFilter(input, predicate) - rel(type, op) + return rel(type, op) } /** @@ -361,15 +404,15 @@ internal object RelConverter { // Build the rel operator var strategy = Rel.Op.Aggregate.Strategy.FULL val aggs = aggregations.mapIndexed { i, agg -> - val binding = Plan.relBinding( + val binding = relBinding( name = syntheticAgg(i), type = (StaticType.ANY), ) schema.add(binding) val args = agg.args.map { arg -> arg.toRex(env) } val id = AstToPlan.convert(agg.function) - val fn = Plan.fnUnresolved(id) - Plan.relOpAggregateAgg(fn, args) + val fn = fnUnresolved(id) + relOpAggregateAgg(fn, args) } var groups = emptyList() if (groupBy != null) { @@ -377,7 +420,7 @@ internal object RelConverter { if (it.asAlias == null) { error("not normalized, group key $it missing unique name") } - val binding = Plan.relBinding( + val binding = relBinding( name = it.asAlias!!.symbol, type = (StaticType.ANY) ) @@ -389,9 +432,9 @@ internal object RelConverter { GroupBy.Strategy.PARTIAL -> Rel.Op.Aggregate.Strategy.PARTIAL } } - val type = Plan.relType(schema, props) - val op = Plan.relOpAggregate(input, strategy, aggs, groups) - val rel = Plan.rel(type, op) + val type = relType(schema, props) + val op = relOpAggregate(input, strategy, aggs, groups) + val rel = rel(type, op) return Pair(sel, rel) } @@ -401,14 +444,14 @@ internal object RelConverter { * Notes: * - This currently does not support aggregation expressions in the WHERE condition */ - private fun convertHaving(input: Rel, expr: Expr?): Rel = Plan.create { + private fun convertHaving(input: Rel, expr: Expr?): Rel { if (expr == null) { return input } val type = input.type val predicate = expr.toRex(env) val op = relOpFilter(input, predicate) - rel(type, op) + return rel(type, op) } /** @@ -417,7 +460,7 @@ internal object RelConverter { * TODO combine/compare schemas * TODO set quantifier */ - private fun convertSetOp(input: Rel, setOp: Expr.SFW.SetOp?): Rel = Plan.create { + private fun convertSetOp(input: Rel, setOp: Expr.SFW.SetOp?): Rel { if (setOp == null) { return input } @@ -427,15 +470,15 @@ internal object RelConverter { val op = when (setOp.type.type) { SetOp.Type.UNION -> relOpUnion(lhs, rhs) SetOp.Type.INTERSECT -> relOpIntersect(lhs, rhs) - SetOp.Type.EXCEPT -> relOpIntersect(lhs, rhs) + SetOp.Type.EXCEPT -> relOpExcept(lhs, rhs) } - rel(type, op) + return rel(type, op) } /** * Append [Rel.Op.Sort] only if an ORDER BY clause is present */ - private fun convertOrderBy(input: Rel, orderBy: OrderBy?) = Plan.create { + private fun convertOrderBy(input: Rel, orderBy: OrderBy?): Rel { if (orderBy == null) { return input } @@ -455,33 +498,33 @@ internal object RelConverter { relOpSortSpec(rex, order) } val op = relOpSort(input, specs) - rel(type, op) + return rel(type, op) } /** * Append [Rel.Op.Limit] if there is a LIMIT */ - private fun convertLimit(input: Rel, limit: Expr?): Rel = Plan.create { + private fun convertLimit(input: Rel, limit: Expr?): Rel { if (limit == null) { return input } val type = input.type val rex = RexConverter.apply(limit, env) val op = relOpLimit(input, rex) - rel(type, op) + return rel(type, op) } /** * Append [Rel.Op.Offset] if there is an OFFSET */ - private fun convertOffset(input: Rel, offset: Expr?): Rel = Plan.create { + private fun convertOffset(input: Rel, offset: Expr?): Rel { if (offset == null) { return input } val type = input.type val rex = RexConverter.apply(offset, env) val op = relOpOffset(input, rex) - rel(type, op) + return rel(type, op) } // /** diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/transforms/RexConverter.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/transforms/RexConverter.kt index 513ae0ba42..a5ec405191 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/transforms/RexConverter.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/transforms/RexConverter.kt @@ -1,3 +1,19 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + */ + package org.partiql.planner.transforms import org.partiql.ast.AstNode @@ -7,12 +23,25 @@ import org.partiql.ast.Select import org.partiql.ast.Type import org.partiql.ast.visitor.AstBaseVisitor import org.partiql.plan.Identifier -import org.partiql.plan.Plan -import org.partiql.plan.Plan.rex -import org.partiql.plan.Plan.rexOpLit -import org.partiql.plan.PlanNode import org.partiql.plan.Rex -import org.partiql.plan.builder.PlanFactory +import org.partiql.plan.fnUnresolved +import org.partiql.plan.identifierSymbol +import org.partiql.plan.rex +import org.partiql.plan.rexOpCall +import org.partiql.plan.rexOpCase +import org.partiql.plan.rexOpCaseBranch +import org.partiql.plan.rexOpCollToScalar +import org.partiql.plan.rexOpCollToScalarSubquery +import org.partiql.plan.rexOpCollection +import org.partiql.plan.rexOpLit +import org.partiql.plan.rexOpPath +import org.partiql.plan.rexOpPathStepIndex +import org.partiql.plan.rexOpPathStepSymbol +import org.partiql.plan.rexOpPathStepUnpivot +import org.partiql.plan.rexOpPathStepWildcard +import org.partiql.plan.rexOpStruct +import org.partiql.plan.rexOpStructField +import org.partiql.plan.rexOpVarUnresolved import org.partiql.planner.ATTRIBUTES import org.partiql.planner.Env import org.partiql.planner.typer.toNonNullStaticType @@ -30,10 +59,6 @@ import org.partiql.value.nullValue */ internal object RexConverter { - private val factory = Plan - - private inline fun transform(block: PlanFactory.() -> T): T = factory.block() - internal fun apply(expr: Expr, context: Env): Rex = expr.accept(ToRex, context) // expr.toRex() @OptIn(PartiQLValueExperimental::class) @@ -43,16 +68,16 @@ internal object RexConverter { override fun defaultReturn(node: AstNode, context: Env): Rex = throw IllegalArgumentException("unsupported rex $node") - override fun visitExprLit(node: Expr.Lit, context: Env) = transform { + override fun visitExprLit(node: Expr.Lit, context: Env): Rex { val type = when (node.value.isNull) { true -> node.value.type.toStaticType() else -> node.value.type.toNonNullStaticType() } val op = rexOpLit(node.value) - rex(type, op) + return rex(type, op) } - override fun visitExprVar(node: Expr.Var, context: Env) = transform { + override fun visitExprVar(node: Expr.Var, context: Env): Rex { val type = (StaticType.ANY) val identifier = AstToPlan.convert(node.identifier) val scope = when (node.scope) { @@ -60,10 +85,10 @@ internal object RexConverter { Expr.Var.Scope.LOCAL -> Rex.Op.Var.Scope.LOCAL } val op = rexOpVarUnresolved(identifier, scope) - rex(type, op) + return rex(type, op) } - override fun visitExprUnary(node: Expr.Unary, context: Env) = transform { + override fun visitExprUnary(node: Expr.Unary, context: Env): Rex { val type = (StaticType.ANY) // Args val arg = node.expr.accept(ToRex, context) @@ -73,10 +98,10 @@ internal object RexConverter { val fn = fnUnresolved(id) // Rex val op = rexOpCall(fn, args) - rex(type, op) + return rex(type, op) } - override fun visitExprBinary(node: Expr.Binary, context: Env) = transform { + override fun visitExprBinary(node: Expr.Binary, context: Env): Rex { val type = (StaticType.ANY) // Args val lhs = node.lhs.accept(ToRex, context) @@ -87,10 +112,10 @@ internal object RexConverter { val fn = fnUnresolved(id) // Rex val op = rexOpCall(fn, args) - rex(type, op) + return rex(type, op) } - override fun visitExprPath(node: Expr.Path, context: Env): Rex = transform { + override fun visitExprPath(node: Expr.Path, context: Env): Rex { val type = (StaticType.ANY) // Args val root = visitExpr(node.root, context) @@ -110,10 +135,10 @@ internal object RexConverter { } // Rex val op = rexOpPath(root, steps) - rex(type, op) + return rex(type, op) } - override fun visitExprCall(node: Expr.Call, context: Env) = transform { + override fun visitExprCall(node: Expr.Call, context: Env): Rex { val type = (StaticType.ANY) // Args val args = node.args.map { visitExpr(it, context) } @@ -122,10 +147,10 @@ internal object RexConverter { val fn = fnUnresolved(id) // Rex val op = rexOpCall(fn, args) - rex(type, op) + return rex(type, op) } - override fun visitExprCase(node: Expr.Case, context: Env) = transform { + override fun visitExprCase(node: Expr.Case, context: Env): Rex { val type = (StaticType.ANY) val rex = when (node.expr) { null -> bool(true) // match `true` @@ -153,10 +178,10 @@ internal object RexConverter { } branches += rexOpCaseBranch(bool(true), defaultRex) val op = rexOpCase(branches) - rex(type, op) + return rex(type, op) } - override fun visitExprCollection(node: Expr.Collection, context: Env) = transform { + override fun visitExprCollection(node: Expr.Collection, context: Env): Rex { val type = when (node.type) { Expr.Collection.Type.BAG -> StaticType.BAG Expr.Collection.Type.ARRAY -> StaticType.LIST @@ -166,10 +191,10 @@ internal object RexConverter { } val values = node.values.map { visitExpr(it, context) } val op = rexOpCollection(values) - rex(type, op) + return rex(type, op) } - override fun visitExprStruct(node: Expr.Struct, context: Env) = transform { + override fun visitExprStruct(node: Expr.Struct, context: Env): Rex { val type = (StaticType.STRUCT) val fields = node.fields.map { val k = visitExpr(it.name, context) @@ -177,7 +202,7 @@ internal object RexConverter { rexOpStructField(k, v) } val op = rexOpStruct(fields) - rex(type, op) + return rex(type, op) } // SPECIAL FORMS @@ -185,7 +210,7 @@ internal object RexConverter { /** * NOT? LIKE ( ESCAPE )? */ - override fun visitExprLike(node: Expr.Like, ctx: Env) = transform { + override fun visitExprLike(node: Expr.Like, ctx: Env): Rex { val type = StaticType.BOOL // Args val arg0 = visitExpr(node.value, ctx) @@ -200,13 +225,13 @@ internal object RexConverter { if (node.not == true) { call = negate(call) } - rex(type, call) + return rex(type, call) } /** * NOT? BETWEEN AND */ - override fun visitExprBetween(node: Expr.Between, ctx: Env) = transform { + override fun visitExprBetween(node: Expr.Between, ctx: Env): Rex { val type = StaticType.BOOL // Args val arg0 = visitExpr(node.value, ctx) @@ -218,13 +243,13 @@ internal object RexConverter { if (node.not == true) { call = negate(call) } - rex(type, call) + return rex(type, call) } /** * NOT? IN */ - override fun visitExprInCollection(node: Expr.InCollection, ctx: Env) = transform { + override fun visitExprInCollection(node: Expr.InCollection, ctx: Env): Rex { val type = StaticType.BOOL // Args val arg0 = visitExpr(node.lhs, ctx) @@ -235,13 +260,13 @@ internal object RexConverter { if (node.not == true) { call = negate(call) } - rex(type, call) + return rex(type, call) } /** * IS ? */ - override fun visitExprIsType(node: Expr.IsType, ctx: Env) = transform { + override fun visitExprIsType(node: Expr.IsType, ctx: Env): Rex { val type = StaticType.BOOL // arg val arg0 = visitExpr(node.value, ctx) @@ -290,35 +315,35 @@ internal object RexConverter { call = negate(call) } - rex(type, call) + return rex(type, call) } - override fun visitExprCoalesce(node: Expr.Coalesce, ctx: Env): Rex = transform { + override fun visitExprCoalesce(node: Expr.Coalesce, ctx: Env): Rex { val type = StaticType.ANY // Args val arg0 = rex(StaticType.LIST, rexOpCollection(node.args.map { visitExpr(it, ctx) })) // Call val call = call("coalesce", arg0) - rex(type, call) + return rex(type, call) } /** * NULLIF(, ) */ - override fun visitExprNullIf(node: Expr.NullIf, ctx: Env) = transform { + override fun visitExprNullIf(node: Expr.NullIf, ctx: Env): Rex { val type = StaticType.ANY // Args val arg0 = visitExpr(node.value, ctx) val arg1 = visitExpr(node.nullifier, ctx) // Call val call = call("null_if", arg0, arg1) - rex(type, call) + return rex(type, call) } /** * SUBSTRING( (FROM (FOR )?)? ) */ - override fun visitExprSubstring(node: Expr.Substring, ctx: Env) = transform { + override fun visitExprSubstring(node: Expr.Substring, ctx: Env): Rex { val type = StaticType.ANY // Args val arg0 = visitExpr(node.value, ctx) @@ -329,26 +354,26 @@ internal object RexConverter { null -> call("substring", arg0, arg1) else -> call("substring_length", arg0, arg1, arg2) } - rex(type, call) + return rex(type, call) } /** * POSITION( IN ) */ - override fun visitExprPosition(node: Expr.Position, ctx: Env) = transform { + override fun visitExprPosition(node: Expr.Position, ctx: Env): Rex { val type = StaticType.ANY // Args val arg0 = visitExpr(node.lhs, ctx) val arg1 = visitExpr(node.rhs, ctx) // Call val call = call("position", arg0, arg1) - rex(type, call) + return rex(type, call) } /** * TRIM([LEADING|TRAILING|BOTH]? ( FROM)? ) */ - override fun visitExprTrim(node: Expr.Trim, ctx: Env) = transform { + override fun visitExprTrim(node: Expr.Trim, ctx: Env): Rex { val type = StaticType.TEXT // Args val arg0 = visitExpr(node.value, ctx) @@ -368,7 +393,7 @@ internal object RexConverter { else -> call("trim_chars", arg0, arg1) } } - rex(type, call) + return rex(type, call) } override fun visitExprOverlay(node: Expr.Overlay, ctx: Env): Rex { @@ -380,10 +405,10 @@ internal object RexConverter { } // TODO: Ignoring type parameter now - override fun visitExprCast(node: Expr.Cast, ctx: Env): Rex = transform { + override fun visitExprCast(node: Expr.Cast, ctx: Env): Rex { val type = node.asType val arg0 = visitExpr(node.value, ctx) - when (type) { + return when (type) { is Type.NullType -> rex(StaticType.NULL, call("cast_null", arg0)) is Type.Missing -> rex(StaticType.MISSING, call("cast_missing", arg0)) is Type.Bool -> rex(StaticType.BOOL, call("cast_bool", arg0)) @@ -430,7 +455,7 @@ internal object RexConverter { TODO("PartiQL Special Form CAN_LOSSLESS_CAST") } - override fun visitExprDateAdd(node: Expr.DateAdd, ctx: Env) = transform { + override fun visitExprDateAdd(node: Expr.DateAdd, ctx: Env): Rex { val type = StaticType.TIMESTAMP // Args val arg0 = visitExpr(node.lhs, ctx) @@ -441,10 +466,10 @@ internal object RexConverter { DatetimeField.TIMEZONE_MINUTE -> error("Invalid call DATE_ADD(TIMEZONE_MINUTE, ...)") else -> call("date_add_${node.field.name.lowercase()}", arg0, arg1) } - rex(type, call) + return rex(type, call) } - override fun visitExprDateDiff(node: Expr.DateDiff, ctx: Env) = transform { + override fun visitExprDateDiff(node: Expr.DateDiff, ctx: Env): Rex { val type = StaticType.TIMESTAMP // Args val arg0 = visitExpr(node.lhs, ctx) @@ -455,10 +480,10 @@ internal object RexConverter { DatetimeField.TIMEZONE_MINUTE -> error("Invalid call DATE_DIFF(TIMEZONE_MINUTE, ...)") else -> call("date_diff_${node.field.name.lowercase()}", arg0, arg1) } - rex(type, call) + return rex(type, call) } - override fun visitExprSessionAttribute(node: Expr.SessionAttribute, ctx: Env) = transform { + override fun visitExprSessionAttribute(node: Expr.SessionAttribute, ctx: Env): Rex { val type = StaticType.ANY val attribute = node.attribute.name.uppercase() val fn = ATTRIBUTES[attribute] @@ -467,7 +492,7 @@ internal object RexConverter { error("Unknown session attribute $attribute") } val call = call(fn) - rex(type, call) + return rex(type, call) } /** @@ -479,9 +504,9 @@ internal object RexConverter { * - It is the collection expression of a FROM clause * - It is the RHS of an IN predicate */ - override fun visitExprSFW(node: Expr.SFW, context: Env): Rex = transform { + override fun visitExprSFW(node: Expr.SFW, context: Env): Rex { val query = RelConverter.apply(node, context) - when (val select = query.op) { + return when (val select = query.op) { is Rex.Op.Select -> { if (node.select is Select.Value) { // SELECT VALUE does not implicitly coerce to a scalar @@ -500,11 +525,11 @@ internal object RexConverter { private fun bool(v: Boolean): Rex { val type = StaticType.BOOL - val op = Plan.rexOpLit(boolValue(v)) - return Plan.rex(type, op) + val op = rexOpLit(boolValue(v)) + return rex(type, op) } - private fun PlanFactory.negate(call: Rex.Op.Call): Rex.Op.Call { + private fun negate(call: Rex.Op.Call): Rex.Op.Call { val name = Expr.Unary.Op.NOT.name val id = identifierSymbol(name.lowercase(), Identifier.CaseSensitivity.SENSITIVE) val fn = fnUnresolved(id) @@ -514,7 +539,7 @@ internal object RexConverter { return rexOpCall(fn, listOf(arg)) } - private fun PlanFactory.call(name: String, vararg args: Rex): Rex.Op.Call { + private fun call(name: String, vararg args: Rex): Rex.Op.Call { val id = identifierSymbol(name, Identifier.CaseSensitivity.SENSITIVE) val fn = fnUnresolved(id) return rexOpCall(fn, args.toList()) diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/typer/PlanTyper.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/typer/PlanTyper.kt index 9d854c85a9..48148a5e66 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/typer/PlanTyper.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/typer/PlanTyper.kt @@ -1,3 +1,19 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + */ + package org.partiql.planner.typer import org.partiql.errors.Problem @@ -5,10 +21,35 @@ import org.partiql.errors.ProblemCallback import org.partiql.errors.UNKNOWN_PROBLEM_LOCATION import org.partiql.plan.Fn import org.partiql.plan.Identifier -import org.partiql.plan.Plan import org.partiql.plan.Rel import org.partiql.plan.Rex import org.partiql.plan.Statement +import org.partiql.plan.fnResolved +import org.partiql.plan.identifierSymbol +import org.partiql.plan.rel +import org.partiql.plan.relBinding +import org.partiql.plan.relOpErr +import org.partiql.plan.relOpFilter +import org.partiql.plan.relOpJoin +import org.partiql.plan.relOpLimit +import org.partiql.plan.relOpOffset +import org.partiql.plan.relOpProject +import org.partiql.plan.relOpScan +import org.partiql.plan.relOpUnpivot +import org.partiql.plan.relType +import org.partiql.plan.rex +import org.partiql.plan.rexOpCall +import org.partiql.plan.rexOpCollection +import org.partiql.plan.rexOpErr +import org.partiql.plan.rexOpGlobal +import org.partiql.plan.rexOpPath +import org.partiql.plan.rexOpPathStepSymbol +import org.partiql.plan.rexOpSelect +import org.partiql.plan.rexOpStruct +import org.partiql.plan.rexOpStructField +import org.partiql.plan.rexOpTupleUnion +import org.partiql.plan.rexOpVarResolved +import org.partiql.plan.statementQuery import org.partiql.plan.util.PlanRewriter import org.partiql.planner.Env import org.partiql.planner.FnMatch @@ -61,16 +102,9 @@ internal class PlanTyper( strategy = ResolutionStrategy.GLOBAL, ) val root = statement.root.type(typeEnv) - return Plan.statementQuery(root) + return statementQuery(root) } - /** - * Use default factory for rewrites - */ - private val factory = Plan - - private inline fun rewrite(block: Plan.() -> T): T = block.invoke(factory) - /** * Types the relational operators of a query expression. * @@ -83,7 +117,7 @@ internal class PlanTyper( /** * The output schema of a `rel.op.scan` is the single value binding. */ - override fun visitRelOpScan(node: Rel.Op.Scan, ctx: Rel.Type?): Rel = rewrite { + override fun visitRelOpScan(node: Rel.Op.Scan, ctx: Rel.Type?): Rel { // descend, with GLOBAL resolution strategy val rex = node.rex.type(outer.global()) // compute rel type @@ -91,13 +125,13 @@ internal class PlanTyper( val type = ctx!!.copyWithSchema(listOf(valueT)) // rewrite val op = relOpScan(rex) - rel(type, op) + return rel(type, op) } /** * The output schema of a `rel.op.scan_index` is the value binding and index binding. */ - override fun visitRelOpScanIndexed(node: Rel.Op.ScanIndexed, ctx: Rel.Type?): Rel = rewrite { + override fun visitRelOpScanIndexed(node: Rel.Op.ScanIndexed, ctx: Rel.Type?): Rel { // descend, with GLOBAL resolution strategy val rex = node.rex.type(outer.global()) // compute rel type @@ -106,13 +140,13 @@ internal class PlanTyper( val type = ctx!!.copyWithSchema(listOf(valueT, indexT)) // rewrite val op = relOpScan(rex) - rel(type, op) + return rel(type, op) } /** * TODO handle NULL|STRUCT type */ - override fun visitRelOpUnpivot(node: Rel.Op.Unpivot, ctx: Rel.Type?): Rel = rewrite { + override fun visitRelOpUnpivot(node: Rel.Op.Unpivot, ctx: Rel.Type?): Rel { // descend, with GLOBAL resolution strategy val rex = node.rex.type(outer.global()) @@ -137,14 +171,14 @@ internal class PlanTyper( // rewrite val op = relOpUnpivot(rex) - rel(type, op) + return rel(type, op) } override fun visitRelOpDistinct(node: Rel.Op.Distinct, ctx: Rel.Type?): Rel { TODO("Type RelOp Distinct") } - override fun visitRelOpFilter(node: Rel.Op.Filter, ctx: Rel.Type?): Rel = rewrite { + override fun visitRelOpFilter(node: Rel.Op.Filter, ctx: Rel.Type?): Rel { // compute input schema val input = visitRel(node.input, ctx) // type sub-nodes @@ -154,7 +188,7 @@ internal class PlanTyper( val type = input.type // rewrite val op = relOpFilter(input, predicate) - rel(type, op) + return rel(type, op) } override fun visitRelOpSort(node: Rel.Op.Sort, ctx: Rel.Type?): Rel { @@ -177,7 +211,7 @@ internal class PlanTyper( TODO("Type RelOp Except") } - override fun visitRelOpLimit(node: Rel.Op.Limit, ctx: Rel.Type?) = rewrite { + override fun visitRelOpLimit(node: Rel.Op.Limit, ctx: Rel.Type?): Rel { // compute input schema val input = visitRel(node.input, ctx) // type limit expression using outer scope with global resolution @@ -189,10 +223,10 @@ internal class PlanTyper( val type = input.type // rewrite val op = relOpLimit(input, limit) - rel(type, op) + return rel(type, op) } - override fun visitRelOpOffset(node: Rel.Op.Offset, ctx: Rel.Type?) = rewrite { + override fun visitRelOpOffset(node: Rel.Op.Offset, ctx: Rel.Type?): Rel { // compute input schema val input = visitRel(node.input, ctx) // type offset expression using outer scope with global resolution @@ -204,10 +238,10 @@ internal class PlanTyper( val type = input.type // rewrite val op = relOpOffset(input, offset) - rel(type, op) + return rel(type, op) } - override fun visitRelOpProject(node: Rel.Op.Project, ctx: Rel.Type?): Rel = rewrite { + override fun visitRelOpProject(node: Rel.Op.Project, ctx: Rel.Type?): Rel { // compute input schema val input = visitRel(node.input, ctx) // type sub-nodes @@ -218,10 +252,10 @@ internal class PlanTyper( val type = ctx!!.copyWithSchema(schema) // rewrite val op = relOpProject(input, projections) - rel(type, op) + return rel(type, op) } - override fun visitRelOpJoin(node: Rel.Op.Join, ctx: Rel.Type?): Rel = rewrite { + override fun visitRelOpJoin(node: Rel.Op.Join, ctx: Rel.Type?): Rel { // Rewrite LHS and RHS val lhs = visitRel(node.lhs, ctx) val rhs = visitRel(node.rhs, ctx) @@ -241,7 +275,7 @@ internal class PlanTyper( val condition = node.rex.type(TypeEnv(type.schema, ResolutionStrategy.LOCAL)) val op = relOpJoin(lhs, rhs, condition, node.type) - rel(type, op) + return rel(type, op) } override fun visitRelOpAggregate(node: Rel.Op.Aggregate, ctx: Rel.Type?): Rel { @@ -270,18 +304,18 @@ internal class PlanTyper( override fun visitRex(node: Rex, ctx: StaticType?): Rex = visitRexOp(node.op, node.type) as Rex - override fun visitRexOpLit(node: Rex.Op.Lit, ctx: StaticType?): Rex = rewrite { + override fun visitRexOpLit(node: Rex.Op.Lit, ctx: StaticType?): Rex { // type comes from RexConverter - rex(ctx!!, node) + return rex(ctx!!, node) } - override fun visitRexOpVarResolved(node: Rex.Op.Var.Resolved, ctx: StaticType?): Rex = rewrite { + override fun visitRexOpVarResolved(node: Rex.Op.Var.Resolved, ctx: StaticType?): Rex { assert(node.ref < locals.schema.size) { "Invalid resolved variable (var ${node.ref}) for $locals" } val type = locals.schema[node.ref].type - rex(type, node) + return rex(type, node) } - override fun visitRexOpVarUnresolved(node: Rex.Op.Var.Unresolved, ctx: StaticType?): Rex = rewrite { + override fun visitRexOpVarUnresolved(node: Rex.Op.Var.Unresolved, ctx: StaticType?): Rex { val path = node.identifier.toBindingPath() val resolvedVar = env.resolve(path, locals, node.scope) @@ -294,19 +328,19 @@ internal class PlanTyper( is ResolvedVar.Global -> rexOpGlobal(resolvedVar.ordinal) is ResolvedVar.Local -> resolvedLocalPath(resolvedVar) } - rex(type, op) + return rex(type, op) } - override fun visitRexOpGlobal(node: Rex.Op.Global, ctx: StaticType?): Rex = rewrite { + override fun visitRexOpGlobal(node: Rex.Op.Global, ctx: StaticType?): Rex { val global = env.globals[node.ref] val type = global.type - rex(type, node) + return rex(type, node) } /** * Match path as far as possible (rewriting the steps), then infer based on resolved root and rewritten steps. */ - override fun visitRexOpPath(node: Rex.Op.Path, ctx: StaticType?): Rex = rewrite { + override fun visitRexOpPath(node: Rex.Op.Path, ctx: StaticType?): Rex { // 1. Resolve path prefix val (root, steps) = when (val rootOp = node.root.op) { is Rex.Op.Var.Unresolved -> { @@ -361,7 +395,7 @@ internal class PlanTyper( } // 5. Non-missing, root is resolved - rex(type, rexOpPath(root, steps)) + return rex(type, rexOpPath(root, steps)) } /** @@ -374,7 +408,7 @@ internal class PlanTyper( * @param ctx * @return */ - override fun visitRexOpCall(node: Rex.Op.Call, ctx: StaticType?): Rex = rewrite { + override fun visitRexOpCall(node: Rex.Op.Call, ctx: StaticType?): Rex { // Already resolved; unreachable but handle gracefully. if (node.fn is Fn.Resolved) return rex(ctx!!, node) @@ -391,11 +425,11 @@ internal class PlanTyper( // 7.1 All functions return MISSING when one of their inputs is MISSING (except `=`) if (missingArg && !isEq) { handleAlwaysMissing() - return@rewrite rex(StaticType.MISSING, rexOpCall(fn, args)) + return rex(StaticType.MISSING, rexOpCall(fn, args)) } // Try to match the arguments to functions defined in the catalog - when (val match = env.resolveFn(fn, args)) { + return when (val match = env.resolveFn(fn, args)) { is FnMatch.Ok -> { // Found a match! @@ -443,10 +477,10 @@ internal class PlanTyper( } } - override fun visitRexOpCase(node: Rex.Op.Case, ctx: StaticType?): Rex = rewrite { + override fun visitRexOpCase(node: Rex.Op.Case, ctx: StaticType?): Rex { val visitedBranches = node.branches.map { visitRexOpCaseBranch(it, null) } val resultTypes = visitedBranches.map { it.rex }.map { it.type } - rex(AnyOfType(resultTypes.toSet()).flatten(), node.copy(branches = visitedBranches)) + return rex(AnyOfType(resultTypes.toSet()).flatten(), node.copy(branches = visitedBranches)) } override fun visitRexOpCaseBranch(node: Rex.Op.Case.Branch, ctx: StaticType?): Rex.Op.Case.Branch { @@ -455,7 +489,7 @@ internal class PlanTyper( return node.copy(condition = visitedCondition, rex = visitedReturn) } - override fun visitRexOpCollection(node: Rex.Op.Collection, ctx: StaticType?): Rex = rewrite { + override fun visitRexOpCollection(node: Rex.Op.Collection, ctx: StaticType?): Rex { if (ctx!! !is CollectionType) { handleUnexpectedType(ctx, setOf(StaticType.LIST, StaticType.BAG, StaticType.SEXP)) return rex(StaticType.NULL_OR_MISSING, rexOpErr("Expected collection type")) @@ -467,11 +501,11 @@ internal class PlanTyper( is ListType -> ListType(t) is SexpType -> SexpType(t) } - rex(type, rexOpCollection(values)) + return rex(type, rexOpCollection(values)) } @OptIn(PartiQLValueExperimental::class) - override fun visitRexOpStruct(node: Rex.Op.Struct, ctx: StaticType?): Rex = rewrite { + override fun visitRexOpStruct(node: Rex.Op.Struct, ctx: StaticType?): Rex { val fields = node.fields.map { val k = visitRex(it.k, null) val v = visitRex(it.v, null) @@ -510,7 +544,7 @@ internal class PlanTyper( TupleConstraint.UniqueAttrs(structKeysSeent.size == fields.size) ), ) - rex(type, rexOpStruct(fields)) + return rex(type, rexOpStruct(fields)) } override fun visitRexOpPivot(node: Rex.Op.Pivot, ctx: StaticType?): Rex { @@ -525,7 +559,7 @@ internal class PlanTyper( TODO("Type RexOpCollToScalarSubquery") } - override fun visitRexOpSelect(node: Rex.Op.Select, ctx: StaticType?): Rex = rewrite { + override fun visitRexOpSelect(node: Rex.Op.Select, ctx: StaticType?): Rex { val rel = node.rel.type(locals) val typeEnv = TypeEnv(rel.type.schema, ResolutionStrategy.LOCAL) var constructor = node.constructor.type(typeEnv) @@ -541,10 +575,10 @@ internal class PlanTyper( true -> ListType(constructor.type) else -> BagType(constructor.type) } - rex(type, rexOpSelect(constructor, rel)) + return rex(type, rexOpSelect(constructor, rel)) } - override fun visitRexOpTupleUnion(node: Rex.Op.TupleUnion, ctx: StaticType?): Rex = rewrite { + override fun visitRexOpTupleUnion(node: Rex.Op.TupleUnion, ctx: StaticType?): Rex { val args = node.args.map { visitTupleUnionArg(it) } val structFields = mutableListOf() var structIsClosed = true @@ -581,7 +615,7 @@ internal class PlanTyper( ), ) val op = rexOpTupleUnion(args) - rex(type, op) + return rex(type, op) } private fun visitTupleUnionArg(node: Rex.Op.TupleUnion.Arg) = when (node) { @@ -630,7 +664,7 @@ internal class PlanTyper( if (step.key.op is Rex.Op.Lit) { val lit = (step.key.op as Rex.Op.Lit).value if (lit is TextValue<*> && !lit.isNull) { - val id = Plan.identifierSymbol(lit.string!!, Identifier.CaseSensitivity.SENSITIVE) + val id = identifierSymbol(lit.string!!, Identifier.CaseSensitivity.SENSITIVE) inferStructLookup(struct, id) } else { error("Expected text literal, but got $lit") @@ -770,7 +804,7 @@ internal class PlanTyper( /** * Rewrites function arguments, wrapping in the given function if exists. */ - private fun Plan.rewriteFnArgs(mapping: List, args: List): List { + private fun rewriteFnArgs(mapping: List, args: List): List { if (mapping.size != args.size) { error("Fatal, malformed function mapping") // should be unreachable given how a mapping is generated. } @@ -798,7 +832,7 @@ internal class PlanTyper( /** * Constructs a Rex.Op.Path from a resolved local */ - private fun Plan.resolvedLocalPath(local: ResolvedVar.Local): Rex.Op.Path { + private fun resolvedLocalPath(local: ResolvedVar.Local): Rex.Op.Path { val root = rex(local.rootType, rexOpVarResolved(local.ordinal)) val steps = local.tail.map { val case = when (it.bindingCase) { @@ -877,7 +911,7 @@ internal class PlanTyper( is StructType -> t.withNullableFields() else -> t.asNullable() } - Plan.relBinding(it.name, type) + relBinding(it.name, type) } private fun StructType.withNullableFields(): StructType {