From 3de4c881a75584aef51a81dd66a7e4814326b631 Mon Sep 17 00:00:00 2001 From: Yingtao Liu Date: Tue, 28 Nov 2023 14:45:38 -0800 Subject: [PATCH] hack around dynamic function --- .../PartiQLSchemaInferencerTests.kt | 7 ++- .../org/partiql/planner/typer/PlanTyper.kt | 59 +++++++++++++------ 2 files changed, 45 insertions(+), 21 deletions(-) diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/planner/transforms/PartiQLSchemaInferencerTests.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/planner/transforms/PartiQLSchemaInferencerTests.kt index b09883d658..8ce6be66f4 100644 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/planner/transforms/PartiQLSchemaInferencerTests.kt +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/planner/transforms/PartiQLSchemaInferencerTests.kt @@ -3190,12 +3190,13 @@ class PartiQLSchemaInferencerTests { query = "order_info.CUSTOMER_ID = 1", expected = TYPE_BOOL ), + // MISSING = 1 ErrorTestCase( name = "Case Sensitive failure", catalog = CATALOG_DB, catalogPath = DB_SCHEMA_MARKETS, query = "order_info.\"CUSTOMER_ID\" = 1", - expected = TYPE_BOOL + expected = NULL ), SuccessTestCase( name = "Case Sensitive success", @@ -3209,14 +3210,14 @@ class PartiQLSchemaInferencerTests { catalog = CATALOG_DB, catalogPath = DB_SCHEMA_MARKETS, query = "(order_info.customer_id = 1) AND (order_info.marketplace_id = 2)", - expected = TYPE_BOOL + expected = StaticType.unionOf(BOOL, NULL) ), SuccessTestCase( name = "2-Level Junction", catalog = CATALOG_DB, catalogPath = DB_SCHEMA_MARKETS, query = "(order_info.customer_id = 1) AND (order_info.marketplace_id = 2) OR (order_info.customer_id = 3) AND (order_info.marketplace_id = 4)", - expected = TYPE_BOOL + expected = StaticType.unionOf(BOOL, NULL) ), SuccessTestCase( name = "INT and STR Comparison", diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/typer/PlanTyper.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/typer/PlanTyper.kt index 89e6f9cc2d..e8e15e0a23 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/typer/PlanTyper.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/typer/PlanTyper.kt @@ -687,26 +687,49 @@ internal class PlanTyper( * currently limiting the scope of this intentionally. */ private fun foldCaseBranch(condition: Rex, result: Rex): Rex.Op.Case.Branch { - val call = condition.op as? Rex.Op.Call.Static ?: return rexOpCaseBranch(condition, result) - val fn = call.fn as? Fn.Resolved ?: return rexOpCaseBranch(condition, result) - if (fn.signature.name.equals("is_struct", ignoreCase = true).not()) { - return rexOpCaseBranch(condition, result) - } - val ref = call.args.getOrNull(0) ?: error("IS STRUCT requires an argument.") - val simplifiedCondition = when { - ref.type.allTypes.all { it is StructType } -> rex(StaticType.BOOL, rexOpLit(boolValue(true))) - ref.type.allTypes.none { it is StructType } -> rex(StaticType.BOOL, rexOpLit(boolValue(false))) - else -> condition - } + val call = condition.op as? Rex.Op.Call ?: return rexOpCaseBranch(condition, result) + when (call) { + is Rex.Op.Call.Dynamic -> { + val rex = call.candidates.map { candidate -> + val fn = candidate.fn as? Fn.Resolved ?: return rexOpCaseBranch(condition, result) + if (fn.signature.name.equals("is_struct", ignoreCase = true).not()) { + return rexOpCaseBranch(condition, result) + } + val ref = call.args.getOrNull(0) ?: error("IS STRUCT requires an argument.") + // Replace the result's type + val type = AnyOfType(ref.type.allTypes.filterIsInstance().toSet()) + val replacementVal = ref.copy(type = type) + when (ref.op is Rex.Op.Var.Resolved) { + true -> RexReplacer.replace(result, ref, replacementVal) + false -> result + } + } + val type = rex.toUnionType().flatten() - // Replace the result's type - val type = AnyOfType(ref.type.allTypes.filterIsInstance().toSet()) - val replacementVal = ref.copy(type = type) - val rex = when (ref.op is Rex.Op.Var.Resolved) { - true -> RexReplacer.replace(result, ref, replacementVal) - false -> result + return rexOpCaseBranch(condition, result.copy(type)) + } + is Rex.Op.Call.Static -> { + val fn = call.fn as? Fn.Resolved ?: return rexOpCaseBranch(condition, result) + if (fn.signature.name.equals("is_struct", ignoreCase = true).not()) { + return rexOpCaseBranch(condition, result) + } + val ref = call.args.getOrNull(0) ?: error("IS STRUCT requires an argument.") + val simplifiedCondition = when { + ref.type.allTypes.all { it is StructType } -> rex(StaticType.BOOL, rexOpLit(boolValue(true))) + ref.type.allTypes.none { it is StructType } -> rex(StaticType.BOOL, rexOpLit(boolValue(false))) + else -> condition + } + + // Replace the result's type + val type = AnyOfType(ref.type.allTypes.filterIsInstance().toSet()) + val replacementVal = ref.copy(type = type) + val rex = when (ref.op is Rex.Op.Var.Resolved) { + true -> RexReplacer.replace(result, ref, replacementVal) + false -> result + } + return rexOpCaseBranch(simplifiedCondition, rex) + } } - return rexOpCaseBranch(simplifiedCondition, rex) } override fun visitRexOpCollection(node: Rex.Op.Collection, ctx: StaticType?): Rex {