diff --git a/partiql-ast/api/partiql-ast.api b/partiql-ast/api/partiql-ast.api index 347351147..c21dc12ad 100644 --- a/partiql-ast/api/partiql-ast.api +++ b/partiql-ast/api/partiql-ast.api @@ -102,7 +102,7 @@ public final class org/partiql/ast/Ast { public static final fun selectProjectItemAll (Lorg/partiql/ast/Expr;)Lorg/partiql/ast/Select$Project$Item$All; public static final fun selectProjectItemExpression (Lorg/partiql/ast/Expr;Lorg/partiql/ast/Binder;)Lorg/partiql/ast/Select$Project$Item$Expression; public static final fun selectStar (Lorg/partiql/ast/SetQuantifier;)Lorg/partiql/ast/Select$Star; - public static final fun selectValue (Lorg/partiql/ast/Expr;Lorg/partiql/ast/SetQuantifier;)Lorg/partiql/ast/Select$Value; + public static final fun selectValue (Lorg/partiql/ast/Expr;Lorg/partiql/ast/SetQuantifier;Z)Lorg/partiql/ast/Select$Value; public static final fun setOp (Lorg/partiql/ast/SetOp$Type;Lorg/partiql/ast/SetQuantifier;)Lorg/partiql/ast/SetOp; public static final fun sort (Lorg/partiql/ast/Expr;Lorg/partiql/ast/Sort$Dir;Lorg/partiql/ast/Sort$Nulls;)Lorg/partiql/ast/Sort; public static final fun statementDDL (Lorg/partiql/ast/DdlOp;)Lorg/partiql/ast/Statement$DDL; @@ -2589,15 +2589,17 @@ public final class org/partiql/ast/Select$Star$Companion { public final class org/partiql/ast/Select$Value : org/partiql/ast/Select { public static final field Companion Lorg/partiql/ast/Select$Value$Companion; + public final field coercible Z public final field constructor Lorg/partiql/ast/Expr; public final field setq Lorg/partiql/ast/SetQuantifier; - public fun (Lorg/partiql/ast/Expr;Lorg/partiql/ast/SetQuantifier;)V + public fun (Lorg/partiql/ast/Expr;Lorg/partiql/ast/SetQuantifier;Z)V public fun accept (Lorg/partiql/ast/visitor/AstVisitor;Ljava/lang/Object;)Ljava/lang/Object; public static final fun builder ()Lorg/partiql/ast/builder/SelectValueBuilder; public final fun component1 ()Lorg/partiql/ast/Expr; public final fun component2 ()Lorg/partiql/ast/SetQuantifier; - public final fun copy (Lorg/partiql/ast/Expr;Lorg/partiql/ast/SetQuantifier;)Lorg/partiql/ast/Select$Value; - public static synthetic fun copy$default (Lorg/partiql/ast/Select$Value;Lorg/partiql/ast/Expr;Lorg/partiql/ast/SetQuantifier;ILjava/lang/Object;)Lorg/partiql/ast/Select$Value; + public final fun component3 ()Z + public final fun copy (Lorg/partiql/ast/Expr;Lorg/partiql/ast/SetQuantifier;Z)Lorg/partiql/ast/Select$Value; + public static synthetic fun copy$default (Lorg/partiql/ast/Select$Value;Lorg/partiql/ast/Expr;Lorg/partiql/ast/SetQuantifier;ZILjava/lang/Object;)Lorg/partiql/ast/Select$Value; public fun equals (Ljava/lang/Object;)Z public fun getChildren ()Ljava/util/List; public fun hashCode ()I @@ -4211,8 +4213,8 @@ public final class org/partiql/ast/builder/AstBuilder { public static synthetic fun selectProjectItemExpression$default (Lorg/partiql/ast/builder/AstBuilder;Lorg/partiql/ast/Expr;Lorg/partiql/ast/Binder;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lorg/partiql/ast/Select$Project$Item$Expression; public final fun selectStar (Lorg/partiql/ast/SetQuantifier;Lkotlin/jvm/functions/Function1;)Lorg/partiql/ast/Select$Star; public static synthetic fun selectStar$default (Lorg/partiql/ast/builder/AstBuilder;Lorg/partiql/ast/SetQuantifier;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lorg/partiql/ast/Select$Star; - public final fun selectValue (Lorg/partiql/ast/Expr;Lorg/partiql/ast/SetQuantifier;Lkotlin/jvm/functions/Function1;)Lorg/partiql/ast/Select$Value; - public static synthetic fun selectValue$default (Lorg/partiql/ast/builder/AstBuilder;Lorg/partiql/ast/Expr;Lorg/partiql/ast/SetQuantifier;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lorg/partiql/ast/Select$Value; + public final fun selectValue (Lorg/partiql/ast/Expr;Lorg/partiql/ast/SetQuantifier;Ljava/lang/Boolean;Lkotlin/jvm/functions/Function1;)Lorg/partiql/ast/Select$Value; + public static synthetic fun selectValue$default (Lorg/partiql/ast/builder/AstBuilder;Lorg/partiql/ast/Expr;Lorg/partiql/ast/SetQuantifier;Ljava/lang/Boolean;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lorg/partiql/ast/Select$Value; public final fun setOp (Lorg/partiql/ast/SetOp$Type;Lorg/partiql/ast/SetQuantifier;Lkotlin/jvm/functions/Function1;)Lorg/partiql/ast/SetOp; public static synthetic fun setOp$default (Lorg/partiql/ast/builder/AstBuilder;Lorg/partiql/ast/SetOp$Type;Lorg/partiql/ast/SetQuantifier;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lorg/partiql/ast/SetOp; public final fun sort (Lorg/partiql/ast/Expr;Lorg/partiql/ast/Sort$Dir;Lorg/partiql/ast/Sort$Nulls;Lkotlin/jvm/functions/Function1;)Lorg/partiql/ast/Sort; @@ -5628,12 +5630,15 @@ public final class org/partiql/ast/builder/SelectStarBuilder { public final class org/partiql/ast/builder/SelectValueBuilder { public fun ()V - public fun (Lorg/partiql/ast/Expr;Lorg/partiql/ast/SetQuantifier;)V - public synthetic fun (Lorg/partiql/ast/Expr;Lorg/partiql/ast/SetQuantifier;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun (Lorg/partiql/ast/Expr;Lorg/partiql/ast/SetQuantifier;Ljava/lang/Boolean;)V + public synthetic fun (Lorg/partiql/ast/Expr;Lorg/partiql/ast/SetQuantifier;Ljava/lang/Boolean;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun build ()Lorg/partiql/ast/Select$Value; + public final fun coercible (Ljava/lang/Boolean;)Lorg/partiql/ast/builder/SelectValueBuilder; public final fun constructor (Lorg/partiql/ast/Expr;)Lorg/partiql/ast/builder/SelectValueBuilder; + public final fun getCoercible ()Ljava/lang/Boolean; public final fun getConstructor ()Lorg/partiql/ast/Expr; public final fun getSetq ()Lorg/partiql/ast/SetQuantifier; + public final fun setCoercible (Ljava/lang/Boolean;)V public final fun setConstructor (Lorg/partiql/ast/Expr;)V public final fun setSetq (Lorg/partiql/ast/SetQuantifier;)V public final fun setq (Lorg/partiql/ast/SetQuantifier;)Lorg/partiql/ast/builder/SelectValueBuilder; diff --git a/partiql-ast/src/main/resources/partiql_ast.ion b/partiql-ast/src/main/resources/partiql_ast.ion index 3c375836c..d4fee4542 100644 --- a/partiql-ast/src/main/resources/partiql_ast.ion +++ b/partiql-ast/src/main/resources/partiql_ast.ion @@ -611,6 +611,7 @@ select::[ value::{ constructor: expr, setq: optional::set_quantifier, + coercible: bool }, ] diff --git a/partiql-ast/src/test/kotlin/org/partiql/ast/helpers/ToLegacyAstTest.kt b/partiql-ast/src/test/kotlin/org/partiql/ast/helpers/ToLegacyAstTest.kt index 79a0b47f5..b778613e0 100644 --- a/partiql-ast/src/test/kotlin/org/partiql/ast/helpers/ToLegacyAstTest.kt +++ b/partiql-ast/src/test/kotlin/org/partiql/ast/helpers/ToLegacyAstTest.kt @@ -570,6 +570,7 @@ class ToLegacyAstTest { expect("(project_value (lit null))") { selectValue { constructor = NULL + coercible = false } }, // FROM_SOURCE Variants diff --git a/partiql-ast/src/test/kotlin/org/partiql/ast/sql/SqlDialectTest.kt b/partiql-ast/src/test/kotlin/org/partiql/ast/sql/SqlDialectTest.kt index fe9c61dc6..57b85a8d6 100644 --- a/partiql-ast/src/test/kotlin/org/partiql/ast/sql/SqlDialectTest.kt +++ b/partiql-ast/src/test/kotlin/org/partiql/ast/sql/SqlDialectTest.kt @@ -1073,6 +1073,7 @@ class SqlDialectTest { exprSFW { select = selectValue { constructor = v("a") + coercible = false } from = table("T") } @@ -1082,6 +1083,7 @@ class SqlDialectTest { select = selectValue { setq = SetQuantifier.ALL constructor = v("a") + coercible = false } from = table("T") } @@ -1091,6 +1093,7 @@ class SqlDialectTest { select = selectValue { setq = SetQuantifier.DISTINCT constructor = v("a") + coercible = false } from = table("T") } diff --git a/partiql-parser/src/main/kotlin/org/partiql/parser/internal/PartiQLParserDefault.kt b/partiql-parser/src/main/kotlin/org/partiql/parser/internal/PartiQLParserDefault.kt index bfb62b750..04e816dbf 100644 --- a/partiql-parser/src/main/kotlin/org/partiql/parser/internal/PartiQLParserDefault.kt +++ b/partiql-parser/src/main/kotlin/org/partiql/parser/internal/PartiQLParserDefault.kt @@ -1071,7 +1071,7 @@ internal class PartiQLParserDefault : PartiQLParser { override fun visitSelectValue(ctx: GeneratedParser.SelectValueContext) = translate(ctx) { val constructor = visitExpr(ctx.expr()) val setq = convertSetQuantifier(ctx.setQuantifierStrategy()) - selectValue(constructor, setq) + selectValue(constructor, setq, false) } override fun visitProjectionItem(ctx: GeneratedParser.ProjectionItemContext) = translate(ctx) { diff --git a/partiql-planner/api/partiql-planner.api b/partiql-planner/api/partiql-planner.api index d1f7ddd52..a8abf405f 100644 --- a/partiql-planner/api/partiql-planner.api +++ b/partiql-planner/api/partiql-planner.api @@ -37,7 +37,9 @@ public final class org/partiql/planner/PartiQLPlannerBuilder { public final fun addPass (Lorg/partiql/planner/PartiQLPlannerPass;)Lorg/partiql/planner/PartiQLPlannerBuilder; public final fun addPasses ([Lorg/partiql/planner/PartiQLPlannerPass;)Lorg/partiql/planner/PartiQLPlannerBuilder; public final fun build ()Lorg/partiql/planner/PartiQLPlanner; + public final fun casePreserve ()Lorg/partiql/planner/PartiQLPlannerBuilder; public final fun catalogs ([Lkotlin/Pair;)Lorg/partiql/planner/PartiQLPlannerBuilder; + public final fun lookUpBehavior (Ljava/lang/String;)Lorg/partiql/planner/PartiQLPlannerBuilder; public final fun passes (Ljava/util/List;)Lorg/partiql/planner/PartiQLPlannerBuilder; public final fun signalMode ()Lorg/partiql/planner/PartiQLPlannerBuilder; } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlanner.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlanner.kt index 9e5b72930..4e02d59d2 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlanner.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlanner.kt @@ -56,6 +56,10 @@ public interface PartiQLPlanner { public fun builder(): PartiQLPlannerBuilder = PartiQLPlannerBuilder() @JvmStatic - public fun default(): PartiQLPlanner = PartiQLPlannerBuilder().build() + public fun default(): PartiQLPlanner = + PartiQLPlannerBuilder() + .casePreserve() + .lookUpBehavior("INSENSITIVE") + .build() } } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlannerBuilder.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlannerBuilder.kt index 7af3c275f..92cf986ab 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlannerBuilder.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlannerBuilder.kt @@ -1,7 +1,9 @@ package org.partiql.planner +import org.partiql.planner.internal.BooleanFlag import org.partiql.planner.internal.PartiQLPlannerDefault import org.partiql.planner.internal.PlannerFlag +import org.partiql.planner.internal.RValue import org.partiql.spi.connector.ConnectorMetadata /** @@ -50,7 +52,24 @@ public class PartiQLPlannerBuilder { * Java style method for setting the planner to signal mode */ public fun signalMode(): PartiQLPlannerBuilder = this.apply { - this.flags.add(PlannerFlag.SIGNAL_MODE) + this.flags.add(BooleanFlag.SIGNAL_MODE) + } + + public fun casePreserve(): PartiQLPlannerBuilder = this.apply { + this.flags.add(BooleanFlag.CASE_PRESERVATION) + } + + public fun lookUpBehavior(behavior: String): PartiQLPlannerBuilder = this.apply { + this.flags.removeAll( + this.flags.filterIsInstance().toSet() + ) + when (behavior.uppercase()) { + "FOLDING_UP" -> this.flags.add(RValue.FOLDING_UP) + "FOLDING_DOWN" -> this.flags.add(RValue.FOLDING_DOWN) + "SENSITIVE" -> this.flags.add(RValue.SENSITIVE) + "INSENSITIVE" -> this.flags.add(RValue.INSENSITIVE) + else -> error("Illegal flag, expect one of ${RValue.values()}") + } } /** diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PlannerFlag.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PlannerFlag.kt index bd91be053..ed62fa390 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PlannerFlag.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PlannerFlag.kt @@ -1,6 +1,9 @@ package org.partiql.planner.internal -internal enum class PlannerFlag { +// Marker interface +internal interface PlannerFlag + +internal enum class BooleanFlag : PlannerFlag { /** * Determine the planner behavior upon encounter an operation that always returns MISSING. * @@ -16,5 +19,38 @@ internal enum class PlannerFlag { * * The result plan will turn the problematic operation into a missing node. */ - SIGNAL_MODE + SIGNAL_MODE, + + /** + * Determines lvalue behavior + * + * If this flag is included: + * Lvalue (AS binding, DDL, etc) are case-sensitive + * + * If not included: + * Lvalue are normalized using an implementation defined normalization rule + */ + CASE_PRESERVATION +} + +internal enum class RValue : PlannerFlag { + /** + * Using upper case text to look up + */ + FOLDING_UP, + + /** + * Using lower case text to look up + */ + FOLDING_DOWN, + + /** + * Using original text to look up + */ + SENSITIVE, + + /** + * Match behavior: text comparison with case ignored. + */ + INSENSITIVE } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/astPasses/Normalize.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/astPasses/Normalize.kt index 9465005b0..1eb2508bd 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/astPasses/Normalize.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/astPasses/Normalize.kt @@ -24,5 +24,6 @@ internal fun Statement.normalize(): Statement { var ast = this ast = NormalizeFromSource.apply(ast) ast = NormalizeGroupBy.apply(ast) + ast = NormalizeSelect.apply(ast) return ast } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/astPasses/NormalizeIdentifier.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/astPasses/NormalizeIdentifier.kt new file mode 100644 index 000000000..7b7dbdd61 --- /dev/null +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/astPasses/NormalizeIdentifier.kt @@ -0,0 +1,23 @@ +package org.partiql.planner.internal.astPasses + +import org.partiql.planner.internal.RValue + +/** + * Case Preservation: + * - If on, turn case-insensitive binder to case-sensitive binder, with symbol preserve the original case + * - If off, normalize the binder. + * - The normalization function in the spec perhaps will end up being implementation defined. + * - Therefore, for now we normalize by preserving case. + * - This is to stay consistent with the current behavior, and reduce one moving element during testing.... + * + * Identifier Normalization: + * - Determines the normalization for identifier (variable look up) + * - Folding down: Using lower case text to look up + * - Folding up: Using upper case text to look up. + * - Case Sensitive: Using original text to look up + * - Case Insensitive: Matching behavior, string comparison with case ignored. + */ +internal class NormalizeIdentifier( + val casePreservation: Boolean, + val rValue: RValue +) diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/astPasses/NormalizeSelect.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/astPasses/NormalizeSelect.kt index 0d0afae25..e4a9609e7 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/astPasses/NormalizeSelect.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/astPasses/NormalizeSelect.kt @@ -20,6 +20,7 @@ import org.partiql.ast.From import org.partiql.ast.GroupBy import org.partiql.ast.Identifier import org.partiql.ast.Select +import org.partiql.ast.Statement import org.partiql.ast.binder import org.partiql.ast.exprCall import org.partiql.ast.exprCase @@ -30,6 +31,7 @@ import org.partiql.ast.exprStruct import org.partiql.ast.exprStructField import org.partiql.ast.exprVar import org.partiql.ast.identifierSymbol +import org.partiql.ast.normalize.AstPass import org.partiql.ast.selectProject import org.partiql.ast.selectProjectItemExpression import org.partiql.ast.selectValue @@ -82,25 +84,15 @@ import org.partiql.value.stringValue * } FROM A AS x * ``` * - * NOTE: This does NOT transform subqueries. It operates directly on an [Expr.SFW] -- and that is it. Therefore: - * ``` - * SELECT - * (SELECT 1 FROM T AS "T") - * FROM R AS "R" - * ``` - * will be transformed to: - * ``` - * SELECT VALUE { - * '_1': (SELECT 1 FROM T AS "T") -- notice that SELECT 1 didn't get transformed. - * } FROM R AS "R" - * ``` - * * Requires [NormalizeFromSource]. */ -internal object NormalizeSelect { +internal object NormalizeSelect : AstPass { - internal fun normalize(node: Expr.SFW): Expr.SFW { - return Visitor.visitSFW(node, newCtx()) + override fun apply(statement: Statement): Statement { + return when (statement) { + is Statement.Query -> Visitor.visitStatementQuery(statement, newCtx()) + else -> statement + } } /** @@ -138,7 +130,10 @@ internal object NormalizeSelect { */ private val col = { index: Int -> "_${index + 1}" } - internal fun visitSFW(node: Expr.SFW, ctx: () -> Int): Expr.SFW { + override fun visitStatementQuery(node: Statement.Query, ctx: () -> Int) = + super.visitStatementQuery(node, ctx) as Statement.Query + + override fun visitExprSFW(node: Expr.SFW, ctx: () -> Int): Expr.SFW { val sfw = super.visitExprSFW(node, ctx) as Expr.SFW return when (val select = sfw.select) { is Select.Star -> { @@ -152,10 +147,6 @@ internal object NormalizeSelect { } } - override fun visitExprSFW(node: Expr.SFW, ctx: () -> Int): Expr.SFW { - return node - } - override fun visitSelectProject(node: Select.Project, ctx: () -> Int): Select.Value { // Visit items, adding a binder if necessary @@ -221,7 +212,8 @@ internal object NormalizeSelect { args = tupleUnionArgs, setq = null // setq = null for scalar fn ), - setq = select.setq + setq = select.setq, + coercible = false ) } @@ -240,7 +232,8 @@ internal object NormalizeSelect { val constructor = exprStruct(fields) return selectValue( constructor = constructor, - setq = select.setq + setq = select.setq, + coercible = false ) } @@ -261,7 +254,8 @@ internal object NormalizeSelect { function = identifierSymbol("TUPLEUNION", Identifier.CaseSensitivity.SENSITIVE), args = tupleUnionArgs, setq = null // setq = null for scalar fn - ) + ), + coercible = false ) } @@ -278,7 +272,8 @@ internal object NormalizeSelect { setq = node.setq, constructor = exprStruct( fields = structFields - ) + ), + coercible = true ) } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt index 0d0c7de20..196b0ee78 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt @@ -5,6 +5,7 @@ import org.partiql.plan.PlanNode import org.partiql.plan.partiQLPlan import org.partiql.plan.rexOpCast import org.partiql.plan.rexOpErr +import org.partiql.planner.internal.BooleanFlag import org.partiql.planner.internal.PlannerFlag import org.partiql.planner.internal.ProblemGenerator import org.partiql.planner.internal.ir.Identifier @@ -27,7 +28,7 @@ import org.partiql.value.PartiQLValueExperimental internal class PlanTransform( flags: Set ) { - private val signalMode = flags.contains(PlannerFlag.SIGNAL_MODE) + private val signalMode = flags.contains(BooleanFlag.SIGNAL_MODE) fun transform(node: PartiQLPlan, onProblem: ProblemCallback): org.partiql.plan.PartiQLPlan { val symbols = Symbols.empty() diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RelConverter.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RelConverter.kt index b16c0944e..f79de19dc 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RelConverter.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RelConverter.kt @@ -34,7 +34,6 @@ import org.partiql.ast.identifierSymbol import org.partiql.ast.util.AstRewriter import org.partiql.ast.visitor.AstBaseVisitor import org.partiql.planner.internal.Env -import org.partiql.planner.internal.astPasses.NormalizeSelect import org.partiql.planner.internal.ir.Rel import org.partiql.planner.internal.ir.Rex import org.partiql.planner.internal.ir.rel @@ -90,9 +89,8 @@ internal object RelConverter { * Here we convert an SFW to composed [Rel]s, then apply the appropriate relation-value projection to get a [Rex]. */ internal fun apply(sfw: Expr.SFW, env: Env): Rex { - val normalizedSfw = NormalizeSelect.normalize(sfw) - val rel = normalizedSfw.accept(ToRel(env), nil) - val rex = when (val projection = normalizedSfw.select) { + val rel = sfw.accept(ToRel(env), nil) + val rex = when (val projection = sfw.select) { // PIVOT ... FROM is Select.Pivot -> { val key = projection.key.toRex(env) diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt index 330397672..eba4a6e22 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt @@ -120,20 +120,20 @@ internal object RexConverter { */ internal fun visitExprCoerce(node: Expr, ctx: Env, coercion: Rex.Op.Subquery.Coercion = Rex.Op.Subquery.Coercion.SCALAR): Rex { val rex = super.visitExpr(node, ctx) - return when (isSqlSelect(node)) { - true -> { - val select = rex.op as Rex.Op.Select - rex( - CompilerType(PType.typeDynamic()), - rexOpSubquery( - constructor = select.constructor, - rel = select.rel, - coercion = coercion - ) + // Only need to look into the coercible flag if this is a project value + val sfw = node as? Expr.SFW ?: return rex + val select = node.select as? Select.Value ?: return rex + return if (select.coercible) { + val selectRex = rex.op as Rex.Op.Select + rex( + CompilerType(PType.typeDynamic()), + rexOpSubquery( + constructor = selectRex.constructor, + rel = selectRex.rel, + coercion = coercion ) - } - false -> rex - } + ) + } else rex } override fun visitExprVar(node: Expr.Var, context: Env): Rex { diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/transforms/NormalizeSelectTest.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/transforms/NormalizeSelectTest.kt index 8e1c60b33..224cbfb9a 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/transforms/NormalizeSelectTest.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/transforms/NormalizeSelectTest.kt @@ -6,12 +6,14 @@ import org.partiql.ast.Expr import org.partiql.ast.From import org.partiql.ast.Identifier import org.partiql.ast.Select +import org.partiql.ast.Statement import org.partiql.ast.binder import org.partiql.ast.builder.ast import org.partiql.ast.exprLit import org.partiql.ast.exprVar import org.partiql.ast.identifierSymbol import org.partiql.ast.selectProjectItemExpression +import org.partiql.ast.statementQuery import org.partiql.planner.internal.astPasses.NormalizeSelect import org.partiql.value.PartiQLValueExperimental import org.partiql.value.int32Value @@ -36,13 +38,12 @@ class NormalizeSelectTest { varItem("b"), varItem("c"), ) - val expected = selectValue( + val expected = selectValueFromSelect( "a" to variable("a"), "b" to variable("b"), "c" to variable("c"), ) - val actual = NormalizeSelect.normalize(input) - assertEquals(expected, actual) + assertEqualsShim(expected, input) } /** @@ -61,13 +62,12 @@ class NormalizeSelectTest { litItem(2), litItem(3), ) - val expected = selectValue( + val expected = selectValueFromSelect( "_1" to lit(1), "_2" to lit(2), "_3" to lit(3), ) - val actual = NormalizeSelect.normalize(input) - assertEquals(expected, actual) + assertEqualsShim(expected, input) } /** @@ -86,13 +86,12 @@ class NormalizeSelectTest { litItem(2), litItem(3), ) - val expected = selectValue( + val expected = selectValueFromSelect( "a" to variable("a"), "_1" to lit(2), "_2" to lit(3), ) - val actual = NormalizeSelect.normalize(input) - assertEquals(expected, actual) + assertEqualsShim(expected, input) } /** @@ -111,13 +110,13 @@ class NormalizeSelectTest { litItem(2, "b"), litItem(3, "c"), ) - val expected = selectValue( + val expected = selectValueFromSelect( "a" to variable("a"), "b" to lit(2), "c" to lit(3), + ) - val actual = NormalizeSelect.normalize(input) - assertEquals(expected, actual) + assertEqualsShim(expected, input) } // ----- HELPERS ------------------------- @@ -143,7 +142,7 @@ class NormalizeSelectTest { } @OptIn(PartiQLValueExperimental::class) - private fun selectValue(vararg items: Pair) = ast { + private fun selectValueFromSelect(vararg items: Pair) = ast { exprSFW { select = selectValue { constructor = exprStruct { @@ -154,6 +153,7 @@ class NormalizeSelectTest { } } } + coercible = true } from = fromValue { expr = exprVar { @@ -180,4 +180,10 @@ class NormalizeSelectTest { @OptIn(PartiQLValueExperimental::class) private fun lit(value: Int) = exprLit(int32Value(value)) + + private fun assertEqualsShim(expected: Expr.SFW, input: Expr.SFW) { + val normalizedStatement = NormalizeSelect.apply(statementQuery(input)) as Statement.Query + val normalizedSFW = normalizedStatement.expr + assertEquals(expected, normalizedSFW) + } }