From 42da8441fe762ca4a14093aa53b169d36d016621 Mon Sep 17 00:00:00 2001 From: Yingtao Liu Date: Mon, 5 Feb 2024 15:59:03 -0800 Subject: [PATCH 1/2] rebase eval --- .../eval/internal/operator/rex/ExprCast.kt | 348 +++++++++++++++++- .../internal/operator/rex/ExprPathIndex.kt | 16 +- .../eval/internal/operator/rex/ExprStruct.kt | 4 +- partiql-parser/src/main/antlr/PartiQL.g4 | 3 +- .../src/main/antlr/PartiQLTokens.g4 | 1 + .../parser/internal/PartiQLParserDefault.kt | 24 +- .../partiql/spi/connector/sql/SqlBuiltins.kt | 11 + .../spi/connector/sql/builtins/FnAbs.kt | 174 +++++++++ .../connector/sql/builtins/FnCharLength.kt | 78 ++++ .../connector/sql/builtins/FnDateAddDay.kt | 20 +- .../connector/sql/builtins/FnDateAddHour.kt | 16 +- .../connector/sql/builtins/FnDateAddMinute.kt | 16 +- .../connector/sql/builtins/FnDateAddMonth.kt | 15 +- .../connector/sql/builtins/FnDateAddSecond.kt | 16 +- .../connector/sql/builtins/FnDateAddYear.kt | 16 +- .../spi/connector/sql/builtins/FnIsChar.kt | 2 +- .../spi/connector/sql/builtins/FnIsString.kt | 2 +- .../spi/connector/sql/builtins/FnSubstring.kt | 19 +- .../kotlin/org/partiql/value/PartiQLValue.kt | 23 +- .../partiql/value/impl/DecimalValueImpl.kt | 66 ++++ .../partiql/value/impl/Float32ValueImpl.kt | 68 ++++ .../partiql/value/impl/Float64ValueImpl.kt | 73 ++++ .../org/partiql/value/impl/Int16ValueImpl.kt | 38 ++ .../org/partiql/value/impl/Int32ValueImpl.kt | 44 +++ .../org/partiql/value/impl/Int64ValueImpl.kt | 50 +++ .../org/partiql/value/impl/Int8ValueImpl.kt | 32 ++ .../org/partiql/value/impl/IntValueImpl.kt | 50 +++ .../org/partiql/value/impl/NullValueImpl.kt | 56 +++ .../internal/fn/scalar/FnDateAddYear.kt | 188 ++++++++++ .../plugin/internal/fn/scalar/FnIsChar.kt | 69 ++++ .../partiql/runner/executor/EvalExecutor.kt | 23 +- 31 files changed, 1476 insertions(+), 85 deletions(-) create mode 100644 partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnAbs.kt create mode 100644 partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnCharLength.kt create mode 100644 plugins/partiql-plugin/src/main/kotlin/org/partiql/plugin/internal/fn/scalar/FnDateAddYear.kt create mode 100644 plugins/partiql-plugin/src/main/kotlin/org/partiql/plugin/internal/fn/scalar/FnIsChar.kt diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCast.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCast.kt index 3e0d9515c7..bea7325d0f 100644 --- a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCast.kt +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCast.kt @@ -1,18 +1,354 @@ package org.partiql.eval.internal.operator.rex +import com.amazon.ion.Decimal +import com.amazon.ionelement.api.ElementType +import com.amazon.ionelement.api.IonElementException +import com.amazon.ionelement.api.createIonElementLoader +import org.partiql.errors.DataException +import org.partiql.errors.TypeCheckException import org.partiql.eval.internal.Record import org.partiql.eval.internal.operator.Operator import org.partiql.plan.Ref +import org.partiql.value.BagValue +import org.partiql.value.BoolValue +import org.partiql.value.CollectionValue +import org.partiql.value.DecimalValue +import org.partiql.value.Float32Value +import org.partiql.value.Float64Value +import org.partiql.value.Int16Value +import org.partiql.value.Int32Value +import org.partiql.value.Int64Value +import org.partiql.value.Int8Value +import org.partiql.value.IntValue +import org.partiql.value.ListValue +import org.partiql.value.NumericValue import org.partiql.value.PartiQLValue import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.PartiQLValueType +import org.partiql.value.SexpValue +import org.partiql.value.StringValue +import org.partiql.value.SymbolValue +import org.partiql.value.TextValue +import org.partiql.value.bagValue +import org.partiql.value.boolValue +import org.partiql.value.decimalValue +import org.partiql.value.float32Value +import org.partiql.value.float64Value +import org.partiql.value.int16Value +import org.partiql.value.int32Value +import org.partiql.value.int64Value +import org.partiql.value.int8Value +import org.partiql.value.intValue +import org.partiql.value.listValue +import org.partiql.value.sexpValue +import org.partiql.value.stringValue +import org.partiql.value.symbolValue +import java.math.BigDecimal +import java.math.BigInteger -internal class ExprCast( - private val arg: Operator.Expr, - private val cast: Ref.Cast, -) : Operator.Expr { - +// TODO: This is incomplete +internal class ExprCast(val arg: Operator.Expr, val cast: Ref.Cast) : Operator.Expr { @OptIn(PartiQLValueExperimental::class) override fun eval(record: Record): PartiQLValue { - TODO("Not yet implemented") + val arg = arg.eval(record) + try { + return when (arg.type) { + PartiQLValueType.ANY -> TODO("Not Possible") + PartiQLValueType.BOOL -> castFromBool(arg as BoolValue, cast.target) + PartiQLValueType.INT8 -> castFromNumeric(arg as Int8Value, cast.target) + PartiQLValueType.INT16 -> castFromNumeric(arg as Int16Value, cast.target) + PartiQLValueType.INT32 -> castFromNumeric(arg as Int32Value, cast.target) + PartiQLValueType.INT64 -> castFromNumeric(arg as Int64Value, cast.target) + PartiQLValueType.INT -> castFromNumeric(arg as IntValue, cast.target) + PartiQLValueType.DECIMAL -> castFromNumeric(arg as DecimalValue, cast.target) + PartiQLValueType.DECIMAL_ARBITRARY -> castFromNumeric(arg as DecimalValue, cast.target) + PartiQLValueType.FLOAT32 -> castFromNumeric(arg as Float32Value, cast.target) + PartiQLValueType.FLOAT64 -> castFromNumeric(arg as Float64Value, cast.target) + PartiQLValueType.CHAR -> TODO("Char value implementation is wrong") + PartiQLValueType.STRING -> castFromText(arg as StringValue, cast.target) + PartiQLValueType.SYMBOL -> castFromText(arg as SymbolValue, cast.target) + PartiQLValueType.BINARY -> TODO("Static Type does not support Binary") + PartiQLValueType.BYTE -> TODO("Static Type does not support Byte") + PartiQLValueType.BLOB -> TODO("CAST FROM BLOB not yet implemented") + PartiQLValueType.CLOB -> TODO("CAST FROM CLOB not yet implemented") + PartiQLValueType.DATE -> TODO("CAST FROM DATE not yet implemented") + PartiQLValueType.TIME -> TODO("CAST FROM TIME not yet implemented") + PartiQLValueType.TIMESTAMP -> TODO("CAST FROM TIMESTAMP not yet implemented") + PartiQLValueType.INTERVAL -> TODO("Static Type does not support INTERVAL") + PartiQLValueType.BAG -> castFromCollection(arg as BagValue<*>, cast.target) + PartiQLValueType.LIST -> castFromCollection(arg as ListValue<*>, cast.target) + PartiQLValueType.SEXP -> castFromCollection(arg as SexpValue<*>, cast.target) + PartiQLValueType.STRUCT -> TODO("CAST FROM STRUCT not yet implemented") + PartiQLValueType.NULL -> error("cast from NULL should be handled by Typer") + PartiQLValueType.MISSING -> error("cast from MISSING should be handled by Typer") + } + } catch (e: DataException) { + throw TypeCheckException() + } + } + + @OptIn(PartiQLValueExperimental::class) + private fun castFromBool(value: BoolValue, t: PartiQLValueType): PartiQLValue { + val v = value.value + return when (t) { + PartiQLValueType.ANY -> value + PartiQLValueType.BOOL -> value + PartiQLValueType.INT8 -> when (v) { + true -> int8Value(1) + false -> int8Value(0) + null -> int8Value(null) + } + + PartiQLValueType.INT16 -> when (v) { + true -> int16Value(1) + false -> int16Value(0) + null -> int16Value(null) + } + + PartiQLValueType.INT32 -> when (v) { + true -> int32Value(1) + false -> int32Value(0) + null -> int32Value(null) + } + + PartiQLValueType.INT64 -> when (v) { + true -> int64Value(1) + false -> int64Value(0) + null -> int64Value(null) + } + + PartiQLValueType.INT -> when (v) { + true -> intValue(BigInteger.valueOf(1)) + false -> intValue(BigInteger.valueOf(0)) + null -> intValue(null) + } + + PartiQLValueType.DECIMAL, PartiQLValueType.DECIMAL_ARBITRARY -> when (v) { + true -> decimalValue(BigDecimal.ONE) + false -> decimalValue(BigDecimal.ZERO) + null -> decimalValue(null) + } + + PartiQLValueType.FLOAT32 -> { + when (v) { + true -> float32Value(1.0.toFloat()) + false -> float32Value(0.0.toFloat()) + null -> float32Value(null) + } + } + + PartiQLValueType.FLOAT64 -> when (v) { + true -> float64Value(1.0) + false -> float64Value(0.0) + null -> float64Value(null) + } + + PartiQLValueType.CHAR -> TODO("Char value implementation is wrong") + PartiQLValueType.STRING -> stringValue(v?.toString()) + PartiQLValueType.SYMBOL -> symbolValue(v?.toString()) + PartiQLValueType.BINARY, PartiQLValueType.BYTE, + PartiQLValueType.BLOB, PartiQLValueType.CLOB, + PartiQLValueType.DATE, PartiQLValueType.TIME, PartiQLValueType.TIMESTAMP, + PartiQLValueType.INTERVAL, + PartiQLValueType.BAG, PartiQLValueType.LIST, + PartiQLValueType.SEXP, + PartiQLValueType.STRUCT -> error("can not perform cast from $value to $t") + PartiQLValueType.NULL -> error("cast to null not supported") + PartiQLValueType.MISSING -> error("cast to missing not supported") + } + } + @OptIn(PartiQLValueExperimental::class) + private fun castFromNumeric(value: NumericValue<*>, t: PartiQLValueType): PartiQLValue { + val v = value.value + return when (t) { + PartiQLValueType.ANY -> value + PartiQLValueType.BOOL -> when { + v == null -> boolValue(null) + v == 0.0 -> boolValue(false) + else -> boolValue(true) + } + PartiQLValueType.INT8 -> value.toInt8() + PartiQLValueType.INT16 -> value.toInt16() + PartiQLValueType.INT32 -> value.toInt32() + PartiQLValueType.INT64 -> value.toInt64() + PartiQLValueType.INT -> value.toInt() + PartiQLValueType.DECIMAL -> value.toDecimal() + PartiQLValueType.DECIMAL_ARBITRARY -> value.toDecimal() + PartiQLValueType.FLOAT32 -> value.toFloat32() + PartiQLValueType.FLOAT64 -> value.toFloat64() + PartiQLValueType.CHAR -> TODO("Char value implementation is wrong") + PartiQLValueType.STRING -> stringValue(v?.toString(), value.annotations) + PartiQLValueType.SYMBOL -> symbolValue(v?.toString(), value.annotations) + PartiQLValueType.BINARY, PartiQLValueType.BYTE, + PartiQLValueType.BLOB, PartiQLValueType.CLOB, + PartiQLValueType.DATE, PartiQLValueType.TIME, PartiQLValueType.TIMESTAMP, + PartiQLValueType.INTERVAL, + PartiQLValueType.BAG, PartiQLValueType.LIST, + PartiQLValueType.SEXP, + PartiQLValueType.STRUCT -> error("can not perform cast from $value to $t") + PartiQLValueType.NULL -> error("cast to null not supported") + PartiQLValueType.MISSING -> error("cast to missing not supported") + } + } + + @OptIn(PartiQLValueExperimental::class) + private fun castFromText(value: TextValue, t: PartiQLValueType): PartiQLValue { + return when (t) { + PartiQLValueType.ANY -> value + PartiQLValueType.BOOL -> { + val str = value.value?.lowercase() ?: return boolValue(null, value.annotations) + if (str == "true") return boolValue(true, value.annotations) + if (str == "false") return boolValue(false, value.annotations) + throw TypeCheckException() + } + PartiQLValueType.INT8 -> { + val stringValue = value.value ?: return int8Value(null, value.annotations) + when (val number = getNumberValueFromString(stringValue)) { + is BigInteger -> intValue(number, value.annotations).toInt8() + else -> throw TypeCheckException() + } + } + PartiQLValueType.INT16 -> { + val stringValue = value.value ?: return int16Value(null, value.annotations) + when (val number = getNumberValueFromString(stringValue)) { + is BigInteger -> intValue(number, value.annotations).toInt16() + else -> throw TypeCheckException() + } + } + PartiQLValueType.INT32 -> { + val stringValue = value.value ?: return int32Value(null, value.annotations) + when (val number = getNumberValueFromString(stringValue)) { + is BigInteger -> intValue(number, value.annotations).toInt32() + else -> throw TypeCheckException() + } + } + PartiQLValueType.INT64 -> { + val stringValue = value.value ?: return int64Value(null, value.annotations) + when (val number = getNumberValueFromString(stringValue)) { + is BigInteger -> intValue(number, value.annotations).toInt64() + else -> throw TypeCheckException() + } + } + PartiQLValueType.INT -> { + val stringValue = value.value ?: return intValue(null, value.annotations) + when (val number = getNumberValueFromString(stringValue)) { + is BigInteger -> intValue(number, value.annotations).toInt() + else -> throw TypeCheckException() + } + } + PartiQLValueType.DECIMAL -> { + val stringValue = value.value ?: return int16Value(null, value.annotations) + when (val number = getNumberValueFromString(stringValue)) { + is Decimal -> decimalValue(number, value.annotations).toDecimal() + else -> throw TypeCheckException() + } + } + PartiQLValueType.DECIMAL_ARBITRARY -> { + val stringValue = value.value ?: return int16Value(null, value.annotations) + when (val number = getNumberValueFromString(stringValue)) { + is Decimal -> decimalValue(number, value.annotations).toDecimal() + else -> throw TypeCheckException() + } + } + PartiQLValueType.FLOAT32 -> { + val stringValue = value.value ?: return int16Value(null, value.annotations) + when (val number = getNumberValueFromString(stringValue)) { + is Double -> float64Value(number, value.annotations).toFloat32() + else -> throw TypeCheckException() + } + } + PartiQLValueType.FLOAT64 -> { + val stringValue = value.value ?: return int16Value(null, value.annotations) + when (val number = getNumberValueFromString(stringValue)) { + is Double -> float64Value(number, value.annotations).toFloat32() + else -> throw TypeCheckException() + } + } + PartiQLValueType.CHAR -> TODO("Char value implementation is wrong") + PartiQLValueType.STRING -> stringValue(value.value, value.annotations) + PartiQLValueType.SYMBOL -> symbolValue(value.value, value.annotations) + PartiQLValueType.BINARY, PartiQLValueType.BYTE, + PartiQLValueType.BLOB, PartiQLValueType.CLOB, + PartiQLValueType.DATE, PartiQLValueType.TIME, PartiQLValueType.TIMESTAMP, + PartiQLValueType.INTERVAL, + PartiQLValueType.BAG, PartiQLValueType.LIST, + PartiQLValueType.SEXP, + PartiQLValueType.STRUCT -> error("can not perform cast from INT8 to $t") + PartiQLValueType.NULL -> error("cast to null not supported") + PartiQLValueType.MISSING -> error("cast to missing not supported") + } + } + + // TODO: Fix NULL Collection + @OptIn(PartiQLValueExperimental::class) + private fun castFromCollection(value: CollectionValue<*>, t: PartiQLValueType): PartiQLValue { + val elements = mutableListOf() + value.iterator().forEachRemaining { + elements.add(it) + } + return when (t) { + PartiQLValueType.BAG -> bagValue(elements) + PartiQLValueType.LIST -> listValue(elements) + PartiQLValueType.SEXP -> sexpValue(elements) + else -> error("can not perform cast from $value to $t") + } + } + + // For now, utilize ion to parse string such as 0b10, etc. + private fun getNumberValueFromString(str: String): Number? { + val ion = try { + str.let { createIonElementLoader().loadSingleElement(it.normalizeForCastToInt()) } + } catch (e: IonElementException) { + throw TypeCheckException() + } + return when (ion.type) { + ElementType.INT -> ion.bigIntegerValueOrNull + ElementType.FLOAT -> ion.doubleValueOrNull + ElementType.DECIMAL -> ion.decimalValueOrNull + else -> null + } + } + + private fun String.normalizeForCastToInt(): String { + fun Char.isSign() = this == '-' || this == '+' + fun Char.isHexOrBase2Marker(): Boolean { + val c = this.lowercaseChar() + + return c == 'x' || c == 'b' + } + + fun String.possiblyHexOrBase2() = (length >= 2 && this[1].isHexOrBase2Marker()) || + (length >= 3 && this[0].isSign() && this[2].isHexOrBase2Marker()) + + return when { + length == 0 -> this + possiblyHexOrBase2() -> { + if (this[0] == '+') { + this.drop(1) + } else { + this + } + } + else -> { + val (isNegative, startIndex) = when (this[0]) { + '-' -> Pair(true, 1) + '+' -> Pair(false, 1) + else -> Pair(false, 0) + } + + var toDrop = startIndex + while (toDrop < length && this[toDrop] == '0') { + toDrop += 1 + } + + when { + toDrop == length -> "0" // string is all zeros + toDrop == 0 -> this + toDrop == 1 && isNegative -> this + toDrop > 1 && isNegative -> '-' + this.drop(toDrop) + else -> this.drop(toDrop) + } + } + } } } diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprPathIndex.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprPathIndex.kt index 8a4c35d77d..a812383e06 100644 --- a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprPathIndex.kt +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprPathIndex.kt @@ -1,5 +1,6 @@ package org.partiql.eval.internal.operator.rex +import org.partiql.errors.DataException import org.partiql.errors.TypeCheckException import org.partiql.eval.internal.Record import org.partiql.eval.internal.operator.Operator @@ -9,6 +10,7 @@ import org.partiql.value.Int32Value import org.partiql.value.Int64Value import org.partiql.value.Int8Value import org.partiql.value.IntValue +import org.partiql.value.NumericValue import org.partiql.value.PartiQLValue import org.partiql.value.PartiQLValueExperimental import org.partiql.value.check @@ -24,11 +26,15 @@ internal class ExprPathIndex( // Calculate index val index = when (val k = key.eval(record)) { - is Int16Value -> k.int - is Int32Value -> k.int - is Int64Value -> k.int - is Int8Value -> k.int - is IntValue -> k.int + is Int16Value, + is Int32Value, + is Int64Value, + is Int8Value, + is IntValue -> try { + (k as NumericValue<*>).toInt32().value + } catch (e: DataException) { + throw TypeCheckException() + } else -> throw TypeCheckException() } ?: throw TypeCheckException() diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprStruct.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprStruct.kt index 3afd9840c9..232a27e6be 100644 --- a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprStruct.kt +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprStruct.kt @@ -5,7 +5,7 @@ import org.partiql.eval.internal.operator.Operator import org.partiql.value.MissingValue import org.partiql.value.PartiQLValue import org.partiql.value.PartiQLValueExperimental -import org.partiql.value.StringValue +import org.partiql.value.TextValue import org.partiql.value.check import org.partiql.value.structValue @@ -13,7 +13,7 @@ internal class ExprStruct(val fields: List) : Operator.Expr { @OptIn(PartiQLValueExperimental::class) override fun eval(record: Record): PartiQLValue { val fields = fields.mapNotNull { - val key = it.key.eval(record).check() + val key = it.key.eval(record).check>() when (val value = it.value.eval(record)) { is MissingValue -> null else -> key.value!! to value diff --git a/partiql-parser/src/main/antlr/PartiQL.g4 b/partiql-parser/src/main/antlr/PartiQL.g4 index 748d038366..041c5e50ae 100644 --- a/partiql-parser/src/main/antlr/PartiQL.g4 +++ b/partiql-parser/src/main/antlr/PartiQL.g4 @@ -716,7 +716,8 @@ functionCall // SQL-99 10.4 — ::= [ ] functionName - : (qualifier+=symbolPrimitive PERIOD)* name=( CHAR_LENGTH | CHARACTER_LENGTH | OCTET_LENGTH | BIT_LENGTH | UPPER | LOWER | SIZE | EXISTS | COUNT ) # FunctionNameReserved + : (qualifier+=symbolPrimitive PERIOD)* name=( CHAR_LENGTH | CHARACTER_LENGTH | OCTET_LENGTH | BIT_LENGTH | + UPPER | LOWER | SIZE | EXISTS | COUNT | MOD) # FunctionNameReserved | (qualifier+=symbolPrimitive PERIOD)* name=symbolPrimitive # FunctionNameSymbol ; diff --git a/partiql-parser/src/main/antlr/PartiQLTokens.g4 b/partiql-parser/src/main/antlr/PartiQLTokens.g4 index 12d396d30d..e80269aff5 100644 --- a/partiql-parser/src/main/antlr/PartiQLTokens.g4 +++ b/partiql-parser/src/main/antlr/PartiQLTokens.g4 @@ -156,6 +156,7 @@ LOWER: 'LOWER'; MATCH: 'MATCH'; MAX: 'MAX'; MIN: 'MIN'; +MOD: 'MOD'; MODULE: 'MODULE'; NAMES: 'NAMES'; NATIONAL: 'NATIONAL'; diff --git a/partiql-parser/src/main/kotlin/org/partiql/parser/internal/PartiQLParserDefault.kt b/partiql-parser/src/main/kotlin/org/partiql/parser/internal/PartiQLParserDefault.kt index 73a029b6e9..acb34d27cf 100644 --- a/partiql-parser/src/main/kotlin/org/partiql/parser/internal/PartiQLParserDefault.kt +++ b/partiql-parser/src/main/kotlin/org/partiql/parser/internal/PartiQLParserDefault.kt @@ -854,7 +854,7 @@ internal class PartiQLParserDefault : PartiQLParser { throw error(ctx, "Expected a path element literal") } when (val i = v.value) { - is NumericValue<*> -> pathStepIndex(i.int!!) + is NumericValue<*> -> pathStepIndex(i.toInt32().value!!) is StringValue -> pathStepSymbol( identifierSymbol( i.value!!, Identifier.CaseSensitivity.SENSITIVE @@ -1725,14 +1725,30 @@ internal class PartiQLParserDefault : PartiQLParser { } override fun visitFunctionCall(ctx: GeneratedParser.FunctionCallContext) = translate(ctx) { - val function = visit(ctx.functionName()) as Identifier val args = visitOrEmpty(ctx.expr()) - exprCall(function, args) + when (val funcName = ctx.functionName()) { + is GeneratedParser.FunctionNameReservedContext -> { + when (funcName.name.type) { + GeneratedParser.MOD -> exprBinary(Expr.Binary.Op.MODULO, args[0], args[1]) + else -> visitNonReservedFunctionCall(ctx, args) + } + } + else -> visitNonReservedFunctionCall(ctx, args) + } + } + private fun visitNonReservedFunctionCall(ctx: GeneratedParser.FunctionCallContext, args: List): Expr.Call { + val function = visit(ctx.functionName()) as Identifier + return exprCall(function, args) } override fun visitFunctionNameReserved(ctx: GeneratedParser.FunctionNameReservedContext): Identifier { val path = ctx.qualifier.map { visitSymbolPrimitive(it) } - val name = identifierSymbol(ctx.name.text, Identifier.CaseSensitivity.INSENSITIVE) + val name = when (ctx.name.type) { + GeneratedParser.CHARACTER_LENGTH, GeneratedParser.CHAR_LENGTH -> + identifierSymbol("char_length", Identifier.CaseSensitivity.INSENSITIVE) + else -> + identifierSymbol(ctx.name.text, Identifier.CaseSensitivity.INSENSITIVE) + } return if (path.isEmpty()) { name } else { diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/SqlBuiltins.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/SqlBuiltins.kt index 17be430f28..a67452a353 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/SqlBuiltins.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/SqlBuiltins.kt @@ -13,6 +13,17 @@ internal object SqlBuiltins { @JvmStatic val builtins: List = listOf( + Fn_ABS__INT8__INT8, + Fn_ABS__INT16__INT16, + Fn_ABS__INT32__INT32, + Fn_ABS__INT64__INT64, + Fn_ABS__INT__INT, + Fn_ABS__DECIMAL_ARBITRARY__DECIMAL_ARBITRARY, + Fn_ABS__FLOAT32__FLOAT32, + Fn_ABS__FLOAT64__FLOAT64, + Fn_CHAR_LENGTH__STRING__INT, + Fn_CHAR_LENGTH__SYMBOL__INT, + Fn_CHAR_LENGTH__CLOB__INT, Fn_POS__INT8__INT8, Fn_POS__INT16__INT16, Fn_POS__INT32__INT32, diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnAbs.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnAbs.kt new file mode 100644 index 0000000000..bfd4a054a0 --- /dev/null +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnAbs.kt @@ -0,0 +1,174 @@ +// ktlint-disable filename +@file:Suppress("ClassName") + +package org.partiql.spi.connector.sql.builtins + +import org.partiql.spi.fn.Fn +import org.partiql.spi.fn.FnExperimental +import org.partiql.spi.fn.FnParameter +import org.partiql.spi.fn.FnSignature +import org.partiql.value.DecimalValue +import org.partiql.value.Float32Value +import org.partiql.value.Float64Value +import org.partiql.value.Int16Value +import org.partiql.value.Int32Value +import org.partiql.value.Int64Value +import org.partiql.value.Int8Value +import org.partiql.value.IntValue +import org.partiql.value.PartiQLValue +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.PartiQLValueType.DECIMAL_ARBITRARY +import org.partiql.value.PartiQLValueType.FLOAT32 +import org.partiql.value.PartiQLValueType.FLOAT64 +import org.partiql.value.PartiQLValueType.INT +import org.partiql.value.PartiQLValueType.INT16 +import org.partiql.value.PartiQLValueType.INT32 +import org.partiql.value.PartiQLValueType.INT64 +import org.partiql.value.PartiQLValueType.INT8 +import org.partiql.value.check +import org.partiql.value.decimalValue +import org.partiql.value.float32Value +import org.partiql.value.float64Value +import org.partiql.value.int16Value +import org.partiql.value.int32Value +import org.partiql.value.int64Value +import org.partiql.value.int8Value +import org.partiql.value.intValue +import kotlin.math.absoluteValue + +// TODO: When negate a negative value, we need to consider overflow +@OptIn(PartiQLValueExperimental::class, FnExperimental::class) +internal object Fn_ABS__INT8__INT8 : Fn { + + override val signature = FnSignature( + name = "abs", + returns = INT8, + parameters = listOf(FnParameter("value", INT8)), + isNullCall = true, + isNullable = false, + ) + + override fun invoke(args: Array): Int8Value { + val value = args[0].check().value!! + return if (value < 0) int8Value(value.times(-1).toByte()) else int8Value(value) + } +} + +@OptIn(PartiQLValueExperimental::class, FnExperimental::class) +internal object Fn_ABS__INT16__INT16 : Fn { + + override val signature = FnSignature( + name = "abs", + returns = INT16, + parameters = listOf(FnParameter("value", INT16)), + isNullCall = true, + isNullable = false, + ) + + override fun invoke(args: Array): Int16Value { + val value = args[0].check().value!! + return if (value < 0) int16Value(value.times(-1).toShort()) else int16Value(value) + } +} + +@OptIn(PartiQLValueExperimental::class, FnExperimental::class) +internal object Fn_ABS__INT32__INT32 : Fn { + + override val signature = FnSignature( + name = "abs", + returns = INT32, + parameters = listOf(FnParameter("value", INT32)), + isNullCall = true, + isNullable = false, + ) + + override fun invoke(args: Array): Int32Value { + val value = args[0].check().value!! + return int32Value(value.absoluteValue) + } +} + +@OptIn(PartiQLValueExperimental::class, FnExperimental::class) +internal object Fn_ABS__INT64__INT64 : Fn { + + override val signature = FnSignature( + name = "abs", + returns = INT64, + parameters = listOf(FnParameter("value", INT64)), + isNullCall = true, + isNullable = false, + ) + + override fun invoke(args: Array): Int64Value { + val value = args[0].check().value!! + return int64Value(value.absoluteValue) + } +} + +@OptIn(PartiQLValueExperimental::class, FnExperimental::class) +internal object Fn_ABS__INT__INT : Fn { + + override val signature = FnSignature( + name = "abs", + returns = INT, + parameters = listOf(FnParameter("value", INT)), + isNullCall = true, + isNullable = false, + ) + + override fun invoke(args: Array): IntValue { + val value = args[0].check().value!! + return intValue(value.abs()) + } +} + +@OptIn(PartiQLValueExperimental::class, FnExperimental::class) +internal object Fn_ABS__DECIMAL_ARBITRARY__DECIMAL_ARBITRARY : Fn { + + override val signature = FnSignature( + name = "abs", + returns = DECIMAL_ARBITRARY, + parameters = listOf(FnParameter("value", DECIMAL_ARBITRARY)), + isNullCall = true, + isNullable = false, + ) + + override fun invoke(args: Array): DecimalValue { + val value = args[0].check().value!! + return decimalValue(value.abs()) + } +} + +@OptIn(PartiQLValueExperimental::class, FnExperimental::class) +internal object Fn_ABS__FLOAT32__FLOAT32 : Fn { + + override val signature = FnSignature( + name = "abs", + returns = FLOAT32, + parameters = listOf(FnParameter("value", FLOAT32)), + isNullCall = true, + isNullable = false, + ) + + override fun invoke(args: Array): Float32Value { + val value = args[0].check().value!! + return float32Value(value.absoluteValue) + } +} + +@OptIn(PartiQLValueExperimental::class, FnExperimental::class) +internal object Fn_ABS__FLOAT64__FLOAT64 : Fn { + + override val signature = FnSignature( + name = "abs", + returns = FLOAT64, + parameters = listOf(FnParameter("value", FLOAT64)), + isNullCall = true, + isNullable = false, + ) + + override fun invoke(args: Array): Float64Value { + val value = args[0].check().value!! + return float64Value(value.absoluteValue) + } +} diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnCharLength.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnCharLength.kt new file mode 100644 index 0000000000..e2c0b40003 --- /dev/null +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnCharLength.kt @@ -0,0 +1,78 @@ +// ktlint-disable filename +@file:Suppress("ClassName") + +package org.partiql.spi.connector.sql.builtins + +import org.partiql.spi.fn.Fn +import org.partiql.spi.fn.FnExperimental +import org.partiql.spi.fn.FnParameter +import org.partiql.spi.fn.FnSignature +import org.partiql.value.ClobValue +import org.partiql.value.Int32Value +import org.partiql.value.PartiQLValue +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.PartiQLValueType.CLOB +import org.partiql.value.PartiQLValueType.INT32 +import org.partiql.value.PartiQLValueType.STRING +import org.partiql.value.PartiQLValueType.SYMBOL +import org.partiql.value.StringValue +import org.partiql.value.SymbolValue +import org.partiql.value.check +import org.partiql.value.int32Value + +@OptIn(PartiQLValueExperimental::class, FnExperimental::class) +internal object Fn_CHAR_LENGTH__STRING__INT : Fn { + + override val signature = FnSignature( + name = "char_length", + returns = INT32, + parameters = listOf( + FnParameter("value", STRING), + ), + isNullCall = true, + isNullable = false, + ) + + override fun invoke(args: Array): Int32Value { + val value = args[0].check().value!! + return int32Value(value.codePointCount(0, value.length)) + } +} + +@OptIn(PartiQLValueExperimental::class, FnExperimental::class) +internal object Fn_CHAR_LENGTH__SYMBOL__INT : Fn { + + override val signature = FnSignature( + name = "char_length", + returns = INT32, + parameters = listOf( + FnParameter("lhs", SYMBOL), + ), + isNullCall = true, + isNullable = false, + ) + + override fun invoke(args: Array): Int32Value { + val value = args[0].check().value!! + return int32Value(value.codePointCount(0, value.length)) + } +} + +@OptIn(PartiQLValueExperimental::class, FnExperimental::class) +internal object Fn_CHAR_LENGTH__CLOB__INT : Fn { + + override val signature = FnSignature( + name = "char_length", + returns = INT32, + parameters = listOf( + FnParameter("lhs", CLOB), + ), + isNullCall = true, + isNullable = false, + ) + + override fun invoke(args: Array): Int32Value { + val value = args[0].check().value!! + return int32Value(value.size) + } +} diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnDateAddDay.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnDateAddDay.kt index d0a8b2cc63..1f6c6cc08b 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnDateAddDay.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnDateAddDay.kt @@ -3,6 +3,8 @@ package org.partiql.spi.connector.sql.builtins +import org.partiql.errors.DataException +import org.partiql.errors.TypeCheckException import org.partiql.spi.fn.Fn import org.partiql.spi.fn.FnExperimental import org.partiql.spi.fn.FnParameter @@ -41,7 +43,7 @@ internal object Fn_DATE_ADD_DAY__INT32_DATE__DATE : Fn { val interval = args[0].check() val datetime = args[1].check() val datetimeValue = datetime.value!! - val intervalValue = interval.long!! + val intervalValue = interval.toInt64().value!! return dateValue(datetimeValue.plusDays(intervalValue)) } } @@ -64,7 +66,7 @@ internal object Fn_DATE_ADD_DAY__INT64_DATE__DATE : Fn { val interval = args[0].check() val datetime = args[1].check() val datetimeValue = datetime.value!! - val intervalValue = interval.long!! + val intervalValue = interval.toInt64().value!! return dateValue(datetimeValue.plusDays(intervalValue)) } } @@ -87,8 +89,9 @@ internal object Fn_DATE_ADD_DAY__INT_DATE__DATE : Fn { val interval = args[0].check() val datetime = args[1].check() val datetimeValue = datetime.value!! - // TODO: We need to consider overflow here - val intervalValue = interval.long!! + val intervalValue = try { interval.toInt64().value!! } catch (e: DataException) { + throw TypeCheckException() + } return dateValue(datetimeValue.plusDays(intervalValue)) } } @@ -111,7 +114,7 @@ internal object Fn_DATE_ADD_DAY__INT32_TIMESTAMP__TIMESTAMP : Fn { val interval = args[0].check() val datetime = args[1].check() val datetimeValue = datetime.value!! - val intervalValue = interval.long!! + val intervalValue = interval.toInt64().value!! return timestampValue(datetimeValue.plusDays(intervalValue)) } } @@ -134,7 +137,7 @@ internal object Fn_DATE_ADD_DAY__INT64_TIMESTAMP__TIMESTAMP : Fn { val interval = args[0].check() val datetime = args[1].check() val datetimeValue = datetime.value!! - val intervalValue = interval.long!! + val intervalValue = interval.toInt64().value!! return timestampValue(datetimeValue.plusDays(intervalValue)) } } @@ -157,8 +160,9 @@ internal object Fn_DATE_ADD_DAY__INT_TIMESTAMP__TIMESTAMP : Fn { val interval = args[0].check() val datetime = args[1].check() val datetimeValue = datetime.value!! - // TODO: We need to consider overflow here - val intervalValue = interval.long!! + val intervalValue = try { interval.toInt64().value!! } catch (e: DataException) { + throw TypeCheckException() + } return timestampValue(datetimeValue.plusDays(intervalValue)) } } diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnDateAddHour.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnDateAddHour.kt index cb8994c4f8..72d59304cc 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnDateAddHour.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnDateAddHour.kt @@ -3,6 +3,8 @@ package org.partiql.spi.connector.sql.builtins +import org.partiql.errors.DataException +import org.partiql.errors.TypeCheckException import org.partiql.spi.fn.Fn import org.partiql.spi.fn.FnExperimental import org.partiql.spi.fn.FnParameter @@ -41,7 +43,7 @@ internal object Fn_DATE_ADD_HOUR__INT32_TIME__TIME : Fn { val interval = args[0].check() val datetime = args[1].check() val datetimeValue = datetime.value!! - val intervalValue = interval.long!! + val intervalValue = interval.toInt64().value!! return timeValue(datetimeValue.plusHours(intervalValue)) } } @@ -64,7 +66,7 @@ internal object Fn_DATE_ADD_HOUR__INT64_TIME__TIME : Fn { val interval = args[0].check() val datetime = args[1].check() val datetimeValue = datetime.value!! - val intervalValue = interval.long!! + val intervalValue = interval.value!! return timeValue(datetimeValue.plusHours(intervalValue)) } } @@ -87,8 +89,7 @@ internal object Fn_DATE_ADD_HOUR__INT_TIME__TIME : Fn { val interval = args[0].check() val datetime = args[1].check() val datetimeValue = datetime.value!! - // TODO: We need to consider overflow here - val intervalValue = interval.long!! + val intervalValue = try { interval.toInt64().value!! } catch (e: DataException) { throw TypeCheckException() } return timeValue(datetimeValue.plusHours(intervalValue)) } } @@ -111,7 +112,7 @@ internal object Fn_DATE_ADD_HOUR__INT32_TIMESTAMP__TIMESTAMP : Fn { val interval = args[0].check() val datetime = args[1].check() val datetimeValue = datetime.value!! - val intervalValue = interval.long!! + val intervalValue = interval.toInt64().value!! return timestampValue(datetimeValue.plusHours(intervalValue)) } } @@ -134,7 +135,7 @@ internal object Fn_DATE_ADD_HOUR__INT64_TIMESTAMP__TIMESTAMP : Fn { val interval = args[0].check() val datetime = args[1].check() val datetimeValue = datetime.value!! - val intervalValue = interval.long!! + val intervalValue = interval.value!! return timestampValue(datetimeValue.plusHours(intervalValue)) } } @@ -157,8 +158,7 @@ internal object Fn_DATE_ADD_HOUR__INT_TIMESTAMP__TIMESTAMP : Fn { val interval = args[0].check() val datetime = args[1].check() val datetimeValue = datetime.value!! - // TODO: We need to consider overflow here - val intervalValue = interval.long!! + val intervalValue = try { interval.toInt64().value!! } catch (e: DataException) { throw TypeCheckException() } return timestampValue(datetimeValue.plusHours(intervalValue)) } } diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnDateAddMinute.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnDateAddMinute.kt index 39092deb2a..b8c11b2910 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnDateAddMinute.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnDateAddMinute.kt @@ -3,6 +3,8 @@ package org.partiql.spi.connector.sql.builtins +import org.partiql.errors.DataException +import org.partiql.errors.TypeCheckException import org.partiql.spi.fn.Fn import org.partiql.spi.fn.FnExperimental import org.partiql.spi.fn.FnParameter @@ -41,7 +43,7 @@ internal object Fn_DATE_ADD_MINUTE__INT32_TIME__TIME : Fn { val interval = args[0].check() val datetime = args[1].check() val datetimeValue = datetime.value!! - val intervalValue = interval.long!! + val intervalValue = interval.toInt64().value!! return timeValue(datetimeValue.plusMinutes(intervalValue)) } } @@ -64,7 +66,7 @@ internal object Fn_DATE_ADD_MINUTE__INT64_TIME__TIME : Fn { val interval = args[0].check() val datetime = args[1].check() val datetimeValue = datetime.value!! - val intervalValue = interval.long!! + val intervalValue = interval.value!! return timeValue(datetimeValue.plusMinutes(intervalValue)) } } @@ -87,8 +89,7 @@ internal object Fn_DATE_ADD_MINUTE__INT_TIME__TIME : Fn { val interval = args[0].check() val datetime = args[1].check() val datetimeValue = datetime.value!! - // TODO: We need to consider overflow here - val intervalValue = interval.long!! + val intervalValue = try { interval.toInt64().value!! } catch (e: DataException) { throw TypeCheckException() } return timeValue(datetimeValue.plusMinutes(intervalValue)) } } @@ -111,7 +112,7 @@ internal object Fn_DATE_ADD_MINUTE__INT32_TIMESTAMP__TIMESTAMP : Fn { val interval = args[0].check() val datetime = args[1].check() val datetimeValue = datetime.value!! - val intervalValue = interval.long!! + val intervalValue = interval.toInt64().value!! return timestampValue(datetimeValue.plusMinutes(intervalValue)) } } @@ -134,7 +135,7 @@ internal object Fn_DATE_ADD_MINUTE__INT64_TIMESTAMP__TIMESTAMP : Fn { val interval = args[0].check() val datetime = args[1].check() val datetimeValue = datetime.value!! - val intervalValue = interval.long!! + val intervalValue = interval.value!! return timestampValue(datetimeValue.plusMinutes(intervalValue)) } } @@ -157,8 +158,7 @@ internal object Fn_DATE_ADD_MINUTE__INT_TIMESTAMP__TIMESTAMP : Fn { val interval = args[0].check() val datetime = args[1].check() val datetimeValue = datetime.value!! - // TODO: We need to consider overflow here - val intervalValue = interval.long!! + val intervalValue = try { interval.toInt64().value!! } catch (e: DataException) { throw TypeCheckException() } return timestampValue(datetimeValue.plusMinutes(intervalValue)) } } diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnDateAddMonth.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnDateAddMonth.kt index ffdc65a2d4..c243dea3d1 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnDateAddMonth.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnDateAddMonth.kt @@ -3,6 +3,8 @@ package org.partiql.spi.connector.sql.builtins +import org.partiql.errors.DataException +import org.partiql.errors.TypeCheckException import org.partiql.spi.fn.Fn import org.partiql.spi.fn.FnExperimental import org.partiql.spi.fn.FnParameter @@ -41,7 +43,7 @@ internal object Fn_DATE_ADD_MONTH__INT32_DATE__DATE : Fn { val interval = args[0].check() val datetime = args[1].check() val datetimeValue = datetime.value!! - val intervalValue = interval.long!! + val intervalValue = interval.toInt64().value!! return dateValue(datetimeValue.plusMonths(intervalValue)) } } @@ -64,7 +66,7 @@ internal object Fn_DATE_ADD_MONTH__INT64_DATE__DATE : Fn { val interval = args[0].check() val datetime = args[1].check() val datetimeValue = datetime.value!! - val intervalValue = interval.long!! + val intervalValue = interval.value!! return dateValue(datetimeValue.plusMonths(intervalValue)) } } @@ -87,7 +89,7 @@ internal object Fn_DATE_ADD_MONTH__INT_DATE__DATE : Fn { val interval = args[0].check() val datetime = args[1].check() val datetimeValue = datetime.value!! - val intervalValue = interval.long!! + val intervalValue = try { interval.toInt64().value!! } catch (e: DataException) { throw TypeCheckException() } return dateValue(datetimeValue.plusMonths(intervalValue)) } } @@ -110,7 +112,7 @@ internal object Fn_DATE_ADD_MONTH__INT32_TIMESTAMP__TIMESTAMP : Fn { val interval = args[0].check() val datetime = args[1].check() val datetimeValue = datetime.value!! - val intervalValue = interval.long!! + val intervalValue = interval.toInt64().value!! return timestampValue(datetimeValue.plusMonths(intervalValue)) } } @@ -133,7 +135,7 @@ internal object Fn_DATE_ADD_MONTH__INT64_TIMESTAMP__TIMESTAMP : Fn { val interval = args[0].check() val datetime = args[1].check() val datetimeValue = datetime.value!! - val intervalValue = interval.long!! + val intervalValue = interval.value!! return timestampValue(datetimeValue.plusMonths(intervalValue)) } } @@ -156,8 +158,7 @@ internal object Fn_DATE_ADD_MONTH__INT_TIMESTAMP__TIMESTAMP : Fn { val interval = args[0].check() val datetime = args[1].check() val datetimeValue = datetime.value!! - // TODO: We need to consider overflow here - val intervalValue = interval.long!! + val intervalValue = try { interval.toInt64().value!! } catch (e: DataException) { throw TypeCheckException() } return timestampValue(datetimeValue.plusMonths(intervalValue)) } } diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnDateAddSecond.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnDateAddSecond.kt index 8db624de40..8d1715037d 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnDateAddSecond.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnDateAddSecond.kt @@ -3,6 +3,8 @@ package org.partiql.spi.connector.sql.builtins +import org.partiql.errors.DataException +import org.partiql.errors.TypeCheckException import org.partiql.spi.fn.Fn import org.partiql.spi.fn.FnExperimental import org.partiql.spi.fn.FnParameter @@ -41,7 +43,7 @@ internal object Fn_DATE_ADD_SECOND__INT32_TIME__TIME : Fn { val interval = args[0].check() val datetime = args[1].check() val datetimeValue = datetime.value!! - val intervalValue = interval.long!! + val intervalValue = interval.toInt64().value!! return timeValue(datetimeValue.plusSeconds(intervalValue)) } } @@ -64,7 +66,7 @@ internal object Fn_DATE_ADD_SECOND__INT64_TIME__TIME : Fn { val interval = args[0].check() val datetime = args[1].check() val datetimeValue = datetime.value!! - val intervalValue = interval.long!! + val intervalValue = interval.value!! return timeValue(datetimeValue.plusSeconds(intervalValue)) } } @@ -87,8 +89,7 @@ internal object Fn_DATE_ADD_SECOND__INT_TIME__TIME : Fn { val interval = args[0].check() val datetime = args[1].check() val datetimeValue = datetime.value!! - // TODO: We need to consider overflow here - val intervalValue = interval.long!! + val intervalValue = try { interval.toInt64().value!! } catch (e: DataException) { throw TypeCheckException() } return timeValue(datetimeValue.plusSeconds(intervalValue)) } } @@ -111,7 +112,7 @@ internal object Fn_DATE_ADD_SECOND__INT32_TIMESTAMP__TIMESTAMP : Fn { val interval = args[0].check() val datetime = args[1].check() val datetimeValue = datetime.value!! - val intervalValue = interval.long!! + val intervalValue = interval.toInt64().value!! return timestampValue(datetimeValue.plusSeconds(intervalValue)) } } @@ -134,7 +135,7 @@ internal object Fn_DATE_ADD_SECOND__INT64_TIMESTAMP__TIMESTAMP : Fn { val interval = args[0].check() val datetime = args[1].check() val datetimeValue = datetime.value!! - val intervalValue = interval.long!! + val intervalValue = interval.value!! return timestampValue(datetimeValue.plusSeconds(intervalValue)) } } @@ -157,8 +158,7 @@ internal object Fn_DATE_ADD_SECOND__INT_TIMESTAMP__TIMESTAMP : Fn { val interval = args[0].check() val datetime = args[1].check() val datetimeValue = datetime.value!! - // TODO: We need to consider overflow here - val intervalValue = interval.long!! + val intervalValue = try { interval.toInt64().value!! } catch (e: DataException) { throw TypeCheckException() } return timestampValue(datetimeValue.plusSeconds(intervalValue)) } } diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnDateAddYear.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnDateAddYear.kt index facf0767f1..3d34375e42 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnDateAddYear.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnDateAddYear.kt @@ -3,6 +3,8 @@ package org.partiql.spi.connector.sql.builtins +import org.partiql.errors.DataException +import org.partiql.errors.TypeCheckException import org.partiql.spi.fn.Fn import org.partiql.spi.fn.FnExperimental import org.partiql.spi.fn.FnParameter @@ -41,7 +43,7 @@ internal object Fn_DATE_ADD_YEAR__INT32_DATE__DATE : Fn { val interval = args[0].check() val datetime = args[1].check() val datetimeValue = datetime.value!! - val intervalValue = interval.long!! + val intervalValue = interval.toInt64().value!! return dateValue(datetimeValue.plusYears(intervalValue)) } } @@ -64,7 +66,7 @@ internal object Fn_DATE_ADD_YEAR__INT64_DATE__DATE : Fn { val interval = args[0].check() val datetime = args[1].check() val datetimeValue = datetime.value!! - val intervalValue = interval.long!! + val intervalValue = interval.value!! return dateValue(datetimeValue.plusYears(intervalValue)) } } @@ -87,8 +89,7 @@ internal object Fn_DATE_ADD_YEAR__INT_DATE__DATE : Fn { val interval = args[0].check() val datetime = args[1].check() val datetimeValue = datetime.value!! - // TODO: We need to consider overflow here - val intervalValue = interval.long!! + val intervalValue = try { interval.toInt64().value!! } catch (e: DataException) { throw TypeCheckException() } return dateValue(datetimeValue.plusYears(intervalValue)) } } @@ -111,7 +112,7 @@ internal object Fn_DATE_ADD_YEAR__INT32_TIMESTAMP__TIMESTAMP : Fn { val interval = args[0].check() val datetime = args[1].check() val datetimeValue = datetime.value!! - val intervalValue = interval.long!! + val intervalValue = interval.toInt64().value!! return timestampValue(datetimeValue.plusYears(intervalValue)) } } @@ -134,7 +135,7 @@ internal object Fn_DATE_ADD_YEAR__INT64_TIMESTAMP__TIMESTAMP : Fn { val interval = args[0].check() val datetime = args[1].check() val datetimeValue = datetime.value!! - val intervalValue = interval.long!! + val intervalValue = interval.value!! return timestampValue(datetimeValue.plusYears(intervalValue)) } } @@ -157,8 +158,7 @@ internal object Fn_DATE_ADD_YEAR__INT_TIMESTAMP__TIMESTAMP : Fn { val interval = args[0].check() val datetime = args[1].check() val datetimeValue = datetime.value!! - // TODO: We need to consider overflow here - val intervalValue = interval.long!! + val intervalValue = try { interval.toInt64().value!! } catch (e: DataException) { throw TypeCheckException() } return timestampValue(datetimeValue.plusYears(intervalValue)) } } diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnIsChar.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnIsChar.kt index d2488d9ad8..9b1c812811 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnIsChar.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnIsChar.kt @@ -54,7 +54,7 @@ internal object Fn_IS_CHAR__INT32_ANY__BOOL : Fn { if (value !is StringValue) { return boolValue(false) } - val length = args[0].check().int + val length = args[0].check().value if (length == null || length < 0) { throw TypeCheckException() } diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnIsString.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnIsString.kt index 027df96465..8470dbbdca 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnIsString.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnIsString.kt @@ -53,7 +53,7 @@ internal object Fn_IS_STRING__INT32_ANY__BOOL : Fn { if (v !is StringValue) { return boolValue(false) } - val length = args[0].check().int + val length = args[0].check().value if (length == null || length < 0) { throw TypeCheckException() } diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnSubstring.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnSubstring.kt index d7626c8a99..88655694f4 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnSubstring.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnSubstring.kt @@ -3,6 +3,7 @@ package org.partiql.spi.connector.sql.builtins +import org.partiql.errors.DataException import org.partiql.errors.TypeCheckException import org.partiql.spi.fn.Fn import org.partiql.spi.fn.FnExperimental @@ -108,7 +109,7 @@ internal object Fn_SUBSTRING__STRING_INT64__STRING : Fn { override fun invoke(args: Array): PartiQLValue { val value = args[0].check().string!! - val start = args[1].check().int!! + val start = try { args[1].check().toInt32().value!! } catch (e: DataException) { throw TypeCheckException() } val result = value.codepointSubstring(start) return stringValue(result) } @@ -131,8 +132,8 @@ internal object Fn_SUBSTRING__STRING_INT64_INT64__STRING : Fn { override fun invoke(args: Array): PartiQLValue { val value = args[0].check().string!! - val start = args[1].check().int!! - val end = args[2].check().int!! + val start = try { args[1].check().toInt32().value!! } catch (e: DataException) { throw TypeCheckException() } + val end = try { args[2].check().toInt32().value!! } catch (e: DataException) { throw TypeCheckException() } if (end < 0) throw TypeCheckException() val result = value.codepointSubstring(start, end) return stringValue(result) @@ -155,7 +156,7 @@ internal object Fn_SUBSTRING__SYMBOL_INT64__SYMBOL : Fn { override fun invoke(args: Array): PartiQLValue { val value = args[0].check().string!! - val start = args[1].check().int!! + val start = try { args[1].check().toInt32().value!! } catch (e: DataException) { throw TypeCheckException() } val result = value.codepointSubstring(start) return symbolValue(result) } @@ -178,8 +179,8 @@ internal object Fn_SUBSTRING__SYMBOL_INT64_INT64__SYMBOL : Fn { override fun invoke(args: Array): PartiQLValue { val value = args[0].check().string!! - val start = args[1].check().int!! - val end = args[2].check().int!! + val start = try { args[1].check().toInt32().value!! } catch (e: DataException) { throw TypeCheckException() } + val end = try { args[2].check().toInt32().value!! } catch (e: DataException) { throw TypeCheckException() } if (end < 0) throw TypeCheckException() val result = value.codepointSubstring(start, end) return symbolValue(result) @@ -202,7 +203,7 @@ internal object Fn_SUBSTRING__CLOB_INT64__CLOB : Fn { override fun invoke(args: Array): PartiQLValue { val value = args[0].check().string!! - val start = args[1].check().int!! + val start = try { args[1].check().toInt32().value!! } catch (e: DataException) { throw TypeCheckException() } val result = value.codepointSubstring(start) return clobValue(result.toByteArray()) } @@ -225,8 +226,8 @@ internal object Fn_SUBSTRING__CLOB_INT64_INT64__CLOB : Fn { override fun invoke(args: Array): PartiQLValue { val value = args[0].check().string!! - val start = args[1].check().int!! - val end = args[2].check().int!! + val start = try { args[1].check().toInt32().value!! } catch (e: DataException) { throw TypeCheckException() } + val end = try { args[2].check().toInt32().value!! } catch (e: DataException) { throw TypeCheckException() } if (end < 0) throw TypeCheckException() val result = value.codepointSubstring(start, end) return clobValue(result.toByteArray()) diff --git a/partiql-types/src/main/kotlin/org/partiql/value/PartiQLValue.kt b/partiql-types/src/main/kotlin/org/partiql/value/PartiQLValue.kt index 91e754e8d2..5e69760f5b 100644 --- a/partiql-types/src/main/kotlin/org/partiql/value/PartiQLValue.kt +++ b/partiql-types/src/main/kotlin/org/partiql/value/PartiQLValue.kt @@ -94,17 +94,21 @@ public abstract class BoolValue : ScalarValue { @PartiQLValueExperimental public sealed class NumericValue : ScalarValue { - public val int: Int? - get() = value?.toInt() + public abstract fun toInt8(): Int8Value - public val long: Long? - get() = value?.toLong() + public abstract fun toInt16(): Int16Value - public val float: Float? - get() = value?.toFloat() + public abstract fun toInt32(): Int32Value - public val double: Double? - get() = value?.toDouble() + public abstract fun toInt64(): Int64Value + + public abstract fun toInt(): IntValue + + public abstract fun toDecimal(): DecimalValue + + public abstract fun toFloat32(): Float32Value + + public abstract fun toFloat64(): Float64Value abstract override fun copy(annotations: Annotations): NumericValue @@ -514,6 +518,7 @@ public abstract class StructValue : PartiQLValue { } override fun hashCode(): Int { + // TODO return entries.hashCode() } @@ -532,6 +537,8 @@ public abstract class NullValue : PartiQLValue { override val isNull: Boolean = true + public abstract fun withType(type: PartiQLValueType): PartiQLValue + abstract override fun copy(annotations: Annotations): NullValue abstract override fun withAnnotations(annotations: Annotations): NullValue diff --git a/partiql-types/src/main/kotlin/org/partiql/value/impl/DecimalValueImpl.kt b/partiql-types/src/main/kotlin/org/partiql/value/impl/DecimalValueImpl.kt index 8d8c455a8d..62784fb43c 100644 --- a/partiql-types/src/main/kotlin/org/partiql/value/impl/DecimalValueImpl.kt +++ b/partiql-types/src/main/kotlin/org/partiql/value/impl/DecimalValueImpl.kt @@ -16,11 +16,27 @@ package org.partiql.value.impl import kotlinx.collections.immutable.PersistentList import kotlinx.collections.immutable.toPersistentList +import org.partiql.errors.DataException import org.partiql.value.Annotations import org.partiql.value.DecimalValue +import org.partiql.value.Float32Value +import org.partiql.value.Float64Value +import org.partiql.value.Int16Value +import org.partiql.value.Int32Value +import org.partiql.value.Int64Value +import org.partiql.value.Int8Value +import org.partiql.value.IntValue import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.float32Value +import org.partiql.value.float64Value +import org.partiql.value.int16Value +import org.partiql.value.int32Value +import org.partiql.value.int64Value +import org.partiql.value.int8Value +import org.partiql.value.intValue import org.partiql.value.util.PartiQLValueVisitor import java.math.BigDecimal +import java.math.RoundingMode @OptIn(PartiQLValueExperimental::class) internal data class DecimalValueImpl( @@ -34,5 +50,55 @@ internal data class DecimalValueImpl( override fun withoutAnnotations(): DecimalValue = _withoutAnnotations() + // permits if no leading significant digits loss + // rounding down for cast + override fun toInt8(): Int8Value = + try { + int8Value(this.value?.setScale(0, RoundingMode.DOWN)?.byteValueExact(), annotations) + } catch (e: ArithmeticException) { + throw DataException("Overflow when casting ${this.value} to INT8") + } + + override fun toInt16(): Int16Value = + try { + int16Value(this.value?.setScale(0, RoundingMode.DOWN)?.shortValueExact(), annotations) + } catch (e: ArithmeticException) { + throw DataException("Overflow when casting ${this.value} to INT16") + } + + override fun toInt32(): Int32Value = + try { + int32Value(this.value?.setScale(0, RoundingMode.DOWN)?.intValueExact(), annotations) + } catch (e: ArithmeticException) { + throw DataException("Overflow when casting ${this.value} to INT32") + } + + override fun toInt64(): Int64Value = + try { + int64Value(this.value?.setScale(0, RoundingMode.DOWN)?.longValueExact(), annotations) + } catch (e: ArithmeticException) { + throw DataException("Overflow when casting ${this.value} to INT64") + } + + override fun toInt(): IntValue = intValue(this.value?.setScale(0, RoundingMode.DOWN)?.toBigInteger(), annotations) + + override fun toDecimal(): DecimalValue = this + + override fun toFloat32(): Float32Value { + val float = this.value?.toFloat() + if (float == Float.NEGATIVE_INFINITY || float == Float.NEGATIVE_INFINITY) { + throw DataException("Overflow when casting ${this.value} to FLOAT32") + } + return float32Value(float, annotations) + } + + override fun toFloat64(): Float64Value { + val double = this.value?.toDouble() + if (double == Double.NEGATIVE_INFINITY || double == Double.NEGATIVE_INFINITY) { + throw DataException("Overflow when casting ${this.value} to FLOAT64") + } + return float64Value(double, annotations) + } + override fun accept(visitor: PartiQLValueVisitor, ctx: C): R = visitor.visitDecimal(this, ctx) } diff --git a/partiql-types/src/main/kotlin/org/partiql/value/impl/Float32ValueImpl.kt b/partiql-types/src/main/kotlin/org/partiql/value/impl/Float32ValueImpl.kt index 54278cf7b5..2502223036 100644 --- a/partiql-types/src/main/kotlin/org/partiql/value/impl/Float32ValueImpl.kt +++ b/partiql-types/src/main/kotlin/org/partiql/value/impl/Float32ValueImpl.kt @@ -16,10 +16,26 @@ package org.partiql.value.impl import kotlinx.collections.immutable.PersistentList import kotlinx.collections.immutable.toPersistentList +import org.partiql.errors.DataException import org.partiql.value.Annotations +import org.partiql.value.DecimalValue import org.partiql.value.Float32Value +import org.partiql.value.Float64Value +import org.partiql.value.Int16Value +import org.partiql.value.Int32Value +import org.partiql.value.Int64Value +import org.partiql.value.Int8Value +import org.partiql.value.IntValue import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.decimalValue +import org.partiql.value.float64Value +import org.partiql.value.int16Value +import org.partiql.value.int32Value +import org.partiql.value.int64Value +import org.partiql.value.int8Value +import org.partiql.value.intValue import org.partiql.value.util.PartiQLValueVisitor +import java.math.BigDecimal @OptIn(PartiQLValueExperimental::class) internal data class Float32ValueImpl( @@ -32,6 +48,58 @@ internal data class Float32ValueImpl( override fun withAnnotations(annotations: Annotations): Float32Value = _withAnnotations(annotations) override fun withoutAnnotations(): Float32Value = _withoutAnnotations() + override fun toInt8(): Int8Value { + if (this.value == null) { + return int8Value(null, annotations) + } + if (this.value > Byte.MAX_VALUE || this.value < Byte.MIN_VALUE) { + throw DataException("Overflow when casting ${this.value} to INT8") + } + return int8Value(this.value.toInt().toByte(), annotations) + } + + override fun toInt16(): Int16Value { + if (this.value == null) { + return int16Value(null, annotations) + } + if (this.value > Short.MAX_VALUE || this.value < Short.MIN_VALUE) { + throw DataException("Overflow when casting ${this.value} to INT16") + } + return int16Value(this.value.toInt().toShort(), annotations) + } + + override fun toInt32(): Int32Value { + if (this.value == null) { + return int32Value(null, annotations) + } + if (this.value > Int.MAX_VALUE || this.value < Int.MIN_VALUE) { + throw DataException("Overflow when casting ${this.value} to INT32") + } + return int32Value(this.value.toInt(), annotations) + } + + override fun toInt64(): Int64Value { + if (this.value == null) { + return int64Value(null, annotations) + } + if (this.value > Long.MAX_VALUE || this.value < Long.MIN_VALUE) { + throw DataException("Overflow when casting ${this.value} to INT64") + } + return int64Value(this.value.toLong(), annotations) + } + + override fun toInt(): IntValue = + intValue(this.value?.toDouble()?.let { BigDecimal(it) }?.toBigInteger(), annotations) + + // TODO: FIX-ME: + // This first convert the float value to bigDecimal + // which mess up with precision. + override fun toDecimal(): DecimalValue = + decimalValue(this.value?.toDouble()?.let { BigDecimal(it) }, annotations) + + override fun toFloat32(): Float32Value = this + + override fun toFloat64(): Float64Value = float64Value(this.value?.toDouble(), annotations) override fun accept(visitor: PartiQLValueVisitor, ctx: C): R = visitor.visitFloat32(this, ctx) } diff --git a/partiql-types/src/main/kotlin/org/partiql/value/impl/Float64ValueImpl.kt b/partiql-types/src/main/kotlin/org/partiql/value/impl/Float64ValueImpl.kt index ee337a3592..6d40fd213e 100644 --- a/partiql-types/src/main/kotlin/org/partiql/value/impl/Float64ValueImpl.kt +++ b/partiql-types/src/main/kotlin/org/partiql/value/impl/Float64ValueImpl.kt @@ -16,10 +16,26 @@ package org.partiql.value.impl import kotlinx.collections.immutable.PersistentList import kotlinx.collections.immutable.toPersistentList +import org.partiql.errors.DataException import org.partiql.value.Annotations +import org.partiql.value.DecimalValue +import org.partiql.value.Float32Value import org.partiql.value.Float64Value +import org.partiql.value.Int16Value +import org.partiql.value.Int32Value +import org.partiql.value.Int64Value +import org.partiql.value.Int8Value +import org.partiql.value.IntValue import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.decimalValue +import org.partiql.value.float32Value +import org.partiql.value.int16Value +import org.partiql.value.int32Value +import org.partiql.value.int64Value +import org.partiql.value.int8Value +import org.partiql.value.intValue import org.partiql.value.util.PartiQLValueVisitor +import java.math.BigDecimal @OptIn(PartiQLValueExperimental::class) internal data class Float64ValueImpl( @@ -31,6 +47,63 @@ internal data class Float64ValueImpl( override fun withAnnotations(annotations: Annotations): Float64Value = _withAnnotations(annotations) override fun withoutAnnotations(): Float64Value = _withoutAnnotations() + override fun toInt8(): Int8Value { + if (this.value == null) { + return int8Value(null, annotations) + } + if (this.value > Byte.MAX_VALUE || this.value < Byte.MIN_VALUE) { + throw DataException("Overflow when casting ${this.value} to INT8") + } + return int8Value(this.value.toInt().toByte(), annotations) + } + + override fun toInt16(): Int16Value { + if (this.value == null) { + return int16Value(null, annotations) + } + if (this.value > Short.MAX_VALUE || this.value < Short.MIN_VALUE) { + throw DataException("Overflow when casting ${this.value} to INT16") + } + return int16Value(this.value.toInt().toShort(), annotations) + } + + override fun toInt32(): Int32Value { + if (this.value == null) { + return int32Value(null, annotations) + } + if (this.value > Int.MAX_VALUE || this.value < Int.MIN_VALUE) { + throw DataException("Overflow when casting ${this.value} to INT32") + } + return int32Value(this.value.toInt(), annotations) + } + + override fun toInt64(): Int64Value { + if (this.value == null) { + return int64Value(null, annotations) + } + if (this.value > Long.MAX_VALUE || this.value < Long.MIN_VALUE) { + throw DataException("Overflow when casting ${this.value} to INT64") + } + return int64Value(this.value.toLong(), annotations) + } + + override fun toInt(): IntValue = + intValue(this.value?.let { BigDecimal(it) }?.toBigInteger(), annotations) + + override fun toDecimal(): DecimalValue = + decimalValue(this.value?.let { BigDecimal(it) }, annotations) + + override fun toFloat32(): Float32Value { + if (this.value == null) { + return float32Value(null, annotations) + } + if (this.value > Float.MAX_VALUE || this.value < Float.MIN_VALUE) { + throw DataException("Overflow when casting ${this.value} to Float32") + } + return float32Value(this.value.toFloat(), annotations) + } + + override fun toFloat64(): Float64Value = this override fun accept(visitor: PartiQLValueVisitor, ctx: C): R = visitor.visitFloat64(this, ctx) } diff --git a/partiql-types/src/main/kotlin/org/partiql/value/impl/Int16ValueImpl.kt b/partiql-types/src/main/kotlin/org/partiql/value/impl/Int16ValueImpl.kt index 7f3dc8bd6c..56e4510d04 100644 --- a/partiql-types/src/main/kotlin/org/partiql/value/impl/Int16ValueImpl.kt +++ b/partiql-types/src/main/kotlin/org/partiql/value/impl/Int16ValueImpl.kt @@ -16,10 +16,27 @@ package org.partiql.value.impl import kotlinx.collections.immutable.PersistentList import kotlinx.collections.immutable.toPersistentList +import org.partiql.errors.DataException import org.partiql.value.Annotations +import org.partiql.value.DecimalValue +import org.partiql.value.Float32Value +import org.partiql.value.Float64Value import org.partiql.value.Int16Value +import org.partiql.value.Int32Value +import org.partiql.value.Int64Value +import org.partiql.value.Int8Value +import org.partiql.value.IntValue import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.datetime.DateTimeUtil.toBigDecimal +import org.partiql.value.decimalValue +import org.partiql.value.float32Value +import org.partiql.value.float64Value +import org.partiql.value.int32Value +import org.partiql.value.int64Value +import org.partiql.value.int8Value +import org.partiql.value.intValue import org.partiql.value.util.PartiQLValueVisitor +import java.math.BigInteger @OptIn(PartiQLValueExperimental::class) internal data class Int16ValueImpl( @@ -32,6 +49,27 @@ internal data class Int16ValueImpl( override fun withAnnotations(annotations: Annotations): Int16Value = _withAnnotations(annotations) override fun withoutAnnotations(): Int16Value = _withoutAnnotations() + override fun toInt8(): Int8Value { + val byte = this.value?.toByte() ?: return int8Value(null, annotations) + if (byte.toShort() != this.value) { + throw DataException("Overflow when casting ${this.value} to INT8") + } + return int8Value(byte, annotations) + } + + override fun toInt16(): Int16Value = this + + override fun toInt32(): Int32Value = int32Value(this.value?.toInt(), annotations) + + override fun toInt64(): Int64Value = int64Value(this.value?.toLong(), annotations) + + override fun toInt(): IntValue = intValue(this.value?.toLong()?.let { BigInteger.valueOf(it) }, annotations) + + override fun toDecimal(): DecimalValue = decimalValue(this.value?.toBigDecimal(), annotations) + + override fun toFloat32(): Float32Value = float32Value(this.value?.toFloat(), annotations) + + override fun toFloat64(): Float64Value = float64Value(this.value?.toDouble(), annotations) override fun accept(visitor: PartiQLValueVisitor, ctx: C): R = visitor.visitInt16(this, ctx) } diff --git a/partiql-types/src/main/kotlin/org/partiql/value/impl/Int32ValueImpl.kt b/partiql-types/src/main/kotlin/org/partiql/value/impl/Int32ValueImpl.kt index 055bd73bdf..c6fce319f1 100644 --- a/partiql-types/src/main/kotlin/org/partiql/value/impl/Int32ValueImpl.kt +++ b/partiql-types/src/main/kotlin/org/partiql/value/impl/Int32ValueImpl.kt @@ -16,10 +16,27 @@ package org.partiql.value.impl import kotlinx.collections.immutable.PersistentList import kotlinx.collections.immutable.toPersistentList +import org.partiql.errors.DataException import org.partiql.value.Annotations +import org.partiql.value.DecimalValue +import org.partiql.value.Float32Value +import org.partiql.value.Float64Value +import org.partiql.value.Int16Value import org.partiql.value.Int32Value +import org.partiql.value.Int64Value +import org.partiql.value.Int8Value +import org.partiql.value.IntValue import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.datetime.DateTimeUtil.toBigDecimal +import org.partiql.value.decimalValue +import org.partiql.value.float32Value +import org.partiql.value.float64Value +import org.partiql.value.int16Value +import org.partiql.value.int64Value +import org.partiql.value.int8Value +import org.partiql.value.intValue import org.partiql.value.util.PartiQLValueVisitor +import java.math.BigInteger @OptIn(PartiQLValueExperimental::class) internal data class Int32ValueImpl( @@ -32,6 +49,33 @@ internal data class Int32ValueImpl( override fun withAnnotations(annotations: Annotations): Int32Value = _withAnnotations(annotations) override fun withoutAnnotations(): Int32Value = _withoutAnnotations() + override fun toInt8(): Int8Value { + val byte = this.value?.toByte() ?: return int8Value(null, annotations) + if (byte.toInt() != this.value) { + throw DataException("Overflow when casting ${this.value} to INT8") + } + return int8Value(byte, annotations) + } + + override fun toInt16(): Int16Value { + val short = this.value?.toShort() ?: return int16Value(null, annotations) + if (short.toInt() != this.value) { + throw DataException("Overflow when casting ${this.value} to INT16") + } + return int16Value(short, annotations) + } + + override fun toInt32(): Int32Value = this + + override fun toInt64(): Int64Value = int64Value(this.value?.toLong(), annotations) + + override fun toInt(): IntValue = intValue(this.value?.toLong()?.let { BigInteger.valueOf(it) }, annotations) + + override fun toDecimal(): DecimalValue = decimalValue(this.value?.toBigDecimal(), annotations) + + override fun toFloat32(): Float32Value = float32Value(this.value?.toFloat(), annotations) + + override fun toFloat64(): Float64Value = float64Value(this.value?.toDouble(), annotations) override fun accept(visitor: PartiQLValueVisitor, ctx: C): R = visitor.visitInt32(this, ctx) } diff --git a/partiql-types/src/main/kotlin/org/partiql/value/impl/Int64ValueImpl.kt b/partiql-types/src/main/kotlin/org/partiql/value/impl/Int64ValueImpl.kt index 8c7e45590f..63a4c84588 100644 --- a/partiql-types/src/main/kotlin/org/partiql/value/impl/Int64ValueImpl.kt +++ b/partiql-types/src/main/kotlin/org/partiql/value/impl/Int64ValueImpl.kt @@ -16,10 +16,27 @@ package org.partiql.value.impl import kotlinx.collections.immutable.PersistentList import kotlinx.collections.immutable.toPersistentList +import org.partiql.errors.DataException import org.partiql.value.Annotations +import org.partiql.value.DecimalValue +import org.partiql.value.Float32Value +import org.partiql.value.Float64Value +import org.partiql.value.Int16Value +import org.partiql.value.Int32Value import org.partiql.value.Int64Value +import org.partiql.value.Int8Value +import org.partiql.value.IntValue import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.datetime.DateTimeUtil.toBigDecimal +import org.partiql.value.decimalValue +import org.partiql.value.float32Value +import org.partiql.value.float64Value +import org.partiql.value.int16Value +import org.partiql.value.int32Value +import org.partiql.value.int8Value +import org.partiql.value.intValue import org.partiql.value.util.PartiQLValueVisitor +import java.math.BigInteger @OptIn(PartiQLValueExperimental::class) internal data class Int64ValueImpl( @@ -31,6 +48,39 @@ internal data class Int64ValueImpl( override fun withAnnotations(annotations: Annotations): Int64Value = _withAnnotations(annotations) override fun withoutAnnotations(): Int64Value = _withoutAnnotations() + override fun toInt8(): Int8Value { + val byte = this.value?.toByte() ?: return int8Value(null, annotations) + if (byte.toLong() != this.value) { + throw DataException("Overflow when casting ${this.value} to INT8") + } + return int8Value(byte, annotations) + } + + override fun toInt16(): Int16Value { + val short = this.value?.toShort() ?: return int16Value(null, annotations) + if (short.toLong() != this.value) { + throw DataException("Overflow when casting ${this.value} to INT16") + } + return int16Value(short, annotations) + } + + override fun toInt32(): Int32Value { + val int = this.value?.toInt() ?: return int32Value(null, annotations) + if (int.toLong() != this.value) { + throw DataException("Overflow when casting ${this.value} to INT16") + } + return int32Value(int, annotations) + } + + override fun toInt64(): Int64Value = this + + override fun toInt(): IntValue = intValue(this.value?.let { BigInteger.valueOf(it) }, annotations) + + override fun toDecimal(): DecimalValue = decimalValue(this.value?.toBigDecimal(), annotations) + + override fun toFloat32(): Float32Value = float32Value(this.value?.toFloat(), annotations) + + override fun toFloat64(): Float64Value = float64Value(this.value?.toDouble(), annotations) override fun accept(visitor: PartiQLValueVisitor, ctx: C): R = visitor.visitInt64(this, ctx) } diff --git a/partiql-types/src/main/kotlin/org/partiql/value/impl/Int8ValueImpl.kt b/partiql-types/src/main/kotlin/org/partiql/value/impl/Int8ValueImpl.kt index cb9d4ec489..04baf8df43 100644 --- a/partiql-types/src/main/kotlin/org/partiql/value/impl/Int8ValueImpl.kt +++ b/partiql-types/src/main/kotlin/org/partiql/value/impl/Int8ValueImpl.kt @@ -17,9 +17,25 @@ package org.partiql.value.impl import kotlinx.collections.immutable.PersistentList import kotlinx.collections.immutable.toPersistentList import org.partiql.value.Annotations +import org.partiql.value.DecimalValue +import org.partiql.value.Float32Value +import org.partiql.value.Float64Value +import org.partiql.value.Int16Value +import org.partiql.value.Int32Value +import org.partiql.value.Int64Value import org.partiql.value.Int8Value +import org.partiql.value.IntValue import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.datetime.DateTimeUtil.toBigDecimal +import org.partiql.value.decimalValue +import org.partiql.value.float32Value +import org.partiql.value.float64Value +import org.partiql.value.int16Value +import org.partiql.value.int32Value +import org.partiql.value.int64Value +import org.partiql.value.intValue import org.partiql.value.util.PartiQLValueVisitor +import java.math.BigInteger @OptIn(PartiQLValueExperimental::class) internal data class Int8ValueImpl( @@ -32,5 +48,21 @@ internal data class Int8ValueImpl( override fun withAnnotations(annotations: Annotations): Int8Value = _withAnnotations(annotations) override fun withoutAnnotations(): Int8Value = _withoutAnnotations() + override fun toInt8(): Int8Value = this + + override fun toInt16(): Int16Value = int16Value(this.value?.toShort(), annotations) + + override fun toInt32(): Int32Value = int32Value(this.value?.toInt(), annotations) + + override fun toInt64(): Int64Value = int64Value(this.value?.toLong(), annotations) + + override fun toInt(): IntValue = intValue(this.value?.toLong()?.let { BigInteger.valueOf(it) }, annotations) + + override fun toDecimal(): DecimalValue = decimalValue(this.value?.toBigDecimal(), annotations) + + override fun toFloat32(): Float32Value = float32Value(this.value?.toFloat(), annotations) + + override fun toFloat64(): Float64Value = float64Value(this.value?.toDouble(), annotations) + override fun accept(visitor: PartiQLValueVisitor, ctx: C): R = visitor.visitInt8(this, ctx) } diff --git a/partiql-types/src/main/kotlin/org/partiql/value/impl/IntValueImpl.kt b/partiql-types/src/main/kotlin/org/partiql/value/impl/IntValueImpl.kt index 7aa21efe3d..326569501d 100644 --- a/partiql-types/src/main/kotlin/org/partiql/value/impl/IntValueImpl.kt +++ b/partiql-types/src/main/kotlin/org/partiql/value/impl/IntValueImpl.kt @@ -16,9 +16,24 @@ package org.partiql.value.impl import kotlinx.collections.immutable.PersistentList import kotlinx.collections.immutable.toPersistentList +import org.partiql.errors.DataException import org.partiql.value.Annotations +import org.partiql.value.DecimalValue +import org.partiql.value.Float32Value +import org.partiql.value.Float64Value +import org.partiql.value.Int16Value +import org.partiql.value.Int32Value +import org.partiql.value.Int64Value +import org.partiql.value.Int8Value import org.partiql.value.IntValue import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.decimalValue +import org.partiql.value.float32Value +import org.partiql.value.float64Value +import org.partiql.value.int16Value +import org.partiql.value.int32Value +import org.partiql.value.int64Value +import org.partiql.value.int8Value import org.partiql.value.util.PartiQLValueVisitor import java.math.BigInteger @@ -33,6 +48,41 @@ internal data class IntValueImpl( override fun withAnnotations(annotations: Annotations): IntValue = _withAnnotations(annotations) override fun withoutAnnotations(): IntValue = _withoutAnnotations() + override fun toInt8(): Int8Value = + try { + int8Value(this.value?.byteValueExact(), annotations) + } catch (e: ArithmeticException) { + throw DataException("Overflow when casting ${this.value} to INT8") + } + + override fun toInt16(): Int16Value = + try { + int16Value(this.value?.shortValueExact(), annotations) + } catch (e: ArithmeticException) { + throw DataException("Overflow when casting ${this.value} to INT16") + } + + override fun toInt32(): Int32Value = + try { + int32Value(this.value?.intValueExact(), annotations) + } catch (e: ArithmeticException) { + throw DataException("Overflow when casting ${this.value} to INT32") + } + + override fun toInt64(): Int64Value = + try { + int64Value(this.value?.longValueExact(), annotations) + } catch (e: ArithmeticException) { + throw DataException("Overflow when casting ${this.value} to INT64") + } + + override fun toInt(): IntValue = this + + override fun toDecimal(): DecimalValue = decimalValue(this.value?.toBigDecimal(), annotations) + + override fun toFloat32(): Float32Value = float32Value(this.value?.toFloat(), annotations) + + override fun toFloat64(): Float64Value = float64Value(this.value?.toDouble(), annotations) override fun accept(visitor: PartiQLValueVisitor, ctx: C): R = visitor.visitInt(this, ctx) } diff --git a/partiql-types/src/main/kotlin/org/partiql/value/impl/NullValueImpl.kt b/partiql-types/src/main/kotlin/org/partiql/value/impl/NullValueImpl.kt index c6d28cba25..eb83eafe1e 100644 --- a/partiql-types/src/main/kotlin/org/partiql/value/impl/NullValueImpl.kt +++ b/partiql-types/src/main/kotlin/org/partiql/value/impl/NullValueImpl.kt @@ -18,13 +18,69 @@ import kotlinx.collections.immutable.PersistentList import kotlinx.collections.immutable.toPersistentList import org.partiql.value.Annotations import org.partiql.value.NullValue +import org.partiql.value.PartiQLValue import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.PartiQLValueType +import org.partiql.value.bagValue +import org.partiql.value.binaryValue +import org.partiql.value.blobValue +import org.partiql.value.boolValue +import org.partiql.value.byteValue +import org.partiql.value.charValue +import org.partiql.value.clobValue +import org.partiql.value.dateValue +import org.partiql.value.decimalValue +import org.partiql.value.float32Value +import org.partiql.value.float64Value +import org.partiql.value.int16Value +import org.partiql.value.int32Value +import org.partiql.value.int64Value +import org.partiql.value.int8Value +import org.partiql.value.intValue +import org.partiql.value.intervalValue +import org.partiql.value.listValue +import org.partiql.value.sexpValue +import org.partiql.value.stringValue +import org.partiql.value.structValue +import org.partiql.value.symbolValue +import org.partiql.value.timeValue +import org.partiql.value.timestampValue import org.partiql.value.util.PartiQLValueVisitor @OptIn(PartiQLValueExperimental::class) internal data class NullValueImpl( override val annotations: PersistentList, ) : NullValue() { + override fun withType(type: PartiQLValueType): PartiQLValue = when (type) { + PartiQLValueType.ANY -> this + PartiQLValueType.BOOL -> boolValue(null, annotations) + PartiQLValueType.INT8 -> int8Value(null, annotations) + PartiQLValueType.INT16 -> int16Value(null, annotations) + PartiQLValueType.INT32 -> int32Value(null, annotations) + PartiQLValueType.INT64 -> int64Value(null, annotations) + PartiQLValueType.INT -> intValue(null, annotations) + PartiQLValueType.DECIMAL -> decimalValue(null, annotations) + PartiQLValueType.DECIMAL_ARBITRARY -> decimalValue(null, annotations) + PartiQLValueType.FLOAT32 -> float32Value(null, annotations) + PartiQLValueType.FLOAT64 -> float64Value(null, annotations) + PartiQLValueType.CHAR -> charValue(null, annotations) + PartiQLValueType.STRING -> stringValue(null, annotations) + PartiQLValueType.SYMBOL -> symbolValue(null, annotations) + PartiQLValueType.BINARY -> binaryValue(null, annotations) + PartiQLValueType.BYTE -> byteValue(null, annotations) + PartiQLValueType.BLOB -> blobValue(null, annotations) + PartiQLValueType.CLOB -> clobValue(null, annotations) + PartiQLValueType.DATE -> dateValue(null, annotations) + PartiQLValueType.TIME -> timeValue(null, annotations) + PartiQLValueType.TIMESTAMP -> timestampValue(null, annotations) + PartiQLValueType.INTERVAL -> intervalValue(null, annotations) + PartiQLValueType.BAG -> bagValue(null, annotations) + PartiQLValueType.LIST -> listValue(null, annotations) + PartiQLValueType.SEXP -> sexpValue(null, annotations) + PartiQLValueType.STRUCT -> structValue(null, annotations) + PartiQLValueType.NULL -> this + PartiQLValueType.MISSING -> error("cast to missing not supported") + } override fun copy(annotations: Annotations) = NullValueImpl(annotations.toPersistentList()) diff --git a/plugins/partiql-plugin/src/main/kotlin/org/partiql/plugin/internal/fn/scalar/FnDateAddYear.kt b/plugins/partiql-plugin/src/main/kotlin/org/partiql/plugin/internal/fn/scalar/FnDateAddYear.kt new file mode 100644 index 0000000000..c8b736c236 --- /dev/null +++ b/plugins/partiql-plugin/src/main/kotlin/org/partiql/plugin/internal/fn/scalar/FnDateAddYear.kt @@ -0,0 +1,188 @@ +// ktlint-disable filename +@file:Suppress("ClassName") + +package org.partiql.plugin.internal.fn.scalar + +import org.partiql.errors.DataException +import org.partiql.errors.TypeCheckException +import org.partiql.spi.function.PartiQLFunction +import org.partiql.spi.function.PartiQLFunctionExperimental +import org.partiql.types.function.FunctionParameter +import org.partiql.types.function.FunctionSignature +import org.partiql.value.DateValue +import org.partiql.value.Int32Value +import org.partiql.value.Int64Value +import org.partiql.value.IntValue +import org.partiql.value.PartiQLValue +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.PartiQLValueType.DATE +import org.partiql.value.PartiQLValueType.INT +import org.partiql.value.PartiQLValueType.INT32 +import org.partiql.value.PartiQLValueType.INT64 +import org.partiql.value.PartiQLValueType.TIMESTAMP +import org.partiql.value.TimestampValue +import org.partiql.value.check +import org.partiql.value.dateValue +import org.partiql.value.timestampValue + +@OptIn(PartiQLValueExperimental::class, PartiQLFunctionExperimental::class) +internal object Fn_DATE_ADD_YEAR__INT32_DATE__DATE : PartiQLFunction.Scalar { + + override val signature = FunctionSignature.Scalar( + name = "date_add_year", + returns = DATE, + parameters = listOf( + FunctionParameter("interval", INT32), + FunctionParameter("datetime", DATE), + ), + isNullCall = true, + isNullable = false, + ) + + override fun invoke(args: Array): PartiQLValue { + val interval = args[0].check() + val datetime = args[1].check() + return if (datetime.value == null || interval.value == null) { + dateValue(null) + } else { + val datetimeValue = datetime.value!! + val intervalValue = interval.toInt64().value!! + dateValue(datetimeValue.plusYears(intervalValue)) + } + } +} + +@OptIn(PartiQLValueExperimental::class, PartiQLFunctionExperimental::class) +internal object Fn_DATE_ADD_YEAR__INT64_DATE__DATE : PartiQLFunction.Scalar { + + override val signature = FunctionSignature.Scalar( + name = "date_add_year", + returns = DATE, + parameters = listOf( + FunctionParameter("interval", INT64), + FunctionParameter("datetime", DATE), + ), + isNullCall = true, + isNullable = false, + ) + + override fun invoke(args: Array): PartiQLValue { + val interval = args[0].check() + val datetime = args[1].check() + return if (datetime.value == null || interval.value == null) { + dateValue(null) + } else { + val datetimeValue = datetime.value!! + val intervalValue = interval.value!! + dateValue(datetimeValue.plusYears(intervalValue)) + } + } +} + +@OptIn(PartiQLValueExperimental::class, PartiQLFunctionExperimental::class) +internal object Fn_DATE_ADD_YEAR__INT_DATE__DATE : PartiQLFunction.Scalar { + + override val signature = FunctionSignature.Scalar( + name = "date_add_year", + returns = DATE, + parameters = listOf( + FunctionParameter("interval", INT), + FunctionParameter("datetime", DATE), + ), + isNullCall = true, + isNullable = false, + ) + + override fun invoke(args: Array): PartiQLValue { + val interval = args[0].check() + val datetime = args[1].check() + return if (datetime.value == null || interval.value == null) { + dateValue(null) + } else { + val datetimeValue = datetime.value!! + val intervalValue = try { interval.toInt64().value!! } catch (e: DataException) { throw TypeCheckException() } + dateValue(datetimeValue.plusYears(intervalValue)) + } + } +} + +@OptIn(PartiQLValueExperimental::class, PartiQLFunctionExperimental::class) +internal object Fn_DATE_ADD_YEAR__INT32_TIMESTAMP__TIMESTAMP : PartiQLFunction.Scalar { + + override val signature = FunctionSignature.Scalar( + name = "date_add_year", + returns = TIMESTAMP, + parameters = listOf( + FunctionParameter("interval", INT32), + FunctionParameter("datetime", TIMESTAMP), + ), + isNullCall = true, + isNullable = false, + ) + + override fun invoke(args: Array): PartiQLValue { + val interval = args[0].check() + val datetime = args[1].check() + return if (datetime.value == null || interval.value == null) { + timestampValue(null) + } else { + val datetimeValue = datetime.value!! + val intervalValue = interval.toInt64().value!! + timestampValue(datetimeValue.plusYears(intervalValue)) + } + } +} + +@OptIn(PartiQLValueExperimental::class, PartiQLFunctionExperimental::class) +internal object Fn_DATE_ADD_YEAR__INT64_TIMESTAMP__TIMESTAMP : PartiQLFunction.Scalar { + + override val signature = FunctionSignature.Scalar( + name = "date_add_year", + returns = TIMESTAMP, + parameters = listOf( + FunctionParameter("interval", INT64), + FunctionParameter("datetime", TIMESTAMP), + ), + isNullCall = true, + isNullable = false, + ) + + override fun invoke(args: Array): PartiQLValue { + val interval = args[0].check() + val datetime = args[1].check() + return if (datetime.value == null || interval.value == null) { + timestampValue(null) + } else { + val datetimeValue = datetime.value!! + val intervalValue = interval.value!! + timestampValue(datetimeValue.plusYears(intervalValue)) + } + } +} + +@OptIn(PartiQLValueExperimental::class, PartiQLFunctionExperimental::class) +internal object Fn_DATE_ADD_YEAR__INT_TIMESTAMP__TIMESTAMP : PartiQLFunction.Scalar { + + override val signature = FunctionSignature.Scalar( + name = "date_add_year", + returns = TIMESTAMP, + parameters = listOf( + FunctionParameter("interval", INT), + FunctionParameter("datetime", TIMESTAMP), + ), + isNullCall = true, + isNullable = false, + ) + + override fun invoke(args: Array): PartiQLValue { + val interval = args[0].check() + val datetime = args[1].check() + return if (datetime.value == null || interval.value == null) { + timestampValue(null) + } else { + val datetimeValue = datetime.value!! + val intervalValue = try { interval.toInt64().value!! } catch (e: DataException) { throw TypeCheckException() } + timestampValue(datetimeValue.plusYears(intervalValue)) + } + } +} diff --git a/plugins/partiql-plugin/src/main/kotlin/org/partiql/plugin/internal/fn/scalar/FnIsChar.kt b/plugins/partiql-plugin/src/main/kotlin/org/partiql/plugin/internal/fn/scalar/FnIsChar.kt new file mode 100644 index 0000000000..ae010524e2 --- /dev/null +++ b/plugins/partiql-plugin/src/main/kotlin/org/partiql/plugin/internal/fn/scalar/FnIsChar.kt @@ -0,0 +1,69 @@ +// ktlint-disable filename +@file:Suppress("ClassName") + +package org.partiql.plugin.internal.fn.scalar + +import org.partiql.errors.TypeCheckException +import org.partiql.spi.function.PartiQLFunction +import org.partiql.spi.function.PartiQLFunctionExperimental +import org.partiql.types.function.FunctionParameter +import org.partiql.types.function.FunctionSignature +import org.partiql.value.CharValue +import org.partiql.value.Int32Value +import org.partiql.value.PartiQLValue +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.PartiQLValueType.ANY +import org.partiql.value.PartiQLValueType.BOOL +import org.partiql.value.PartiQLValueType.INT32 +import org.partiql.value.StringValue +import org.partiql.value.boolValue +import org.partiql.value.check + +@OptIn(PartiQLValueExperimental::class, PartiQLFunctionExperimental::class) +internal object Fn_IS_CHAR__ANY__BOOL : PartiQLFunction.Scalar { + + override val signature = FunctionSignature.Scalar( + name = "is_char", + returns = BOOL, + parameters = listOf(FunctionParameter("value", ANY)), + isNullCall = true, + isNullable = false, + ) + + override fun invoke(args: Array): PartiQLValue { + val arg = args[0] + return if (arg.isNull) { + boolValue(null) + } else { + boolValue(arg is CharValue) + } + } +} + +@OptIn(PartiQLValueExperimental::class, PartiQLFunctionExperimental::class) +internal object Fn_IS_CHAR__INT32_ANY__BOOL : PartiQLFunction.Scalar { + + override val signature = FunctionSignature.Scalar( + name = "is_char", + returns = BOOL, + parameters = listOf( + FunctionParameter("length", INT32), + FunctionParameter("value", ANY), + ), + isNullCall = true, + isNullable = false, + ) + + override fun invoke(args: Array): PartiQLValue { + val length = args[0].check().value + if (length == null || length < 0) { + throw TypeCheckException() + } + val v = args[1] + return when { + v.isNull -> boolValue(null) + v !is StringValue -> boolValue(false) + else -> boolValue(v.value!!.length == length) + } + } +} diff --git a/test/partiql-tests-runner/src/test/kotlin/org/partiql/runner/executor/EvalExecutor.kt b/test/partiql-tests-runner/src/test/kotlin/org/partiql/runner/executor/EvalExecutor.kt index eaacbcf22b..4014777f2b 100644 --- a/test/partiql-tests-runner/src/test/kotlin/org/partiql/runner/executor/EvalExecutor.kt +++ b/test/partiql-tests-runner/src/test/kotlin/org/partiql/runner/executor/EvalExecutor.kt @@ -2,6 +2,7 @@ package org.partiql.runner.executor import com.amazon.ion.IonStruct import com.amazon.ion.IonValue +import com.amazon.ionelement.api.AnyElement import com.amazon.ionelement.api.ElementType import com.amazon.ionelement.api.StructElement import com.amazon.ionelement.api.toIonElement @@ -12,6 +13,7 @@ import org.partiql.eval.PartiQLStatement import org.partiql.lang.eval.CompileOptions import org.partiql.lang.eval.TypingMode import org.partiql.parser.PartiQLParser +import org.partiql.plan.Statement import org.partiql.planner.PartiQLPlanner import org.partiql.plugins.memory.MemoryCatalog import org.partiql.plugins.memory.MemoryConnector @@ -145,13 +147,32 @@ class EvalExecutor( private fun infer(env: StructElement): Connector { val map = mutableMapOf() env.fields.forEach { - map[it.name] = StaticType.ANY + map[it.name] = inferEnv(it.value) } val catalog = MemoryCatalog("default") catalog.load(env) return MemoryConnector(catalog) } + private fun inferEnv(env: AnyElement): StaticType { + val catalog = MemoryCatalog.builder().name("conformance_test").build() + val connector = MemoryConnector(catalog) + val session = PartiQLPlanner.Session( + queryId = "query", + userId = "user", + currentCatalog = "default", + catalogs = mapOf( + "default" to connector.getMetadata(object : ConnectorSession { + override fun getQueryId(): String = "query" + override fun getUserId(): String = "user" + }) + ) + ) + val stmt = parser.parse("`$env`").root + val plan = planner.plan(stmt, session) + return (plan.plan.statement as Statement.Query).root.type + } + /** * Loads each declared global of the catalog from the data element. * From ec697097993e5463111f9c9fb591a6562e63b2f4 Mon Sep 17 00:00:00 2001 From: Yingtao Liu Date: Tue, 6 Feb 2024 14:37:15 -0800 Subject: [PATCH 2/2] fix merge issue --- .../org/partiql/spi/connector/sql/builtins/FnSubstring.kt | 2 -- 1 file changed, 2 deletions(-) diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnSubstring.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnSubstring.kt index dd2d58825d..33fc29a880 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnSubstring.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnSubstring.kt @@ -228,8 +228,6 @@ internal object Fn_SUBSTRING__CLOB_INT64_INT64__CLOB : Fn { val string = args[0].check().value!!.toString(Charsets.UTF_8) val start = try { args[1].check().toInt32().value!! } catch (e: DataException) { throw TypeCheckException() } val end = try { args[2].check().toInt32().value!! } catch (e: DataException) { throw TypeCheckException() } - val start = args[1].check().int!! - val end = args[2].check().int!! if (end < 0) throw TypeCheckException() val result = string.codepointSubstring(start, end) return clobValue(result.toByteArray())