diff --git a/lang/resources/org/partiql/type-domains/partiql.ion b/lang/resources/org/partiql/type-domains/partiql.ion index 084efa3f94..8224ca3948 100644 --- a/lang/resources/org/partiql/type-domains/partiql.ion +++ b/lang/resources/org/partiql/type-domains/partiql.ion @@ -672,7 +672,7 @@ may then be further optimized by selecting better implementations of each operat // The value of `uniqueId` is PartiQL integration defined and can be any symbol that uniquely // identifies the global variable. Examples include database object ids or the alphabetical case // respecting table name found after case-insensitive lookup. - (global_id name::symbol uniqueId::symbol) + (global_id uniqueId::symbol case::case_sensitivity) ) ) ) diff --git a/lang/src/org/partiql/lang/CompilerPipeline.kt b/lang/src/org/partiql/lang/CompilerPipeline.kt index 57ec44f2a3..f481b3a2ac 100644 --- a/lang/src/org/partiql/lang/CompilerPipeline.kt +++ b/lang/src/org/partiql/lang/CompilerPipeline.kt @@ -39,7 +39,7 @@ import org.partiql.lang.types.StaticType import org.partiql.lang.util.interruptibleFold /** - * Contains all of the information needed for processing steps. + * Contains all information needed for processing steps. */ data class StepContext( /** The instance of [ExprValueFactory] that is used by the pipeline. */ @@ -102,6 +102,11 @@ interface CompilerPipeline { */ val procedures: @JvmSuppressWildcards Map + /** + * The configured global type bindings. + */ + val globalTypeBindings: Bindings? + /** Compiles the specified PartiQL query using the configured parser. */ fun compile(query: String): Expression @@ -244,7 +249,7 @@ internal class CompilerPipelineImpl( override val customDataTypes: List, override val procedures: Map, private val preProcessingSteps: List, - private val globalTypeBindings: Bindings? + override val globalTypeBindings: Bindings? ) : CompilerPipeline { private val compiler = EvaluatingCompiler( diff --git a/lang/src/org/partiql/lang/ast/passes/SemanticException.kt b/lang/src/org/partiql/lang/ast/passes/SemanticException.kt index 7996aa9a83..c32576d104 100644 --- a/lang/src/org/partiql/lang/ast/passes/SemanticException.kt +++ b/lang/src/org/partiql/lang/ast/passes/SemanticException.kt @@ -37,7 +37,7 @@ class SemanticException( constructor(err: Problem, cause: Throwable? 
= null) : this( message = "", - errorCode = ErrorCode.SEMANTIC_INFERENCER_ERROR, + errorCode = ErrorCode.SEMANTIC_PROBLEM, errorContext = propertyValueMapOf( Property.LINE_NUMBER to err.sourceLocation.lineNum, Property.COLUMN_NUMBER to err.sourceLocation.charOffset, diff --git a/lang/src/org/partiql/lang/errors/ErrorCode.kt b/lang/src/org/partiql/lang/errors/ErrorCode.kt index 256be64e65..61cc8e8e59 100644 --- a/lang/src/org/partiql/lang/errors/ErrorCode.kt +++ b/lang/src/org/partiql/lang/errors/ErrorCode.kt @@ -647,7 +647,7 @@ enum class ErrorCode( "got: ${errorContext?.get(Property.ACTUAL_ARGUMENT_TYPES) ?: UNKNOWN}" }, - SEMANTIC_INFERENCER_ERROR( + SEMANTIC_PROBLEM( ErrorCategory.SEMANTIC, LOCATION + setOf(Property.MESSAGE), "" @@ -980,12 +980,6 @@ enum class ErrorCode( ErrorBehaviorInPermissiveMode.RETURN_MISSING ), - EVALUATOR_SQL_EXCEPTION( - ErrorCategory.EVALUATOR, - LOCATION, - "SQL exception" - ), - EVALUATOR_COUNT_START_NOT_ALLOWED( ErrorCategory.EVALUATOR, LOCATION, diff --git a/lang/src/org/partiql/lang/eval/builtins/DynamicLookupExprFunction.kt b/lang/src/org/partiql/lang/eval/builtins/DynamicLookupExprFunction.kt new file mode 100644 index 0000000000..7d87cc3bd2 --- /dev/null +++ b/lang/src/org/partiql/lang/eval/builtins/DynamicLookupExprFunction.kt @@ -0,0 +1,101 @@ +package org.partiql.lang.eval.builtins + +import org.partiql.lang.errors.ErrorCode +import org.partiql.lang.eval.BindingCase +import org.partiql.lang.eval.BindingName +import org.partiql.lang.eval.EvaluationException +import org.partiql.lang.eval.EvaluationSession +import org.partiql.lang.eval.ExprFunction +import org.partiql.lang.eval.ExprValue +import org.partiql.lang.eval.ExprValueType +import org.partiql.lang.eval.physical.throwUndefinedVariableException +import org.partiql.lang.eval.stringValue +import org.partiql.lang.types.FunctionSignature +import org.partiql.lang.types.StaticType +import org.partiql.lang.types.VarargFormalParameter + +/** + * Performs dynamic variable resolution. Query authors should never call this function directly (and indeed it is + * named to avoid collision with the names of custom functions)--instead, the query planner injects call sites + * to this function to perform dynamic variable resolution of undefined variables. This provides a migration path + * for legacy customers that depend on this behavior. + * + * Arguments: + * + * - variable name + * - case sensitivity + * - lookup strategy (globals then locals or locals then globals) + * - A variadic list of locations to be searched. + * + * The variadic arguments must be of type `any` because the planner doesn't yet have knowledge of static types + * and therefore cannot filter out local variables types that are not structs. 
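 + *
 + * For illustration only (this sketch is not part of the change itself), a planner-injected call site for an
 + * undefined variable `foo` conceptually looks like a call to the function named by [DYNAMIC_LOOKUP_FUNCTION_NAME]:
 + *
 + *     <dynamic lookup fn>(`foo`, `case_insensitive`, `locals_then_globals`, <candidate struct locals>...)
 + *
 + * The first three symbols correspond to the values matched in [callWithVariadic] below; the variadic tail is
 + * whatever local variables happen to be in scope at the call site.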
+ */ +class DynamicLookupExprFunction : ExprFunction { + override val signature: FunctionSignature + get() { + return FunctionSignature( + name = DYNAMIC_LOOKUP_FUNCTION_NAME, + // Required parameters are: variable name, case sensitivity and lookup strategy + requiredParameters = listOf(StaticType.SYMBOL, StaticType.SYMBOL, StaticType.SYMBOL), + variadicParameter = VarargFormalParameter(StaticType.ANY, 0..Int.MAX_VALUE), + returnType = StaticType.ANY + ) + } + + override fun callWithVariadic( + session: EvaluationSession, + required: List, + variadic: List + ): ExprValue { + val variableName = required[0].stringValue() + + val caseSensitivity = when (val caseSensitivityParameterValue = required[1].stringValue()) { + "case_sensitive" -> BindingCase.SENSITIVE + "case_insensitive" -> BindingCase.INSENSITIVE + else -> throw EvaluationException( + message = "Invalid case sensitivity: $caseSensitivityParameterValue", + errorCode = ErrorCode.INTERNAL_ERROR, + internal = true + ) + } + + val bindingName = BindingName(variableName, caseSensitivity) + + val globalsFirst = when (val lookupStrategyParameterValue = required[2].stringValue()) { + "locals_then_globals" -> false + "globals_then_locals" -> true + else -> throw EvaluationException( + message = "Invalid lookup strategy: $lookupStrategyParameterValue", + errorCode = ErrorCode.INTERNAL_ERROR, + internal = true + ) + } + + val found = when { + globalsFirst -> { + session.globals[bindingName] ?: searchLocals(variadic, bindingName) + } + else -> { + searchLocals(variadic, bindingName) ?: session.globals[bindingName] + } + } + + if (found == null) { + // We don't know the metas inside ExprFunction implementations. The ThunkFactory error handlers + // should add line & col info to the exception & rethrow anyway. + throwUndefinedVariableException(bindingName, metas = null) + } else { + return found + } + } + + private fun searchLocals(possibleLocations: List, bindingName: BindingName) = + possibleLocations.asSequence().map { + when (it.type) { + ExprValueType.STRUCT -> + it.bindings[bindingName] + else -> + null + } + }.firstOrNull { it != null } +} diff --git a/lang/src/org/partiql/lang/eval/physical/EvaluatorState.kt b/lang/src/org/partiql/lang/eval/physical/EvaluatorState.kt new file mode 100644 index 0000000000..fd9e711226 --- /dev/null +++ b/lang/src/org/partiql/lang/eval/physical/EvaluatorState.kt @@ -0,0 +1,35 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + */ + +package org.partiql.lang.eval.physical + +import org.partiql.lang.eval.EvaluationSession +import org.partiql.lang.eval.ExprValue + +/** + * Contains state needed during query evaluation such as an instance of [EvaluationSession] and an array of [registers] + * for each local variable that is part of the query. + * + * Since the elements of [registers] are mutable, when/if we decide to make query execution multi-threaded, we'll have + * to take care to not share [EvaluatorState] instances among different threads. 
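 + *
 + * As a rough construction sketch (illustrative only; it mirrors how the physical-plan compiler sets this up,
 + * with `localsCount` and `valueFactory` standing in for the real values):
 + *
 + *     val state = EvaluatorState(session, Array(localsCount) { valueFactory.missingValue })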
+ * + * @param session The evaluation session. + * @param registers An array of registers containing [ExprValue]s needed during query execution. Generally, there is + * one register per local variable. This is an array (and not a [List]) because its semantics match exactly what we + * need: fixed length but mutable elements. + */ +internal class EvaluatorState( + val session: EvaluationSession, + val registers: Array +) diff --git a/lang/src/org/partiql/lang/eval/physical/OffsetLimitHelpers.kt b/lang/src/org/partiql/lang/eval/physical/OffsetLimitHelpers.kt new file mode 100644 index 0000000000..f68ac2d049 --- /dev/null +++ b/lang/src/org/partiql/lang/eval/physical/OffsetLimitHelpers.kt @@ -0,0 +1,106 @@ +package org.partiql.lang.eval.physical + +import com.amazon.ion.IntegerSize +import com.amazon.ion.IonInt +import org.partiql.lang.ast.SourceLocationMeta +import org.partiql.lang.errors.ErrorCode +import org.partiql.lang.errors.Property +import org.partiql.lang.eval.ExprValueType +import org.partiql.lang.eval.err +import org.partiql.lang.eval.errorContextFrom +import org.partiql.lang.eval.numberValue + +// The functions in this file look very similar and so the temptation to DRY is quite strong.... +// However, there are enough subtle differences between them that avoiding the duplication isn't worth it. + +internal fun evalLimitRowCount(rowCountThunk: PhysicalPlanThunk, env: EvaluatorState, limitLocationMeta: SourceLocationMeta?): Long { + val limitExprValue = rowCountThunk(env) + + if (limitExprValue.type != ExprValueType.INT) { + err( + "LIMIT value was not an integer", + ErrorCode.EVALUATOR_NON_INT_LIMIT_VALUE, + errorContextFrom(limitLocationMeta).also { + it[Property.ACTUAL_TYPE] = limitExprValue.type.toString() + }, + internal = false + ) + } + + // `Number.toLong()` (used below) does *not* cause an overflow exception if the underlying [Number] + // implementation (i.e. Decimal or BigInteger) exceeds the range that can be represented by Longs. + // This can cause very confusing behavior if the user specifies a LIMIT value that exceeds + // Long.MAX_VALUE, because no results will be returned from their query. That no overflow exception + // is thrown is not a problem as long as PartiQL's restriction of integer values to +/- 2^63 remains. + // We throw an exception here if the value exceeds the supported range (say if we change that + // restriction or if a custom [ExprValue] is provided which exceeds that value). + val limitIonValue = limitExprValue.ionValue as IonInt + if (limitIonValue.integerSize == IntegerSize.BIG_INTEGER) { + err( + "IntegerSize.BIG_INTEGER not supported for LIMIT values", + ErrorCode.INTERNAL_ERROR, + errorContextFrom(limitLocationMeta), + internal = true + ) + } + + val limitValue = limitExprValue.numberValue().toLong() + + if (limitValue < 0) { + err( + "negative LIMIT", + ErrorCode.EVALUATOR_NEGATIVE_LIMIT, + errorContextFrom(limitLocationMeta), + internal = false + ) + } + + // we can't use the Kotlin's Sequence.take(n) for this since it accepts only an integer. + // this references [Sequence.take(count: Long): Sequence] defined in [org.partiql.util]. 
+ return limitValue +} + +internal fun evalOffsetRowCount(rowCountThunk: PhysicalPlanThunk, env: EvaluatorState, offsetLocationMeta: SourceLocationMeta?): Long { + val offsetExprValue = rowCountThunk(env) + + if (offsetExprValue.type != ExprValueType.INT) { + err( + "OFFSET value was not an integer", + ErrorCode.EVALUATOR_NON_INT_OFFSET_VALUE, + errorContextFrom(offsetLocationMeta).also { + it[Property.ACTUAL_TYPE] = offsetExprValue.type.toString() + }, + internal = false + ) + } + + // `Number.toLong()` (used below) does *not* cause an overflow exception if the underlying [Number] + // implementation (i.e. Decimal or BigInteger) exceeds the range that can be represented by Longs. + // This can cause very confusing behavior if the user specifies a OFFSET value that exceeds + // Long.MAX_VALUE, because no results will be returned from their query. That no overflow exception + // is thrown is not a problem as long as PartiQL's restriction of integer values to +/- 2^63 remains. + // We throw an exception here if the value exceeds the supported range (say if we change that + // restriction or if a custom [ExprValue] is provided which exceeds that value). + val offsetIonValue = offsetExprValue.ionValue as IonInt + if (offsetIonValue.integerSize == IntegerSize.BIG_INTEGER) { + err( + "IntegerSize.BIG_INTEGER not supported for OFFSET values", + ErrorCode.INTERNAL_ERROR, + errorContextFrom(offsetLocationMeta), + internal = true + ) + } + + val offsetValue = offsetExprValue.numberValue().toLong() + + if (offsetValue < 0) { + err( + "negative OFFSET", + ErrorCode.EVALUATOR_NEGATIVE_OFFSET, + errorContextFrom(offsetLocationMeta), + internal = false + ) + } + + return offsetValue +} diff --git a/lang/src/org/partiql/lang/eval/physical/PhysicalBexprToThunkConverter.kt b/lang/src/org/partiql/lang/eval/physical/PhysicalBexprToThunkConverter.kt new file mode 100644 index 0000000000..76daad8443 --- /dev/null +++ b/lang/src/org/partiql/lang/eval/physical/PhysicalBexprToThunkConverter.kt @@ -0,0 +1,335 @@ +package org.partiql.lang.eval.physical + +import com.amazon.ionelement.api.BoolElement +import com.amazon.ionelement.api.MetaContainer +import org.partiql.lang.domains.PartiqlPhysical +import org.partiql.lang.eval.ExprValue +import org.partiql.lang.eval.ExprValueFactory +import org.partiql.lang.eval.ExprValueType +import org.partiql.lang.eval.Thunk +import org.partiql.lang.eval.ThunkValue +import org.partiql.lang.eval.address +import org.partiql.lang.eval.booleanValue +import org.partiql.lang.eval.isUnknown +import org.partiql.lang.eval.name +import org.partiql.lang.eval.relation.RelationIterator +import org.partiql.lang.eval.relation.RelationScope +import org.partiql.lang.eval.relation.RelationType +import org.partiql.lang.eval.relation.relation +import org.partiql.lang.eval.sourceLocationMeta +import org.partiql.lang.eval.unnamedValue +import org.partiql.lang.util.toIntExact + +private val DEFAULT_IMPL = PartiqlPhysical.build { impl("default") } + +/** A specialization of [Thunk] that we use for evaluation of physical plans. */ +internal typealias PhysicalPlanThunk = Thunk + +/** A specialization of [ThunkValue] that we use for evaluation of physical plans. 
*/ +internal typealias PhsycialPlanThunkValue = ThunkValue + +internal class PhysicalBexprToThunkConverter( + private val exprConverter: PhysicalExprToThunkConverter, + private val valueFactory: ExprValueFactory, +) : PartiqlPhysical.Bexpr.Converter { + + private fun blockNonDefaultImpl(i: PartiqlPhysical.Impl) { + if (i != DEFAULT_IMPL) { + TODO("Support non-default operator implementations") + } + } + + override fun convertProject(node: PartiqlPhysical.Bexpr.Project): RelationThunkEnv { + TODO("not implemented") + } + + override fun convertScan(node: PartiqlPhysical.Bexpr.Scan): RelationThunkEnv { + blockNonDefaultImpl(node.i) + + val exprThunk = exprConverter.convert(node.expr) + val asIndex = node.asDecl.index.value.toIntExact() + val atIndex = node.atDecl?.index?.value?.toIntExact() ?: -1 + val byIndex = node.byDecl?.index?.value?.toIntExact() ?: -1 + + return relationThunk(node.metas) { env -> + val valueToScan = exprThunk.invoke(env) + + // coerces non-collection types to a singleton Sequence<>. + val rows: Sequence = when (valueToScan.type) { + ExprValueType.LIST, ExprValueType.BAG -> valueToScan.asSequence() + else -> sequenceOf(valueToScan) + } + + relation(RelationType.BAG) { + var rowsIter: Iterator = rows.iterator() + while (rowsIter.hasNext()) { + val item = rowsIter.next() + env.registers[asIndex] = item.unnamedValue() // Remove any ordinal (output is a bag) + + if (atIndex >= 0) { + env.registers[atIndex] = item.name ?: valueFactory.missingValue + } + + if (byIndex >= 0) { + env.registers[byIndex] = item.address ?: valueFactory.missingValue + } + yield() + } + } + } + } + + override fun convertFilter(node: PartiqlPhysical.Bexpr.Filter): RelationThunkEnv { + blockNonDefaultImpl(node.i) + + val predicateThunk = exprConverter.convert(node.predicate) + val sourceThunk = this.convert(node.source) + + return relationThunk(node.metas) { env -> + val sourceToFilter = sourceThunk(env) + createFilterRelItr(sourceToFilter, predicateThunk, env) + } + } + + override fun convertJoin(node: PartiqlPhysical.Bexpr.Join): RelationThunkEnv { + blockNonDefaultImpl(node.i) + + val leftThunk = this.convert(node.left) + val rightThunk = this.convert(node.right) + val predicateThunk = node.predicate?.let { exprConverter.convert(it).takeIf { !node.predicate.isLitTrue() } } + + return when (node.joinType) { + is PartiqlPhysical.JoinType.Inner -> { + createInnerJoinThunk(node.metas, leftThunk, rightThunk, predicateThunk) + } + is PartiqlPhysical.JoinType.Left -> { + val rightVariableIndexes = node.right.extractAccessibleVarDecls().map { it.index.value.toIntExact() } + createLeftJoinThunk( + joinMetas = node.metas, + leftThunk = leftThunk, + rightThunk = rightThunk, + rightVariableIndexes = rightVariableIndexes, + predicateThunk = predicateThunk + ) + } + is PartiqlPhysical.JoinType.Right -> { + // Note that this is the same as the left join but the right and left sides are swapped. + val leftVariableIndexes = node.left.extractAccessibleVarDecls().map { it.index.value.toIntExact() } + createLeftJoinThunk( + joinMetas = node.metas, + leftThunk = rightThunk, + rightThunk = leftThunk, + rightVariableIndexes = leftVariableIndexes, + predicateThunk = predicateThunk + ) + } + is PartiqlPhysical.JoinType.Full -> TODO("Full join") + } + } + + private fun createInnerJoinThunk( + joinMetas: MetaContainer, + leftThunk: RelationThunkEnv, + rightThunk: RelationThunkEnv, + predicateThunk: PhysicalPlanThunk? 
+ ) = if (predicateThunk == null) { + relationThunk(joinMetas) { env -> + createCrossJoinRelItr(leftThunk, rightThunk, env) + } + } else { + relationThunk(joinMetas) { env -> + val crossJoinRelItr = createCrossJoinRelItr(leftThunk, rightThunk, env) + createFilterRelItr(crossJoinRelItr, predicateThunk, env) + } + } + + private fun createCrossJoinRelItr( + leftThunk: RelationThunkEnv, + rightThunk: RelationThunkEnv, + env: EvaluatorState + ): RelationIterator { + return relation(RelationType.BAG) { + val leftItr = leftThunk(env) + while (leftItr.nextRow()) { + val rightItr = rightThunk(env) + while (rightItr.nextRow()) { + yield() + } + } + } + } + + private fun createLeftJoinThunk( + joinMetas: MetaContainer, + leftThunk: RelationThunkEnv, + rightThunk: RelationThunkEnv, + rightVariableIndexes: List, + predicateThunk: PhysicalPlanThunk? + ) = + relationThunk(joinMetas) { env -> + createLeftJoinRelItr(leftThunk, rightThunk, rightVariableIndexes, predicateThunk, env) + } + + /** + * Like [createCrossJoinRelItr], but the right-hand relation is padded with unknown values in the event + * that it is empty or that the predicate does not match. + */ + private fun createLeftJoinRelItr( + leftThunk: RelationThunkEnv, + rightThunk: RelationThunkEnv, + rightVariableIndexes: List, + predicateThunk: PhysicalPlanThunk?, + env: EvaluatorState + ): RelationIterator { + return if (predicateThunk == null) { + relation(RelationType.BAG) { + val leftItr = leftThunk(env) + while (leftItr.nextRow()) { + val rightItr = rightThunk(env) + // if the rightItr does has a row... + if (rightItr.nextRow()) { + yield() // yield current row + yieldAll(rightItr) // yield remaining rows + } else { + // no row--yield padded row + yieldPaddedUnknowns(rightVariableIndexes, env) + } + } + } + } else { + relation(RelationType.BAG) { + val leftItr = leftThunk(env) + while (leftItr.nextRow()) { + val rightItr = rightThunk(env) + var yieldedSomething = false + while (rightItr.nextRow()) { + if (coercePredicateResult(predicateThunk(env))) { + yield() + yieldedSomething = true + } + } + // If we still haven't yielded anything, we still need to emit a row with right-hand side variables + // padded with unknowns. + if (!yieldedSomething) { + yieldPaddedUnknowns(rightVariableIndexes, env) + } + } + } + } + } + + private suspend fun RelationScope.yieldPaddedUnknowns( + rightVariableIndexes: List, + env: EvaluatorState + ) { + rightVariableIndexes.forEach { env.registers[it] = valueFactory.nullValue } + yield() + } + + private fun PartiqlPhysical.Bexpr.extractAccessibleVarDecls(): List = + // This fold traverses a [PartiqlPhysical.Bexpr] node and extracts all variable declarations within + // It avoids recursing into sub-queries. + object : PartiqlPhysical.VisitorFold>() { + override fun visitVarDecl( + node: PartiqlPhysical.VarDecl, + accumulator: List + ): List = accumulator + node + + /** + * Avoids recursion into expressions, since these may contain sub-queries with other var-decls that we don't + * care about here. 
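 + *
 + * For example (illustrative), folding over a fragment like `(join inner (scan ... (var_decl 0) ...) (scan ... (var_decl 1) ...) ...)`
 + * yields the declarations for registers 0 and 1, while any `var_decl` nested inside a sub-query expression
 + * (e.g. within the join predicate) is not collected.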
+ */ + override fun walkExpr( + node: PartiqlPhysical.Expr, + accumulator: List + ): List { + return accumulator + } + }.walkBexpr(this, emptyList()) + + private fun createFilterRelItr( + relItr: RelationIterator, + predicateThunk: PhysicalPlanThunk, + env: EvaluatorState + ) = relation(RelationType.BAG) { + while (true) { + if (!relItr.nextRow()) { + break + } else { + val matches = predicateThunk(env) + if (coercePredicateResult(matches)) { + yield() + } + } + } + } + + private fun coercePredicateResult(value: ExprValue): Boolean = + when { + value.isUnknown() -> false + else -> value.booleanValue() // <-- throws if [value] is not a boolean. + } + + override fun convertOffset(node: PartiqlPhysical.Bexpr.Offset): RelationThunkEnv { + val rowCountThunk = exprConverter.convert(node.rowCount) + val sourceThunk = this.convert(node.source) + val rowCountLocation = node.rowCount.metas.sourceLocationMeta + return relationThunk(node.metas) { env -> + val skipCount: Long = evalOffsetRowCount(rowCountThunk, env, rowCountLocation) + relation(RelationType.BAG) { + val sourceRel = sourceThunk(env) + var rowCount = 0L + while (rowCount++ < skipCount) { + // stop iterating if we finish run out of rows before we hit the offset. + if (!sourceRel.nextRow()) { + return@relation + } + } + + yieldAll(sourceRel) + } + } + } + + override fun convertLimit(node: PartiqlPhysical.Bexpr.Limit): RelationThunkEnv { + val rowCountThunk = exprConverter.convert(node.rowCount) + val sourceThunk = this.convert(node.source) + val rowCountLocation = node.rowCount.metas.sourceLocationMeta + return relationThunk(node.metas) { env -> + val limitCount = evalLimitRowCount(rowCountThunk, env, rowCountLocation) + val rowIter = sourceThunk(env) + relation(RelationType.BAG) { + var rowCount = 0L + while (rowCount++ < limitCount && rowIter.nextRow()) { + yield() + } + } + } + } + + override fun convertLet(node: PartiqlPhysical.Bexpr.Let): RelationThunkEnv { + val sourceThunk = this.convert(node.source) + class CompiledBinding(val index: Int, val valueThunk: PhysicalPlanThunk) + val compiledBindings = node.bindings.map { + CompiledBinding( + it.decl.index.value.toIntExact(), + exprConverter.convert(it.value) + ) + } + return relationThunk(node.metas) { env -> + val sourceItr = sourceThunk(env) + + relation(sourceItr.relType) { + while (sourceItr.nextRow()) { + compiledBindings.forEach { + env.registers[it.index] = it.valueThunk(env) + } + yield() + } + } + } + } +} + +private fun PartiqlPhysical.Expr.isLitTrue() = + this is PartiqlPhysical.Expr.Lit && this.value is BoolElement && this.value.booleanValue diff --git a/lang/src/org/partiql/lang/eval/physical/PhysicalExprToThunkConverter.kt b/lang/src/org/partiql/lang/eval/physical/PhysicalExprToThunkConverter.kt new file mode 100644 index 0000000000..57f11c12d1 --- /dev/null +++ b/lang/src/org/partiql/lang/eval/physical/PhysicalExprToThunkConverter.kt @@ -0,0 +1,13 @@ +package org.partiql.lang.eval.physical + +import org.partiql.lang.domains.PartiqlPhysical + +/** + * Simple API that defines a method to convert a [PartiqlPhysical.Expr] to a [PhysicalPlanThunk]. + * + * Intended to prevent [PhysicalBexprToThunkConverter] from having to take a direct dependency on + * [org.partiql.lang.eval.EvaluatingCompiler]. 
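 + *
 + * [PhysicalBexprToThunkConverter] needs nothing beyond [convert]; for example, its `convertScan` simply does:
 + *
 + *     val exprThunk = exprConverter.convert(node.expr)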
+ */ +internal interface PhysicalExprToThunkConverter { + fun convert(expr: PartiqlPhysical.Expr): PhysicalPlanThunk +} diff --git a/lang/src/org/partiql/lang/eval/physical/PhysicalExprToThunkConverterImpl.kt b/lang/src/org/partiql/lang/eval/physical/PhysicalExprToThunkConverterImpl.kt new file mode 100644 index 0000000000..0673175914 --- /dev/null +++ b/lang/src/org/partiql/lang/eval/physical/PhysicalExprToThunkConverterImpl.kt @@ -0,0 +1,1902 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + */ + +package org.partiql.lang.eval.physical + +import com.amazon.ion.IonString +import com.amazon.ion.IonValue +import com.amazon.ion.Timestamp +import com.amazon.ionelement.api.MetaContainer +import com.amazon.ionelement.api.toIonValue +import org.partiql.lang.ast.SourceLocationMeta +import org.partiql.lang.ast.sourceLocation +import org.partiql.lang.ast.toPartiQlMetaContainer +import org.partiql.lang.domains.PartiqlPhysical +import org.partiql.lang.domains.staticType +import org.partiql.lang.domains.toBindingCase +import org.partiql.lang.errors.ErrorCode +import org.partiql.lang.errors.Property +import org.partiql.lang.errors.PropertyValueMap +import org.partiql.lang.eval.AnyOfCastTable +import org.partiql.lang.eval.Arguments +import org.partiql.lang.eval.BaseExprValue +import org.partiql.lang.eval.BindingName +import org.partiql.lang.eval.CastFunc +import org.partiql.lang.eval.DEFAULT_COMPARATOR +import org.partiql.lang.eval.ErrorDetails +import org.partiql.lang.eval.EvaluationException +import org.partiql.lang.eval.EvaluationSession +import org.partiql.lang.eval.ExprFunction +import org.partiql.lang.eval.ExprValue +import org.partiql.lang.eval.ExprValueFactory +import org.partiql.lang.eval.ExprValueType +import org.partiql.lang.eval.Expression +import org.partiql.lang.eval.Named +import org.partiql.lang.eval.ProjectionIterationBehavior +import org.partiql.lang.eval.RequiredArgs +import org.partiql.lang.eval.RequiredWithOptional +import org.partiql.lang.eval.RequiredWithVariadic +import org.partiql.lang.eval.SequenceExprValue +import org.partiql.lang.eval.StructOrdering +import org.partiql.lang.eval.ThunkValue +import org.partiql.lang.eval.TypedOpBehavior +import org.partiql.lang.eval.TypingMode +import org.partiql.lang.eval.booleanValue +import org.partiql.lang.eval.builtins.storedprocedure.StoredProcedure +import org.partiql.lang.eval.call +import org.partiql.lang.eval.cast +import org.partiql.lang.eval.compareTo +import org.partiql.lang.eval.createErrorSignaler +import org.partiql.lang.eval.createThunkFactory +import org.partiql.lang.eval.err +import org.partiql.lang.eval.errInvalidArgumentType +import org.partiql.lang.eval.errNoContext +import org.partiql.lang.eval.errorContextFrom +import org.partiql.lang.eval.errorIf +import org.partiql.lang.eval.exprEquals +import org.partiql.lang.eval.fillErrorContext +import org.partiql.lang.eval.isNotUnknown +import org.partiql.lang.eval.isUnknown +import org.partiql.lang.eval.like.parsePattern +import 
org.partiql.lang.eval.namedValue +import org.partiql.lang.eval.numberValue +import org.partiql.lang.eval.rangeOver +import org.partiql.lang.eval.stringValue +import org.partiql.lang.eval.syntheticColumnName +import org.partiql.lang.eval.time.Time +import org.partiql.lang.eval.unnamedValue +import org.partiql.lang.eval.visitors.PartiqlPhysicalSanityValidator +import org.partiql.lang.planner.EvaluatorOptions +import org.partiql.lang.types.AnyOfType +import org.partiql.lang.types.AnyType +import org.partiql.lang.types.FunctionSignature +import org.partiql.lang.types.IntType +import org.partiql.lang.types.SingleType +import org.partiql.lang.types.StaticType +import org.partiql.lang.types.TypedOpParameter +import org.partiql.lang.types.UnknownArguments +import org.partiql.lang.types.UnsupportedTypeCheckException +import org.partiql.lang.types.toTypedOpParameter +import org.partiql.lang.util.checkThreadInterrupted +import org.partiql.lang.util.codePointSequence +import org.partiql.lang.util.div +import org.partiql.lang.util.isZero +import org.partiql.lang.util.minus +import org.partiql.lang.util.plus +import org.partiql.lang.util.rem +import org.partiql.lang.util.stringValue +import org.partiql.lang.util.times +import org.partiql.lang.util.timestampValue +import org.partiql.lang.util.toIntExact +import org.partiql.lang.util.totalMinutes +import org.partiql.lang.util.unaryMinus +import java.math.BigDecimal +import java.util.LinkedList +import java.util.TreeSet +import java.util.regex.Pattern +import kotlin.collections.ArrayList + +/** + * A basic "compiler" that converts an instance of [PartiqlPhysical.Expr] to an [Expression]. + * + * This is a modified copy of the now-legacy `EvaluatingCompiler` class. + * The primary differences between this class and `EvaluatingCompiler` are: + * + * - All references to `PartiqlAst` are replaced with `PartiqlPhysical`. + * - `EvaluatingCompiler` compiles "monolithic" SFW queries--this class compiles relational + * operators (in concert with [PhysicalBexprToThunkConverter]). + * + * This implementation produces a "compiled" form consisting of context-threaded + * code in the form of a tree of [PhysicalPlanThunk]s. An overview of this technique can be found + * [here][1]. + * + * **Note:** *threaded* in this context is used in how the code gets *threaded* together for + * interpretation and **not** the concurrency primitive. That is to say this code is NOT thread + * safe.
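 + *
 + * As a conceptual sketch (illustrative, not additional API), a sub-expression such as `a + b` compiles down to
 + * a thunk roughly of the form
 + *
 + *     { env -> (aThunk(env).numberValue() + bThunk(env).numberValue()).exprValue() }
 + *
 + * i.e. evaluation walks a pre-built tree of closures instead of re-interpreting the plan node by node.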
+ * + * [1]: https://www.complang.tuwien.ac.at/anton/lvas/sem06w/fest.pdf + */ +internal class PhysicalExprToThunkConverterImpl( + private val valueFactory: ExprValueFactory, + private val functions: Map, + private val customTypedOpParameters: Map, + private val procedures: Map, + private val evaluatorOptions: EvaluatorOptions = EvaluatorOptions.standard() +) : PhysicalExprToThunkConverter { + private val errorSignaler = evaluatorOptions.typingMode.createErrorSignaler(valueFactory) + private val thunkFactory = evaluatorOptions.typingMode.createThunkFactory( + evaluatorOptions.thunkOptions, + valueFactory + ) + + private fun Number.exprValue(): ExprValue = when (this) { + is Int -> valueFactory.newInt(this) + is Long -> valueFactory.newInt(this) + is Double -> valueFactory.newFloat(this) + is BigDecimal -> valueFactory.newDecimal(this) + else -> errNoContext( + "Cannot convert number to expression value: $this", + errorCode = ErrorCode.EVALUATOR_INVALID_CONVERSION, + internal = true + ) + } + + private fun Boolean.exprValue(): ExprValue = valueFactory.newBoolean(this) + private fun String.exprValue(): ExprValue = valueFactory.newString(this) + + /** + * Compiles a [PartiqlPhysical.Statement] tree to an [Expression]. + * + * Checks [Thread.interrupted] before every expression and sub-expression is compiled + * and throws [InterruptedException] if [Thread.interrupted] it has been set in the + * hope that long-running compilations may be aborted by the caller. + */ + fun compile(plan: PartiqlPhysical.Plan): Expression { + PartiqlPhysicalSanityValidator(evaluatorOptions).walkPlan(plan) + + val thunk = compileAstStatement(plan.stmt) + + return object : Expression { + override fun eval(session: EvaluationSession): ExprValue { + val env = EvaluatorState( + session = session, + registers = Array(plan.locals.size) { valueFactory.missingValue } + ) + + return thunk(env) + } + } + } + + override fun convert(expr: PartiqlPhysical.Expr): PhysicalPlanThunk = this.compileAstExpr(expr) + + /** + * Compiles the specified [PartiqlPhysical.Statement] into a [PhysicalPlanThunk]. + * + * This function will [InterruptedException] if [Thread.interrupted] has been set. 
+ */ + private fun compileAstStatement(ast: PartiqlPhysical.Statement): PhysicalPlanThunk { + return when (ast) { + is PartiqlPhysical.Statement.Query -> compileAstExpr(ast.expr) + is PartiqlPhysical.Statement.Exec -> compileExec(ast) + } + } + + private fun compileAstExpr(expr: PartiqlPhysical.Expr): PhysicalPlanThunk { + checkThreadInterrupted() + val metas = expr.metas + + return when (expr) { + is PartiqlPhysical.Expr.Lit -> compileLit(expr, metas) + is PartiqlPhysical.Expr.Missing -> compileMissing(metas) + is PartiqlPhysical.Expr.LocalId -> compileLocalId(expr, metas) + is PartiqlPhysical.Expr.GlobalId -> compileGlobalId(expr) + is PartiqlPhysical.Expr.SimpleCase -> compileSimpleCase(expr, metas) + is PartiqlPhysical.Expr.SearchedCase -> compileSearchedCase(expr, metas) + is PartiqlPhysical.Expr.Path -> compilePath(expr, metas) + is PartiqlPhysical.Expr.Struct -> compileStruct(expr) + is PartiqlPhysical.Expr.CallAgg -> compileCallAgg(expr, metas) + is PartiqlPhysical.Expr.Parameter -> compileParameter(expr, metas) + is PartiqlPhysical.Expr.Date -> compileDate(expr, metas) + is PartiqlPhysical.Expr.LitTime -> compileLitTime(expr, metas) + + // arithmetic operations + is PartiqlPhysical.Expr.Plus -> compilePlus(expr, metas) + is PartiqlPhysical.Expr.Times -> compileTimes(expr, metas) + is PartiqlPhysical.Expr.Minus -> compileMinus(expr, metas) + is PartiqlPhysical.Expr.Divide -> compileDivide(expr, metas) + is PartiqlPhysical.Expr.Modulo -> compileModulo(expr, metas) + + // comparison operators + is PartiqlPhysical.Expr.And -> compileAnd(expr, metas) + is PartiqlPhysical.Expr.Between -> compileBetween(expr, metas) + is PartiqlPhysical.Expr.Eq -> compileEq(expr, metas) + is PartiqlPhysical.Expr.Gt -> compileGt(expr, metas) + is PartiqlPhysical.Expr.Gte -> compileGte(expr, metas) + is PartiqlPhysical.Expr.Lt -> compileLt(expr, metas) + is PartiqlPhysical.Expr.Lte -> compileLte(expr, metas) + is PartiqlPhysical.Expr.Like -> compileLike(expr, metas) + is PartiqlPhysical.Expr.InCollection -> compileIn(expr, metas) + + // logical operators + is PartiqlPhysical.Expr.Ne -> compileNe(expr, metas) + is PartiqlPhysical.Expr.Or -> compileOr(expr, metas) + + // unary + is PartiqlPhysical.Expr.Not -> compileNot(expr, metas) + is PartiqlPhysical.Expr.Pos -> compilePos(expr, metas) + is PartiqlPhysical.Expr.Neg -> compileNeg(expr, metas) + + // other operators + is PartiqlPhysical.Expr.Concat -> compileConcat(expr, metas) + is PartiqlPhysical.Expr.Call -> compileCall(expr, metas) + is PartiqlPhysical.Expr.NullIf -> compileNullIf(expr, metas) + is PartiqlPhysical.Expr.Coalesce -> compileCoalesce(expr, metas) + + // "typed" operators (RHS is a data type and not an expression) + is PartiqlPhysical.Expr.Cast -> compileCast(expr, metas) + is PartiqlPhysical.Expr.IsType -> compileIs(expr, metas) + is PartiqlPhysical.Expr.CanCast -> compileCanCast(expr, metas) + is PartiqlPhysical.Expr.CanLosslessCast -> compileCanLosslessCast(expr, metas) + + // sequence constructors + is PartiqlPhysical.Expr.List -> compileSeq(ExprValueType.LIST, expr.values, metas) + is PartiqlPhysical.Expr.Sexp -> compileSeq(ExprValueType.SEXP, expr.values, metas) + is PartiqlPhysical.Expr.Bag -> compileSeq(ExprValueType.BAG, expr.values, metas) + + // set operators + is PartiqlPhysical.Expr.Intersect, + is PartiqlPhysical.Expr.Union, + is PartiqlPhysical.Expr.Except -> { + err( + "${expr.javaClass.canonicalName} is not yet supported", + ErrorCode.EVALUATOR_FEATURE_NOT_SUPPORTED_YET, + errorContextFrom(metas).also { + 
it[Property.FEATURE_NAME] = expr.javaClass.canonicalName + }, + internal = false + ) + } + is PartiqlPhysical.Expr.BindingsToValues -> compileBindingsToValues(expr) + } + } + + private fun compileBindingsToValues(expr: PartiqlPhysical.Expr.BindingsToValues): PhysicalPlanThunk { + val mapThunk = compileAstExpr(expr.exp) + val bexprThunk: RelationThunkEnv = PhysicalBexprToThunkConverter(this, thunkFactory.valueFactory) + .convert(expr.query) + + return thunkFactory.thunkEnv(expr.metas) { env -> + val elements = sequence { + val relItr = bexprThunk(env) + while (relItr.nextRow()) { + yield(mapThunk(env)) + } + } + valueFactory.newBag(elements) + } + } + + private fun compileAstExprs(args: List) = args.map { compileAstExpr(it) } + + private fun compileNullIf(expr: PartiqlPhysical.Expr.NullIf, metas: MetaContainer): PhysicalPlanThunk { + val expr1Thunk = compileAstExpr(expr.expr1) + val expr2Thunk = compileAstExpr(expr.expr2) + + // Note: NULLIF does not propagate the unknown values and .exprEquals provides the correct semantics. + return thunkFactory.thunkEnv(metas) { env -> + val expr1Value = expr1Thunk(env) + val expr2Value = expr2Thunk(env) + when { + expr1Value.exprEquals(expr2Value) -> valueFactory.nullValue + else -> expr1Value + } + } + } + + private fun compileCoalesce(expr: PartiqlPhysical.Expr.Coalesce, metas: MetaContainer): PhysicalPlanThunk { + val argThunks = compileAstExprs(expr.args) + + return thunkFactory.thunkEnv(metas) { env -> + var nullFound = false + var knownValue: ExprValue? = null + for (thunk in argThunks) { + val argValue = thunk(env) + if (argValue.isNotUnknown()) { + knownValue = argValue + // No need to execute remaining thunks to save computation as first non-unknown value is found + break + } + if (argValue.type == ExprValueType.NULL) { + nullFound = true + } + } + when (knownValue) { + null -> when { + evaluatorOptions.typingMode == TypingMode.PERMISSIVE && !nullFound -> valueFactory.missingValue + else -> valueFactory.nullValue + } + else -> knownValue + } + } + } + + /** + * Returns a function that accepts an [ExprValue] as an argument and returns true it is `NULL`, `MISSING`, or + * within the range specified by [range]. + */ + private fun integerValueValidator( + range: LongRange + ): (ExprValue) -> Boolean = { value -> + when (value.type) { + ExprValueType.NULL, ExprValueType.MISSING -> true + ExprValueType.INT -> { + val longValue: Long = value.scalar.numberValue()?.toLong() + ?: error( + "ExprValue.numberValue() must not be `NULL` when its type is INT." + + "This indicates that the ExprValue instance has a bug." + ) + + // PRO-TIP: make sure to use the `Long` primitive type here with `.contains` otherwise + // Kotlin will use the version of `.contains` that treats [range] as a collection, and it will + // be very slow! + range.contains(longValue) + } + else -> error( + "The expression's static type was supposed to be INT but instead it was ${value.type}" + + "This may indicate the presence of a bug in the type inferencer." + ) + } + } + + /** + * For operators which could return integer type, check integer overflow in case of [TypingMode.PERMISSIVE]. + */ + private fun resolveIntConstraint(computeThunk: PhysicalPlanThunk, metas: MetaContainer): PhysicalPlanThunk = + when (val staticTypes = metas.staticType?.type?.getTypes()) { + // No staticType, can't validate integer size. 
+ null -> computeThunk + else -> { + when (evaluatorOptions.typingMode) { + TypingMode.LEGACY -> { + // integer size constraints have not been tested under [TypingMode.LEGACY] because the + // [StaticTypeInferenceVisitorTransform] doesn't support being used with legacy mode yet. + // throw an exception in case we encounter this untested scenario. This might work fine, but I + // wouldn't bet on it. + val hasConstrainedInteger = staticTypes.any { + it is IntType && it.rangeConstraint != IntType.IntRangeConstraint.UNCONSTRAINED + } + if (hasConstrainedInteger) { + TODO("Legacy mode doesn't support integer size constraints yet.") + } else { + computeThunk + } + } + TypingMode.PERMISSIVE -> { + val biggestIntegerType = staticTypes.filterIsInstance().maxByOrNull { + it.rangeConstraint.numBytes + } + when (biggestIntegerType) { + is IntType -> { + val validator = integerValueValidator(biggestIntegerType.rangeConstraint.validRange) + + thunkFactory.thunkEnv(metas) { env -> + val naryResult = computeThunk(env) + errorSignaler.errorIf( + !validator(naryResult), + ErrorCode.EVALUATOR_INTEGER_OVERFLOW, + { ErrorDetails(metas, "Integer overflow", errorContextFrom(metas)) }, + { naryResult } + ) + } + } + // If there is no IntType StaticType, can't validate the integer size either. + null -> computeThunk + else -> computeThunk + } + } + } + } + } + + private fun compilePlus(expr: PartiqlPhysical.Expr.Plus, metas: MetaContainer): PhysicalPlanThunk { + if (expr.operands.size < 2) { + error("Internal Error: PartiqlPhysical.Expr.Plus must have at least 2 arguments") + } + + val argThunks = compileAstExprs(expr.operands) + + val computeThunk = thunkFactory.thunkFold(metas, argThunks) { lValue, rValue -> + (lValue.numberValue() + rValue.numberValue()).exprValue() + } + + return resolveIntConstraint(computeThunk, metas) + } + + private fun compileMinus(expr: PartiqlPhysical.Expr.Minus, metas: MetaContainer): PhysicalPlanThunk { + if (expr.operands.size < 2) { + error("Internal Error: PartiqlPhysical.Expr.Minus must have at least 2 arguments") + } + + val argThunks = compileAstExprs(expr.operands) + + val computeThunk = thunkFactory.thunkFold(metas, argThunks) { lValue, rValue -> + (lValue.numberValue() - rValue.numberValue()).exprValue() + } + + return resolveIntConstraint(computeThunk, metas) + } + + private fun compilePos(expr: PartiqlPhysical.Expr.Pos, metas: MetaContainer): PhysicalPlanThunk { + val exprThunk = compileAstExpr(expr.expr) + + val computeThunk = thunkFactory.thunkEnvOperands(metas, exprThunk) { _, value -> + // Invoking .numberValue() here makes this essentially just a type check + value.numberValue() + // Original value is returned unmodified. 
+ value + } + + return resolveIntConstraint(computeThunk, metas) + } + + private fun compileNeg(expr: PartiqlPhysical.Expr.Neg, metas: MetaContainer): PhysicalPlanThunk { + val exprThunk = compileAstExpr(expr.expr) + + val computeThunk = thunkFactory.thunkEnvOperands(metas, exprThunk) { _, value -> + (-value.numberValue()).exprValue() + } + + return resolveIntConstraint(computeThunk, metas) + } + + private fun compileTimes(expr: PartiqlPhysical.Expr.Times, metas: MetaContainer): PhysicalPlanThunk { + val argThunks = compileAstExprs(expr.operands) + + val computeThunk = thunkFactory.thunkFold(metas, argThunks) { lValue, rValue -> + (lValue.numberValue() * rValue.numberValue()).exprValue() + } + + return resolveIntConstraint(computeThunk, metas) + } + + private fun compileDivide(expr: PartiqlPhysical.Expr.Divide, metas: MetaContainer): PhysicalPlanThunk { + val argThunks = compileAstExprs(expr.operands) + + val computeThunk = thunkFactory.thunkFold(metas, argThunks) { lValue, rValue -> + val denominator = rValue.numberValue() + + errorSignaler.errorIf( + denominator.isZero(), + ErrorCode.EVALUATOR_DIVIDE_BY_ZERO, + { ErrorDetails(metas, "/ by zero") } + ) { + try { + (lValue.numberValue() / denominator).exprValue() + } catch (e: ArithmeticException) { + // Setting the internal flag as true as it is not clear what + // ArithmeticException may be thrown by the above + throw EvaluationException( + cause = e, + errorCode = ErrorCode.EVALUATOR_ARITHMETIC_EXCEPTION, + internal = true + ) + } + } + } + + return resolveIntConstraint(computeThunk, metas) + } + + private fun compileModulo(expr: PartiqlPhysical.Expr.Modulo, metas: MetaContainer): PhysicalPlanThunk { + val argThunks = compileAstExprs(expr.operands) + + val computeThunk = thunkFactory.thunkFold(metas, argThunks) { lValue, rValue -> + val denominator = rValue.numberValue() + if (denominator.isZero()) { + err("% by zero", ErrorCode.EVALUATOR_MODULO_BY_ZERO, errorContext = null, internal = false) + } + + (lValue.numberValue() % denominator).exprValue() + } + + return resolveIntConstraint(computeThunk, metas) + } + + private fun compileEq(expr: PartiqlPhysical.Expr.Eq, metas: MetaContainer): PhysicalPlanThunk { + val argThunks = compileAstExprs(expr.operands) + + return thunkFactory.thunkAndMap(metas, argThunks) { lValue, rValue -> + (lValue.exprEquals(rValue)) + } + } + + private fun compileNe(expr: PartiqlPhysical.Expr.Ne, metas: MetaContainer): PhysicalPlanThunk { + val argThunks = compileAstExprs(expr.operands) + + return thunkFactory.thunkFold(metas, argThunks) { lValue, rValue -> + ((!lValue.exprEquals(rValue)).exprValue()) + } + } + + private fun compileLt(expr: PartiqlPhysical.Expr.Lt, metas: MetaContainer): PhysicalPlanThunk { + val argThunks = compileAstExprs(expr.operands) + + return thunkFactory.thunkAndMap(metas, argThunks) { lValue, rValue -> lValue < rValue } + } + + private fun compileLte(expr: PartiqlPhysical.Expr.Lte, metas: MetaContainer): PhysicalPlanThunk { + val argThunks = compileAstExprs(expr.operands) + + return thunkFactory.thunkAndMap(metas, argThunks) { lValue, rValue -> lValue <= rValue } + } + + private fun compileGt(expr: PartiqlPhysical.Expr.Gt, metas: MetaContainer): PhysicalPlanThunk { + val argThunks = compileAstExprs(expr.operands) + + return thunkFactory.thunkAndMap(metas, argThunks) { lValue, rValue -> lValue > rValue } + } + + private fun compileGte(expr: PartiqlPhysical.Expr.Gte, metas: MetaContainer): PhysicalPlanThunk { + val argThunks = compileAstExprs(expr.operands) + + return 
thunkFactory.thunkAndMap(metas, argThunks) { lValue, rValue -> lValue >= rValue } + } + + private fun compileBetween(expr: PartiqlPhysical.Expr.Between, metas: MetaContainer): PhysicalPlanThunk { + val valueThunk = compileAstExpr(expr.value) + val fromThunk = compileAstExpr(expr.from) + val toThunk = compileAstExpr(expr.to) + + return thunkFactory.thunkEnvOperands(metas, valueThunk, fromThunk, toThunk) { _, v, f, t -> + (v >= f && v <= t).exprValue() + } + } + + /** + * `IN` can *almost* be thought of has being syntactic sugar for the `OR` operator. + * + * `a IN (b, c, d)` is equivalent to `a = b OR a = c OR a = d`. On deep inspection, there + * are important implications to this regarding propagation of unknown values. Specifically, the + * presence of any unknown in `b`, `c`, or `d` will result in unknown propagation iif `a` does not + * equal `b`, `c`, or `d`. i.e.: + * + * - `1 in (null, 2, 3)` -> `null` + * - `2 in (null, 2, 3)` -> `true` + * - `2 in (1, 2, 3)` -> `true` + * - `0 in (1, 2, 4)` -> `false` + * + * `IN` is varies from the `OR` operator in that this behavior holds true when other types of expressions are + * used on the right side of `IN` such as sub-queries and variables whose value is that of a list or bag. + */ + private fun compileIn(expr: PartiqlPhysical.Expr.InCollection, metas: MetaContainer): PhysicalPlanThunk { + val args = expr.operands + val leftThunk = compileAstExpr(args[0]) + val rightOp = args[1] + + fun isOptimizedCase(values: List): Boolean = values.all { it is PartiqlPhysical.Expr.Lit && !it.value.isNull } + + fun optimizedCase(values: List): PhysicalPlanThunk { + // Put all the literals in the sequence into a pre-computed map to be checked later by the thunk. + // If the left-hand value is one of these we can short-circuit with a result of TRUE. + // This is the fastest possible case and allows for hundreds of literal values (or more) in the + // sequence without a huge performance penalty. + // NOTE: we cannot use a [HashSet<>] here because [ExprValue] does not implement [Object.hashCode] or + // [Object.equals]. + val precomputedLiteralsMap = values + .filterIsInstance() + .mapTo(TreeSet(DEFAULT_COMPARATOR)) { + valueFactory.newFromIonValue( + it.value.toIonValue( + valueFactory.ion + ) + ) + } + + // the compiled thunk simply checks if the left side is contained on the right side. + // thunkEnvOperands takes care of unknown propagation for the left side; for the right, + // this unknown propagation does not apply since we've eliminated the possibility of unknowns above. + return thunkFactory.thunkEnvOperands(metas, leftThunk) { _, leftValue -> + precomputedLiteralsMap.contains(leftValue).exprValue() + } + } + + return when { + // We can significantly optimize this if rightArg is a sequence constructor which is composed of entirely + // of non-null literal values. + rightOp is PartiqlPhysical.Expr.List && isOptimizedCase(rightOp.values) -> optimizedCase(rightOp.values) + rightOp is PartiqlPhysical.Expr.Bag && isOptimizedCase(rightOp.values) -> optimizedCase(rightOp.values) + rightOp is PartiqlPhysical.Expr.Sexp && isOptimizedCase(rightOp.values) -> optimizedCase(rightOp.values) + // The unoptimized case... 
+ else -> { + val rightThunk = compileAstExpr(rightOp) + + // Legacy mode: + // Returns FALSE when the right side of IN is not a sequence + // Returns NULL if the right side is MISSING or any value on the right side is MISSING + // Permissive mode: + // Returns MISSING when the right side of IN is not a sequence + // Returns MISSING if the right side is MISSING or any value on the right side is MISSING + val (propagateMissingAs, propagateNotASeqAs) = with(valueFactory) { + when (evaluatorOptions.typingMode) { + TypingMode.LEGACY -> nullValue to newBoolean(false) + TypingMode.PERMISSIVE -> missingValue to missingValue + } + } + + // Note that standard unknown propagation applies to the left and right operands. Both [TypingMode]s + // are handled by [ThunkFactory.thunkEnvOperands] and that additional rules for unknown propagation are + // implemented within the thunk for the values within the sequence on the right side of IN. + thunkFactory.thunkEnvOperands(metas, leftThunk, rightThunk) { _, leftValue, rightValue -> + var nullSeen = false + var missingSeen = false + + when { + rightValue.type == ExprValueType.MISSING -> propagateMissingAs + !rightValue.type.isSequence -> propagateNotASeqAs + else -> { + rightValue.forEach { + when (it.type) { + ExprValueType.NULL -> nullSeen = true + ExprValueType.MISSING -> missingSeen = true + // short-circuit to TRUE on the first matching value + else -> if (it.exprEquals(leftValue)) { + return@thunkEnvOperands valueFactory.newBoolean(true) + } + } + } + // If we make it here then there was no match. Propagate MISSING, NULL or return false. + // Note that if both MISSING and NULL was encountered, MISSING takes precedence. + when { + missingSeen -> propagateMissingAs + nullSeen -> valueFactory.nullValue + else -> valueFactory.newBoolean(false) + } + } + } + } + } + } + } + + private fun compileNot(expr: PartiqlPhysical.Expr.Not, metas: MetaContainer): PhysicalPlanThunk { + val argThunk = compileAstExpr(expr.expr) + + return thunkFactory.thunkEnvOperands(metas, argThunk) { _, value -> + (!value.booleanValue()).exprValue() + } + } + + private fun compileAnd(expr: PartiqlPhysical.Expr.And, metas: MetaContainer): PhysicalPlanThunk { + val argThunks = compileAstExprs(expr.operands) + + // can't use the null propagation supplied by [ThunkFactory.thunkEnv] here because AND short-circuits on + // false values and *NOT* on NULL or MISSING + return when (evaluatorOptions.typingMode) { + TypingMode.LEGACY -> thunkFactory.thunkEnv(metas) thunk@{ env -> + var hasUnknowns = false + argThunks.forEach { currThunk -> + val currValue = currThunk(env) + when { + currValue.isUnknown() -> hasUnknowns = true + // Short circuit only if we encounter a known false value. + !currValue.booleanValue() -> return@thunk valueFactory.newBoolean(false) + } + } + + when (hasUnknowns) { + true -> valueFactory.nullValue + false -> valueFactory.newBoolean(true) + } + } + TypingMode.PERMISSIVE -> thunkFactory.thunkEnv(metas) thunk@{ env -> + var hasNull = false + var hasMissing = false + argThunks.forEach { currThunk -> + val currValue = currThunk(env) + when (currValue.type) { + // Short circuit only if we encounter a known false value. 
+ ExprValueType.BOOL -> if (!currValue.booleanValue()) return@thunk valueFactory.newBoolean(false) + ExprValueType.NULL -> hasNull = true + // type mismatch, return missing + else -> hasMissing = true + } + } + + when { + hasMissing -> valueFactory.missingValue + hasNull -> valueFactory.nullValue + else -> valueFactory.newBoolean(true) + } + } + } + } + + private fun compileOr(expr: PartiqlPhysical.Expr.Or, metas: MetaContainer): PhysicalPlanThunk { + val argThunks = compileAstExprs(expr.operands) + + // can't use the null propagation supplied by [ThunkFactory.thunkEnv] here because OR short-circuits on + // true values and *NOT* on NULL or MISSING + return when (evaluatorOptions.typingMode) { + TypingMode.LEGACY -> + thunkFactory.thunkEnv(metas) thunk@{ env -> + var hasUnknowns = false + argThunks.forEach { currThunk -> + val currValue = currThunk(env) + // How null-propagation works for OR is rather weird according to the SQL-92 spec. + // Nulls are propagated like other expressions only when none of the terms are TRUE. + // If any one of them is TRUE, then the entire expression evaluates to TRUE, i.e.: + // NULL OR TRUE -> TRUE + // NULL OR FALSE -> NULL + // (strange but true) + when { + currValue.isUnknown() -> hasUnknowns = true + currValue.booleanValue() -> return@thunk valueFactory.newBoolean(true) + } + } + + when (hasUnknowns) { + true -> valueFactory.nullValue + false -> valueFactory.newBoolean(false) + } + } + TypingMode.PERMISSIVE -> thunkFactory.thunkEnv(metas) thunk@{ env -> + var hasNull = false + var hasMissing = false + argThunks.forEach { currThunk -> + val currValue = currThunk(env) + when (currValue.type) { + // Short circuit only if we encounter a known true value. + ExprValueType.BOOL -> if (currValue.booleanValue()) return@thunk valueFactory.newBoolean(true) + ExprValueType.NULL -> hasNull = true + else -> hasMissing = true // type mismatch, return missing. 
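+                        // Illustrative truth table for PERMISSIVE mode (companion to the LEGACY comment above):
+                        //   MISSING OR TRUE  -> TRUE
+                        //   MISSING OR FALSE -> MISSING
+                        //   NULL    OR FALSE -> NULL
+                        //   MISSING OR NULL  -> MISSING (MISSING takes precedence over NULL)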
+ } + } + + when { + hasMissing -> valueFactory.missingValue + hasNull -> valueFactory.nullValue + else -> valueFactory.newBoolean(false) + } + } + } + } + + private fun compileConcat(expr: PartiqlPhysical.Expr.Concat, metas: MetaContainer): PhysicalPlanThunk { + val argThunks = compileAstExprs(expr.operands) + + return thunkFactory.thunkFold(metas, argThunks) { lValue, rValue -> + val lType = lValue.type + val rType = rValue.type + + if (lType.isText && rType.isText) { + // null/missing propagation is handled before getting here + (lValue.stringValue() + rValue.stringValue()).exprValue() + } else { + err( + "Wrong argument type for ||", + ErrorCode.EVALUATOR_CONCAT_FAILED_DUE_TO_INCOMPATIBLE_TYPE, + errorContextFrom(metas).also { + it[Property.ACTUAL_ARGUMENT_TYPES] = listOf(lType, rType).toString() + }, + internal = false + ) + } + } + } + + private fun compileCall(expr: PartiqlPhysical.Expr.Call, metas: MetaContainer): PhysicalPlanThunk { + val funcArgThunks = compileAstExprs(expr.args) + val func = functions[expr.funcName.text] ?: err( + "No such function: ${expr.funcName.text}", + ErrorCode.EVALUATOR_NO_SUCH_FUNCTION, + errorContextFrom(metas).also { + it[Property.FUNCTION_NAME] = expr.funcName.text + }, + internal = false + ) + + // Check arity + if (funcArgThunks.size !in func.signature.arity) { + val errorContext = errorContextFrom(metas).also { + it[Property.FUNCTION_NAME] = func.signature.name + it[Property.EXPECTED_ARITY_MIN] = func.signature.arity.first + it[Property.EXPECTED_ARITY_MAX] = func.signature.arity.last + it[Property.ACTUAL_ARITY] = funcArgThunks.size + } + + val message = when { + func.signature.arity.first == 1 && func.signature.arity.last == 1 -> + "${func.signature.name} takes a single argument, received: ${funcArgThunks.size}" + func.signature.arity.first == func.signature.arity.last -> + "${func.signature.name} takes exactly ${func.signature.arity.first} arguments, received: ${funcArgThunks.size}" + else -> + "${func.signature.name} takes between ${func.signature.arity.first} and " + + "${func.signature.arity.last} arguments, received: ${funcArgThunks.size}" + } + + throw EvaluationException( + message, + ErrorCode.EVALUATOR_INCORRECT_NUMBER_OF_ARGUMENTS_TO_FUNC_CALL, + errorContext, + internal = false + ) + } + + fun checkArgumentTypes(signature: FunctionSignature, args: List): Arguments { + fun checkArgumentType(formalStaticType: StaticType, actualArg: ExprValue, position: Int) { + val formalExprValueTypeDomain = formalStaticType.typeDomain + + val actualExprValueType = actualArg.type + val actualStaticType = StaticType.fromExprValue(actualArg) + + if (!actualStaticType.isSubTypeOf(formalStaticType)) { + errInvalidArgumentType( + signature = signature, + position = position, + expectedTypes = formalExprValueTypeDomain.toList(), + actualType = actualExprValueType + ) + } + } + + val required = args.take(signature.requiredParameters.size) + val rest = args.drop(signature.requiredParameters.size) + + signature.requiredParameters.zip(required).forEachIndexed { idx, (expected, actual) -> + checkArgumentType(expected, actual, idx + 1) + } + + return if (signature.optionalParameter != null && rest.isNotEmpty()) { + val opt = rest.last() + checkArgumentType(signature.optionalParameter, opt, required.size + 1) + RequiredWithOptional(required, opt) + } else if (signature.variadicParameter != null) { + rest.forEachIndexed { idx, arg -> + checkArgumentType(signature.variadicParameter.type, arg, required.size + 1 + idx) + } + RequiredWithVariadic(required, rest) + } else 
{ + RequiredArgs(required) + } + } + + val computeThunk = when (func.signature.unknownArguments) { + UnknownArguments.PROPAGATE -> thunkFactory.thunkEnvOperands(metas, funcArgThunks) { env, values -> + val checkedArgs = checkArgumentTypes(func.signature, values) + func.call(env.session, checkedArgs) + } + UnknownArguments.PASS_THRU -> thunkFactory.thunkEnv(metas) { env -> + val funcArgValues = funcArgThunks.map { it(env) } + val checkedArgs = checkArgumentTypes(func.signature, funcArgValues) + func.call(env.session, checkedArgs) + } + } + + return resolveIntConstraint(computeThunk, metas) + } + + private fun compileLit(expr: PartiqlPhysical.Expr.Lit, metas: MetaContainer): PhysicalPlanThunk { + val value = valueFactory.newFromIonValue(expr.value.toIonValue(valueFactory.ion)) + + return thunkFactory.thunkEnv(metas) { value } + } + + private fun compileMissing(metas: MetaContainer): PhysicalPlanThunk = + thunkFactory.thunkEnv(metas) { valueFactory.missingValue } + + private fun compileGlobalId(expr: PartiqlPhysical.Expr.GlobalId): PhysicalPlanThunk { + val bindingCase = expr.case.toBindingCase() + return thunkFactory.thunkEnv(expr.metas) { env -> + val bindingName = BindingName(expr.uniqueId.text, bindingCase) + env.session.globals[bindingName] ?: throwUndefinedVariableException(bindingName, expr.metas) + } + } + + @Suppress("UNUSED_PARAMETER") + private fun compileLocalId(expr: PartiqlPhysical.Expr.LocalId, metas: MetaContainer): PhysicalPlanThunk { + val localIndex = expr.index.value.toIntExact() + return thunkFactory.thunkEnv(metas) { env -> + env.registers[localIndex] + } + } + + private fun compileParameter(expr: PartiqlPhysical.Expr.Parameter, metas: MetaContainer): PhysicalPlanThunk { + val ordinal = expr.index.value.toInt() + val index = ordinal - 1 + + return { env -> + val params = env.session.parameters + if (params.size <= index) { + throw EvaluationException( + "Unbound parameter for ordinal: $ordinal", + ErrorCode.EVALUATOR_UNBOUND_PARAMETER, + errorContextFrom(metas).also { + it[Property.EXPECTED_PARAMETER_ORDINAL] = ordinal + it[Property.BOUND_PARAMETER_COUNT] = params.size + }, + internal = false + ) + } + params[index] + } + } + + /** + * Returns a lambda that implements the `IS` operator type check according to the current + * [TypedOpBehavior]. + */ + private fun makeIsCheck( + staticType: SingleType, + typedOpParameter: TypedOpParameter, + metas: MetaContainer + ): (ExprValue) -> Boolean { + val exprValueType = staticType.runtimeType + + // The "simple" type match function only looks at the [ExprValueType] of the [ExprValue] + // and invokes the custom [validationThunk] if one exists. 
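The two TypedOpBehavior branches that follow differ in how much of the declared type an IS check enforces. The standalone sketch below illustrates the idea with a simplified CHAR(n) check; the names and the exact constraint semantics are assumptions for illustration only, not the library's behavior verbatim.

    // Hypothetical sketch: a LEGACY-style check compares only the runtime type tag,
    // while an HONOR_PARAMETERS-style check also applies the type's parameters.
    enum class ToyTypedOpBehavior { LEGACY, HONOR_PARAMETERS }

    fun isChar(value: Any?, length: Int, behavior: ToyTypedOpBehavior): Boolean = when (behavior) {
        ToyTypedOpBehavior.LEGACY -> value is String
        ToyTypedOpBehavior.HONOR_PARAMETERS -> value is String && value.length == length
    }

    fun main() {
        println(isChar("ab", length = 1, behavior = ToyTypedOpBehavior.LEGACY))           // true
        println(isChar("ab", length = 1, behavior = ToyTypedOpBehavior.HONOR_PARAMETERS)) // false
    }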
+ val simpleTypeMatchFunc = { expValue: ExprValue -> + val isTypeMatch = when (exprValueType) { + // MISSING IS NULL and NULL IS MISSING + ExprValueType.NULL -> expValue.type.isUnknown + else -> expValue.type == exprValueType + } + (isTypeMatch && typedOpParameter.validationThunk?.let { it(expValue) } != false) + } + + return when (evaluatorOptions.typedOpBehavior) { + TypedOpBehavior.LEGACY -> simpleTypeMatchFunc + TypedOpBehavior.HONOR_PARAMETERS -> { expValue: ExprValue -> + staticType.allTypes.any { + val matchesStaticType = try { + it.isInstance(expValue) + } catch (e: UnsupportedTypeCheckException) { + err( + e.message!!, + ErrorCode.UNIMPLEMENTED_FEATURE, + errorContextFrom(metas), + internal = true + ) + } + + when { + !matchesStaticType -> false + else -> when (val validator = typedOpParameter.validationThunk) { + null -> true + else -> validator(expValue) + } + } + } + } + } + } + + private fun compileIs(expr: PartiqlPhysical.Expr.IsType, metas: MetaContainer): PhysicalPlanThunk { + val expThunk = compileAstExpr(expr.value) + val typedOpParameter = expr.type.toTypedOpParameter(customTypedOpParameters) + if (typedOpParameter.staticType is AnyType) { + return thunkFactory.thunkEnv(metas) { valueFactory.newBoolean(true) } + } + if (evaluatorOptions.typedOpBehavior == TypedOpBehavior.HONOR_PARAMETERS && expr.type is PartiqlPhysical.Type.FloatType && expr.type.precision != null) { + err( + "FLOAT precision parameter is unsupported", + ErrorCode.SEMANTIC_FLOAT_PRECISION_UNSUPPORTED, + errorContextFrom(expr.type.metas), + internal = false + ) + } + + val typeMatchFunc = when (val staticType = typedOpParameter.staticType) { + is SingleType -> makeIsCheck(staticType, typedOpParameter, metas) + is AnyOfType -> staticType.types.map { childType -> + when (childType) { + is SingleType -> makeIsCheck(childType, typedOpParameter, metas) + else -> err( + "Union type cannot have ANY or nested AnyOf type for IS", + ErrorCode.SEMANTIC_UNION_TYPE_INVALID, + errorContextFrom(metas), + internal = true + ) + } + }.let { typeMatchFuncs -> + { expValue: ExprValue -> typeMatchFuncs.any { func -> func(expValue) } } + } + is AnyType -> throw IllegalStateException("Unexpected ANY type in IS compilation") + } + + return thunkFactory.thunkEnv(metas) { env -> + val expValue = expThunk(env) + typeMatchFunc(expValue).exprValue() + } + } + + private fun compileCastHelper(value: PartiqlPhysical.Expr, asType: PartiqlPhysical.Type, metas: MetaContainer): PhysicalPlanThunk { + val expThunk = compileAstExpr(value) + val typedOpParameter = asType.toTypedOpParameter(customTypedOpParameters) + if (typedOpParameter.staticType is AnyType) { + return expThunk + } + if (evaluatorOptions.typedOpBehavior == TypedOpBehavior.HONOR_PARAMETERS && asType is PartiqlPhysical.Type.FloatType && asType.precision != null) { + err( + "FLOAT precision parameter is unsupported", + ErrorCode.SEMANTIC_FLOAT_PRECISION_UNSUPPORTED, + errorContextFrom(asType.metas), + internal = false + ) + } + + fun typeOpValidate( + value: ExprValue, + castOutput: ExprValue, + typeName: String, + locationMeta: SourceLocationMeta? 
+ ) { + if (typedOpParameter.validationThunk?.let { it(castOutput) } == false) { + val errorContext = PropertyValueMap().also { + it[Property.CAST_FROM] = value.type.toString() + it[Property.CAST_TO] = typeName + } + + locationMeta?.let { fillErrorContext(errorContext, it) } + + throw EvaluationException( + "Validation failure for $asType", + ErrorCode.EVALUATOR_CAST_FAILED, + errorContext, + internal = false + ) + } + } + + fun singleTypeCastFunc(singleType: SingleType): CastFunc { + val locationMeta = metas.sourceLocationMeta + return { value -> + val castOutput = value.cast( + singleType, + valueFactory, + evaluatorOptions.typedOpBehavior, + locationMeta, + evaluatorOptions.defaultTimezoneOffset + ) + typeOpValidate(value, castOutput, singleType.runtimeType.toString(), locationMeta) + castOutput + } + } + + fun compileSingleTypeCast(singleType: SingleType): PhysicalPlanThunk { + val castFunc = singleTypeCastFunc(singleType) + // We do not use thunkFactory here because we want to explicitly avoid + // the optional evaluation-time type check for CAN_CAST below. + // Using thunkFactory here would interfere with the error handling for can_cast, which returns false if an + // exception is thrown during a normal cast operation. + return { env -> + val valueToCast = expThunk(env) + castFunc(valueToCast) + } + } + + fun compileCast(type: StaticType): PhysicalPlanThunk = when (type) { + is SingleType -> compileSingleTypeCast(type) + is AnyOfType -> { + val locationMeta = metas.sourceLocationMeta + val castTable = AnyOfCastTable(type, metas, valueFactory, ::singleTypeCastFunc); + + // We do not use thunkFactory here because we want to explicitly avoid + // the optional evaluation-time type check for CAN_CAST below. + // note that this would interfere with the error handling for can_cast that returns false if an + // exception is thrown during a normal cast operation. + { env -> + val sourceValue = expThunk(env) + castTable.cast(sourceValue).also { + // TODO put the right type name here + typeOpValidate(sourceValue, it, "", locationMeta) + } + } + } + is AnyType -> throw IllegalStateException("Unreachable code") + } + + return compileCast(typedOpParameter.staticType) + } + + private fun compileCast(expr: PartiqlPhysical.Expr.Cast, metas: MetaContainer): PhysicalPlanThunk = + thunkFactory.thunkEnv(metas, compileCastHelper(expr.value, expr.asType, metas)) + + private fun compileCanCast(expr: PartiqlPhysical.Expr.CanCast, metas: MetaContainer): PhysicalPlanThunk { + val typedOpParameter = expr.asType.toTypedOpParameter(customTypedOpParameters) + if (typedOpParameter.staticType is AnyType) { + return thunkFactory.thunkEnv(metas) { valueFactory.newBoolean(true) } + } + + val expThunk = compileAstExpr(expr.value) + + // TODO consider making this more efficient by not directly delegating to CAST + // TODO consider also making the operand not double evaluated (e.g.
having expThunk memoize) + val castThunkEnv = compileCastHelper(expr.value, expr.asType, expr.metas) + return thunkFactory.thunkEnv(metas) { env -> + val sourceValue = expThunk(env) + try { + when { + // NULL/MISSING can cast to anything as themselves + sourceValue.isUnknown() -> valueFactory.newBoolean(true) + else -> { + val castedValue = castThunkEnv(env) + when { + // NULL/MISSING from cast is a permissive way to signal failure + castedValue.isUnknown() -> valueFactory.newBoolean(false) + else -> valueFactory.newBoolean(true) + } + } + } + } catch (e: EvaluationException) { + if (e.internal) { + throw e + } + valueFactory.newBoolean(false) + } + } + } + + private fun compileCanLosslessCast(expr: PartiqlPhysical.Expr.CanLosslessCast, metas: MetaContainer): PhysicalPlanThunk { + val typedOpParameter = expr.asType.toTypedOpParameter(customTypedOpParameters) + if (typedOpParameter.staticType is AnyType) { + return thunkFactory.thunkEnv(metas) { valueFactory.newBoolean(true) } + } + + val expThunk = compileAstExpr(expr.value) + + // TODO consider making this more efficient by not directly delegating to CAST + val castThunkEnv = compileCastHelper(expr.value, expr.asType, expr.metas) + return thunkFactory.thunkEnv(metas) { env -> + val sourceValue = expThunk(env) + val sourceType = StaticType.fromExprValue(sourceValue) + + fun roundTrip(): ExprValue { + val castedValue = castThunkEnv(env) + + val locationMeta = metas.sourceLocationMeta + fun castFunc(singleType: SingleType) = + { value: ExprValue -> + value.cast( + singleType, + valueFactory, + evaluatorOptions.typedOpBehavior, + locationMeta, + evaluatorOptions.defaultTimezoneOffset + ) + } + + val roundTripped = when (sourceType) { + is SingleType -> castFunc(sourceType)(castedValue) + is AnyOfType -> { + val castTable = AnyOfCastTable(sourceType, metas, valueFactory, ::castFunc) + castTable.cast(sourceValue) + } + // Should not be possible + is AnyType -> throw IllegalStateException("ANY type is not configured correctly in compiler") + } + + val lossless = sourceValue.exprEquals(roundTripped) + return valueFactory.newBoolean(lossless) + } + + try { + when (sourceValue.type) { + // NULL can cast to anything as itself + ExprValueType.NULL -> valueFactory.newBoolean(true) + + // Short-circuit timestamp -> date roundtrip if precision isn't [Timestamp.Precision.DAY] or + // [Timestamp.Precision.MONTH] or [Timestamp.Precision.YEAR] + ExprValueType.TIMESTAMP -> when (typedOpParameter.staticType) { + StaticType.DATE -> when (sourceValue.ionValue.timestampValue().precision) { + Timestamp.Precision.DAY, Timestamp.Precision.MONTH, Timestamp.Precision.YEAR -> roundTrip() + else -> valueFactory.newBoolean(false) + } + StaticType.TIME -> valueFactory.newBoolean(false) + else -> roundTrip() + } + + // For all other cases, attempt a round-trip of the value through the source and dest types + else -> roundTrip() + } + } catch (e: EvaluationException) { + if (e.internal) { + throw e + } + valueFactory.newBoolean(false) + } + } + } + + private fun compileSimpleCase(expr: PartiqlPhysical.Expr.SimpleCase, metas: MetaContainer): PhysicalPlanThunk { + val valueThunk = compileAstExpr(expr.expr) + val branchThunks = expr.cases.pairs.map { Pair(compileAstExpr(it.first), compileAstExpr(it.second)) } + val elseThunk = when (expr.default) { + null -> thunkFactory.thunkEnv(metas) { valueFactory.nullValue } + else -> compileAstExpr(expr.default) + } + + return thunkFactory.thunkEnv(metas) thunk@{ env -> + val caseValue = valueThunk(env) + // if the case value is unknown 
then we can short-circuit to the elseThunk directly + when { + caseValue.isUnknown() -> elseThunk(env) + else -> { + branchThunks.forEach { bt -> + val branchValue = bt.first(env) + // Just skip any branch values that are unknown, which we consider the same as false here. + when { + branchValue.isUnknown() -> { /* intentionally blank */ + } + else -> { + if (caseValue.exprEquals(branchValue)) { + return@thunk bt.second(env) + } + } + } + } + } + } + elseThunk(env) + } + } + + private fun compileSearchedCase(expr: PartiqlPhysical.Expr.SearchedCase, metas: MetaContainer): PhysicalPlanThunk { + val branchThunks = expr.cases.pairs.map { compileAstExpr(it.first) to compileAstExpr(it.second) } + val elseThunk = when (expr.default) { + null -> thunkFactory.thunkEnv(metas) { valueFactory.nullValue } + else -> compileAstExpr(expr.default) + } + + return when (evaluatorOptions.typingMode) { + TypingMode.LEGACY -> thunkFactory.thunkEnv(metas) thunk@{ env -> + branchThunks.forEach { bt -> + val conditionValue = bt.first(env) + // Any unknown value is considered the same as false. + // Note that .booleanValue() here will throw an EvaluationException if + // the data type is not boolean. + // TODO: .booleanValue does not have access to metas, so the EvaluationException is reported to be + // at the line & column of the CASE statement, not the predicate, unfortunately. + if (conditionValue.isNotUnknown() && conditionValue.booleanValue()) { + return@thunk bt.second(env) + } + } + elseThunk(env) + } + // Permissive mode propagates data type mismatches as MISSING, which is + // equivalent to false for searched CASE predicates. To simplify this, + // all we really need to do is consider any non-boolean result from the + // predicate to be false. + TypingMode.PERMISSIVE -> thunkFactory.thunkEnv(metas) thunk@{ env -> + branchThunks.forEach { bt -> + val conditionValue = bt.first(env) + if (conditionValue.type == ExprValueType.BOOL && conditionValue.booleanValue()) { + return@thunk bt.second(env) + } + } + elseThunk(env) + } + } + } + + private fun compileStruct(expr: PartiqlPhysical.Expr.Struct): PhysicalPlanThunk { + val structParts = compileStructParts(expr.parts) + + val ordering = if (expr.parts.none { it is PartiqlPhysical.StructPart.StructFields }) + StructOrdering.ORDERED + else + StructOrdering.UNORDERED + + return thunkFactory.thunkEnv(expr.metas) { env -> + val columns = mutableListOf() + for (element in structParts) { + when (element) { + is CompiledStructPart.Field -> { + val fieldName = element.nameThunk(env) + when (evaluatorOptions.typingMode) { + TypingMode.LEGACY -> + if (!fieldName.type.isText) { + err( + "Found struct field key to be of type ${fieldName.type}", + ErrorCode.EVALUATOR_NON_TEXT_STRUCT_FIELD_KEY, + errorContextFrom(expr.metas.sourceLocationMeta).also { pvm -> + pvm[Property.ACTUAL_TYPE] = fieldName.type.toString() + }, + internal = false + ) + } + TypingMode.PERMISSIVE -> + if (!fieldName.type.isText) { + continue + } + } + val fieldValue = element.valueThunk(env) + columns.add(fieldValue.namedValue(fieldName)) + } + is CompiledStructPart.StructMerge -> { + for (projThunk in element.thunks) { + val value = projThunk(env) + if (value.type == ExprValueType.MISSING) continue + + val children = value.asSequence() + if (!children.any() || value.type.isSequence) { + val name = syntheticColumnName(columns.size).exprValue() + columns.add(value.namedValue(name)) + } else { + val valuesToProject = + when (evaluatorOptions.projectionIteration) { + ProjectionIterationBehavior.FILTER_MISSING 
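The effect of the two typing modes on searched CASE, described above, can be shown with a small standalone sketch. The names are hypothetical; the point is only that a predicate which is not a boolean raises an error under a LEGACY-style mode but is treated as false under a PERMISSIVE-style mode.

    // Standalone sketch (hypothetical names) of the searched-CASE branch selection above.
    enum class ToyTypingMode { LEGACY, PERMISSIVE }

    fun <R> searchedCase(
        mode: ToyTypingMode,
        branches: List<Pair<() -> Any?, () -> R>>,
        default: () -> R
    ): R {
        for ((predicate, result) in branches) {
            val condition = predicate()
            val matched = when (mode) {
                // LEGACY: null counts as false, any other non-boolean predicate is an error.
                ToyTypingMode.LEGACY -> when (condition) {
                    null -> false
                    is Boolean -> condition
                    else -> error("Expected boolean predicate, got: $condition")
                }
                // PERMISSIVE: anything other than literal `true` counts as false.
                ToyTypingMode.PERMISSIVE -> condition == true
            }
            if (matched) return result()
        }
        return default()
    }

    fun main() {
        // The first predicate is a string; PERMISSIVE skips it, LEGACY would throw.
        val result = searchedCase(
            ToyTypingMode.PERMISSIVE,
            branches = listOf<Pair<() -> Any?, () -> String>>(
                Pair({ "not a boolean" }, { "first" }),
                Pair({ true }, { "second" })
            ),
            default = { "else" }
        )
        println(result) // second
    }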
-> { + value.filter { it.type != ExprValueType.MISSING } + } + ProjectionIterationBehavior.UNFILTERED -> value + } + for (childValue in valuesToProject) { + val namedFacet = childValue.asFacet(Named::class.java) + val name = namedFacet?.name + ?: syntheticColumnName(columns.size).exprValue() + columns.add(childValue.namedValue(name)) + } + } + } + } + } + } + createStructExprValue(columns.asSequence(), ordering) + } + } + + private fun compileStructParts(projectItems: List): List = + projectItems.map { it -> + when (it) { + is PartiqlPhysical.StructPart.StructField -> { + val fieldThunk = compileAstExpr(it.fieldName) + val valueThunk = compileAstExpr(it.value) + CompiledStructPart.Field(fieldThunk, valueThunk) + } + is PartiqlPhysical.StructPart.StructFields -> { + CompiledStructPart.StructMerge(listOf(compileAstExpr(it.partExpr))) + } + } + } + + private fun compileSeq(seqType: ExprValueType, itemExprs: List, metas: MetaContainer): PhysicalPlanThunk { + require(seqType.isSequence) { "seqType must be a sequence!" } + + val itemThunks = compileAstExprs(itemExprs) + + val makeItemThunkSequence = when (seqType) { + ExprValueType.BAG -> { env: EvaluatorState -> + itemThunks.asSequence().map { itemThunk -> + // call to unnamedValue() makes sure we don't expose any underlying value name/ordinal + itemThunk(env).unnamedValue() + } + } + else -> { env: EvaluatorState -> + itemThunks.asSequence().mapIndexed { i, itemThunk -> itemThunk(env).namedValue(i.exprValue()) } + } + } + + return thunkFactory.thunkEnv(metas) { env -> + // todo: use valueFactory.newSequence() instead. + SequenceExprValue( + valueFactory.ion, + seqType, + makeItemThunkSequence(env) + ) + } + } + + @Suppress("UNUSED_PARAMETER") + private fun compileCallAgg(expr: PartiqlPhysical.Expr.CallAgg, metas: MetaContainer): PhysicalPlanThunk = TODO("call_agg") + + private fun compilePath(expr: PartiqlPhysical.Expr.Path, metas: MetaContainer): PhysicalPlanThunk { + val rootThunk = compileAstExpr(expr.root) + val remainingComponents = LinkedList() + + expr.steps.forEach { remainingComponents.addLast(it) } + + val componentThunk = compilePathComponents(remainingComponents, metas) + + return thunkFactory.thunkEnv(metas) { env -> + val rootValue = rootThunk(env) + componentThunk(env, rootValue) + } + } + + private fun compilePathComponents( + remainingComponents: LinkedList, + pathMetas: MetaContainer + ): PhsycialPlanThunkValue { + + val componentThunks = ArrayList>() + + while (!remainingComponents.isEmpty()) { + val pathComponent = remainingComponents.removeFirst() + val componentMetas = pathComponent.metas + componentThunks.add( + when (pathComponent) { + is PartiqlPhysical.PathStep.PathExpr -> { + val indexExpr = pathComponent.index + val caseSensitivity = pathComponent.case + when { + // If indexExpr is a literal string, there is no need to evaluate it--just compile a + // thunk that directly returns a bound value + indexExpr is PartiqlPhysical.Expr.Lit && indexExpr.value.toIonValue(valueFactory.ion) is IonString -> { + val lookupName = BindingName( + indexExpr.value.toIonValue(valueFactory.ion).stringValue()!!, + caseSensitivity.toBindingCase() + ) + thunkFactory.thunkEnvValue(componentMetas) { _, componentValue -> + componentValue.bindings[lookupName] ?: valueFactory.missingValue + } + } + else -> { + val indexThunk = compileAstExpr(indexExpr) + thunkFactory.thunkEnvValue(componentMetas) { env, componentValue -> + val indexValue = indexThunk(env) + when { + indexValue.type == ExprValueType.INT -> { + 
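As the comment above notes, a path step whose index is a literal string can have its lookup key fixed at compile time, while any other index expression must be evaluated against the current row. The standalone sketch below captures that distinction with plain maps in place of ExprValue bindings; all names are hypothetical.

    // Hypothetical sketch: precompute the key for literal path steps, defer it for dynamic ones.
    typealias ToyRow = Map<String, Any?>
    typealias ToyStruct = Map<String, Any?>

    sealed class ToyStep {
        abstract fun eval(row: ToyRow, struct: ToyStruct): Any?
    }

    /** The index was a literal string, so the lookup key is fixed when the step is compiled. */
    class LiteralStep(private val key: String) : ToyStep() {
        override fun eval(row: ToyRow, struct: ToyStruct): Any? = struct[key]
    }

    /** The index is an arbitrary expression, so it is evaluated against the current row first. */
    class DynamicStep(private val keyExpr: (ToyRow) -> String) : ToyStep() {
        override fun eval(row: ToyRow, struct: ToyStruct): Any? = struct[keyExpr(row)]
    }

    fun main() {
        val step: ToyStep = LiteralStep("name") // e.g. compiled from the path step foo."name"
        println(step.eval(emptyMap(), mapOf("name" to "partiql"))) // partiql
    }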
componentValue.ordinalBindings[indexValue.numberValue().toInt()] + } + indexValue.type.isText -> { + val lookupName = + BindingName(indexValue.stringValue(), caseSensitivity.toBindingCase()) + componentValue.bindings[lookupName] + } + else -> { + when (evaluatorOptions.typingMode) { + TypingMode.LEGACY -> err( + "Cannot convert index to int/string: $indexValue", + ErrorCode.EVALUATOR_INVALID_CONVERSION, + errorContextFrom(componentMetas), + internal = false + ) + TypingMode.PERMISSIVE -> valueFactory.missingValue + } + } + } ?: valueFactory.missingValue + } + } + } + } + is PartiqlPhysical.PathStep.PathUnpivot -> { + when { + !remainingComponents.isEmpty() -> { + val tempThunk = compilePathComponents(remainingComponents, pathMetas) + thunkFactory.thunkEnvValue(componentMetas) { env, componentValue -> + val mapped = componentValue.unpivot() + .flatMap { tempThunk(env, it).rangeOver() } + .asSequence() + valueFactory.newBag(mapped) + } + } + else -> + thunkFactory.thunkEnvValue(componentMetas) { _, componentValue -> + valueFactory.newBag(componentValue.unpivot().asSequence()) + } + } + } + // this is for `path[*].component` + is PartiqlPhysical.PathStep.PathWildcard -> { + when { + !remainingComponents.isEmpty() -> { + val hasMoreWildCards = + remainingComponents.filterIsInstance().any() + val tempThunk = compilePathComponents(remainingComponents, pathMetas) + + when { + !hasMoreWildCards -> thunkFactory.thunkEnvValue(componentMetas) { env, componentValue -> + val mapped = componentValue + .rangeOver() + .map { tempThunk(env, it) } + .asSequence() + + valueFactory.newBag(mapped) + } + else -> thunkFactory.thunkEnvValue(componentMetas) { env, componentValue -> + val mapped = componentValue + .rangeOver() + .flatMap { + val tempValue = tempThunk(env, it) + tempValue + } + .asSequence() + + valueFactory.newBag(mapped) + } + } + } + else -> { + thunkFactory.thunkEnvValue(componentMetas) { _, componentValue -> + val mapped = componentValue.rangeOver().asSequence() + valueFactory.newBag(mapped) + } + } + } + } + } + ) + } + return when (componentThunks.size) { + 1 -> componentThunks.first() + else -> thunkFactory.thunkEnvValue(pathMetas) { env, rootValue -> + componentThunks.fold(rootValue) { componentValue, componentThunk -> + componentThunk(env, componentValue) + } + } + } + } + + /** + * Given an AST node that represents a `LIKE` predicate, return an ExprThunk that evaluates a `LIKE` predicate. + * + * Three cases + * + * 1. All arguments are literals, then compile and run the pattern + * 1. Search pattern and escape pattern are literals, compile the pattern. Running the pattern deferred to evaluation time. + * 1. Pattern or escape (or both) are *not* literals, compile and running of pattern deferred to evaluation time. + * + * ``` + * LIKE [ESCAPE ] + * ``` + * + * @return a thunk that when provided with an environment evaluates the `LIKE` predicate + */ + private fun compileLike(expr: PartiqlPhysical.Expr.Like, metas: MetaContainer): PhysicalPlanThunk { + val valueExpr = expr.value + val patternExpr = expr.pattern + val escapeExpr = expr.escape + + val patternLocationMeta = patternExpr.metas.toPartiQlMetaContainer().sourceLocation + val escapeLocationMeta = escapeExpr?.metas?.toPartiQlMetaContainer()?.sourceLocation + + // This is so that null short-circuits can be supported. + fun getRegexPattern(pattern: ExprValue, escape: ExprValue?): (() -> Pattern)? 
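All three compilation cases listed in the comment above eventually turn the SQL pattern into a compiled regular expression; the cases only differ in when that translation happens. The helper below is a standalone, simplified sketch of such a translation (it is not the library's parsePattern, and it assumes the pattern has already passed the validation described further down): `%` becomes `.*`, `_` becomes `.`, an escaped character is matched literally, and everything else is quoted.

    import java.util.regex.Pattern

    // Simplified standalone sketch of translating a SQL LIKE pattern to a regex.
    // Assumes the pattern was already validated (escape may only precede '_', '%', or itself).
    fun likeToRegex(pattern: String, escapeChar: Char? = null): Pattern {
        val sb = StringBuilder()
        var i = 0
        while (i < pattern.length) {
            val c = pattern[i]
            when {
                escapeChar != null && c == escapeChar && i + 1 < pattern.length -> {
                    sb.append(Pattern.quote(pattern[i + 1].toString())) // escaped char is literal
                    i++
                }
                c == '%' -> sb.append(".*") // any string, including the empty string
                c == '_' -> sb.append(".")  // exactly one character
                else -> sb.append(Pattern.quote(c.toString()))
            }
            i++
        }
        return Pattern.compile(sb.toString())
    }

    fun main() {
        println(likeToRegex("50!% off", escapeChar = '!').matcher("50% off").matches()) // true
        println(likeToRegex("A_C").matcher("ABC").matches())                            // true
        println(likeToRegex("A_C").matcher("AC").matches())                             // false
    }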
{ + val patternArgs = listOfNotNull(pattern, escape) + when { + patternArgs.any { it.type.isUnknown } -> return null + patternArgs.any { !it.type.isText } -> return { + err( + "LIKE expression must be given non-null strings as input", + ErrorCode.EVALUATOR_LIKE_INVALID_INPUTS, + errorContextFrom(metas).also { + it[Property.LIKE_PATTERN] = pattern.ionValue.toString() + if (escape != null) it[Property.LIKE_ESCAPE] = escape.ionValue.toString() + }, + internal = false + ) + } + else -> { + val (patternString: String, escapeChar: Int?) = + checkPattern(pattern.ionValue, patternLocationMeta, escape?.ionValue, escapeLocationMeta) + val likeRegexPattern = when { + patternString.isEmpty() -> Pattern.compile("") + else -> parsePattern(patternString, escapeChar) + } + return { likeRegexPattern } + } + } + } + + fun matchRegexPattern(value: ExprValue, likePattern: (() -> Pattern)?): ExprValue { + return when { + likePattern == null || value.type.isUnknown -> valueFactory.nullValue + !value.type.isText -> err( + "LIKE expression must be given non-null strings as input", + ErrorCode.EVALUATOR_LIKE_INVALID_INPUTS, + errorContextFrom(metas).also { + it[Property.LIKE_VALUE] = value.ionValue.toString() + }, + internal = false + ) + else -> valueFactory.newBoolean(likePattern().matcher(value.stringValue()).matches()) + } + } + + val valueThunk = compileAstExpr(valueExpr) + + // If the pattern and escape expressions are literals then we can compile the pattern now and + // re-use it with every execution. Otherwise, we must re-compile the pattern every time. + return when { + patternExpr is PartiqlPhysical.Expr.Lit && (escapeExpr == null || escapeExpr is PartiqlPhysical.Expr.Lit) -> { + val patternParts = getRegexPattern( + valueFactory.newFromIonValue(patternExpr.value.toIonValue(valueFactory.ion)), + (escapeExpr as? PartiqlPhysical.Expr.Lit)?.value?.toIonValue(valueFactory.ion) + ?.let { valueFactory.newFromIonValue(it) } + ) + + // If valueExpr is also a literal then we can evaluate this at compile time and return a constant. + if (valueExpr is PartiqlPhysical.Expr.Lit) { + val resultValue = matchRegexPattern( + valueFactory.newFromIonValue(valueExpr.value.toIonValue(valueFactory.ion)), + patternParts + ) + return thunkFactory.thunkEnv(metas) { resultValue } + } else { + thunkFactory.thunkEnvOperands(metas, valueThunk) { _, value -> + matchRegexPattern(value, patternParts) + } + } + } + else -> { + val patternThunk = compileAstExpr(patternExpr) + when (escapeExpr) { + null -> { + // thunk that re-compiles the DFA every evaluation without a custom escape sequence + thunkFactory.thunkEnvOperands(metas, valueThunk, patternThunk) { _, value, pattern -> + val pps = getRegexPattern(pattern, null) + matchRegexPattern(value, pps) + } + } + else -> { + // thunk that re-compiles the pattern every evaluation but *with* a custom escape sequence + val escapeThunk = compileAstExpr(escapeExpr) + thunkFactory.thunkEnvOperands( + metas, + valueThunk, + patternThunk, + escapeThunk + ) { _, value, pattern, escape -> + val pps = getRegexPattern(pattern, escape) + matchRegexPattern(value, pps) + } + } + } + } + } + } + + /** + * Given the pattern and optional escape character in a `LIKE` predicate as [IonValue]s + * check their validity based on the SQL92 spec and return a triple that contains in order + * + * - the search pattern as a string + * - the escape character, possibly `null` + * - the length of the search pattern. 
The length of the search pattern is either + * - the length of the string representing the search pattern when no escape character is used + * - the length of the string representing the search pattern without counting uses of the escape character + * when an escape character is used + * + * A search pattern is valid when + * 1. pattern is not null + * 1. pattern contains characters where `_` means any 1 character and `%` means any string of length 0 or more + * 1. if the escape character is specified then pattern can be deterministically partitioned into character groups where + * 1. A length 1 character group consists of any character other than the ESCAPE character + * 1. A length 2 character group consists of the ESCAPE character followed by either `_` or `%` or the ESCAPE character itself + * + * @param pattern search pattern + * @param escape optional escape character provided in the `LIKE` predicate + * + * @return a triple that contains in order the search pattern as a [String], optionally the code point for the escape character if one was provided + * and the size of the search pattern excluding uses of the escape character + */ + private fun checkPattern( + pattern: IonValue, + patternLocationMeta: SourceLocationMeta?, + escape: IonValue?, + escapeLocationMeta: SourceLocationMeta? + ): Pair { + + val patternString = pattern.stringValue() + ?: err( + "Must provide a non-null value for PATTERN in a LIKE predicate: $pattern", + ErrorCode.EVALUATOR_LIKE_PATTERN_INVALID_ESCAPE_SEQUENCE, + errorContextFrom(patternLocationMeta), + internal = false + ) + + escape?.let { + val escapeCharString = checkEscapeChar(escape, escapeLocationMeta) + val escapeCharCodePoint = escapeCharString.codePointAt(0) // escape is a string of length 1 + val validEscapedChars = setOf('_'.toInt(), '%'.toInt(), escapeCharCodePoint) + val iter = patternString.codePointSequence().iterator() + + while (iter.hasNext()) { + val current = iter.next() + if (current == escapeCharCodePoint && (!iter.hasNext() || !validEscapedChars.contains(iter.next()))) { + err( + "Invalid escape sequence : $patternString", + ErrorCode.EVALUATOR_LIKE_PATTERN_INVALID_ESCAPE_SEQUENCE, + errorContextFrom(patternLocationMeta).apply { + set(Property.LIKE_PATTERN, patternString) + set(Property.LIKE_ESCAPE, escapeCharString) + }, + internal = false + ) + } + } + return Pair(patternString, escapeCharCodePoint) + } + return Pair(patternString, null) + } + + /** + * Given an [IonValue] to be used as the escape character in a `LIKE` predicate check that it is + * a valid character based on the SQL Spec. + * + * + * A value is a valid escape when + * 1. it is 1 character long, and, + * 1. 
Cannot be null (SQL92 spec marks this cases as *unknown*) + * + * @param escape value provided as an escape character for a `LIKE` predicate + * + * @return the escape character as a [String] or throws an exception when the input is invalid + */ + private fun checkEscapeChar(escape: IonValue, locationMeta: SourceLocationMeta?): String { + val escapeChar = escape.stringValue() ?: err( + "Must provide a value when using ESCAPE in a LIKE predicate: $escape", + ErrorCode.EVALUATOR_LIKE_PATTERN_INVALID_ESCAPE_SEQUENCE, + errorContextFrom(locationMeta), + internal = false + ) + when (escapeChar) { + "" -> { + err( + "Cannot use empty character as ESCAPE character in a LIKE predicate: $escape", + ErrorCode.EVALUATOR_LIKE_PATTERN_INVALID_ESCAPE_SEQUENCE, + errorContextFrom(locationMeta), + internal = false + ) + } + else -> { + if (escapeChar.trim().length != 1) { + err( + "Escape character must have size 1 : $escapeChar", + ErrorCode.EVALUATOR_LIKE_PATTERN_INVALID_ESCAPE_SEQUENCE, + errorContextFrom(locationMeta), + internal = false + ) + } + } + } + return escapeChar + } + + private fun compileExec(node: PartiqlPhysical.Statement.Exec): PhysicalPlanThunk { + val metas = node.metas + val procedureName = node.procedureName.text + val procedure = procedures[procedureName] ?: err( + "No such stored procedure: $procedureName", + ErrorCode.EVALUATOR_NO_SUCH_PROCEDURE, + errorContextFrom(metas).also { + it[Property.PROCEDURE_NAME] = procedureName + }, + internal = false + ) + + val args = node.args + // Check arity + if (args.size !in procedure.signature.arity) { + val errorContext = errorContextFrom(metas).also { + it[Property.EXPECTED_ARITY_MIN] = procedure.signature.arity.first + it[Property.EXPECTED_ARITY_MAX] = procedure.signature.arity.last + } + + val message = when { + procedure.signature.arity.first == 1 && procedure.signature.arity.last == 1 -> + "${procedure.signature.name} takes a single argument, received: ${args.size}" + procedure.signature.arity.first == procedure.signature.arity.last -> + "${procedure.signature.name} takes exactly ${procedure.signature.arity.first} arguments, received: ${args.size}" + else -> + "${procedure.signature.name} takes between ${procedure.signature.arity.first} and " + + "${procedure.signature.arity.last} arguments, received: ${args.size}" + } + + throw EvaluationException( + message, + ErrorCode.EVALUATOR_INCORRECT_NUMBER_OF_ARGUMENTS_TO_PROCEDURE_CALL, + errorContext, + internal = false + ) + } + + // Compile the procedure's arguments + val argThunks = compileAstExprs(args) + + return thunkFactory.thunkEnv(metas) { env -> + val procedureArgValues = argThunks.map { it(env) } + procedure.call(env.session, procedureArgValues) + } + } + + private fun compileDate(expr: PartiqlPhysical.Expr.Date, metas: MetaContainer): PhysicalPlanThunk = + thunkFactory.thunkEnv(metas) { + valueFactory.newDate( + expr.year.value.toInt(), + expr.month.value.toInt(), + expr.day.value.toInt() + ) + } + + private fun compileLitTime(expr: PartiqlPhysical.Expr.LitTime, metas: MetaContainer): PhysicalPlanThunk = + thunkFactory.thunkEnv(metas) { + // Add the default time zone if the type "TIME WITH TIME ZONE" does not have an explicitly specified time zone. 
+ valueFactory.newTime( + Time.of( + expr.value.hour.value.toInt(), + expr.value.minute.value.toInt(), + expr.value.second.value.toInt(), + expr.value.nano.value.toInt(), + expr.value.precision.value.toInt(), + if (expr.value.withTimeZone.value && expr.value.tzMinutes == null) evaluatorOptions.defaultTimezoneOffset.totalMinutes else expr.value.tzMinutes?.value?.toInt() + ) + ) + } + + /** A special wrapper for `UNPIVOT` values as a BAG. */ + private class UnpivotedExprValue(private val values: Iterable) : BaseExprValue() { + override val type = ExprValueType.BAG + override fun iterator() = values.iterator() + + // XXX this value is only ever produced in a FROM iteration, thus none of these should ever be called + override val ionValue + get() = throw UnsupportedOperationException("Synthetic value cannot provide ion value") + } + + /** Unpivots a `struct`, and synthesizes a synthetic singleton `struct` for other [ExprValue]. */ + internal fun ExprValue.unpivot(): ExprValue = when { + // special case for our special UNPIVOT value to avoid double wrapping + this is UnpivotedExprValue -> this + // Wrap into a pseudo-BAG + type == ExprValueType.STRUCT || type == ExprValueType.MISSING -> UnpivotedExprValue(this) + // for non-struct, this wraps any value into a BAG with a synthetic name + else -> UnpivotedExprValue( + listOf( + this.namedValue(valueFactory.newString(syntheticColumnName(0))) + ) + ) + } + + private fun createStructExprValue(seq: Sequence, ordering: StructOrdering) = + valueFactory.newStruct( + when (evaluatorOptions.projectionIteration) { + ProjectionIterationBehavior.FILTER_MISSING -> seq.filter { it.type != ExprValueType.MISSING } + ProjectionIterationBehavior.UNFILTERED -> seq + }, + ordering + ) +} + +internal val MetaContainer.sourceLocationMeta get() = this[SourceLocationMeta.TAG] as? SourceLocationMeta + +internal fun StaticType.getTypes() = when (val flattened = this.flatten()) { + is AnyOfType -> flattened.types + else -> listOf(this) +} + +/** + * Represents an element in a select list that is to be projected into the final result. + * i.e. an expression, or a (project_all) node. + */ +private sealed class CompiledStructPart { + + /** + * Represents a single compiled expression to be projected into the final result. + * Given `SELECT a + b as value FROM foo`: + * - `name` is "value" + * - `thunk` is compiled expression, i.e. `a + b` + */ + class Field(val nameThunk: PhysicalPlanThunk, val valueThunk: PhysicalPlanThunk) : CompiledStructPart() + + /** + * Represents a wildcard ((path_project_all) node) expression to be projected into the final result. + * This covers two cases. For `SELECT foo.* FROM foo`, `exprThunks` contains a single compiled expression + * `foo`. + * + * For `SELECT * FROM foo, bar, bat`, `exprThunks` would contain a compiled expression for each of `foo`, `bar` and + * `bat`. 
+ */ + class StructMerge(val thunks: List) : CompiledStructPart() +} diff --git a/lang/src/org/partiql/lang/eval/physical/RelationThunk.kt b/lang/src/org/partiql/lang/eval/physical/RelationThunk.kt new file mode 100644 index 0000000000..eee6e976f9 --- /dev/null +++ b/lang/src/org/partiql/lang/eval/physical/RelationThunk.kt @@ -0,0 +1,44 @@ +package org.partiql.lang.eval.physical + +import com.amazon.ionelement.api.MetaContainer +import org.partiql.lang.ast.SourceLocationMeta +import org.partiql.lang.errors.ErrorCode +import org.partiql.lang.errors.Property +import org.partiql.lang.eval.EvaluationException +import org.partiql.lang.eval.errorContextFrom +import org.partiql.lang.eval.fillErrorContext +import org.partiql.lang.eval.relation.RelationIterator + +/** A thunk that returns a [RelationIterator], which is the result of evaluating a relational operator. */ +internal typealias RelationThunkEnv = (EvaluatorState) -> RelationIterator + +/** + * Invokes [t] with error handling like is supplied by [org.partiql.lang.eval.ThunkFactory]. + * + * This function is not currently in `ThunkFactory` to avoid complicating `ThunkFactory` further. If a need arises, + * it could be moved to `ThunkFactory`. + */ +internal inline fun relationThunk(metas: MetaContainer, crossinline t: RelationThunkEnv): RelationThunkEnv { + val sourceLocationMeta = metas[SourceLocationMeta.TAG] as? SourceLocationMeta + return { env: EvaluatorState -> + try { + t(env) + } catch (e: EvaluationException) { + // Only add source location data to the error context if it doesn't already exist + // in [errorContext]. + if (!e.errorContext.hasProperty(Property.LINE_NUMBER)) { + sourceLocationMeta?.let { fillErrorContext(e.errorContext, sourceLocationMeta) } + } + throw e + } catch (e: Exception) { + val message = e.message ?: "" + throw EvaluationException( + "Generic exception, $message", + errorCode = ErrorCode.EVALUATOR_GENERIC_EXCEPTION, + errorContext = errorContextFrom(sourceLocationMeta), + cause = e, + internal = true + ) + } + } +} diff --git a/lang/src/org/partiql/lang/eval/physical/UndefinedVariableUtil.kt b/lang/src/org/partiql/lang/eval/physical/UndefinedVariableUtil.kt new file mode 100644 index 0000000000..ee6bd16d04 --- /dev/null +++ b/lang/src/org/partiql/lang/eval/physical/UndefinedVariableUtil.kt @@ -0,0 +1,31 @@ +package org.partiql.lang.eval.physical + +import com.amazon.ionelement.api.MetaContainer +import org.partiql.lang.errors.ErrorCode +import org.partiql.lang.errors.Property +import org.partiql.lang.errors.UNBOUND_QUOTED_IDENTIFIER_HINT +import org.partiql.lang.eval.BindingCase +import org.partiql.lang.eval.BindingName +import org.partiql.lang.eval.EvaluationException +import org.partiql.lang.eval.errorContextFrom +import org.partiql.lang.util.propertyValueMapOf + +internal fun throwUndefinedVariableException( + bindingName: BindingName, + metas: MetaContainer? 
+): Nothing { + val (errorCode, hint) = when (bindingName.bindingCase) { + BindingCase.SENSITIVE -> + ErrorCode.EVALUATOR_QUOTED_BINDING_DOES_NOT_EXIST to " $UNBOUND_QUOTED_IDENTIFIER_HINT" + BindingCase.INSENSITIVE -> + ErrorCode.EVALUATOR_BINDING_DOES_NOT_EXIST to "" + } + throw EvaluationException( + message = "No such binding: ${bindingName.name}.$hint", + errorCode = errorCode, + errorContext = (metas?.let { errorContextFrom(metas) } ?: propertyValueMapOf()).also { + it[Property.BINDING_NAME] = bindingName.name + }, + internal = false + ) +} diff --git a/lang/src/org/partiql/lang/eval/relation/Relation.kt b/lang/src/org/partiql/lang/eval/relation/Relation.kt new file mode 100644 index 0000000000..9b8915861b --- /dev/null +++ b/lang/src/org/partiql/lang/eval/relation/Relation.kt @@ -0,0 +1,90 @@ +package org.partiql.lang.eval.relation + +import kotlin.coroutines.Continuation +import kotlin.coroutines.CoroutineContext +import kotlin.coroutines.EmptyCoroutineContext +import kotlin.coroutines.createCoroutine +import kotlin.coroutines.intrinsics.COROUTINE_SUSPENDED +import kotlin.coroutines.intrinsics.suspendCoroutineUninterceptedOrReturn +import kotlin.coroutines.resume + +/** + * Builds a [RelationIterator] that yields after every step in evaluating a relational operator. + * + * This is inspired heavily by Kotlin's [sequence] but for [RelationIterator] instead of [Sequence]. + */ +internal fun relation( + seqType: RelationType, + block: suspend RelationScope.() -> Unit +): RelationIterator { + val iterator = RelationBuilderIterator(seqType, block) + iterator.nextStep = block.createCoroutine(receiver = iterator, completion = iterator) + return iterator +} + +@DslMarker +@Target(AnnotationTarget.CLASS, AnnotationTarget.TYPE) +annotation class RelationDsl + +/** Defines functions within a block supplied to [relation]. */ +@RelationDsl +internal interface RelationScope { + /** Suspends the coroutine. Should be called after processing the current row of the relation. */ + suspend fun yield() + + /** Yields once for every row remaining in [relItr]. */ + suspend fun yieldAll(relItr: RelationIterator) +} + +private class RelationBuilderIterator( + override val relType: RelationType, + block: suspend RelationScope.() -> Unit +) : RelationScope, RelationIterator, Continuation { + var yielded = false + + var nextStep: Continuation? = block.createCoroutine(receiver = this, completion = this) + + override suspend fun yield() { + yielded = true + suspendCoroutineUninterceptedOrReturn { c -> + nextStep = c + COROUTINE_SUSPENDED + } + } + + override suspend fun yieldAll(relItr: RelationIterator) { + while (relItr.nextRow()) { + yield() + } + } + + override fun nextRow(): Boolean { + // if nextStep is null it means we've reached the end of the relation, but nextRow() was called again + // for some reason. This probably indicates a bug since we should not in general be attempting to + // read a `RelationIterator` after it has exhausted. + if (nextStep == null) { + error( + "Relation was previously exhausted. " + + "Please don't call nextRow() again after it returns false the first time." + ) + } + val step = nextStep!! 
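For orientation, here is a minimal usage sketch of the relation builder and the nextRow protocol introduced here. It relies only on the API defined in this file; because these declarations are internal, the sketch assumes it lives in the same module, and the plain list stands in for the register writes a real operator would perform before each yield.

    package org.partiql.lang.eval.relation

    // Minimal usage sketch of the relation { } builder (same module assumed, since the API is internal).
    internal fun relationDemo() {
        val rows = mutableListOf<Int>()
        val rel = relation(RelationType.BAG) {
            for (i in 1..3) {
                rows.add(i) // a real operator would populate EvaluatorState registers here
                yield()     // suspend until the consumer asks for the next row
            }
        }
        while (rel.nextRow()) {
            println("current row: ${rows.last()}")
        }
        // Calling rel.nextRow() again at this point would fail: the relation is exhausted by design.
    }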
+ nextStep = null + step.resume(Unit) + + return if (yielded) { + yielded = false + true + } else { + false + } + } + + // Completion continuation implementation + override fun resumeWith(result: Result) { + result.getOrThrow() // just rethrow exception if it is there + } + + override val context: CoroutineContext + get() = EmptyCoroutineContext +} diff --git a/lang/src/org/partiql/lang/eval/relation/RelationIterator.kt b/lang/src/org/partiql/lang/eval/relation/RelationIterator.kt new file mode 100644 index 0000000000..5b6bfaeb8a --- /dev/null +++ b/lang/src/org/partiql/lang/eval/relation/RelationIterator.kt @@ -0,0 +1,34 @@ +package org.partiql.lang.eval.relation + +enum class RelationType { BAG, LIST } + +/** + * Represents an iterator that is returned by a relational operator during evaluation. + * + * This is a "faux" iterator in a sense, because it doesn't provide direct access to a current element. + * + * When initially created, the iterator is positioned "before" the first element. [nextRow] should be called to advance + * the iterator to the first row. + * + * We do not use [Iterator] for this purpose because it is not a natural fit. There are two reasons: + * + * 1. [Iterator.next] returns the current element, but this isn't actually an iterator over a collection. Instead, + * execution of [nextRow] may have a side effect of populating value(s) in the current + * [org.partiql.lang.eval.Environment.registers] array. Bridge operators such as + * [org.partiql.lang.domains.PartiqlPhysical.Expr.BindingsToValues] are responsible for extracting current values from + * [org.partiql.lang.eval.Environment.registers] and converting them to the appropriate container [ExprValue]s. + * 2. [Iterator.hasNext] requires knowing if additional rows remain after the current row, but in the case of a `filter` + * relation, this requires advancing through possibly all remaining rows to see if any remaining row matches the + * predicate. + */ +internal interface RelationIterator { + val relType: RelationType + + /** + * Advances the iterator to the next row. + * + * Returns true to indicate that the next row was found and that [org.partiql.lang.eval.Environment.registers] have + * been updated for the current row. False if there are no more rows. 
+ */ + fun nextRow(): Boolean +} diff --git a/lang/src/org/partiql/lang/eval/visitors/PartiqlPhysicalSanityValidator.kt b/lang/src/org/partiql/lang/eval/visitors/PartiqlPhysicalSanityValidator.kt new file mode 100644 index 0000000000..2859260bd6 --- /dev/null +++ b/lang/src/org/partiql/lang/eval/visitors/PartiqlPhysicalSanityValidator.kt @@ -0,0 +1,121 @@ +package org.partiql.lang.eval.visitors + +import com.amazon.ionelement.api.IntElement +import com.amazon.ionelement.api.IntElementSize +import com.amazon.ionelement.api.MetaContainer +import com.amazon.ionelement.api.TextElement +import org.partiql.lang.ast.IsCountStarMeta +import org.partiql.lang.ast.passes.SemanticException +import org.partiql.lang.domains.PartiqlPhysical +import org.partiql.lang.domains.addSourceLocation +import org.partiql.lang.errors.ErrorCode +import org.partiql.lang.errors.Property +import org.partiql.lang.errors.PropertyValueMap +import org.partiql.lang.eval.EvaluationException +import org.partiql.lang.eval.TypedOpBehavior +import org.partiql.lang.eval.err +import org.partiql.lang.eval.errorContextFrom +import org.partiql.lang.planner.EvaluatorOptions +import org.partiql.lang.util.propertyValueMapOf +import org.partiql.pig.runtime.LongPrimitive + +/** + * Provides rules for basic AST sanity checks that should be performed before any attempt at further physical + * plan processing. This is provided as a distinct [PartiqlPhysical.Visitor] so that the planner and evaluator may + * assume that the physical plan has passed the checks performed here. + * + * Any exception thrown by this class should always be considered an indication of a bug. + */ +class PartiqlPhysicalSanityValidator(private val evaluatorOptions: EvaluatorOptions) : PartiqlPhysical.Visitor() { + + /** + * Quick validation step to make sure the indexes of any variables make sense. + * It is unlikely that this check will ever fail, but if it does, it likely means there's a bug in + * [org.partiql.lang.planner.transforms.VariableIdAllocator] or that the plan was malformed by other means.
+ */ + override fun visitPlan(node: PartiqlPhysical.Plan) { + node.locals.forEachIndexed { idx, it -> + if (it.registerIndex.value != idx.toLong()) { + throw EvaluationException( + message = "Variable index must match ordinal position of variable", + errorCode = ErrorCode.INTERNAL_ERROR, + errorContext = propertyValueMapOf(), + internal = true + ) + } + } + super.visitPlan(node) + } + + override fun visitExprLit(node: PartiqlPhysical.Expr.Lit) { + val ionValue = node.value + val metas = node.metas + if (node.value is IntElement && ionValue.integerSize == IntElementSize.BIG_INTEGER) { + throw EvaluationException( + message = "Int overflow or underflow at compile time", + errorCode = ErrorCode.SEMANTIC_LITERAL_INT_OVERFLOW, + errorContext = errorContextFrom(metas), + internal = false + ) + } + } + + private fun validateDecimalOrNumericType(scale: LongPrimitive?, precision: LongPrimitive?, metas: MetaContainer) { + if (scale != null && precision != null && evaluatorOptions.typedOpBehavior == TypedOpBehavior.HONOR_PARAMETERS) { + if (scale.value !in 0..precision.value) { + err( + "Scale ${scale.value} should be between 0 and precision ${precision.value}", + errorCode = ErrorCode.SEMANTIC_INVALID_DECIMAL_ARGUMENTS, + errorContext = errorContextFrom(metas), + internal = false + ) + } + } + } + + override fun visitTypeDecimalType(node: PartiqlPhysical.Type.DecimalType) { + validateDecimalOrNumericType(node.scale, node.precision, node.metas) + } + + override fun visitTypeNumericType(node: PartiqlPhysical.Type.NumericType) { + validateDecimalOrNumericType(node.scale, node.precision, node.metas) + } + + override fun visitExprCallAgg(node: PartiqlPhysical.Expr.CallAgg) { + val setQuantifier = node.setq + val metas = node.metas + if (setQuantifier is PartiqlPhysical.SetQuantifier.Distinct && metas.containsKey(IsCountStarMeta.TAG)) { + err( + "COUNT(DISTINCT *) is not supported", + ErrorCode.EVALUATOR_COUNT_DISTINCT_STAR, + errorContextFrom(metas), + internal = false + ) + } + } + + override fun visitExprStruct(node: PartiqlPhysical.Expr.Struct) { + node.parts.forEach { part -> + when (part) { + is PartiqlPhysical.StructPart.StructField -> { + if (part.fieldName is PartiqlPhysical.Expr.Missing || + (part.fieldName is PartiqlPhysical.Expr.Lit && part.fieldName.value !is TextElement) + ) { + val type = when (part.fieldName) { + is PartiqlPhysical.Expr.Lit -> part.fieldName.value.type.toString() + else -> "MISSING" + } + throw SemanticException( + "Found struct part to be of type $type", + ErrorCode.SEMANTIC_NON_TEXT_STRUCT_FIELD_KEY, + PropertyValueMap().addSourceLocation(part.fieldName.metas).also { pvm -> + pvm[Property.ACTUAL_TYPE] = type + } + ) + } + } + is PartiqlPhysical.StructPart.StructFields -> { /* intentionally empty */ } + } + } + } +} diff --git a/lang/src/org/partiql/lang/planner/MetadataResolver.kt b/lang/src/org/partiql/lang/planner/MetadataResolver.kt index 2f677b279a..af13cb00e2 100644 --- a/lang/src/org/partiql/lang/planner/MetadataResolver.kt +++ b/lang/src/org/partiql/lang/planner/MetadataResolver.kt @@ -52,7 +52,7 @@ interface MetadataResolver { * without providing an error. (This is consistent with Postres's behavior in this scenario.) * * Note that while [ResolutionResult.LocalVariable] exists, it is intentionally marked `internal` and cannot - * be used by outside this project. + * be used outside this project. 
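A resolver backing resolveVariable (declared just below) typically maps a table name to an integration-defined unique id, honoring case sensitivity. The standalone sketch below mirrors that idea only; its sealed result type is hypothetical and is not the actual ResolutionResult API.

    // Hypothetical sketch of name-to-uniqueId resolution; not the real ResolutionResult type.
    sealed class ToyResolution {
        data class Global(val uniqueId: String) : ToyResolution()
        object Undefined : ToyResolution()
    }

    class ToyCatalogResolver(private val tables: Map<String, String>) {
        fun resolve(name: String, caseSensitive: Boolean): ToyResolution {
            val match = when {
                caseSensitive -> tables.entries.firstOrNull { it.key == name }
                else -> tables.entries.firstOrNull { it.key.equals(name, ignoreCase = true) }
            }
            return match?.let { ToyResolution.Global(it.value) } ?: ToyResolution.Undefined
        }
    }

    fun main() {
        val resolver = ToyCatalogResolver(mapOf("Customers" to "table_0001"))
        println(resolver.resolve("customers", caseSensitive = false)) // resolves to table_0001
        println(resolver.resolve("customers", caseSensitive = true))  // undefined: case does not match
    }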
*/ fun resolveVariable(bindingName: BindingName): ResolutionResult } diff --git a/lang/src/org/partiql/lang/planner/PlannerPipeline.kt b/lang/src/org/partiql/lang/planner/PlannerPipeline.kt index ba9ed67e24..a632af939e 100644 --- a/lang/src/org/partiql/lang/planner/PlannerPipeline.kt +++ b/lang/src/org/partiql/lang/planner/PlannerPipeline.kt @@ -26,7 +26,10 @@ import org.partiql.lang.eval.ExprFunction import org.partiql.lang.eval.ExprValueFactory import org.partiql.lang.eval.Expression import org.partiql.lang.eval.ThunkReturnTypeAssertions +import org.partiql.lang.eval.builtins.DynamicLookupExprFunction import org.partiql.lang.eval.builtins.createBuiltinFunctions +import org.partiql.lang.eval.builtins.storedprocedure.StoredProcedure +import org.partiql.lang.eval.physical.PhysicalExprToThunkConverterImpl import org.partiql.lang.planner.transforms.PlanningProblemDetails import org.partiql.lang.planner.transforms.normalize import org.partiql.lang.planner.transforms.toDefaultPhysicalPlan @@ -35,6 +38,7 @@ import org.partiql.lang.planner.transforms.toResolvedPlan import org.partiql.lang.syntax.Parser import org.partiql.lang.syntax.SqlParser import org.partiql.lang.syntax.SyntaxException +import org.partiql.lang.types.CustomType /** * [PlannerPipeline] is the main interface for planning and compiling PartiQL queries into instances of [Expression] @@ -159,6 +163,9 @@ interface PlannerPipeline { class Builder(val valueFactory: ExprValueFactory) { private var parser: Parser? = null private var evaluatorOptions: EvaluatorOptions? = null + private val customFunctions: MutableMap = HashMap() + private var customDataTypes: List = listOf() + private val customProcedures: MutableMap = HashMap() private var metadataResolver: MetadataResolver = emptyMetadataResolver() private var allowUndefinedVariables: Boolean = false private var enableLegacyExceptionHandling: Boolean = false @@ -189,6 +196,44 @@ interface PlannerPipeline { fun evaluatorOptions(block: EvaluatorOptions.Builder.() -> Unit): Builder = evaluatorOptions(EvaluatorOptions.build(block)) + /** + * Add a custom function which will be callable by the compiled queries. + * + * Functions added here will replace any built-in function with the same name. + * + * This function is marked as internal to prevent it from being used outside the tests in this + * project--it will be replaced during implementation of the open type system. + * https://github.com/partiql/partiql-lang-kotlin/milestone/4 + */ + internal fun addFunction(function: ExprFunction): Builder = this.apply { + customFunctions[function.signature.name] = function + } + + /** + * Add custom types to CAST/IS operators to. + * + * Built-in types will take precedence over custom types in case of a name collision. + * + * This function is marked as internal to prevent it from being used outside the tests in this + * project--it will be replaced during implementation of the open type system. + * https://github.com/partiql/partiql-lang-kotlin/milestone/4 + */ + internal fun customDataTypes(customTypes: List) = this.apply { + customDataTypes = customTypes + } + + /** + * Add a custom stored procedure which will be callable by the compiled queries. + * + * Stored procedures added here will replace any built-in procedure with the same name. + * This function is marked as internal to prevent it from being used outside the tests in this + * project--it will be replaced during implementation of the open type system. 
+ * https://github.com/partiql/partiql-lang-kotlin/milestone/4 + */ + internal fun addProcedure(procedure: StoredProcedure): Builder = this.apply { + customProcedures[procedure.signature.name] = procedure + } + /** * Adds the [MetadataResolver] for global variables. * @@ -231,18 +276,21 @@ interface PlannerPipeline { ) } - val builtinFunctions = createBuiltinFunctions(valueFactory) - // TODO: uncomment when DynamicLookupExprFunction exists -// val builtinFunctions = createBuiltinFunctions(valueFactory) + DynamicLookupExprFunction() + val builtinFunctions = createBuiltinFunctions(valueFactory) + DynamicLookupExprFunction() val builtinFunctionsMap = builtinFunctions.associateBy { it.signature.name } + // customFunctions must be on the right side of + here to ensure that they overwrite any + // built-in functions with the same name. + val allFunctionsMap = builtinFunctionsMap + customFunctions return PlannerPipelineImpl( valueFactory = valueFactory, - parser = parser ?: SqlParser(valueFactory.ion), + parser = parser ?: SqlParser(valueFactory.ion, this.customDataTypes), evaluatorOptions = compileOptionsToUse, - functions = builtinFunctionsMap, + functions = allFunctionsMap, + customDataTypes = customDataTypes, + procedures = customProcedures, metadataResolver = metadataResolver, allowUndefinedVariables = allowUndefinedVariables, enableLegacyExceptionHandling = enableLegacyExceptionHandling @@ -256,6 +304,8 @@ internal class PlannerPipelineImpl( private val parser: Parser, val evaluatorOptions: EvaluatorOptions, val functions: Map, + val customDataTypes: List, + val procedures: Map, val metadataResolver: MetadataResolver, val allowUndefinedVariables: Boolean, val enableLegacyExceptionHandling: Boolean @@ -272,6 +322,12 @@ internal class PlannerPipelineImpl( } } + val customTypedOpParameters = customDataTypes.map { customType -> + (customType.aliases + customType.name).map { alias -> + Pair(alias.toLowerCase(), customType.typedOpParameter) + } + }.flatten().toMap() + override fun plan(query: String): PassResult { val ast = try { parser.parseAstStatement(query) @@ -321,35 +377,34 @@ internal class PlannerPipelineImpl( } override fun compile(physcialPlan: PartiqlPhysical.Plan): PassResult { - TODO("uncomment the code below in the PR introducing the plan evaluator") -// val compiler = PhysicalExprToThunkConverterImpl( -// valueFactory = valueFactory, -// functions = functions, -// customTypedOpParameters = customTypedOpParameters, -// procedures = procedures, -// evaluatorOptions = evaluatorOptions -// ) -// -// val expression = when { -// enableLegacyExceptionHandling -> compiler.compile(physcialPlan) -// else -> { -// // Legacy exception handling is disabled, convert any [SqlException] into a Problem and return -// // PassResult.Error. 
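The builder comment above about keeping customFunctions on the right-hand side of + relies on how Kotlin's map concatenation resolves duplicate keys, which is easy to demonstrate in isolation:

    fun main() {
        val builtins = mapOf("upper" to "built-in upper", "my_fn" to "built-in my_fn")
        val custom = mapOf("my_fn" to "custom my_fn")

        // For duplicate keys, the entries of the right-hand operand win.
        val all = builtins + custom
        println(all["my_fn"]) // custom my_fn
        println(all["upper"]) // built-in upper
    }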
-// try { -// compiler.compile(physcialPlan) -// } catch (e: SqlException) { -// val problem = Problem( -// SourceLocationMeta( -// e.errorContext[Property.LINE_NUMBER]?.longValue() ?: -1, -// e.errorContext[Property.COLUMN_NUMBER]?.longValue() ?: -1 -// ), -// PlanningProblemDetails.CompileError(e.generateMessageNoLocation()) -// ) -// return PassResult.Error(listOf(problem)) -// } -// } -// } -// -// return PassResult.Success(expression, listOf()) + val compiler = PhysicalExprToThunkConverterImpl( + valueFactory = valueFactory, + functions = functions, + customTypedOpParameters = customTypedOpParameters, + procedures = procedures, + evaluatorOptions = evaluatorOptions + ) + + val expression = when { + enableLegacyExceptionHandling -> compiler.compile(physcialPlan) + else -> { + // Legacy exception handling is disabled, convert any [SqlException] into a Problem and return + // PassResult.Error. + try { + compiler.compile(physcialPlan) + } catch (e: SqlException) { + val problem = Problem( + SourceLocationMeta( + e.errorContext[Property.LINE_NUMBER]?.longValue() ?: -1, + e.errorContext[Property.COLUMN_NUMBER]?.longValue() ?: -1 + ), + PlanningProblemDetails.CompileError(e.generateMessageNoLocation()) + ) + return PassResult.Error(listOf(problem)) + } + } + } + + return PassResult.Success(expression, listOf()) } } diff --git a/lang/src/org/partiql/lang/planner/transforms/LogicalToLogicalResolvedVisitorTransform.kt b/lang/src/org/partiql/lang/planner/transforms/LogicalToLogicalResolvedVisitorTransform.kt index 3f4f2ccfa1..4c0f569e14 100644 --- a/lang/src/org/partiql/lang/planner/transforms/LogicalToLogicalResolvedVisitorTransform.kt +++ b/lang/src/org/partiql/lang/planner/transforms/LogicalToLogicalResolvedVisitorTransform.kt @@ -85,25 +85,6 @@ internal fun PartiqlLogical.Plan.toResolvedPlan( return resolvedSt } -private fun PartiqlLogical.Expr.Id.asGlobalId(uniqueId: String): PartiqlLogicalResolved.Expr.GlobalId = - PartiqlLogicalResolved.build { - globalId_( - name = name, - uniqueId = uniqueId.asPrimitive(), - metas = this@asGlobalId.metas - ) - } - -private fun PartiqlLogical.Expr.Id.asLocalId(index: Int): PartiqlLogicalResolved.Expr = - PartiqlLogicalResolved.build { - localId_(index.asPrimitive(), this@asLocalId.metas) - } - -private fun PartiqlLogical.Expr.Id.asErrorId(): PartiqlLogicalResolved.Expr = - PartiqlLogicalResolved.build { - localId_((-1).asPrimitive(), this@asErrorId.metas) - } - /** * A local scope is a list of variable declarations that are produced by a relational operator and an optional * reference to a parent scope. This is handled separately from global variables. 
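A standalone sketch of the scope-chain idea described above: each local scope holds the variables its relational operator declares, and lookup walks outward through parent scopes before the name is handed off to global resolution. The names here are hypothetical.

    // Hypothetical sketch of nested local scopes with a parent reference.
    class ToyScope(private val variables: Map<String, Int>, private val parent: ToyScope? = null) {
        /** Returns the register index for [name], searching this scope first, then its ancestors. */
        fun lookup(name: String): Int? = variables[name] ?: parent?.lookup(name)
    }

    fun main() {
        val outer = ToyScope(mapOf("c" to 0))                  // e.g. FROM customers AS c
        val inner = ToyScope(mapOf("o" to 1), parent = outer)  // e.g. a nested scan over c.orders AS o
        println(inner.lookup("o")) // 1
        println(inner.lookup("c")) // 0, found in the parent scope
        println(inner.lookup("x")) // null, would fall through to global resolution
    }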
@@ -153,6 +134,25 @@ private data class LogicalToLogicalResolvedVisitorTransform( } } + private fun PartiqlLogical.Expr.Id.asGlobalId(uniqueId: String): PartiqlLogicalResolved.Expr.GlobalId = + PartiqlLogicalResolved.build { + globalId_( + uniqueId = uniqueId.asPrimitive(), + case = this@LogicalToLogicalResolvedVisitorTransform.transformCaseSensitivity(this@asGlobalId.case), + metas = this@asGlobalId.metas + ) + } + + private fun PartiqlLogical.Expr.Id.asLocalId(index: Int): PartiqlLogicalResolved.Expr = + PartiqlLogicalResolved.build { + localId_(index.asPrimitive(), this@asLocalId.metas) + } + + private fun PartiqlLogical.Expr.Id.asErrorId(): PartiqlLogicalResolved.Expr = + PartiqlLogicalResolved.build { + localId_((-1).asPrimitive(), this@asErrorId.metas) + } + override fun transformPlan(node: PartiqlLogical.Plan): PartiqlLogicalResolved.Plan = PartiqlLogicalResolved.build { plan_( diff --git a/lang/src/org/partiql/lang/types/PartiqlAstTypeExtensions.kt b/lang/src/org/partiql/lang/types/PartiqlAstTypeExtensions.kt index d9ef4f9096..d726074cae 100644 --- a/lang/src/org/partiql/lang/types/PartiqlAstTypeExtensions.kt +++ b/lang/src/org/partiql/lang/types/PartiqlAstTypeExtensions.kt @@ -1,85 +1,20 @@ package org.partiql.lang.types import org.partiql.lang.domains.PartiqlAst +import org.partiql.lang.domains.PartiqlPhysical -/** - * Helper to convert [PartiqlAst.Type] in AST to a [TypedOpParameter]. - */ -fun PartiqlAst.Type.toTypedOpParameter(customTypedOpParameters: Map): TypedOpParameter = when (this) { - is PartiqlAst.Type.MissingType -> TypedOpParameter(StaticType.MISSING) - is PartiqlAst.Type.NullType -> TypedOpParameter(StaticType.NULL) - is PartiqlAst.Type.BooleanType -> TypedOpParameter(StaticType.BOOL) - is PartiqlAst.Type.SmallintType -> TypedOpParameter(IntType(IntType.IntRangeConstraint.SHORT)) - is PartiqlAst.Type.Integer4Type -> TypedOpParameter(IntType(IntType.IntRangeConstraint.INT4)) - is PartiqlAst.Type.Integer8Type -> TypedOpParameter(IntType(IntType.IntRangeConstraint.LONG)) - is PartiqlAst.Type.IntegerType -> TypedOpParameter(IntType(IntType.IntRangeConstraint.LONG)) - is PartiqlAst.Type.FloatType, is PartiqlAst.Type.RealType, is PartiqlAst.Type.DoublePrecisionType -> TypedOpParameter(StaticType.FLOAT) - is PartiqlAst.Type.DecimalType -> when { - this.precision == null && this.scale == null -> TypedOpParameter(StaticType.DECIMAL) - this.precision != null && this.scale == null -> TypedOpParameter(DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(this.precision.value.toInt()))) - else -> TypedOpParameter( - DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(this.precision!!.value.toInt(), this.scale!!.value.toInt())) - ) - } - is PartiqlAst.Type.NumericType -> when { - this.precision == null && this.scale == null -> TypedOpParameter(StaticType.DECIMAL) - this.precision != null && this.scale == null -> TypedOpParameter(DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(this.precision.value.toInt()))) - else -> TypedOpParameter( - DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(this.precision!!.value.toInt(), this.scale!!.value.toInt())) - ) - } - is PartiqlAst.Type.TimestampType -> TypedOpParameter(StaticType.TIMESTAMP) - is PartiqlAst.Type.CharacterType -> when { - this.length == null -> TypedOpParameter(StringType(StringType.StringLengthConstraint.Constrained(NumberConstraint.Equals(1)))) - else -> TypedOpParameter( - StringType( - StringType.StringLengthConstraint.Constrained( - NumberConstraint.Equals(this.length.value.toInt()) 
- ) - ) - ) - } - is PartiqlAst.Type.CharacterVaryingType -> when (this.length) { - null -> TypedOpParameter(StringType(StringType.StringLengthConstraint.Unconstrained)) - else -> TypedOpParameter(StringType(StringType.StringLengthConstraint.Constrained(NumberConstraint.UpTo(this.length.value.toInt())))) - } - is PartiqlAst.Type.StringType -> TypedOpParameter(StaticType.STRING) - is PartiqlAst.Type.SymbolType -> TypedOpParameter(StaticType.SYMBOL) - is PartiqlAst.Type.ClobType -> TypedOpParameter(StaticType.CLOB) - is PartiqlAst.Type.BlobType -> TypedOpParameter(StaticType.BLOB) - is PartiqlAst.Type.StructType -> TypedOpParameter(StaticType.STRUCT) - is PartiqlAst.Type.TupleType -> TypedOpParameter(StaticType.STRUCT) - is PartiqlAst.Type.ListType -> TypedOpParameter(StaticType.LIST) - is PartiqlAst.Type.SexpType -> TypedOpParameter(StaticType.SEXP) - is PartiqlAst.Type.BagType -> TypedOpParameter(StaticType.BAG) - is PartiqlAst.Type.AnyType -> TypedOpParameter(StaticType.ANY) - is PartiqlAst.Type.CustomType -> - customTypedOpParameters.mapKeys { - (k, _) -> - k.toLowerCase() - }[this.name.text.toLowerCase()] ?: error("Could not find parameter for $this") - is PartiqlAst.Type.DateType -> TypedOpParameter(StaticType.DATE) - is PartiqlAst.Type.TimeType -> TypedOpParameter( - TimeType(this.precision?.value?.toInt(), withTimeZone = false) - ) - is PartiqlAst.Type.TimeWithTimeZoneType -> TypedOpParameter( - TimeType(this.precision?.value?.toInt(), withTimeZone = true) - ) - is PartiqlAst.Type.EsAny, - is PartiqlAst.Type.EsBoolean, - is PartiqlAst.Type.EsFloat, - is PartiqlAst.Type.EsInteger, - is PartiqlAst.Type.EsText, - is PartiqlAst.Type.RsBigint, - is PartiqlAst.Type.RsBoolean, - is PartiqlAst.Type.RsDoublePrecision, - is PartiqlAst.Type.RsInteger, - is PartiqlAst.Type.RsReal, - is PartiqlAst.Type.RsVarcharMax, - is PartiqlAst.Type.SparkBoolean, - is PartiqlAst.Type.SparkDouble, - is PartiqlAst.Type.SparkFloat, - is PartiqlAst.Type.SparkInteger, - is PartiqlAst.Type.SparkLong, - is PartiqlAst.Type.SparkShort -> error("$this node should not be present in PartiQLAST. Consider transforming the AST using CustomTypeVisitorTransform.") +/** Helper to convert [PartiqlAst.Type] in AST to a [TypedOpParameter]. */ +fun PartiqlAst.Type.toTypedOpParameter(customTypedOpParameters: Map): TypedOpParameter { + // hack: to avoid duplicating the function `PartiqlAst.Type.toTypedOpParameter`, we have to convert this + // PartiqlAst.Type to PartiqlPhysical.Type. The easiest way to do that without using a visitor transform + // (which is overkill and comes with some downsides for something this simple), is to transform to and from + // s-expressions again. This will work without difficulty as long as PartiqlAst.Type remains unchanged in all + // permuted domains between PartiqlAst and PartiqlPhysical. + + // This is really just a temporary measure, however, which must exist for as long as the type inferencer works only + // on PartiqlAst. When it has been migrated to use PartiqlPhysical instead, there should no longer be a reason + // to keep this function around. 
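    // A concrete illustration of the round trip (illustrative assumption: the serialized forms are identical
    // in both domains): a no-argument variant such as PartiqlAst.Type.IntegerType serializes to an s-expression
    // along the lines of `(integer_type)`, which PartiqlPhysical.transform reads back as
    // PartiqlPhysical.Type.IntegerType, because the permuted domain leaves the type variants untouched.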
+ val sexp = this.toIonElement() + val physicalType = PartiqlPhysical.transform(sexp) as PartiqlPhysical.Type + return physicalType.toTypedOpParameter(customTypedOpParameters) } diff --git a/lang/src/org/partiql/lang/types/PartiqlPhysicalTypeExtensions.kt b/lang/src/org/partiql/lang/types/PartiqlPhysicalTypeExtensions.kt new file mode 100644 index 0000000000..71f20f06ce --- /dev/null +++ b/lang/src/org/partiql/lang/types/PartiqlPhysicalTypeExtensions.kt @@ -0,0 +1,85 @@ +package org.partiql.lang.types + +import org.partiql.lang.domains.PartiqlPhysical + +/** + * Helper to convert [PartiqlPhysical.Type] in AST to a [TypedOpParameter]. + */ +fun PartiqlPhysical.Type.toTypedOpParameter(customTypedOpParameters: Map): TypedOpParameter = when (this) { + is PartiqlPhysical.Type.MissingType -> TypedOpParameter(StaticType.MISSING) + is PartiqlPhysical.Type.NullType -> TypedOpParameter(StaticType.NULL) + is PartiqlPhysical.Type.BooleanType -> TypedOpParameter(StaticType.BOOL) + is PartiqlPhysical.Type.SmallintType -> TypedOpParameter(IntType(IntType.IntRangeConstraint.SHORT)) + is PartiqlPhysical.Type.Integer4Type -> TypedOpParameter(IntType(IntType.IntRangeConstraint.INT4)) + is PartiqlPhysical.Type.Integer8Type -> TypedOpParameter(IntType(IntType.IntRangeConstraint.LONG)) + is PartiqlPhysical.Type.IntegerType -> TypedOpParameter(IntType(IntType.IntRangeConstraint.LONG)) + is PartiqlPhysical.Type.FloatType, is PartiqlPhysical.Type.RealType, is PartiqlPhysical.Type.DoublePrecisionType -> TypedOpParameter(StaticType.FLOAT) + is PartiqlPhysical.Type.DecimalType -> when { + this.precision == null && this.scale == null -> TypedOpParameter(StaticType.DECIMAL) + this.precision != null && this.scale == null -> TypedOpParameter(DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(this.precision.value.toInt()))) + else -> TypedOpParameter( + DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(this.precision!!.value.toInt(), this.scale!!.value.toInt())) + ) + } + is PartiqlPhysical.Type.NumericType -> when { + this.precision == null && this.scale == null -> TypedOpParameter(StaticType.DECIMAL) + this.precision != null && this.scale == null -> TypedOpParameter(DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(this.precision.value.toInt()))) + else -> TypedOpParameter( + DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(this.precision!!.value.toInt(), this.scale!!.value.toInt())) + ) + } + is PartiqlPhysical.Type.TimestampType -> TypedOpParameter(StaticType.TIMESTAMP) + is PartiqlPhysical.Type.CharacterType -> when { + this.length == null -> TypedOpParameter(StringType(StringType.StringLengthConstraint.Constrained(NumberConstraint.Equals(1)))) + else -> TypedOpParameter( + StringType( + StringType.StringLengthConstraint.Constrained( + NumberConstraint.Equals(this.length.value.toInt()) + ) + ) + ) + } + is PartiqlPhysical.Type.CharacterVaryingType -> when (this.length) { + null -> TypedOpParameter(StringType(StringType.StringLengthConstraint.Unconstrained)) + else -> TypedOpParameter(StringType(StringType.StringLengthConstraint.Constrained(NumberConstraint.UpTo(this.length.value.toInt())))) + } + is PartiqlPhysical.Type.StringType -> TypedOpParameter(StaticType.STRING) + is PartiqlPhysical.Type.SymbolType -> TypedOpParameter(StaticType.SYMBOL) + is PartiqlPhysical.Type.ClobType -> TypedOpParameter(StaticType.CLOB) + is PartiqlPhysical.Type.BlobType -> TypedOpParameter(StaticType.BLOB) + is PartiqlPhysical.Type.StructType -> TypedOpParameter(StaticType.STRUCT) + is 
PartiqlPhysical.Type.TupleType -> TypedOpParameter(StaticType.STRUCT) + is PartiqlPhysical.Type.ListType -> TypedOpParameter(StaticType.LIST) + is PartiqlPhysical.Type.SexpType -> TypedOpParameter(StaticType.SEXP) + is PartiqlPhysical.Type.BagType -> TypedOpParameter(StaticType.BAG) + is PartiqlPhysical.Type.AnyType -> TypedOpParameter(StaticType.ANY) + is PartiqlPhysical.Type.CustomType -> + customTypedOpParameters.mapKeys { + (k, _) -> + k.toLowerCase() + }[this.name.text.toLowerCase()] ?: error("Could not find parameter for $this") + is PartiqlPhysical.Type.DateType -> TypedOpParameter(StaticType.DATE) + is PartiqlPhysical.Type.TimeType -> TypedOpParameter( + TimeType(this.precision?.value?.toInt(), withTimeZone = false) + ) + is PartiqlPhysical.Type.TimeWithTimeZoneType -> TypedOpParameter( + TimeType(this.precision?.value?.toInt(), withTimeZone = true) + ) + is PartiqlPhysical.Type.EsAny, + is PartiqlPhysical.Type.EsBoolean, + is PartiqlPhysical.Type.EsFloat, + is PartiqlPhysical.Type.EsInteger, + is PartiqlPhysical.Type.EsText, + is PartiqlPhysical.Type.RsBigint, + is PartiqlPhysical.Type.RsBoolean, + is PartiqlPhysical.Type.RsDoublePrecision, + is PartiqlPhysical.Type.RsInteger, + is PartiqlPhysical.Type.RsReal, + is PartiqlPhysical.Type.RsVarcharMax, + is PartiqlPhysical.Type.SparkBoolean, + is PartiqlPhysical.Type.SparkDouble, + is PartiqlPhysical.Type.SparkFloat, + is PartiqlPhysical.Type.SparkInteger, + is PartiqlPhysical.Type.SparkLong, + is PartiqlPhysical.Type.SparkShort -> error("$this node should not be present in PartiqlPhysical. Consider transforming the AST using CustomTypeVisitorTransform.") +} diff --git a/lang/test/org/partiql/lang/eval/EvaluatorTestBase.kt b/lang/test/org/partiql/lang/eval/EvaluatorTestBase.kt index fd5b875a8d..c18db638ec 100644 --- a/lang/test/org/partiql/lang/eval/EvaluatorTestBase.kt +++ b/lang/test/org/partiql/lang/eval/EvaluatorTestBase.kt @@ -37,6 +37,7 @@ import org.partiql.lang.eval.evaluatortestframework.LegacySerializerTestAdapter import org.partiql.lang.eval.evaluatortestframework.MultipleTestAdapter import org.partiql.lang.eval.evaluatortestframework.PartiqlAstExprNodeRoundTripAdapter import org.partiql.lang.eval.evaluatortestframework.PipelineEvaluatorTestAdapter +import org.partiql.lang.eval.evaluatortestframework.PlannerPipelineFactory import org.partiql.lang.util.asSequence import org.partiql.lang.util.newFromIonText @@ -47,7 +48,7 @@ abstract class EvaluatorTestBase : TestBase() { private val testHarness: EvaluatorTestAdapter = MultipleTestAdapter( listOf( PipelineEvaluatorTestAdapter(CompilerPipelineFactory()), - // TODO: PipelineEvaluatorTestAdapter(PlannerPipelineFactory()), + PipelineEvaluatorTestAdapter(PlannerPipelineFactory()), PartiqlAstExprNodeRoundTripAdapter(), LegacySerializerTestAdapter(), AstRewriterBaseTestAdapter() diff --git a/lang/test/org/partiql/lang/eval/EvaluatorTests.kt b/lang/test/org/partiql/lang/eval/EvaluatorTests.kt index 167490a2ed..4707b5a9c0 100644 --- a/lang/test/org/partiql/lang/eval/EvaluatorTests.kt +++ b/lang/test/org/partiql/lang/eval/EvaluatorTests.kt @@ -14,7 +14,6 @@ package org.partiql.lang.eval -import org.junit.jupiter.api.Disabled import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.MethodSource import org.partiql.lang.ION @@ -98,10 +97,17 @@ class EvaluatorTests { "selectDistinctWithAggregate", // TODO: Support aggregates in physical plans "selectDistinctAggregationWithGroupBy", // TODO: Support GROUP BY in physical plans "selectDistinctWithGroupBy", 
// TODO: Support GROUP BY in physical plans + "unpivotStructWithMissingField", // TODO: Support UNPIVOT in physical plans "unpivotMissing", // TODO: Support UNPIVOT in physical plans "unpivotEmptyStruct", // TODO: Support UNPIVOT in physical plans "unpivotMissingWithAsAndAt", // TODO: Support UNPIVOT in physical plans "unpivotMissingCrossJoinWithAsAndAt", // TODO: Support UNPIVOT in physical plans + + // UndefinedVariableBehavior.MISSING not supported by plan evaluator + "undefinedUnqualifiedVariableWithUndefinedVariableBehaviorMissing", + "undefinedUnqualifiedVariableIsNullExprWithUndefinedVariableBehaviorMissing", + "undefinedUnqualifiedVariableIsMissingExprWithUndefinedVariableBehaviorMissing", + "undefinedUnqualifiedVariableInSelectWithUndefinedVariableBehaviorMissing", ) @JvmStatic @@ -119,6 +125,5 @@ class EvaluatorTests { @ParameterizedTest @MethodSource("planEvaluatorTests") - @Disabled("The planner will be merged in a future pull request.") - fun planEvalutorTests(tc: IonResultTestCase) = tc.runTestCase(valueFactory, mockDb, EvaluatorTestTarget.PLANNER_PIPELINE) + fun planEvaluatorTests(tc: IonResultTestCase) = tc.runTestCase(valueFactory, mockDb, EvaluatorTestTarget.PLANNER_PIPELINE) } diff --git a/lang/test/org/partiql/lang/eval/TypingModeTests.kt b/lang/test/org/partiql/lang/eval/TypingModeTests.kt index a771b54e5f..fc4022e5ce 100644 --- a/lang/test/org/partiql/lang/eval/TypingModeTests.kt +++ b/lang/test/org/partiql/lang/eval/TypingModeTests.kt @@ -58,7 +58,7 @@ class TypingModeTests : EvaluatorTestBase() { expectedErrorCode = tc.expectedLegacyError.errorCode, expectedPermissiveModeResult = tc.expectedPermissiveModeResult, addtionalExceptionAssertBlock = { ex: SqlException -> - // Have to use the addtionalExceptionAssertBlock instead of error context for this + // Have to use the additionalExceptionAssertBlock instead of error context for this // because there are a few cases with error context values other than line & column that we don't // account for in [TestCase]. assertEquals( diff --git a/lang/test/org/partiql/lang/eval/builtins/BuiltInFunctionTestExtensions.kt b/lang/test/org/partiql/lang/eval/builtins/BuiltInFunctionTestExtensions.kt index 8adc72f79b..952b58c9e0 100644 --- a/lang/test/org/partiql/lang/eval/builtins/BuiltInFunctionTestExtensions.kt +++ b/lang/test/org/partiql/lang/eval/builtins/BuiltInFunctionTestExtensions.kt @@ -6,6 +6,7 @@ import org.partiql.lang.eval.Bindings import org.partiql.lang.eval.EvaluationSession import org.partiql.lang.eval.ExprValue import org.partiql.lang.eval.ExprValueFactory +import org.partiql.lang.eval.evaluatortestframework.EvaluatorTestTarget import org.partiql.lang.util.newFromIonText /** @@ -19,8 +20,13 @@ internal fun checkInvalidArgType(funcName: String, syntaxSuffix: String = "(", a * Internal function used by ExprFunctionTest to test invalid arity. 
*/ internal val invalidArityChecker = InvalidArityChecker() -internal fun checkInvalidArity(funcName: String, minArity: Int, maxArity: Int) = - invalidArityChecker.checkInvalidArity(funcName, minArity, maxArity) +internal fun checkInvalidArity( + funcName: String, + minArity: Int, + maxArity: Int, + targetPipeline: EvaluatorTestTarget = EvaluatorTestTarget.ALL_PIPELINES +) = + invalidArityChecker.checkInvalidArity(funcName, minArity, maxArity, targetPipeline) private val valueFactory = ExprValueFactory.standard(ION) diff --git a/lang/test/org/partiql/lang/eval/builtins/InvalidArityChecker.kt b/lang/test/org/partiql/lang/eval/builtins/InvalidArityChecker.kt index 9fe3b14bf9..848f884f5a 100644 --- a/lang/test/org/partiql/lang/eval/builtins/InvalidArityChecker.kt +++ b/lang/test/org/partiql/lang/eval/builtins/InvalidArityChecker.kt @@ -3,6 +3,7 @@ package org.partiql.lang.eval.builtins import org.partiql.lang.errors.ErrorCode import org.partiql.lang.errors.Property import org.partiql.lang.eval.EvaluatorTestBase +import org.partiql.lang.eval.evaluatortestframework.EvaluatorTestTarget import org.partiql.lang.util.propertyValueMapOf /** @@ -31,7 +32,7 @@ class InvalidArityChecker : EvaluatorTestBase() { * @param maxArity is the maximum arity of an ExprFunction. * @param minArity is the minimum arity of an ExprFunction. */ - fun checkInvalidArity(funcName: String, minArity: Int, maxArity: Int) { + fun checkInvalidArity(funcName: String, minArity: Int, maxArity: Int, targetPipeline: EvaluatorTestTarget) { if (minArity < 0) throw IllegalStateException("Minimum arity has to be larger than 0.") if (maxArity < minArity) throw IllegalStateException("Maximum arity has to be larger than or equal to minimum arity.") @@ -44,7 +45,7 @@ class InvalidArityChecker : EvaluatorTestBase() { else -> sb.append(",null") } if (curArity < minArity || curArity > maxArity) { // If less or more argument provided, we catch invalid arity error - assertThrowsInvalidArity("$sb)", funcName, curArity, minArity, maxArity) + assertThrowsInvalidArity("$sb)", funcName, curArity, minArity, maxArity, targetPipeline) } } } @@ -54,7 +55,8 @@ class InvalidArityChecker : EvaluatorTestBase() { funcName: String, actualArity: Int, minArity: Int, - maxArity: Int + maxArity: Int, + targetPipeline: EvaluatorTestTarget ) = runEvaluatorErrorTestCase( query = query, expectedErrorCode = ErrorCode.EVALUATOR_INCORRECT_NUMBER_OF_ARGUMENTS_TO_FUNC_CALL, @@ -65,5 +67,6 @@ class InvalidArityChecker : EvaluatorTestBase() { Property.EXPECTED_ARITY_MAX to maxArity, Property.ACTUAL_ARITY to actualArity ), + target = targetPipeline ) } diff --git a/lang/test/org/partiql/lang/eval/builtins/functions/DynamicLookupExprFunctionTest.kt b/lang/test/org/partiql/lang/eval/builtins/functions/DynamicLookupExprFunctionTest.kt new file mode 100644 index 0000000000..f55ba07792 --- /dev/null +++ b/lang/test/org/partiql/lang/eval/builtins/functions/DynamicLookupExprFunctionTest.kt @@ -0,0 +1,153 @@ +package org.partiql.lang.eval.builtins.functions + +import org.junit.jupiter.api.Test +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ArgumentsSource +import org.partiql.lang.errors.ErrorCode +import org.partiql.lang.errors.Property +import org.partiql.lang.eval.EvaluatorTestBase +import org.partiql.lang.eval.builtins.DYNAMIC_LOOKUP_FUNCTION_NAME +import org.partiql.lang.eval.builtins.ExprFunctionTestCase +import org.partiql.lang.eval.builtins.checkInvalidArity +import 
org.partiql.lang.eval.evaluatortestframework.EvaluatorErrorTestCase +import org.partiql.lang.eval.evaluatortestframework.EvaluatorTestTarget +import org.partiql.lang.eval.evaluatortestframework.ExpectedResultFormat +import org.partiql.lang.util.ArgumentsProviderBase +import org.partiql.lang.util.propertyValueMapOf +import org.partiql.lang.util.to + +class DynamicLookupExprFunctionTest : EvaluatorTestBase() { + val session = mapOf( + "f" to "{ foo: 42 }", + "b" to "{ bar: 43 }", + "foo" to "44", + ).toSession() + + // Pass test cases + @ParameterizedTest + @ArgumentsSource(ToStringPassCases::class) + fun runPassTests(testCase: ExprFunctionTestCase) = + runEvaluatorTestCase( + query = testCase.source, + expectedResult = testCase.expectedLegacyModeResult, + target = EvaluatorTestTarget.PLANNER_PIPELINE, + expectedResultFormat = ExpectedResultFormat.ION, + session = session + ) + + // We rely on the built-in [DEFAULT_COMPARATOR] for the actual definition of equality, which is not being tested + // here. + class ToStringPassCases : ArgumentsProviderBase() { + override fun getParameters(): List = listOf( + // function signature: $__dynamic_lookup__(, , , *) + // arg #1: the name of the field or variable to locate. + // arg #2: case-insensitive or sensitive + // arg #3: look in globals first or locals first. + // arg #4 and later (variadic): any remaining arguments are the variables to search within, which in general + // are structs. note that in general, these will be local variables, however we don't use local variables + // here to simplify these test cases. + + // locals_then_globals + + // `foo` should be found in the variable f, which is a struct + ExprFunctionTestCase("\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`foo`, `case_insensitive`, `locals_then_globals`, f, b)", "42"), + ExprFunctionTestCase("\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`fOo`, `case_insensitive`, `locals_then_globals`, f, b)", "42"), + ExprFunctionTestCase("\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`FoO`, `case_insensitive`, `locals_then_globals`, f, b)", "42"), + ExprFunctionTestCase("\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`foo`, `case_sensitive`, `locals_then_globals`, f, b)", "42"), + // `bar` should be found in the variable b, which is also a struct + ExprFunctionTestCase("\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`bar`, `case_insensitive`, `locals_then_globals`, f, b)", "43"), + ExprFunctionTestCase("\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`BaR`, `case_insensitive`, `locals_then_globals`, f, b)", "43"), + ExprFunctionTestCase("\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`bAr`, `case_insensitive`, `locals_then_globals`, f, b)", "43"), + ExprFunctionTestCase("\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`bar`, `case_sensitive`, `locals_then_globals`, f, b)", "43"), + + // globals_then_locals + + // The global variable `foo` should be found first, ignoring the `f.foo`, unlike the similar cases above` + ExprFunctionTestCase("\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`foo`, `case_insensitive`, `globals_then_locals`, f, b)", "44"), + ExprFunctionTestCase("\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`fOo`, `case_insensitive`, `globals_then_locals`, f, b)", "44"), + ExprFunctionTestCase("\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`FoO`, `case_insensitive`, `globals_then_locals`, f, b)", "44"), + ExprFunctionTestCase("\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`foo`, `case_sensitive`, `globals_then_locals`, f, b)", "44"), + // `bar` should still be found in the variable b, which is also a struct, since there is no global named `bar`. 
+ ExprFunctionTestCase("\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`bar`, `case_insensitive`, `globals_then_locals`, f, b)", "43"), + ExprFunctionTestCase("\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`BaR`, `case_insensitive`, `globals_then_locals`, f, b)", "43"), + ExprFunctionTestCase("\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`bAr`, `case_insensitive`, `globals_then_locals`, f, b)", "43"), + ExprFunctionTestCase("\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`bar`, `case_sensitive`, `globals_then_locals`, f, b)", "43") + ) + } + + @ParameterizedTest + @ArgumentsSource(MismatchCaseSensitiveCases::class) + fun mismatchedCaseSensitiveTests(testCase: EvaluatorErrorTestCase) = + runEvaluatorErrorTestCase( + testCase.copy( + expectedPermissiveModeResult = "MISSING", + targetPipeline = EvaluatorTestTarget.PLANNER_PIPELINE + ), + session = session + ) + + class MismatchCaseSensitiveCases : ArgumentsProviderBase() { + override fun getParameters(): List = listOf( + // Can't find these variables due to case mismatch when perform case sensitive lookup + EvaluatorErrorTestCase( + query = "\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`fOo`, `case_sensitive`, `locals_then_globals`, f, b)", + expectedErrorCode = ErrorCode.EVALUATOR_QUOTED_BINDING_DOES_NOT_EXIST, + expectedErrorContext = propertyValueMapOf(1, 1, Property.BINDING_NAME to "fOo") + ), + EvaluatorErrorTestCase( + query = "\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`FoO`, `case_sensitive`, `locals_then_globals`, f, b)", + expectedErrorCode = ErrorCode.EVALUATOR_QUOTED_BINDING_DOES_NOT_EXIST, + expectedErrorContext = propertyValueMapOf(1, 1, Property.BINDING_NAME to "FoO") + ), + EvaluatorErrorTestCase( + query = "\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`BaR`, `case_sensitive`, `locals_then_globals`, f, b)", + expectedErrorCode = ErrorCode.EVALUATOR_QUOTED_BINDING_DOES_NOT_EXIST, + expectedErrorContext = propertyValueMapOf(1, 1, Property.BINDING_NAME to "BaR") + ), + EvaluatorErrorTestCase( + query = "\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`bAr`, `case_sensitive`, `locals_then_globals`, f, b)", + expectedErrorCode = ErrorCode.EVALUATOR_QUOTED_BINDING_DOES_NOT_EXIST, + expectedErrorContext = propertyValueMapOf(1, 1, Property.BINDING_NAME to "bAr") + ) + ) + } + + data class InvalidArgTestCase( + val source: String, + val argumentPosition: Int, + val actualArgumentType: String, + ) + + @ParameterizedTest + @ArgumentsSource(InvalidArgCases::class) + fun invalidArgTypeTestCases(testCase: InvalidArgTestCase) = + runEvaluatorErrorTestCase( + query = testCase.source, + expectedErrorCode = ErrorCode.EVALUATOR_INCORRECT_TYPE_OF_ARGUMENTS_TO_FUNC_CALL, + expectedErrorContext = propertyValueMapOf( + 1, 1, + Property.FUNCTION_NAME to DYNAMIC_LOOKUP_FUNCTION_NAME, + Property.EXPECTED_ARGUMENT_TYPES to "SYMBOL", + Property.ACTUAL_ARGUMENT_TYPES to testCase.actualArgumentType, + Property.ARGUMENT_POSITION to testCase.argumentPosition + ), + expectedPermissiveModeResult = "MISSING", + target = EvaluatorTestTarget.PLANNER_PIPELINE + ) + + class InvalidArgCases : ArgumentsProviderBase() { + override fun getParameters(): List = listOf( + InvalidArgTestCase("\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(1, `case_insensitive`, `locals_then_globals`)", 1, "INT"), + InvalidArgTestCase("\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`foo`, 1, `locals_then_globals`)", 2, "INT"), + InvalidArgTestCase("\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`foo`, `case_insensitive`, 1)", 3, "INT") + ) + } + + @Test + fun invalidArityTest() = checkInvalidArity( + funcName = "\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"", + maxArity = Int.MAX_VALUE, + minArity = 3, + targetPipeline = 
EvaluatorTestTarget.PLANNER_PIPELINE + ) +} diff --git a/lang/test/org/partiql/lang/eval/evaluatortestframework/PlannerPipelineFactory.kt b/lang/test/org/partiql/lang/eval/evaluatortestframework/PlannerPipelineFactory.kt new file mode 100644 index 0000000000..cc93b8bb91 --- /dev/null +++ b/lang/test/org/partiql/lang/eval/evaluatortestframework/PlannerPipelineFactory.kt @@ -0,0 +1,122 @@ +package org.partiql.lang.eval.evaluatortestframework + +import org.junit.jupiter.api.fail +import org.partiql.lang.ION +import org.partiql.lang.eval.BindingName +import org.partiql.lang.eval.EvaluationSession +import org.partiql.lang.eval.ExprValue +import org.partiql.lang.eval.TypingMode +import org.partiql.lang.eval.UndefinedVariableBehavior +import org.partiql.lang.planner.EvaluatorOptions +import org.partiql.lang.planner.MetadataResolver +import org.partiql.lang.planner.PassResult +import org.partiql.lang.planner.PlannerPipeline +import org.partiql.lang.planner.ResolutionResult +import kotlin.test.assertNotEquals +import kotlin.test.assertNull + +/** + * Uses the test infrastructure (which is geared toward the legacy [org.partiql.lang.CompilerPipeline]) to create a + * standard [org.partiql.lang.CompilerPipeline], then creates an equivalent [PlannerPipeline] which is wrapped in + * an instance of [AbstractPipeline] and returned to the caller. + * + * Why? Because the entire test infrastructure (and the many thousands of tests) are heavily dependent on + * [org.partiql.lang.CompilerPipeline]. When that class is deprecated or removed we'll want to change this to + * depend on the [PlannerPipeline] instead. + */ +internal class PlannerPipelineFactory : PipelineFactory { + + override val pipelineName: String + get() = "PlannerPipeline (and Physical Plan Evaluator)" + + override val target: EvaluatorTestTarget + get() = EvaluatorTestTarget.PLANNER_PIPELINE + + override fun createPipeline( + evaluatorTestDefinition: EvaluatorTestDefinition, + session: EvaluationSession, + forcePermissiveMode: Boolean + ): AbstractPipeline { + + // Construct a legacy CompilerPipeline + val compilerPipeline = evaluatorTestDefinition.createCompilerPipeline(forcePermissiveMode) + + // Convert it to a PlannerPipeline (to avoid having to refactor many tests cases to use + // PlannerPipeline.Builder and EvaluatorOptions.Builder. + val co = compilerPipeline.compileOptions + + assertNotEquals( + co.undefinedVariable, UndefinedVariableBehavior.MISSING, + "The planner and physical plan evaluator do not support UndefinedVariableBehavior.MISSING. " + + "Please set target = EvaluatorTestTarget.COMPILER_PIPELINE for this test.\n" + + "Test groupName: ${evaluatorTestDefinition.groupName}" + ) + + assertNull( + compilerPipeline.globalTypeBindings, + "The planner and physical plan evaluator do not support globalTypeBindings (yet)" + + "Please set target = EvaluatorTestTarget.COMPILER_PIPELINE for this test." + ) + + val evaluatorOptions = EvaluatorOptions.build { + typingMode(co.typingMode) + thunkOptions(co.thunkOptions) + defaultTimezoneOffset(co.defaultTimezoneOffset) + typedOpBehavior(co.typedOpBehavior) + projectionIteration(co.projectionIteration) + } + + @Suppress("DEPRECATION") + val plannerPipeline = PlannerPipeline.build(ION) { + // this is for support of the existing test suite and may not be desirable for all future tests. 
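            // (That is, variables the planner cannot resolve are rewritten into calls to
            // DynamicLookupExprFunction and resolved dynamically at evaluation time, rather than being
            // reported as undefined-variable problems.)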
+ allowUndefinedVariables(true) + + customDataTypes(compilerPipeline.customDataTypes) + + compilerPipeline.functions.values.forEach { this.addFunction(it) } + compilerPipeline.procedures.values.forEach { this.addProcedure(it) } + + evaluatorOptions(evaluatorOptions) + + // For compatibility with the unit test suite, prevent the planner from catching SqlException during query + // compilation and converting them into Problems + enableLegacyExceptionHandling() + + // Create a fake MetadataResolver implementation which defines any global that is also defined in the + // session. + metadataResolver( + object : MetadataResolver { + override fun resolveVariable(bindingName: BindingName): ResolutionResult { + val boundValue = session.globals[bindingName] + return if (boundValue != null) { + // There is no way to tell the actual name of the global variable as it exists + // in session.globals (case may differ). For now we simply have to use binding.name + // as the uniqueId of the variable, however, this is not desirable in production + // scenarios. At minimum, the name of the variable in its original letter-case should be + // used. + ResolutionResult.GlobalVariable(bindingName.name) + } else { + ResolutionResult.Undefined + } + } + } + ) + } + + return object : AbstractPipeline { + override val typingMode: TypingMode + get() = evaluatorOptions.typingMode + + override fun evaluate(query: String): ExprValue { + when (val planningResult = plannerPipeline.planAndCompile(query)) { + is PassResult.Error -> { + fail("Query compilation unexpectedly failed: ${planningResult.errors}") + } + is PassResult.Success -> { + return planningResult.result.eval(session) + } + } + } + } + } +} diff --git a/lang/test/org/partiql/lang/eval/relation/RelationTests.kt b/lang/test/org/partiql/lang/eval/relation/RelationTests.kt new file mode 100644 index 0000000000..24ed0dc1e9 --- /dev/null +++ b/lang/test/org/partiql/lang/eval/relation/RelationTests.kt @@ -0,0 +1,58 @@ +package org.partiql.lang.eval.relation + +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertFalse +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows + +class RelationTests { + + @Test + fun relType() { + val rel = relation(RelationType.BAG) { } + assertEquals(RelationType.BAG, rel.relType) + } + + @Test + fun `0 yields`() { + val rel = relation(RelationType.BAG) { } + assertEquals(RelationType.BAG, rel.relType) + assertFalse(rel.nextRow()) + assertThrows { rel.nextRow() } + } + + @Test + fun `1 yield`() { + val rel = relation(RelationType.BAG) { yield() } + assertTrue(rel.nextRow()) + assertFalse(rel.nextRow()) + assertThrows { rel.nextRow() } + } + + @Test + fun `2 yields`() { + val rel = relation(RelationType.BAG) { + yield() + yield() + } + assertTrue(rel.nextRow()) + assertTrue(rel.nextRow()) + assertFalse(rel.nextRow()) + assertThrows { rel.nextRow() } + } + + @Test + fun `3 yields`() { + val rel = relation(RelationType.BAG) { + yield() + yield() + yield() + } + assertTrue(rel.nextRow()) + assertTrue(rel.nextRow()) + assertTrue(rel.nextRow()) + assertFalse(rel.nextRow()) + assertThrows { rel.nextRow() } + } +} diff --git a/lang/test/org/partiql/lang/planner/PlannerPipelineSmokeTests.kt b/lang/test/org/partiql/lang/planner/PlannerPipelineSmokeTests.kt index eb7a726802..e65ee32f0a 100644 --- a/lang/test/org/partiql/lang/planner/PlannerPipelineSmokeTests.kt +++ 
b/lang/test/org/partiql/lang/planner/PlannerPipelineSmokeTests.kt @@ -52,7 +52,7 @@ class PlannerPipelineSmokeTests { ), source = scan( i = impl("default"), - expr = globalId("Customer", "fake_uid_for_Customer"), + expr = globalId("fake_uid_for_Customer", caseInsensitive()), asDecl = varDecl(0) ) ) diff --git a/lang/test/org/partiql/lang/planner/transforms/LogicalResolvedToDefaultPhysicalVisitorTransformTests.kt b/lang/test/org/partiql/lang/planner/transforms/LogicalResolvedToDefaultPhysicalVisitorTransformTests.kt index 795d8f8abf..c69fca6597 100644 --- a/lang/test/org/partiql/lang/planner/transforms/LogicalResolvedToDefaultPhysicalVisitorTransformTests.kt +++ b/lang/test/org/partiql/lang/planner/transforms/LogicalResolvedToDefaultPhysicalVisitorTransformTests.kt @@ -22,7 +22,7 @@ class LogicalResolvedToDefaultPhysicalVisitorTransformTests { TestCase( PartiqlLogicalResolved.build { scan( - expr = globalId("foo", "foo"), + expr = globalId("foo", caseInsensitive()), asDecl = varDecl(0), atDecl = varDecl(1), byDecl = varDecl(2) @@ -31,7 +31,7 @@ class LogicalResolvedToDefaultPhysicalVisitorTransformTests { PartiqlPhysical.build { scan( i = DEFAULT_IMPL, - expr = globalId("foo", "foo"), + expr = globalId("foo", caseInsensitive()), asDecl = varDecl(0), atDecl = varDecl(1), byDecl = varDecl(2) @@ -43,7 +43,7 @@ class LogicalResolvedToDefaultPhysicalVisitorTransformTests { filter( predicate = lit(ionBool(true)), source = scan( - expr = globalId("foo", "foo"), + expr = globalId("foo", caseInsensitive()), asDecl = varDecl(0), atDecl = varDecl(1), byDecl = varDecl(2) @@ -56,7 +56,7 @@ class LogicalResolvedToDefaultPhysicalVisitorTransformTests { predicate = lit(ionBool(true)), source = scan( i = DEFAULT_IMPL, - expr = globalId("foo", "foo"), + expr = globalId("foo", caseInsensitive()), asDecl = varDecl(0), atDecl = varDecl(1), byDecl = varDecl(2) diff --git a/lang/test/org/partiql/lang/planner/transforms/LogicalToLogicalResolvedVisitorTransformTests.kt b/lang/test/org/partiql/lang/planner/transforms/LogicalToLogicalResolvedVisitorTransformTests.kt index eccfd4c4b3..d61e399e32 100644 --- a/lang/test/org/partiql/lang/planner/transforms/LogicalToLogicalResolvedVisitorTransformTests.kt +++ b/lang/test/org/partiql/lang/planner/transforms/LogicalToLogicalResolvedVisitorTransformTests.kt @@ -91,7 +91,7 @@ class LogicalToLogicalResolvedVisitorTransformTests { } /** Mock table resolver. That can resolve f, foo, or UPPERCASE_FOO, while respecting case-sensitivity. */ - private val globalBindings = createFakeMetadataResolver( + private val metadataResolver = createFakeMetadataResolver( *listOf( "shadow", "foo", @@ -117,7 +117,7 @@ class LogicalToLogicalResolvedVisitorTransformTests { when (tc.expectation) { is Expectation.Success -> { - val resolved = plan.toResolvedPlan(problemHandler, globalBindings, tc.allowUndefinedVariables) + val resolved = plan.toResolvedPlan(problemHandler, metadataResolver, tc.allowUndefinedVariables) // extract all of the dynamic, global and local ids from the resolved logical plan. 
val actualResolvedIds = @@ -186,7 +186,7 @@ class LogicalToLogicalResolvedVisitorTransformTests { } is Expectation.Problems -> { assertDoesNotThrow("Should not throw when variables are undefined") { - plan.toResolvedPlan(problemHandler, globalBindings) + plan.toResolvedPlan(problemHandler, metadataResolver) } assertEquals(tc.expectation.problems, problemHandler.problems) } @@ -202,17 +202,17 @@ class LogicalToLogicalResolvedVisitorTransformTests { TestCase( // all uppercase sql = "FOO", - expectation = Expectation.Success(ResolvedId(1, 1) { globalId("FOO", "fake_uid_for_foo") }) + expectation = Expectation.Success(ResolvedId(1, 1) { globalId("fake_uid_for_foo", caseInsensitive()) }) ), TestCase( // all lower case "foo", - Expectation.Success(ResolvedId(1, 1) { globalId("foo", "fake_uid_for_foo") }) + Expectation.Success(ResolvedId(1, 1) { globalId("fake_uid_for_foo", caseInsensitive()) }) ), TestCase( // mixed case "fOo", - Expectation.Success(ResolvedId(1, 1) { globalId("fOo", "fake_uid_for_foo") }) + Expectation.Success(ResolvedId(1, 1) { globalId("fake_uid_for_foo", caseInsensitive()) }) ), TestCase( // undefined @@ -233,10 +233,7 @@ class LogicalToLogicalResolvedVisitorTransformTests { // In this case, we resolve to the first matching binding. This is consistent with Postres 9.6. Expectation.Success( ResolvedId(1, 1) { - globalId( - "case_ambiguous_foo", - "fake_uid_for_case_AMBIGUOUS_foo" - ) + globalId("fake_uid_for_case_AMBIGUOUS_foo", caseInsensitive()) } ) ), @@ -247,10 +244,7 @@ class LogicalToLogicalResolvedVisitorTransformTests { "UPPERCASE_FOO", Expectation.Success( ResolvedId(1, 1) { - globalId( - "UPPERCASE_FOO", - "fake_uid_for_UPPERCASE_FOO" - ) + globalId("fake_uid_for_UPPERCASE_FOO", caseInsensitive()) } ) ), @@ -259,10 +253,7 @@ class LogicalToLogicalResolvedVisitorTransformTests { "uppercase_foo", Expectation.Success( ResolvedId(1, 1) { - globalId( - "uppercase_foo", - "fake_uid_for_UPPERCASE_FOO" - ) + globalId("fake_uid_for_UPPERCASE_FOO", caseInsensitive()) } ) ), @@ -271,10 +262,7 @@ class LogicalToLogicalResolvedVisitorTransformTests { "UpPeRcAsE_fOo", Expectation.Success( ResolvedId(1, 1) { - globalId( - "UpPeRcAsE_fOo", - "fake_uid_for_UPPERCASE_FOO" - ) + globalId("fake_uid_for_UPPERCASE_FOO", caseInsensitive()) } ) ), @@ -313,7 +301,7 @@ class LogicalToLogicalResolvedVisitorTransformTests { TestCase( // all lowercase "\"foo\"", - Expectation.Success(ResolvedId(1, 1) { globalId("foo", "fake_uid_for_foo") }) + Expectation.Success(ResolvedId(1, 1) { globalId("fake_uid_for_foo", caseSensitive()) }) ), TestCase( // mixed @@ -334,8 +322,7 @@ class LogicalToLogicalResolvedVisitorTransformTests { Expectation.Success( ResolvedId(1, 1) { globalId( - "UPPERCASE_FOO", - "fake_uid_for_UPPERCASE_FOO" + "fake_uid_for_UPPERCASE_FOO", caseSensitive() ) } ) @@ -359,7 +346,7 @@ class LogicalToLogicalResolvedVisitorTransformTests { "\"case_AMBIGUOUS_foo\"", Expectation.Success( ResolvedId(1, 1) { - globalId("case_AMBIGUOUS_foo", "fake_uid_for_case_AMBIGUOUS_foo") + globalId("fake_uid_for_case_AMBIGUOUS_foo", caseSensitive()) } ) ), @@ -368,7 +355,7 @@ class LogicalToLogicalResolvedVisitorTransformTests { "\"case_ambiguous_FOO\"", Expectation.Success( ResolvedId(1, 1) { - globalId("case_ambiguous_FOO", "fake_uid_for_case_ambiguous_FOO") + globalId("fake_uid_for_case_ambiguous_FOO", caseSensitive()) } ) ), @@ -549,7 +536,7 @@ class LogicalToLogicalResolvedVisitorTransformTests { "SELECT $varName.* FROM foo AS a AT b BY c", Expectation.Success( ResolvedId(1, 8) { 
localId(expectedIndex.toLong()) }, - ResolvedId(1, 17) { globalId("foo", "fake_uid_for_foo") } + ResolvedId(1, 17) { globalId("fake_uid_for_foo", caseInsensitive()) } ).withLocals(localVariable("a", 0), localVariable("b", 1), localVariable("c", 2)) ) @@ -566,7 +553,7 @@ class LogicalToLogicalResolvedVisitorTransformTests { "SELECT b.* FROM bar AS b WHERE b.primaryKey = 42", Expectation.Success( ResolvedId(1, 8) { localId(0) }, - ResolvedId(1, 17) { globalId("bar", "fake_uid_for_bar") }, + ResolvedId(1, 17) { globalId("fake_uid_for_bar", caseInsensitive()) }, ResolvedId(1, 32) { localId(0) }, ).withLocals(localVariable("b", 0)) ), @@ -576,7 +563,7 @@ class LogicalToLogicalResolvedVisitorTransformTests { "SELECT shadow.* FROM shadow AS shadow", // `shadow` defined here shadows the global `shadow` Expectation.Success( ResolvedId(1, 8) { localId(0) }, - ResolvedId(1, 22) { globalId("shadow", "fake_uid_for_shadow") } + ResolvedId(1, 22) { globalId("fake_uid_for_shadow", caseInsensitive()) } ).withLocals(localVariable("shadow", 0)) ), diff --git a/lang/test/org/partiql/lang/util/testdsl/IonResultTestCase.kt b/lang/test/org/partiql/lang/util/testdsl/IonResultTestCase.kt index 2b34676757..bc47d17e8c 100644 --- a/lang/test/org/partiql/lang/util/testdsl/IonResultTestCase.kt +++ b/lang/test/org/partiql/lang/util/testdsl/IonResultTestCase.kt @@ -14,6 +14,7 @@ import org.partiql.lang.eval.evaluatortestframework.EvaluatorTestCase import org.partiql.lang.eval.evaluatortestframework.EvaluatorTestTarget import org.partiql.lang.eval.evaluatortestframework.ExpectedResultFormat import org.partiql.lang.eval.evaluatortestframework.PipelineEvaluatorTestAdapter +import org.partiql.lang.eval.evaluatortestframework.PlannerPipelineFactory import org.partiql.lang.mockdb.MockDb import org.partiql.lang.syntax.SqlParser @@ -78,11 +79,10 @@ internal fun IonResultTestCase.runTestCase( val adapter = PipelineEvaluatorTestAdapter( when (target) { EvaluatorTestTarget.COMPILER_PIPELINE -> CompilerPipelineFactory() - EvaluatorTestTarget.PLANNER_PIPELINE -> TODO("PlannerPipelineFactory()") - EvaluatorTestTarget.ALL_PIPELINES -> - // We don't support ALL_PIPELINES here because each pipeline needs a separate skip list, which - // is decided by the caller of this function. - error("May only test one pipeline at a time with IonResultTestCase") + EvaluatorTestTarget.PLANNER_PIPELINE -> PlannerPipelineFactory() + // We don't support ALL_PIPELINES here because each pipeline needs a separate skip list, which + // is decided by the caller of this function. + EvaluatorTestTarget.ALL_PIPELINES -> error("May only test one pipeline at a time with IonResultTestCase") } ) diff --git a/pts/test/org/partiql/lang/pts/PartiQlPtsEvaluator.kt b/pts/test/org/partiql/lang/pts/PartiQlPtsEvaluator.kt index 373f970401..6d4bf99912 100644 --- a/pts/test/org/partiql/lang/pts/PartiQlPtsEvaluator.kt +++ b/pts/test/org/partiql/lang/pts/PartiQlPtsEvaluator.kt @@ -58,6 +58,9 @@ class PartiQlPtsEvaluator(equality: PtsEquality) : Evaluator(equality) { is ExpectedError -> TestResultSuccess(test) is ExpectedSuccess -> TestFailure(test, e.generateMessage(), TestFailure.FailureReason.UNEXPECTED_ERROR) } + } catch (e: Exception) { + // Other exception types are always failures. + TestFailure(test, "${e.javaClass.canonicalName} : ${e.message}", TestFailure.FailureReason.UNEXPECTED_ERROR) } private fun verifyTestResult(test: TestExpression, actualResult: IonValue): TestResult =