From f01d5622788735aa2e13ebddc0c95f27df32959a Mon Sep 17 00:00:00 2001 From: Yingtao Liu Date: Tue, 31 Oct 2023 13:29:58 -0700 Subject: [PATCH 1/6] add coalesce to header --- .../org/partiql/planner/PartiQLHeader.kt | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLHeader.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLHeader.kt index 02ab2411c9..5565077c3d 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLHeader.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLHeader.kt @@ -4,17 +4,7 @@ import org.partiql.ast.DatetimeField import org.partiql.types.function.FunctionParameter import org.partiql.types.function.FunctionSignature import org.partiql.value.PartiQLValueExperimental -import org.partiql.value.PartiQLValueType.ANY -import org.partiql.value.PartiQLValueType.BOOL -import org.partiql.value.PartiQLValueType.CHAR -import org.partiql.value.PartiQLValueType.DATE -import org.partiql.value.PartiQLValueType.DECIMAL -import org.partiql.value.PartiQLValueType.INT -import org.partiql.value.PartiQLValueType.INT32 -import org.partiql.value.PartiQLValueType.INT64 -import org.partiql.value.PartiQLValueType.STRING -import org.partiql.value.PartiQLValueType.TIME -import org.partiql.value.PartiQLValueType.TIMESTAMP +import org.partiql.value.PartiQLValueType.* /** * A header which uses the PartiQL Lang Kotlin default standard library. All functions exist in a global namespace. @@ -326,8 +316,20 @@ object PartiQLHeader : Header() { ) } - // TODO - private fun coalesce(): List = emptyList() + // COALESCE(expression, expression, ... ) + // Initial implementation of Coalesce. + // As the number of parameter can not be pre-determined, we wrap those into a list + private fun coalesce(): List = listOf( + FunctionSignature.Scalar( + name = "coalesce", + returns = ANY, + parameters = listOf( + FunctionParameter("values", LIST) + ), + isNullCall = false, + isNullable = true, + ) + ) // NULLIF(x, y) private fun nullIf(): List = types.nullable.map { t -> From 4946ab2e0136c6b8e229aef887a30496134238ee Mon Sep 17 00:00:00 2001 From: Yingtao Liu Date: Wed, 1 Nov 2023 12:37:46 -0700 Subject: [PATCH 2/6] runtime type resolution --- partiql-lang/build.gradle.kts | 1 + .../main/kotlin/org/partiql/planner/Header.kt | 2 +- .../org/partiql/planner/PartiQLHeader.kt | 215 +++++++++++++++--- 3 files changed, 190 insertions(+), 28 deletions(-) diff --git a/partiql-lang/build.gradle.kts b/partiql-lang/build.gradle.kts index 7080fc3b70..0c3da19acb 100644 --- a/partiql-lang/build.gradle.kts +++ b/partiql-lang/build.gradle.kts @@ -85,5 +85,6 @@ tasks.processTestResources { tasks.shadowJar { archiveBaseName.set("shadow") + exclude("**/*.kotlin_metadata") archiveClassifier.set("") } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/Header.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/Header.kt index 257be92466..399e017a7b 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/Header.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/Header.kt @@ -43,7 +43,7 @@ public abstract class Header { * For functions, output CREATE FUNCTION statements. */ override fun toString(): String = buildString { - functions.groupBy { it.name }.forEach { + (functions + operators + aggregations).groupBy { it.name }.forEach { appendLine("-- [${it.key}] ---------") appendLine() it.value.forEach { fn -> appendLine(fn) } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLHeader.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLHeader.kt index 5565077c3d..3434355447 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLHeader.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLHeader.kt @@ -105,15 +105,15 @@ object PartiQLHeader : Header() { // OPERATORS - private fun not(): List = listOf(unary("not", BOOL, BOOL)) + private fun not(): List = listOf(unary("not", BOOL, BOOL), unary("not", ANY, ANY)) private fun pos(): List = types.numeric.map { t -> unary("pos", t, t) - } + } + listOf(unary("pos", ANY, ANY)) private fun neg(): List = types.numeric.map { t -> unary("neg", t, t) - } + } + listOf(unary("pos", ANY, ANY)) private fun eq(): List = types.all.map { t -> FunctionSignature.Scalar( @@ -127,55 +127,55 @@ object PartiQLHeader : Header() { private fun ne(): List = types.all.map { t -> binary("ne", BOOL, t, t) - } + } + listOf(binary("ne", ANY, ANY, ANY)) - private fun and(): List = listOf(binary("and", BOOL, BOOL, BOOL)) + private fun and(): List = listOf(binary("and", BOOL, BOOL, BOOL), binary("and", ANY, ANY, ANY)) - private fun or(): List = listOf(binary("or", BOOL, BOOL, BOOL)) + private fun or(): List = listOf(binary("or", BOOL, BOOL, BOOL), binary("or", ANY, ANY, ANY)) private fun lt(): List = types.numeric.map { t -> binary("lt", BOOL, t, t) - } + } + listOf(binary("lt", ANY, ANY, ANY)) private fun lte(): List = types.numeric.map { t -> binary("lte", BOOL, t, t) - } + } + listOf(binary("lte", ANY, ANY, ANY)) private fun gt(): List = types.numeric.map { t -> binary("gt", BOOL, t, t) - } + } + listOf(binary("gt", ANY, ANY, ANY)) private fun gte(): List = types.numeric.map { t -> binary("gte", BOOL, t, t) - } + } + listOf(binary("gte", ANY, ANY, ANY)) private fun plus(): List = types.numeric.map { t -> binary("plus", t, t, t) - } + } + listOf(binary("plus", ANY, ANY, ANY)) private fun minus(): List = types.numeric.map { t -> binary("minus", t, t, t) - } + } + listOf(binary("minus", ANY, ANY, ANY)) private fun times(): List = types.numeric.map { t -> binary("times", t, t, t) - } + } + listOf(binary("times", ANY, ANY, ANY)) private fun div(): List = types.numeric.map { t -> binary("divide", t, t, t) - } + } + listOf(binary("divide", ANY, ANY, ANY)) private fun mod(): List = types.numeric.map { t -> binary("modulo", t, t, t) - } + } + listOf(binary("modulo", ANY, ANY, ANY)) private fun concat(): List = types.text.map { t -> binary("concat", t, t, t) - } + } + listOf(binary("concat", ANY, ANY, ANY)) private fun bitwiseAnd(): List = types.integer.map { t -> binary("bitwise_and", t, t, t) - } + } + listOf(binary("bitwise_and", ANY, ANY, ANY)) // BUILT INTS @@ -223,6 +223,27 @@ object PartiQLHeader : Header() { isNullCall = true, isNullable = false, ), + FunctionSignature.Scalar( + name = "like", + returns = ANY, + parameters = listOf( + FunctionParameter("value", ANY), + FunctionParameter("pattern", ANY), + ), + isNullCall = true, + isNullable = false, + ), + FunctionSignature.Scalar( + name = "like_escape", + returns = ANY, + parameters = listOf( + FunctionParameter("value", ANY), + FunctionParameter("pattern", ANY), + FunctionParameter("escape", ANY), + ), + isNullCall = true, + isNullable = false, + ), ) private fun between(): List = types.numeric.map { t -> @@ -328,7 +349,16 @@ object PartiQLHeader : Header() { ), isNullCall = false, isNullable = true, - ) + ), + FunctionSignature.Scalar( + name = "coalesce", + returns = ANY, + parameters = listOf( + FunctionParameter("values", ANY) + ), + isNullCall = false, + isNullable = true, + ), ) // NULLIF(x, y) @@ -343,7 +373,18 @@ object PartiQLHeader : Header() { isNullCall = true, isNullable = true, ) - } + } + listOf( + FunctionSignature.Scalar( + name = "null_if", + returns = ANY, + parameters = listOf( + FunctionParameter("value", ANY), + FunctionParameter("nullifier", ANY), + ), + isNullCall = true, + isNullable = true, + ) + ) // SUBSTRING (expression, start[, length]?) // SUBSTRINGG(expression from start [FOR length]? ) @@ -371,7 +412,29 @@ object PartiQLHeader : Header() { isNullable = false, ) ) - }.flatten() + }.flatten() + listOf( + FunctionSignature.Scalar( + name = "substring", + returns = ANY, + parameters = listOf( + FunctionParameter("value", ANY), + FunctionParameter("start", ANY), + ), + isNullCall = true, + isNullable = false, + ), + FunctionSignature.Scalar( + name = "substring", + returns = ANY, + parameters = listOf( + FunctionParameter("value", ANY), + FunctionParameter("start", ANY), + FunctionParameter("end", ANY), + ), + isNullCall = true, + isNullable = false, + ) + ) // position (str1, str2) // position (str1 in str2) @@ -386,7 +449,18 @@ object PartiQLHeader : Header() { isNullCall = true, isNullable = false, ) - } + } + listOf( + FunctionSignature.Scalar( + name = "position", + returns = ANY, + parameters = listOf( + FunctionParameter("probe", ANY), + FunctionParameter("value", ANY), + ), + isNullCall = true, + isNullable = false, + ) + ) // trim(str) private fun trim(): List = types.text.map { t -> @@ -399,7 +473,17 @@ object PartiQLHeader : Header() { isNullCall = true, isNullable = false, ) - } + } + listOf( + FunctionSignature.Scalar( + name = "trim", + returns = ANY, + parameters = listOf( + FunctionParameter("value", ANY), + ), + isNullCall = true, + isNullable = false, + ) + ) // TODO: We need to add a special form function for TRIM(BOTH FROM value) private fun trimSpecial(): List = types.text.map { t -> @@ -459,7 +543,62 @@ object PartiQLHeader : Header() { isNullable = false, ), ) - }.flatten() + }.flatten() + listOf( + // TRIM(chars FROM value) + // TRIM(both chars from value) + FunctionSignature.Scalar( + name = "trim_chars", + returns = ANY, + parameters = listOf( + FunctionParameter("value", ANY), + FunctionParameter("chars", ANY), + ), + isNullCall = true, + isNullable = false, + ), + // TRIM(LEADING FROM value) + FunctionSignature.Scalar( + name = "trim_leading", + returns = ANY, + parameters = listOf( + FunctionParameter("value", ANY), + ), + isNullCall = true, + isNullable = false, + ), + // TRIM(LEADING chars FROM value) + FunctionSignature.Scalar( + name = "trim_leading_chars", + returns = ANY, + parameters = listOf( + FunctionParameter("value", ANY), + FunctionParameter("chars", ANY), + ), + isNullCall = true, + isNullable = false, + ), + // TRIM(TRAILING FROM value) + FunctionSignature.Scalar( + name = "trim_trailing", + returns = ANY, + parameters = listOf( + FunctionParameter("value", ANY), + ), + isNullCall = true, + isNullable = false, + ), + // TRIM(TRAILING chars FROM value) + FunctionSignature.Scalar( + name = "trim_trailing_chars", + returns = ANY, + parameters = listOf( + FunctionParameter("value", ANY), + FunctionParameter("chars", ANY), + ), + isNullCall = true, + isNullable = false, + ), + ) // TODO private fun overlay(): List = emptyList() @@ -469,8 +608,8 @@ object PartiQLHeader : Header() { private fun dateAdd(): List { val operators = mutableListOf() - for (type in types.datetime) { - for (field in DatetimeField.values()) { + for (field in DatetimeField.values()) { + for (type in types.datetime) { if (field == DatetimeField.TIMEZONE_HOUR || field == DatetimeField.TIMEZONE_MINUTE) { continue } @@ -486,14 +625,25 @@ object PartiQLHeader : Header() { ) operators.add(signature) } + val anySignature = FunctionSignature.Scalar( + name = "date_diff_${field.name.lowercase()}", + returns = ANY, + parameters = listOf( + FunctionParameter("datetime1", ANY), + FunctionParameter("datetime2", ANY), + ), + isNullCall = true, + isNullable = false, + ) + operators.add(anySignature) } return operators } private fun dateDiff(): List { val operators = mutableListOf() - for (type in types.datetime) { - for (field in DatetimeField.values()) { + for (field in DatetimeField.values()) { + for (type in types.datetime) { if (field == DatetimeField.TIMEZONE_HOUR || field == DatetimeField.TIMEZONE_MINUTE) { continue } @@ -509,6 +659,17 @@ object PartiQLHeader : Header() { ) operators.add(signature) } + val anySignature = FunctionSignature.Scalar( + name = "date_diff_${field.name.lowercase()}", + returns = ANY, + parameters = listOf( + FunctionParameter("datetime1", ANY), + FunctionParameter("datetime2", ANY), + ), + isNullCall = true, + isNullable = false, + ) + operators.add(anySignature) } return operators } From 9129ab4b471b642eb24cd4bdd5cdf822571ff8b8 Mon Sep 17 00:00:00 2001 From: Yingtao Liu Date: Thu, 2 Nov 2023 10:42:00 -0700 Subject: [PATCH 3/6] rewrite control flow functions to case when --- .../main/kotlin/org/partiql/planner/Header.kt | 3 +- .../org/partiql/planner/PartiQLHeader.kt | 150 +++++++++--------- .../planner/transforms/RexConverter.kt | 53 +++++-- .../org/partiql/planner/typer/FnResolver.kt | 2 +- .../types/function/FunctionSignature.kt | 5 +- 5 files changed, 119 insertions(+), 94 deletions(-) diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/Header.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/Header.kt index 399e017a7b..983e042a5f 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/Header.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/Header.kt @@ -68,13 +68,14 @@ public abstract class Header { ) @JvmStatic - internal fun binary(name: String, returns: PartiQLValueType, lhs: PartiQLValueType, rhs: PartiQLValueType) = + internal fun binary(name: String, returns: PartiQLValueType, lhs: PartiQLValueType, rhs: PartiQLValueType, isMissable: Boolean = false) = FunctionSignature.Scalar( name = name, returns = returns, parameters = listOf(FunctionParameter("lhs", lhs), FunctionParameter("rhs", rhs)), isNullCall = true, isNullable = false, + isMissable = isMissable ) } } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLHeader.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLHeader.kt index 3434355447..5a214dee84 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLHeader.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLHeader.kt @@ -10,6 +10,8 @@ import org.partiql.value.PartiQLValueType.* * A header which uses the PartiQL Lang Kotlin default standard library. All functions exist in a global namespace. * Once we have catalogs with information_schema, the PartiQL Header will be fixed on a specification version and * user defined functions will be defined within their own schema. + * + * TODO: The model of ANY type in function signature is less than ideal. If we have */ @OptIn(PartiQLValueExperimental::class) object PartiQLHeader : Header() { @@ -65,8 +67,6 @@ object PartiQLHeader : Header() { private fun scalarBuiltins(): List = listOf( upper(), lower(), - coalesce(), - nullIf(), position(), substring(), trim(), @@ -127,58 +127,63 @@ object PartiQLHeader : Header() { private fun ne(): List = types.all.map { t -> binary("ne", BOOL, t, t) - } + listOf(binary("ne", ANY, ANY, ANY)) + } - private fun and(): List = listOf(binary("and", BOOL, BOOL, BOOL), binary("and", ANY, ANY, ANY)) + private fun and(): List = listOf( + binary("and", BOOL, BOOL, BOOL), + binary("and", BOOL, ANY, ANY, true) + ) - private fun or(): List = listOf(binary("or", BOOL, BOOL, BOOL), binary("or", ANY, ANY, ANY)) + private fun or(): List = listOf( + binary("or", BOOL, BOOL, BOOL), + binary("or", BOOL, ANY, ANY, true) + ) private fun lt(): List = types.numeric.map { t -> binary("lt", BOOL, t, t) - } + listOf(binary("lt", ANY, ANY, ANY)) + } + listOf(binary("lt", BOOL, ANY, ANY, true)) private fun lte(): List = types.numeric.map { t -> binary("lte", BOOL, t, t) - } + listOf(binary("lte", ANY, ANY, ANY)) + } + listOf(binary("lte", BOOL, ANY, ANY, true)) private fun gt(): List = types.numeric.map { t -> binary("gt", BOOL, t, t) - } + listOf(binary("gt", ANY, ANY, ANY)) + } + listOf(binary("gt", BOOL, ANY, ANY, true)) private fun gte(): List = types.numeric.map { t -> binary("gte", BOOL, t, t) - } + listOf(binary("gte", ANY, ANY, ANY)) + } + listOf(binary("gte", BOOL, ANY, ANY, true)) private fun plus(): List = types.numeric.map { t -> binary("plus", t, t, t) - } + listOf(binary("plus", ANY, ANY, ANY)) + } + listOf(binary("plus", BOOL, ANY, ANY, true)) private fun minus(): List = types.numeric.map { t -> binary("minus", t, t, t) - } + listOf(binary("minus", ANY, ANY, ANY)) + } + listOf(binary("minus", ANY, ANY, ANY, true)) private fun times(): List = types.numeric.map { t -> binary("times", t, t, t) - } + listOf(binary("times", ANY, ANY, ANY)) + } + listOf(binary("times", ANY, ANY, ANY, true)) private fun div(): List = types.numeric.map { t -> binary("divide", t, t, t) - } + listOf(binary("divide", ANY, ANY, ANY)) + } + listOf(binary("divide", ANY, ANY, ANY, true)) private fun mod(): List = types.numeric.map { t -> binary("modulo", t, t, t) - } + listOf(binary("modulo", ANY, ANY, ANY)) + } + listOf(binary("modulo", ANY, ANY, ANY, true)) private fun concat(): List = types.text.map { t -> binary("concat", t, t, t) - } + listOf(binary("concat", ANY, ANY, ANY)) + } + listOf(binary("concat", ANY, ANY, ANY, true)) private fun bitwiseAnd(): List = types.integer.map { t -> binary("bitwise_and", t, t, t) - } + listOf(binary("bitwise_and", ANY, ANY, ANY)) - - // BUILT INTS + } + listOf(binary("bitwise_and", ANY, ANY, ANY, true)) + // BUILT INS private fun upper(): List = types.text.map { t -> FunctionSignature.Scalar( name = "upper", @@ -187,7 +192,16 @@ object PartiQLHeader : Header() { isNullCall = true, isNullable = false, ) - } + } + listOf( + FunctionSignature.Scalar( + name = "upper", + returns = ANY, + parameters = listOf(FunctionParameter("value", ANY)), + isNullCall = true, + isNullable = false, + isMissable = true + ) + ) private fun lower(): List = types.text.map { t -> FunctionSignature.Scalar( @@ -197,7 +211,16 @@ object PartiQLHeader : Header() { isNullCall = true, isNullable = false, ) - } + } + listOf( + FunctionSignature.Scalar( + name = "lower", + returns = ANY, + parameters = listOf(FunctionParameter("value", ANY)), + isNullCall = true, + isNullable = false, + isMissable = true + ) + ) // SPECIAL FORMS @@ -225,17 +248,18 @@ object PartiQLHeader : Header() { ), FunctionSignature.Scalar( name = "like", - returns = ANY, + returns = BOOL, parameters = listOf( FunctionParameter("value", ANY), FunctionParameter("pattern", ANY), ), isNullCall = true, isNullable = false, + isMissable = true ), FunctionSignature.Scalar( name = "like_escape", - returns = ANY, + returns = BOOL, parameters = listOf( FunctionParameter("value", ANY), FunctionParameter("pattern", ANY), @@ -243,6 +267,7 @@ object PartiQLHeader : Header() { ), isNullCall = true, isNullable = false, + isMissable = true ), ) @@ -258,7 +283,20 @@ object PartiQLHeader : Header() { isNullCall = true, isNullable = false, ) - } + } + listOf( + FunctionSignature.Scalar( + name = "between", + returns = BOOL, + parameters = listOf( + FunctionParameter("value", ANY), + FunctionParameter("lower", ANY), + FunctionParameter("upper", ANY), + ), + isNullCall = true, + isNullable = false, + isMissable = true + ) + ) private fun inCollection(): List = types.all.map { element -> types.collections.map { collection -> @@ -337,55 +375,6 @@ object PartiQLHeader : Header() { ) } - // COALESCE(expression, expression, ... ) - // Initial implementation of Coalesce. - // As the number of parameter can not be pre-determined, we wrap those into a list - private fun coalesce(): List = listOf( - FunctionSignature.Scalar( - name = "coalesce", - returns = ANY, - parameters = listOf( - FunctionParameter("values", LIST) - ), - isNullCall = false, - isNullable = true, - ), - FunctionSignature.Scalar( - name = "coalesce", - returns = ANY, - parameters = listOf( - FunctionParameter("values", ANY) - ), - isNullCall = false, - isNullable = true, - ), - ) - - // NULLIF(x, y) - private fun nullIf(): List = types.nullable.map { t -> - FunctionSignature.Scalar( - name = "null_if", - returns = t, - parameters = listOf( - FunctionParameter("value", t), - FunctionParameter("nullifier", BOOL), // TODO: why is this BOOL? - ), - isNullCall = true, - isNullable = true, - ) - } + listOf( - FunctionSignature.Scalar( - name = "null_if", - returns = ANY, - parameters = listOf( - FunctionParameter("value", ANY), - FunctionParameter("nullifier", ANY), - ), - isNullCall = true, - isNullable = true, - ) - ) - // SUBSTRING (expression, start[, length]?) // SUBSTRINGG(expression from start [FOR length]? ) private fun substring(): List = types.text.map { t -> @@ -422,6 +411,7 @@ object PartiQLHeader : Header() { ), isNullCall = true, isNullable = false, + isMissable = false, ), FunctionSignature.Scalar( name = "substring", @@ -433,6 +423,7 @@ object PartiQLHeader : Header() { ), isNullCall = true, isNullable = false, + isMissable = false ) ) @@ -459,6 +450,7 @@ object PartiQLHeader : Header() { ), isNullCall = true, isNullable = false, + isMissable = true, ) ) @@ -482,6 +474,7 @@ object PartiQLHeader : Header() { ), isNullCall = true, isNullable = false, + isMissable = true, ) ) @@ -555,6 +548,7 @@ object PartiQLHeader : Header() { ), isNullCall = true, isNullable = false, + isMissable = true, ), // TRIM(LEADING FROM value) FunctionSignature.Scalar( @@ -565,6 +559,7 @@ object PartiQLHeader : Header() { ), isNullCall = true, isNullable = false, + isMissable = true, ), // TRIM(LEADING chars FROM value) FunctionSignature.Scalar( @@ -576,6 +571,7 @@ object PartiQLHeader : Header() { ), isNullCall = true, isNullable = false, + isMissable = true ), // TRIM(TRAILING FROM value) FunctionSignature.Scalar( @@ -586,6 +582,7 @@ object PartiQLHeader : Header() { ), isNullCall = true, isNullable = false, + isMissable = true ), // TRIM(TRAILING chars FROM value) FunctionSignature.Scalar( @@ -597,6 +594,7 @@ object PartiQLHeader : Header() { ), isNullCall = true, isNullable = false, + isMissable = true ), ) @@ -629,11 +627,12 @@ object PartiQLHeader : Header() { name = "date_diff_${field.name.lowercase()}", returns = ANY, parameters = listOf( - FunctionParameter("datetime1", ANY), - FunctionParameter("datetime2", ANY), + FunctionParameter("interval", INT), + FunctionParameter("datetime", ANY), ), isNullCall = true, isNullable = false, + isMissable = true ) operators.add(anySignature) } @@ -661,13 +660,14 @@ object PartiQLHeader : Header() { } val anySignature = FunctionSignature.Scalar( name = "date_diff_${field.name.lowercase()}", - returns = ANY, + returns = INT64, parameters = listOf( FunctionParameter("datetime1", ANY), FunctionParameter("datetime2", ANY), ), isNullCall = true, isNullable = false, + isMissable = true ) operators.add(anySignature) } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/transforms/RexConverter.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/transforms/RexConverter.kt index 699767f7f7..0877352cae 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/transforms/RexConverter.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/transforms/RexConverter.kt @@ -330,26 +330,47 @@ internal object RexConverter { return rex(type, call) } - override fun visitExprCoalesce(node: Expr.Coalesce, ctx: Env): Rex { + // coalesce(expr1, expr2, ... exprN) -> + // CASE + // WHEN expr1 IS NOT NULL THEN EXPR1 + // ... + // WHEN exprn is NOT NULL THEN exprn + // ELSE NULL END + override fun visitExprCoalesce(node: Expr.Coalesce, ctx: Env): Rex = plan { val type = StaticType.ANY - // Args - val arg0 = rex(StaticType.LIST, rexOpCollection(node.args.map { visitExpr(it, ctx) })) - // Call - val call = callNonHidden("coalesce", arg0) - return rex(type, call) + + val createBranch: (Rex) -> Rex.Op.Case.Branch = { expr: Rex -> + val updatedCondition = rex(type, negate(call("is_null", expr))) + rexOpCaseBranch(updatedCondition, expr) + } + + val branches = node.args.map { + createBranch(visitExpr(it, ctx)) + }.toMutableList() + + val defaultRex = rex(type = StaticType.NULL, op = rexOpLit(value = nullValue())) + branches += rexOpCaseBranch(bool(true), defaultRex) + val op = rexOpCase(branches) + rex(type, op) } - /** - * NULLIF(, ) - */ - override fun visitExprNullIf(node: Expr.NullIf, ctx: Env): Rex { + // nullIf(expr1, expr2) -> + // CASE + // WHEN expr1 = expr2 THEN NULL + // ELSE expr1 END + override fun visitExprNullIf(node: Expr.NullIf, ctx: Env): Rex = plan { val type = StaticType.ANY - // Args - val arg0 = visitExpr(node.value, ctx) - val arg1 = visitExpr(node.nullifier, ctx) - // Call - val call = callNonHidden("null_if", arg0, arg1) - return rex(type, call) + val expr1 = visitExpr(node.value, ctx) + val expr2 = visitExpr(node.nullifier, ctx) + val id = identifierSymbol(Expr.Binary.Op.EQ.name.lowercase(), Identifier.CaseSensitivity.SENSITIVE) + val fn = fnUnresolved(id, true) + val call = rexOpCall(fn, listOf(expr1, expr2)) + val branches = listOf( + rexOpCaseBranch(rex(type, call), rex(type = StaticType.NULL, op = rexOpLit(value = nullValue()))), + rexOpCaseBranch(bool(true), expr1) + ) + val op = rexOpCase(branches) + rex(type, op) } /** diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/typer/FnResolver.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/typer/FnResolver.kt index ad1808f1d3..b73d8afb0f 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/typer/FnResolver.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/typer/FnResolver.kt @@ -153,7 +153,7 @@ internal class FnResolver(private val headers: List
) { return when (match) { null -> FnMatch.Error(fn.identifier, args, candidates) else -> { - val isMissable = hadMissingArg || isUnsafeCast(match.signature.specific) + val isMissable = hadMissingArg || isUnsafeCast(match.signature.specific) || match.signature.isMissable FnMatch.Ok(match.signature, match.mapping, isMissable) } } diff --git a/partiql-types/src/main/kotlin/org/partiql/types/function/FunctionSignature.kt b/partiql-types/src/main/kotlin/org/partiql/types/function/FunctionSignature.kt index 89839c5edd..c2abeaa0ac 100644 --- a/partiql-types/src/main/kotlin/org/partiql/types/function/FunctionSignature.kt +++ b/partiql-types/src/main/kotlin/org/partiql/types/function/FunctionSignature.kt @@ -55,6 +55,7 @@ public sealed class FunctionSignature( parameters: List, description: String? = null, isNullable: Boolean = true, + @JvmField public val isMissable: Boolean = false, @JvmField public val isDeterministic: Boolean = true, @JvmField public val isNullCall: Boolean = false, ) : FunctionSignature(name, returns, parameters, description, isNullable) { @@ -67,7 +68,8 @@ public sealed class FunctionSignature( other.parameters.size != parameters.size || other.isDeterministic != isDeterministic || other.isNullCall != isNullCall || - other.isNullable != isNullable + other.isNullable != isNullable || + other.isMissable != isMissable ) { return false } @@ -87,6 +89,7 @@ public sealed class FunctionSignature( result = 31 * result + isDeterministic.hashCode() result = 31 * result + isNullCall.hashCode() result = 31 * result + isNullable.hashCode() + result = 31 * result + isMissable.hashCode() result = 31 * result + (description?.hashCode() ?: 0) return result } From 0822a57e79368ef3018162561950e76b5032d8b7 Mon Sep 17 00:00:00 2001 From: Yingtao Liu Date: Thu, 2 Nov 2023 10:52:13 -0700 Subject: [PATCH 4/6] typo fix --- .../main/kotlin/org/partiql/planner/Header.kt | 3 ++- .../org/partiql/planner/PartiQLHeader.kt | 21 +++++++++++++------ 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/Header.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/Header.kt index 983e042a5f..b474c98255 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/Header.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/Header.kt @@ -58,13 +58,14 @@ public abstract class Header { companion object { @JvmStatic - internal fun unary(name: String, returns: PartiQLValueType, value: PartiQLValueType) = + internal fun unary(name: String, returns: PartiQLValueType, value: PartiQLValueType, isMissable: Boolean = false) = FunctionSignature.Scalar( name = name, returns = returns, parameters = listOf(FunctionParameter("value", value)), isNullCall = true, isNullable = false, + isMissable = isMissable ) @JvmStatic diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLHeader.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLHeader.kt index 5a214dee84..2a0141bf42 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLHeader.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLHeader.kt @@ -4,14 +4,23 @@ import org.partiql.ast.DatetimeField import org.partiql.types.function.FunctionParameter import org.partiql.types.function.FunctionSignature import org.partiql.value.PartiQLValueExperimental -import org.partiql.value.PartiQLValueType.* +import org.partiql.value.PartiQLValueType.ANY +import org.partiql.value.PartiQLValueType.BOOL +import org.partiql.value.PartiQLValueType.CHAR +import org.partiql.value.PartiQLValueType.DATE +import org.partiql.value.PartiQLValueType.DECIMAL +import org.partiql.value.PartiQLValueType.INT +import org.partiql.value.PartiQLValueType.INT32 +import org.partiql.value.PartiQLValueType.INT64 +import org.partiql.value.PartiQLValueType.STRING +import org.partiql.value.PartiQLValueType.TIME +import org.partiql.value.PartiQLValueType.TIMESTAMP /** * A header which uses the PartiQL Lang Kotlin default standard library. All functions exist in a global namespace. * Once we have catalogs with information_schema, the PartiQL Header will be fixed on a specification version and * user defined functions will be defined within their own schema. * - * TODO: The model of ANY type in function signature is less than ideal. If we have */ @OptIn(PartiQLValueExperimental::class) object PartiQLHeader : Header() { @@ -105,7 +114,7 @@ object PartiQLHeader : Header() { // OPERATORS - private fun not(): List = listOf(unary("not", BOOL, BOOL), unary("not", ANY, ANY)) + private fun not(): List = listOf(unary("not", BOOL, BOOL), unary("not", BOOL, ANY, true)) private fun pos(): List = types.numeric.map { t -> unary("pos", t, t) @@ -113,7 +122,7 @@ object PartiQLHeader : Header() { private fun neg(): List = types.numeric.map { t -> unary("neg", t, t) - } + listOf(unary("pos", ANY, ANY)) + } + listOf(unary("neg", ANY, ANY)) private fun eq(): List = types.all.map { t -> FunctionSignature.Scalar( @@ -411,7 +420,7 @@ object PartiQLHeader : Header() { ), isNullCall = true, isNullable = false, - isMissable = false, + isMissable = true, ), FunctionSignature.Scalar( name = "substring", @@ -423,7 +432,7 @@ object PartiQLHeader : Header() { ), isNullCall = true, isNullable = false, - isMissable = false + isMissable = true ) ) From d0540a720b47d53aa3d9e29d7042093826990d06 Mon Sep 17 00:00:00 2001 From: Yingtao Liu Date: Thu, 2 Nov 2023 10:55:19 -0700 Subject: [PATCH 5/6] fix typo --- .../src/main/kotlin/org/partiql/planner/PartiQLHeader.kt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLHeader.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLHeader.kt index 2a0141bf42..cbbfc9682b 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLHeader.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLHeader.kt @@ -166,7 +166,7 @@ object PartiQLHeader : Header() { private fun plus(): List = types.numeric.map { t -> binary("plus", t, t, t) - } + listOf(binary("plus", BOOL, ANY, ANY, true)) + } + listOf(binary("plus", ANY, ANY, ANY, true)) private fun minus(): List = types.numeric.map { t -> binary("minus", t, t, t) From e79268b0afc79663e3e0112ff43e5f7bbeba2b32 Mon Sep 17 00:00:00 2001 From: Yingtao Liu Date: Thu, 2 Nov 2023 11:13:50 -0700 Subject: [PATCH 6/6] address feedback --- .../src/main/kotlin/org/partiql/planner/PartiQLHeader.kt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLHeader.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLHeader.kt index cbbfc9682b..5cc103cf1e 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLHeader.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLHeader.kt @@ -636,7 +636,7 @@ object PartiQLHeader : Header() { name = "date_diff_${field.name.lowercase()}", returns = ANY, parameters = listOf( - FunctionParameter("interval", INT), + FunctionParameter("interval", ANY), FunctionParameter("datetime", ANY), ), isNullCall = true,