diff --git a/docs/user/BuiltInFunctions.md b/docs/user/BuiltInFunctions.md index fe265bb0ae..fbe5c27d53 100644 --- a/docs/user/BuiltInFunctions.md +++ b/docs/user/BuiltInFunctions.md @@ -226,8 +226,6 @@ CAST(<<'a', 'b'>> AS bag) -- <<'a', 'b'>> (REPL does not display << >> and comma ### CHAR_LENGTH, CHARACTER_LENGTH - - Counts the number of characters in the specified string, where 'character' is defined as a single unicode code point. *Note:* `CHAR_LENGTH` and `CHARACTER_LENGTH` are synonyms. @@ -455,9 +453,31 @@ EXTRACT(TIMEZONE_MINUTE FROM TIME WITH TIME ZONE '23:12:59-08:30') -- -30 ``` *Note* that `timezone_hour` and `timezone_minute` are **not supported** for `DATE` and `TIME` (without time zone) type. +### `FILTER_DISTINCT` + +Signature +: `FILTER_DISTINCT: Container -> Bag` + +Header +: `FILTER_DISTINCT(c)` + +Purpose +: Returns a bag of distinct values contained within a bag, list, sexp, or struct. If the container is a struct, +the field names are not considered. + +Examples +: + +```sql +FILTER_DISTINCT([0, 0, 1]) -- <<0, 1>> +FILTER_DISTINCT(<<0, 0, 1>>) -- <<0, 1>> +FILTER_DISTINCT(SEXP(0, 0, 1)) -- <<0, 1>> +FILTER_DISTINCT({'a': 0, 'b': 0, 'c': 1}) -- <<0, 1>> +``` + ### LOWER -Given a string convert all upper case characters to lower case characters. +Given a string convert all upper case characters to lower case characters. Signature : `LOWER: String -> String` diff --git a/examples/src/kotlin/org/partiql/examples/ParserErrorExample.kt b/examples/src/kotlin/org/partiql/examples/ParserErrorExample.kt index abe85375e1..a32ff520a5 100644 --- a/examples/src/kotlin/org/partiql/examples/ParserErrorExample.kt +++ b/examples/src/kotlin/org/partiql/examples/ParserErrorExample.kt @@ -28,7 +28,7 @@ class ParserErrorExample(out: PrintStream) : Example(out) { throw Exception("ParserException was not thrown") } catch (e: ParserException) { - val errorContext = e.errorContext!! + val errorContext = e.errorContext val errorInformation = "errorCode: ${e.errorCode}" + "\nLINE_NUMBER: ${errorContext[Property.LINE_NUMBER]}" + diff --git a/lang/resources/org/partiql/type-domains/partiql.ion b/lang/resources/org/partiql/type-domains/partiql.ion index 09d96fa34d..8224ca3948 100644 --- a/lang/resources/org/partiql/type-domains/partiql.ion +++ b/lang/resources/org/partiql/type-domains/partiql.ion @@ -1,3 +1,31 @@ +/* +Domains defined in this file are listed below. They are listed in transformation order and ultimately arrive at a +physical algebra that is ready to be evaluated. + +- partiql_ast: the result of parsing the PartiQL query. Structure resembles the PartiQL syntax. +- partiql_logical: a direct conversion from the partiql_ast to a logical query plan, with no semantic checking. +- partiql_logical_resolved: a variation of partiql_logical wherein all variable declarations have been allocated unique +identifiers and variable references have been resolved to a local or global variable and their unique identifiers have +been identified. Partial push-downs of filters and projections may be applied here. +- partiql_physical: this is the same as partiql_logical_resolved, but with additional relational operators. Also, all +relational operators include an operand to identify the algorithm to be used at evaluation time. After transforming +from the logical algebra to physical, all operators will be set to use default implementations. The physical algebra +may then be further optimized by selecting better implementations of each operator. + +*/ + + +// Domain transformations + +// Makes PIG emit PartiqlAstToPartiqlLogicalVisitorTransform +(transform partiql_ast partiql_logical) + +// Makes PIG emit PartiqlLogicalToPartiqlLogicalResolvedVisitorTransform +(transform partiql_logical partiql_logical_resolved) + +// Makes PIG emit PartiqlLogicalResolvedToPartiqlPhysicalVisitorTransform +(transform partiql_logical_resolved partiql_physical) + /* The PartiQL AST. @@ -426,4 +454,285 @@ ) // end of domain ) // end of define +// Same as partiql_ast, but without the syntactic representation of SFW queries and introduces PartiQL's relational +// algebra. Also removes some nodes not (yet) supported by the query planner and plan evaluator. +(define partiql_logical + (permute_domain partiql_ast + (include + // This is the new top-level node for plans that are intended to be persisted to storage or survive across + // boundaries. These need to include a version number so at least it is possible to know if a persisted + // plan is compatible with the current version of PartiQL. + (record plan + (stmt statement) + (version symbol) // This should really be a string: https://github.com/partiql/partiql-ir-generator/issues/122 + ) + + // Defines a field within a struct constructor or an expression which is expected to be a container + // that is included in the final struct. + (sum struct_part + // For `.*` in SELECT list + // If `` is a struct, the fields of that struct will be part of the merged struct. + // If `` is not a struct, The field `_n` will be included in the struct, where `n` is the + // ordinal of the field in the final merged struct. If `` returns a container that is not a + // struct, field names will be assigned in the format of `_n` where `n` is the ordinal position of the + // field within the merged struct. If `` returns a scalar value, it will be coerced into a + // singleton bag and the previous logic will apply. + (struct_fields part_expr::expr) + + // For ` [AS ]`. If `field_name` returns a non-text value, in legacy mode an exception + // will be thrown. In permissive mode, the field will be excluded from the final struct. + (struct_field field_name::expr value::expr)) + ) + + (with expr + // Remove the select and struct node from the `expr` sum type, which will be replaced below. + (exclude select struct) + + (include + // Invokes `exp` once in the context of every binding tuple returned by `query`, returning a + // collection of values produced by `exp`. The returned collection's type (bag or list) is the same + // as the bindings collection returned by `query`. + (bindings_to_values exp::expr query::bexpr) + + // `struct` is the primary struct constructor and also encapsulates semantics needed for + // `SELECT .*`, and `SELECT AS y`. It can be used as a regular struct constructor, or as + // a struct-union expression. + // + // Example as struct constructor: + // (struct + // (struct_field (lit a) (lit 42)) + // (struct_field (lit b) (lit 43))) + // Returns: { a: 42, b: 43 } + // + // Example as a struct-union. Given a global environment with `foo` bound to `{ a: 42 }` and `bar` + // bound to `{ b: 43}`, then: + // (struct + // (struct_fields (id foo)) + // (struct_fields (id bar))) + // Returns { a: 42, b: 43 } + // Note that `struct_field` and `struct_fields` may be used in combination: + // (struct + // (struct_fields (id foo)) + // (struct_fields (id bar)) + // (struct_field (lit c) (lit 44))) + // Returns { a: 42, b: 43, c: 44 } + // + // TODO: in the future, when the legacy AST compiler has been removed and the AST is no longer + // part of the public API, we should consider moving this definition to the partiql_ast domain. + (struct parts::(* struct_part 1)) + ) + ) + + // These should be excluded as well since they were referenced only by the `select` variant of `expr`, which + // was excluded above. + (exclude + project_item + projection + from_source + ) + + // Change let_binding so that it has a var_decl instead of only a name to represent bindings. + (exclude let_binding) + (include (product let_binding value::expr decl::var_decl)) + + // Now we include new stuff, including PartiQL's relational algebra. + (include + // Every instance of `var_decl` introduces a new binding in the current scope. + // Every part of the AST that can introduce a variable should be represented with one of these nodes. + // Examples of variable declarations include: + // - The `AS`, `AT`, and `BY` sub-clauses in `FROM` + // - The `AS` sub-clauses in within a `LET` clause. + // - The `AS` and `AT` names specified with, `PIVOT`, i.e. `PIVOT x AS y AT z` + // Note that `AS` aliases specified in a select list (i.e. `SELECT x AS y` are *not* variables, they are + // fields.) + // Modeling this with a separate node (as opposed to just a symbol) is beneficial because it is easy to + // identify all variable declarations within a logical plan during tree traversal, and because in later + // permuted domains we can add information to this type such as the variable's assigned index. + // Elements: + // - `name`: the name of the variable as specified by the query author or determined statically. + (product var_decl name::symbol) + + // The operators of PartiQL's relational algebra. + (sum bexpr + // Converts a value collection to a bindings collection. Not used to perform physical reads. (For + // that, see bexpr.project in the partiql_physical domain.) If evaluating `expr` results in a scalar + // value, it is converted into a singleton bag. + (scan expr::expr as_decl::var_decl at_decl::(? var_decl) by_decl::(? var_decl)) + + // Evaluates `predicate` within the scope of every row of `bexpr`, and only returns those + // rows for which `predicate` returns true. + (filter predicate::expr source::bexpr) + + // Basic join operator. Covers cross, inner, left, right and full joins. + // To represent a cross join, set `predicate` to `null`. + (join + join_type::join_type + left::bexpr + right::bexpr + predicate::(? expr)) + + // Skips `row_count` rows, then emits all remaining rows. + (offset row_count::expr source::bexpr) + + // Emits `row_count` rows, discards all remaining rows. + (limit row_count::expr source::bexpr) + + // For every row of `source`, adds each specified `let_binding`. + (let source::bexpr bindings::(* let_binding 1)) + ) + ) + + // Nodes excluded below this line will eventually have a representation in the logical algebra, but not + // initially. + + (with statement + (exclude + dml + ddl + ) + ) + + (exclude + group_by + grouping_strategy + group_key + group_key_list + order_by + sort_spec + ordering_spec + + let + + dml_op + dml_op_list + ddl_op + conflict_action + on_conflict + returning_expr + returning_elem + column_component + returning_mapping + assignment + identifier + ) + ) +) + +// partiql_logical_resolved is a variation of partiql_logical wherein all variable declarations have been allocated +// unique identifiers and variable references have been resolved. The first set of optimizations such as partial +// push-downs of filters and projections may be applied to this domain. +(define partiql_logical_resolved + (permute_domain partiql_logical + // Add `locals` to `plan`. + (exclude plan) + (include + (record plan + (stmt statement) + (version symbol) // This should really be a string: https://github.com/partiql/partiql-ir-generator/issues/122 + (locals (* local_variable 0)) + ) + + // Local variables currently include a name and register index. In the future, something to indicate the + // static type of the variable may also be included here. The index is included explicitly (instead of + // allowing it to be identified by ordinal position) simply to allow it to be easily identified by humans + // when examining plans with many local variables. + (product local_variable name::symbol register_index::int) + ) + + + // For `var_decl`, `name` with `index`. The name of the variable can still be determined by looking at the + // `local_variable` with the same index. + (exclude var_decl scope_qualifier) + (include + (product var_decl index::int) + ) + + (with expr + // At this point there should be no undefined variables in the plan. All variables are rewritten to + // `local_id`, `global_id`. + (exclude id) + (include + // A resolved reference to a variable that was defined within a query. Otherwise known as a local + // variable. "Resolved" means that the variable is guaranteed to exist and we know its register index. + // Elements: + // - `index`: the index of the `var_decl` that this variable refers to, i.e. this always corresponds to + // the `var_decl` with the same index. + (local_id index::int) + + // Global variable reference--typically a table although it can actually be bound to any value. Unlike + // local variables, global variables are not stored in registers. Instead, they are typically + // retrieved from an implementation defined storage system, persistent or otherwise. Evaluating a + // `global_id` will return a value with an open iterator. There is no syntactic representation of this + // node in PartiQL--`global_id` nodes are produced by the planner during the variable resolution pass + // when a variable is resolved to a global variable. + // Elements: + // - `name`: the original name of the variable, kept mostly just for error reporting purposes. + // - `uniqueId`: any Ion value that uniquely identifies the global variable, typically a storage + // defined UUID or the name of the table in its original letter case. + // The value of `uniqueId` is PartiQL integration defined and can be any symbol that uniquely + // identifies the global variable. Examples include database object ids or the alphabetical case + // respecting table name found after case-insensitive lookup. + (global_id uniqueId::symbol case::case_sensitivity) + ) + ) + ) +) + +// Redefines `bexpr` of `partiql_logical_resolved` to include an `(impl ...)` node within every operator. Following +// transformation from partiql_logical_resolved, the implementation of each `bexpr` will be `(impl default)`. +// Optimizations on this domain include but are not limited to: selection of `(impl ...)` other than `default` and +// rewriting of `filter/scan` `mapValues/scan` to perform final push-down of filters and projections, and optimal +// operator implementation selection (i.e. hash or merge join, etc) +(define partiql_physical + (permute_domain partiql_logical_resolved + (include + // Identifies an implementation that has been selected for an instance of a physical operator and + // identifies any static arguments required. This will initially have the `(impl default)` value, with + // different implementations being selected as needed. + // Elements: + // - `name`: the unique name of the implementation. Each relational operator has a different namespace + // containing its default and custom implementations. + // - `static_args`: Any static arguments. These are arbitrary Ion values and are specific to the + // implementation. These values are made available to the implementation at compile-time and evaluation + // time. + (product impl name::symbol static_args::(* ion 0)) + ) + + // Every variant of bexpr changes by adding an `impl` element in the physical algebra, so let's replace it + // entirely. + (exclude bexpr) + (include + (sum bexpr + // A generic physical read operation. At the moment, implementations of this operator may only + // bind each row read to `binding`. In the future, `binding` might be replaced with multiple + // projection templates (these are Ion like path extractors but are capable of extracting subsets of an + // Ion container.) Examples of physical read operations include: + // - full scan + // - index scan + // - index range scan + // - get-row-by-primary key + // - and many, others. + // The specific read operation represented by this node is determined by the `i::impl` element. + (project i::impl binding::var_decl args::(* arguments::expr 0)) + + // Operators below this point are the same as in the logical algebra, but also include an i::impl + // element. + + (scan i::impl expr::expr as_decl::var_decl at_decl::(? var_decl) by_decl::(? var_decl)) + (filter i::impl predicate::expr source::bexpr) + (join + i::impl + join_type::join_type + left::bexpr + right::bexpr + predicate::(? expr)) + + + (offset i::impl row_count::expr source::bexpr) + (limit i::impl row_count::expr source::bexpr) + (let i::impl source::bexpr bindings::(* let_binding 1)) + ) + ) + ) +) diff --git a/lang/src/org/partiql/lang/CompilerPipeline.kt b/lang/src/org/partiql/lang/CompilerPipeline.kt index 31751e48a9..f481b3a2ac 100644 --- a/lang/src/org/partiql/lang/CompilerPipeline.kt +++ b/lang/src/org/partiql/lang/CompilerPipeline.kt @@ -39,7 +39,7 @@ import org.partiql.lang.types.StaticType import org.partiql.lang.util.interruptibleFold /** - * Contains all of the information needed for processing steps. + * Contains all information needed for processing steps. */ data class StepContext( /** The instance of [ExprValueFactory] that is used by the pipeline. */ @@ -102,6 +102,11 @@ interface CompilerPipeline { */ val procedures: @JvmSuppressWildcards Map + /** + * The configured global type bindings. + */ + val globalTypeBindings: Bindings? + /** Compiles the specified PartiQL query using the configured parser. */ fun compile(query: String): Expression @@ -205,7 +210,7 @@ interface CompilerPipeline { fun build(): CompilerPipeline { val compileOptionsToUse = compileOptions ?: CompileOptions.standard() - when (compileOptionsToUse.thunkReturnTypeAssertions) { + when (compileOptionsToUse.thunkOptions.thunkReturnTypeAssertions) { ThunkReturnTypeAssertions.DISABLED -> { /* intentionally blank */ } ThunkReturnTypeAssertions.ENABLED -> { check(this.globalTypeBindings != null) { @@ -244,7 +249,7 @@ internal class CompilerPipelineImpl( override val customDataTypes: List, override val procedures: Map, private val preProcessingSteps: List, - private val globalTypeBindings: Bindings? + override val globalTypeBindings: Bindings? ) : CompilerPipeline { private val compiler = EvaluatingCompiler( diff --git a/lang/src/org/partiql/lang/SqlException.kt b/lang/src/org/partiql/lang/SqlException.kt index fd64cc9233..6e115ecf8b 100644 --- a/lang/src/org/partiql/lang/SqlException.kt +++ b/lang/src/org/partiql/lang/SqlException.kt @@ -18,6 +18,7 @@ import org.partiql.lang.errors.ErrorCode import org.partiql.lang.errors.Property import org.partiql.lang.errors.PropertyValueMap import org.partiql.lang.errors.UNKNOWN +import org.partiql.lang.util.propertyValueMapOf /** * General exception class for the interpreter. @@ -33,20 +34,18 @@ import org.partiql.lang.errors.UNKNOWN * * @param message the message for this exception * @param errorCode the error code for this exception - * @param propertyValueMap context for this error + * @param errorContextArg context for this error, includes details like line & character offsets, among others. + * TODO: https://github.com/partiql/partiql-lang-kotlin/issues/616 * @param cause for this exception - * - * @constructor a custom error [message], the [errorCode], error context as a [propertyValueMap] and optional [cause] creates an - * [SqlException]. This is the constructor for the second configuration explained above. - * */ open class SqlException( override var message: String, val errorCode: ErrorCode, - val errorContext: PropertyValueMap? = null, + errorContextArg: PropertyValueMap? = null, cause: Throwable? = null -) : - RuntimeException(message, cause) { +) : RuntimeException(message, cause) { + + val errorContext: PropertyValueMap = errorContextArg ?: propertyValueMapOf() /** * Indicates if this exception is due to an internal error or not. @@ -81,7 +80,7 @@ open class SqlException( * * * ErrorCategory is one of `Lexer Error`, `Parser Error`, `Runtime Error` * * ErrorLocation is the line and column where the error occurred - * * Errormessatge is the **generated** error message + * * ErrorMessage is the **generated** error message * * * TODO: Prepend to the auto-generated message the file name. @@ -90,6 +89,10 @@ open class SqlException( fun generateMessage(): String = "${errorCategory(errorCode)}: ${errorLocation(errorContext)}: ${errorMessage(errorCode, errorContext)}" + /** Same as [generateMessage] but without the location. */ + fun generateMessageNoLocation(): String = + "${errorCategory(errorCode)}: ${errorMessage(errorCode, errorContext)}" + private fun errorMessage(errorCode: ErrorCode?, propertyValueMap: PropertyValueMap?): String = errorCode?.getErrorMessage(propertyValueMap) ?: UNKNOWN diff --git a/lang/src/org/partiql/lang/ast/passes/SemanticException.kt b/lang/src/org/partiql/lang/ast/passes/SemanticException.kt index 7996aa9a83..c32576d104 100644 --- a/lang/src/org/partiql/lang/ast/passes/SemanticException.kt +++ b/lang/src/org/partiql/lang/ast/passes/SemanticException.kt @@ -37,7 +37,7 @@ class SemanticException( constructor(err: Problem, cause: Throwable? = null) : this( message = "", - errorCode = ErrorCode.SEMANTIC_INFERENCER_ERROR, + errorCode = ErrorCode.SEMANTIC_PROBLEM, errorContext = propertyValueMapOf( Property.LINE_NUMBER to err.sourceLocation.lineNum, Property.COLUMN_NUMBER to err.sourceLocation.charOffset, diff --git a/lang/src/org/partiql/lang/domains/util.kt b/lang/src/org/partiql/lang/domains/util.kt index 713fdc4f54..b885d0706b 100644 --- a/lang/src/org/partiql/lang/domains/util.kt +++ b/lang/src/org/partiql/lang/domains/util.kt @@ -1,5 +1,6 @@ package org.partiql.lang.domains +import com.amazon.ionelement.api.IonElement import com.amazon.ionelement.api.MetaContainer import com.amazon.ionelement.api.emptyMetaContainer import com.amazon.ionelement.api.metaContainerOf @@ -14,6 +15,19 @@ import org.partiql.lang.eval.BindingCase fun PartiqlAst.Builder.id(name: String) = id(name, caseInsensitive(), unqualified()) +// TODO: once https://github.com/partiql/partiql-ir-generator/issues/6 has been completed, we can delete this. +fun PartiqlLogical.Builder.id(name: String) = + id(name, caseInsensitive(), unqualified()) + +// TODO: once https://github.com/partiql/partiql-ir-generator/issues/6 has been completed, we can delete this. +fun PartiqlLogical.Builder.pathExpr(exp: PartiqlLogical.Expr) = + pathExpr(exp, caseInsensitive()) + +// Workaround for a bug in PIG that is fixed in its next release: +// https://github.com/partiql/partiql-ir-generator/issues/41 +fun List.asAnyElement() = + this.map { it.asAnyElement() } + val MetaContainer.staticType: StaticTypeMeta? get() = this[StaticTypeMeta.TAG] as StaticTypeMeta? /** Constructs a container with the specified metas. */ @@ -60,17 +74,17 @@ fun PartiqlAst.CaseSensitivity.toBindingCase(): BindingCase = when (this) { } /** - * Returns the [SourceLocationMeta] as an error context if the [SourceLocationMeta.TAG] exists in the passed - * [metaContainer]. Otherwise, returns an empty map. + * Converts a [PartiqlLogical.CaseSensitivity] to a [BindingCase]. */ -fun errorContextFrom(metaContainer: MetaContainer?): PropertyValueMap { - if (metaContainer == null) { - return PropertyValueMap() - } - val location = metaContainer[SourceLocationMeta.TAG] as? SourceLocationMeta - return if (location != null) { - org.partiql.lang.eval.errorContextFrom(location) - } else { - PropertyValueMap() - } +fun PartiqlLogical.CaseSensitivity.toBindingCase(): BindingCase = when (this) { + is PartiqlLogical.CaseSensitivity.CaseInsensitive -> BindingCase.INSENSITIVE + is PartiqlLogical.CaseSensitivity.CaseSensitive -> BindingCase.SENSITIVE +} + +/** + * Converts a [PartiqlLogical.CaseSensitivity] to a [BindingCase]. + */ +fun PartiqlPhysical.CaseSensitivity.toBindingCase(): BindingCase = when (this) { + is PartiqlPhysical.CaseSensitivity.CaseInsensitive -> BindingCase.INSENSITIVE + is PartiqlPhysical.CaseSensitivity.CaseSensitive -> BindingCase.SENSITIVE } diff --git a/lang/src/org/partiql/lang/errors/ErrorCode.kt b/lang/src/org/partiql/lang/errors/ErrorCode.kt index 256be64e65..61cc8e8e59 100644 --- a/lang/src/org/partiql/lang/errors/ErrorCode.kt +++ b/lang/src/org/partiql/lang/errors/ErrorCode.kt @@ -647,7 +647,7 @@ enum class ErrorCode( "got: ${errorContext?.get(Property.ACTUAL_ARGUMENT_TYPES) ?: UNKNOWN}" }, - SEMANTIC_INFERENCER_ERROR( + SEMANTIC_PROBLEM( ErrorCategory.SEMANTIC, LOCATION + setOf(Property.MESSAGE), "" @@ -980,12 +980,6 @@ enum class ErrorCode( ErrorBehaviorInPermissiveMode.RETURN_MISSING ), - EVALUATOR_SQL_EXCEPTION( - ErrorCategory.EVALUATOR, - LOCATION, - "SQL exception" - ), - EVALUATOR_COUNT_START_NOT_ALLOWED( ErrorCategory.EVALUATOR, LOCATION, diff --git a/lang/src/org/partiql/lang/eval/CompileOptions.kt b/lang/src/org/partiql/lang/eval/CompileOptions.kt index a26511e013..412507743f 100644 --- a/lang/src/org/partiql/lang/eval/CompileOptions.kt +++ b/lang/src/org/partiql/lang/eval/CompileOptions.kt @@ -141,6 +141,7 @@ enum class ThunkReturnTypeAssertions { * @param defaultTimezoneOffset Default timezone offset to be used when TIME WITH TIME ZONE does not explicitly * specify the time zone. Defaults to [ZoneOffset.UTC] */ +@Suppress("DataClassPrivateConstructor") data class CompileOptions private constructor ( val undefinedVariable: UndefinedVariableBehavior, val projectionIteration: ProjectionIterationBehavior = ProjectionIterationBehavior.FILTER_MISSING, @@ -148,7 +149,6 @@ data class CompileOptions private constructor ( val thunkOptions: ThunkOptions = ThunkOptions.standard(), val typingMode: TypingMode = TypingMode.LEGACY, val typedOpBehavior: TypedOpBehavior = TypedOpBehavior.LEGACY, - val thunkReturnTypeAssertions: ThunkReturnTypeAssertions = ThunkReturnTypeAssertions.DISABLED, val defaultTimezoneOffset: ZoneOffset = ZoneOffset.UTC ) { @@ -177,7 +177,7 @@ data class CompileOptions private constructor ( fun build(options: CompileOptions, block: Builder.() -> Unit) = Builder(options).apply(block).build() /** - * Creates a [CompileOptions] instance with the standard values. + * Creates a [CompileOptions] instance with the standard values for use by the legacy AST compiler. */ @JvmStatic fun standard() = Builder().build() @@ -194,7 +194,7 @@ data class CompileOptions private constructor ( fun typingMode(value: TypingMode) = set { copy(typingMode = value) } fun typedOpBehavior(value: TypedOpBehavior) = set { copy(typedOpBehavior = value) } fun thunkOptions(value: ThunkOptions) = set { copy(thunkOptions = value) } - fun evaluationTimeTypeChecks(value: ThunkReturnTypeAssertions) = set { copy(thunkReturnTypeAssertions = value) } + fun thunkOptions(build: ThunkOptions.Builder.() -> Unit) = set { copy(thunkOptions = ThunkOptions.build(build)) } fun defaultTimezoneOffset(value: ZoneOffset) = set { copy(defaultTimezoneOffset = value) } private inline fun set(block: CompileOptions.() -> CompileOptions): Builder { diff --git a/lang/src/org/partiql/lang/eval/EvaluatingCompiler.kt b/lang/src/org/partiql/lang/eval/EvaluatingCompiler.kt index 99186f2c49..26f97d3eba 100644 --- a/lang/src/org/partiql/lang/eval/EvaluatingCompiler.kt +++ b/lang/src/org/partiql/lang/eval/EvaluatingCompiler.kt @@ -86,6 +86,25 @@ import java.util.TreeSet import java.util.regex.Pattern import kotlin.Comparator +/** + * A thunk with no parameters other than the current environment. + * + * See https://en.wikipedia.org/wiki/Thunk + * + * This name was chosen because it is a thunk that accepts an instance of `Environment`. + */ +private typealias ThunkEnv = Thunk + +/** + * A thunk taking a single [T] argument and the current environment. + * + * See https://en.wikipedia.org/wiki/Thunk + * + * This name was chosen because it is a thunk that accepts an instance of `Environment` and an [ExprValue] as + * its arguments. + */ +private typealias ThunkEnvValue = ThunkValue + /** * A basic compiler that converts an instance of [PartiqlAst] to an [Expression]. * @@ -115,7 +134,7 @@ internal class EvaluatingCompiler( private val compileOptions: CompileOptions = CompileOptions.standard() ) { private val errorSignaler = compileOptions.typingMode.createErrorSignaler(valueFactory) - private val thunkFactory = compileOptions.typingMode.createThunkFactory(compileOptions, valueFactory) + private val thunkFactory = compileOptions.typingMode.createThunkFactory(compileOptions.thunkOptions, valueFactory) private val compilationContextStack = Stack() @@ -3048,7 +3067,7 @@ private class SingleProjectionElement(val name: ExprValue, val thunk: ThunkEnv) */ private class MultipleProjectionElement(val thunks: List) : ProjectionElement() -private val MetaContainer.sourceLocationMeta get() = this[SourceLocationMeta.TAG] as? SourceLocationMeta +internal val MetaContainer.sourceLocationMeta get() = this[SourceLocationMeta.TAG] as? SourceLocationMeta private fun StaticType.getTypes() = when (val flattened = this.flatten()) { is AnyOfType -> flattened.types diff --git a/lang/src/org/partiql/lang/eval/Exceptions.kt b/lang/src/org/partiql/lang/eval/Exceptions.kt index e00c8490a6..49a574a5b5 100644 --- a/lang/src/org/partiql/lang/eval/Exceptions.kt +++ b/lang/src/org/partiql/lang/eval/Exceptions.kt @@ -125,6 +125,10 @@ fun fillErrorContext(errorContext: PropertyValueMap, metaContainer: MetaContaine } } +/** + * Returns the [SourceLocationMeta] as an error context if the [SourceLocationMeta.TAG] exists in the passed + * [metaContainer]. Otherwise, returns an empty map. + */ fun errorContextFrom(metaContainer: MetaContainer?): PropertyValueMap { if (metaContainer == null) { return PropertyValueMap() diff --git a/lang/src/org/partiql/lang/eval/Thunk.kt b/lang/src/org/partiql/lang/eval/Thunk.kt index 8ef7566397..60b715c182 100644 --- a/lang/src/org/partiql/lang/eval/Thunk.kt +++ b/lang/src/org/partiql/lang/eval/Thunk.kt @@ -27,19 +27,21 @@ import org.partiql.lang.errors.Property * * See https://en.wikipedia.org/wiki/Thunk * - * This name was chosen because it is a thunk that accepts an instance of `Environment`. + * @param TEnv The type of the environment. Generic so that the legacy AST compiler and the new compiler may use + * different types here. */ -internal typealias ThunkEnv = (Environment) -> ExprValue +internal typealias Thunk = (TEnv) -> ExprValue /** - * A thunk taking a single [T] argument and the current environment. + * A thunk taking a single argument and the current environment. * * See https://en.wikipedia.org/wiki/Thunk * - * This name was chosen because it is a thunk that accepts an instance of `Environment` and an [ExprValue] as - * its arguments. + * @param TEnv The type of the environment. Generic so that the legacy AST compiler and the new compiler may use + * different types here. + * @param TArg The type of the additional argument. */ -internal typealias ThunkEnvValue = (Environment, T) -> ExprValue +internal typealias ThunkValue = (TEnv, TArg) -> ExprValue /** * A type alias for an exception handler which always throws(primarily used for [TypingMode.LEGACY]). @@ -56,12 +58,17 @@ internal typealias ThunkExceptionHandlerForPermissiveMode = (Throwable, SourceLo * * - [handleExceptionForLegacyMode] will be called when in [TypingMode.LEGACY] mode * - [handleExceptionForPermissiveMode] will be called when in [TypingMode.PERMISSIVE] mode + * - [thunkReturnTypeAssertions] is intended for testing only, and ensures that the return value of every expression + * conforms to its `StaticType` meta. This has negative performance implications so should be avoided in production + * environments. This only be used for testing and diagnostic purposes only. * The default exception handler wraps any [Throwable] exception and throws [EvaluationException] */ data class ThunkOptions private constructor( val handleExceptionForLegacyMode: ThunkExceptionHandlerForLegacyMode = DEFAULT_EXCEPTION_HANDLER_FOR_LEGACY_MODE, - val handleExceptionForPermissiveMode: ThunkExceptionHandlerForPermissiveMode = DEFAULT_EXCEPTION_HANDLER_FOR_PERMISSIVE_MODE + val handleExceptionForPermissiveMode: ThunkExceptionHandlerForPermissiveMode = DEFAULT_EXCEPTION_HANDLER_FOR_PERMISSIVE_MODE, + val thunkReturnTypeAssertions: ThunkReturnTypeAssertions = ThunkReturnTypeAssertions.DISABLED, ) { + companion object { /** @@ -89,6 +96,7 @@ data class ThunkOptions private constructor( private var options = ThunkOptions() fun handleExceptionForLegacyMode(value: ThunkExceptionHandlerForLegacyMode) = set { copy(handleExceptionForLegacyMode = value) } fun handleExceptionForPermissiveMode(value: ThunkExceptionHandlerForPermissiveMode) = set { copy(handleExceptionForPermissiveMode = value) } + fun evaluationTimeTypeChecks(value: ThunkReturnTypeAssertions) = set { copy(thunkReturnTypeAssertions = value) } private inline fun set(block: ThunkOptions.() -> ThunkOptions): Builder { options = block(options) return this @@ -116,18 +124,18 @@ internal val DEFAULT_EXCEPTION_HANDLER_FOR_PERMISSIVE_MODE: ThunkExceptionHandle * - when [TypingMode] is [TypingMode.LEGACY], creates [LegacyThunkFactory] * - when [TypingMode] is [TypingMode.PERMISSIVE], creates [PermissiveThunkFactory] */ -internal fun TypingMode.createThunkFactory( - compileOptions: CompileOptions, +internal fun TypingMode.createThunkFactory( + thunkOptions: ThunkOptions, valueFactory: ExprValueFactory -): ThunkFactory = when (this) { - TypingMode.LEGACY -> LegacyThunkFactory(compileOptions, valueFactory) - TypingMode.PERMISSIVE -> PermissiveThunkFactory(compileOptions, valueFactory) +): ThunkFactory = when (this) { + TypingMode.LEGACY -> LegacyThunkFactory(thunkOptions, valueFactory) + TypingMode.PERMISSIVE -> PermissiveThunkFactory(thunkOptions, valueFactory) } /** * Provides methods for constructing new thunks according to the specified [CompileOptions]. */ -internal abstract class ThunkFactory( - val compileOptions: CompileOptions, +internal abstract class ThunkFactory( + val thunkOptions: ThunkOptions, val valueFactory: ExprValueFactory ) { private fun checkEvaluationTimeType(thunkResult: ExprValue, metas: MetaContainer): ExprValue { @@ -157,11 +165,11 @@ internal abstract class ThunkFactory( * confusion in the case [StaticTypeInferenceVisitorTransform] has a bug which prevents it from assigning a * [StaticTypeMeta] or in case it is not run at all. */ - protected fun ThunkEnv.typeCheck(metas: MetaContainer): ThunkEnv = - when (compileOptions.thunkReturnTypeAssertions) { + protected fun Thunk.typeCheck(metas: MetaContainer): Thunk = + when (thunkOptions.thunkReturnTypeAssertions) { ThunkReturnTypeAssertions.DISABLED -> this ThunkReturnTypeAssertions.ENABLED -> { - val wrapper = { env: Environment -> + val wrapper = { env: TEnv -> val thunkResult: ExprValue = this(env) checkEvaluationTimeType(thunkResult, metas) } @@ -169,12 +177,12 @@ internal abstract class ThunkFactory( } } - /** Same as [typeCheck] but works on a [ThunkEnvValue] instead of a [ThunkEnv]. */ - protected fun ThunkEnvValue.typeCheckEnvValue(metas: MetaContainer): ThunkEnvValue = - when (compileOptions.thunkReturnTypeAssertions) { + /** Same as [typeCheck] but works on a [ThunkEnvValue] instead of a [Thunk]. */ + protected fun ThunkValue.typeCheckEnvValue(metas: MetaContainer): ThunkValue = + when (thunkOptions.thunkReturnTypeAssertions) { ThunkReturnTypeAssertions.DISABLED -> this ThunkReturnTypeAssertions.ENABLED -> { - val wrapper = { env: Environment, value: ExprValue -> + val wrapper = { env: TEnv, value: ExprValue -> val thunkResult: ExprValue = this(env, value) checkEvaluationTimeType(thunkResult, metas) } @@ -182,12 +190,12 @@ internal abstract class ThunkFactory( } } - /** Same as [typeCheck] but works on a [ThunkEnvValue>] instead of a [ThunkEnv]. */ - protected fun ThunkEnvValue>.typeCheckEnvValueList(metas: MetaContainer): ThunkEnvValue> = - when (compileOptions.thunkReturnTypeAssertions) { + /** Same as [typeCheck] but works on a [ThunkEnvValue>] instead of a [Thunk]. */ + protected fun ThunkValue>.typeCheckEnvValueList(metas: MetaContainer): ThunkValue> = + when (thunkOptions.thunkReturnTypeAssertions) { ThunkReturnTypeAssertions.DISABLED -> this ThunkReturnTypeAssertions.ENABLED -> { - val wrapper = { env: Environment, value: List -> + val wrapper = { env: TEnv, value: List -> val thunkResult: ExprValue = this(env, value) checkEvaluationTimeType(thunkResult, metas) } @@ -196,17 +204,17 @@ internal abstract class ThunkFactory( } /** - * Creates a [ThunkEnv] which handles exceptions by wrapping them into an [EvaluationException] which uses + * Creates a [Thunk] which handles exceptions by wrapping them into an [EvaluationException] which uses * [handleException] to handle exceptions appropriately. * * Literal lambdas passed to this function as [t] are inlined into the body of the function being returned, which * reduces the need to create additional call contexts. The lambdas passed as [t] may not contain non-local returns * (`crossinline`). */ - internal inline fun thunkEnv(metas: MetaContainer, crossinline t: ThunkEnv): ThunkEnv { + internal inline fun thunkEnv(metas: MetaContainer, crossinline t: Thunk): Thunk { val sourceLocationMeta = metas[SourceLocationMeta.TAG] as? SourceLocationMeta - return { env: Environment -> + return { env: TEnv -> handleException(sourceLocationMeta) { t(env) } @@ -221,8 +229,11 @@ internal abstract class ThunkFactory( * * For all [TypingMode]s, if the values returned by [getVal1], [getVal2] and [getVal2] are all known, * [compute] is invoked to perform the operation-specific computation. + * + * Note: this must be public due to a Kotlin compiler bug: https://youtrack.jetbrains.com/issue/KT-22625. + * This shouldn't matter though because this class is still `internal`. */ - protected abstract fun propagateUnknowns( + abstract fun propagateUnknowns( getVal1: () -> ExprValue, getVal2: (() -> ExprValue)?, getVal3: (() -> ExprValue)?, @@ -232,14 +243,17 @@ internal abstract class ThunkFactory( /** * Similar to the other [propagateUnknowns] overload, performs unknown propagation for a variadic sequence of * operations. + * + * Note: this must be public due to a Kotlin compiler bug: https://youtrack.jetbrains.com/issue/KT-22625. + * This shouldn't matter though because this class is still `internal`. */ - protected abstract fun propagateUnknowns( + abstract fun propagateUnknowns( operands: Sequence, compute: (List) -> ExprValue ): ExprValue /** - * Creates a thunk that accepts three [ThunkEnv] operands ([t1], [t2], and [t3]), evaluates them and propagates + * Creates a thunk that accepts three [Thunk] operands ([t1], [t2], and [t3]), evaluates them and propagates * unknowns according to the current [TypingMode]. When possible, use this function or one of its overloads * instead of [thunkEnv] when the operation requires propagation of unknown values. * @@ -260,48 +274,48 @@ internal abstract class ThunkFactory( */ internal inline fun thunkEnvOperands( metas: MetaContainer, - crossinline t1: ThunkEnv, - crossinline t2: ThunkEnv, - crossinline t3: ThunkEnv, - crossinline compute: (Environment, ExprValue, ExprValue, ExprValue) -> ExprValue - ): ThunkEnv = + crossinline t1: Thunk, + crossinline t2: Thunk, + crossinline t3: Thunk, + crossinline compute: (TEnv, ExprValue, ExprValue, ExprValue) -> ExprValue + ): Thunk = thunkEnv(metas) { env -> propagateUnknowns({ t1(env) }, { t2(env) }, { t3(env) }) { v1, v2, v3 -> compute(env, v1, v2!!, v3!!) } }.typeCheck(metas) - /** See the [thunkEnvOperands] with three [ThunkEnv] operands. */ + /** See the [thunkEnvOperands] with three [Thunk] operands. */ internal inline fun thunkEnvOperands( metas: MetaContainer, - crossinline t1: ThunkEnv, - crossinline t2: ThunkEnv, - crossinline compute: (Environment, ExprValue, ExprValue) -> ExprValue - ): ThunkEnv = + crossinline t1: Thunk, + crossinline t2: Thunk, + crossinline compute: (TEnv, ExprValue, ExprValue) -> ExprValue + ): Thunk = this.thunkEnv(metas) { env -> propagateUnknowns({ t1(env) }, { t2(env) }, null) { v1, v2, _ -> compute(env, v1, v2!!) } }.typeCheck(metas) - /** See the [thunkEnvOperands] with three [ThunkEnv] operands. */ + /** See the [thunkEnvOperands] with three [Thunk] operands. */ internal inline fun thunkEnvOperands( metas: MetaContainer, - crossinline t1: ThunkEnv, - crossinline compute: (Environment, ExprValue) -> ExprValue - ): ThunkEnv = + crossinline t1: Thunk, + crossinline compute: (TEnv, ExprValue) -> ExprValue + ): Thunk = this.thunkEnv(metas) { env -> propagateUnknowns({ t1(env) }, null, null) { v1, _, _ -> compute(env, v1) } }.typeCheck(metas) - /** See the [thunkEnvOperands] with a variadic list of [ThunkEnv] operands. */ + /** See the [thunkEnvOperands] with a variadic list of [Thunk] operands. */ internal inline fun thunkEnvOperands( metas: MetaContainer, - operandThunks: List, - crossinline compute: (Environment, List) -> ExprValue - ): ThunkEnv { + operandThunks: List>, + crossinline compute: (TEnv, List) -> ExprValue + ): Thunk { return this.thunkEnv(metas) { env -> val operandSeq = sequence { operandThunks.forEach { yield(it(env)) } } @@ -314,11 +328,11 @@ internal abstract class ThunkFactory( /** Similar to [thunkEnv], but creates a [ThunkEnvValue] instead. */ internal inline fun thunkEnvValue( metas: MetaContainer, - crossinline t: ThunkEnvValue - ): ThunkEnvValue { + crossinline t: ThunkValue + ): ThunkValue { val sourceLocationMeta = metas[SourceLocationMeta.TAG] as? SourceLocationMeta - return { env: Environment, arg1: ExprValue -> + return { env: TEnv, arg1: ExprValue -> handleException(sourceLocationMeta) { t(env, arg1) } @@ -328,11 +342,11 @@ internal abstract class ThunkFactory( /** Similar to [thunkEnv], but creates a [ThunkEnvValue>] instead. */ internal inline fun thunkEnvValueList( metas: MetaContainer, - crossinline t: ThunkEnvValue> - ): ThunkEnvValue> { + crossinline t: ThunkValue> + ): ThunkValue> { val sourceLocationMeta = metas[SourceLocationMeta.TAG] as? SourceLocationMeta - return { env: Environment, arg1: List -> + return { env: TEnv, arg1: List -> handleException(sourceLocationMeta) { t(env, arg1) } @@ -353,9 +367,9 @@ internal abstract class ThunkFactory( */ internal abstract fun thunkFold( metas: MetaContainer, - argThunks: List, + argThunks: List>, op: (ExprValue, ExprValue) -> ExprValue - ): ThunkEnv + ): Thunk /** * Similar to [thunkFold] but intended for comparison operators, i.e. `=`, `>`, `>=`, `<`, `<=`. @@ -374,35 +388,25 @@ internal abstract class ThunkFactory( */ internal abstract fun thunkAndMap( metas: MetaContainer, - argThunks: List, + argThunks: List>, op: (ExprValue, ExprValue) -> Boolean - ): ThunkEnv + ): Thunk /** Populates [exception] with the line & column from the specified [SourceLocationMeta]. */ protected fun populateErrorContext( exception: EvaluationException, sourceLocation: SourceLocationMeta? - ) = when (exception.errorContext) { - null -> - EvaluationException( - message = exception.message, - errorCode = exception.errorCode, - errorContext = errorContextFrom(sourceLocation), - cause = exception, - internal = exception.internal - ) - else -> { - // Only add source location data to the error context if it doesn't already exist - // in [errorContext]. - if (!exception.errorContext.hasProperty(Property.LINE_NUMBER)) { - sourceLocation?.let { fillErrorContext(exception.errorContext, sourceLocation) } - } - exception + ): EvaluationException { + // Only add source location data to the error context if it doesn't already exist + // in [errorContext]. + if (!exception.errorContext.hasProperty(Property.LINE_NUMBER)) { + sourceLocation?.let { fillErrorContext(exception.errorContext, sourceLocation) } } + return exception } /** - * Handles exceptions appropriately for a run-time [ThunkEnv]. + * Handles exceptions appropriately for a run-time [Thunk]. * * - The [SourceLocationMeta] will be extracted from [MetaContainer] and included in any [EvaluationException] that * is thrown, if present. @@ -419,10 +423,10 @@ internal abstract class ThunkFactory( /** * Provides methods for constructing new thunks according to the specified [CompileOptions] for [TypingMode.LEGACY] behaviour. */ -internal class LegacyThunkFactory( - compileOptions: CompileOptions, +internal class LegacyThunkFactory( + thunkOptions: ThunkOptions, valueFactory: ExprValueFactory -) : ThunkFactory(compileOptions, valueFactory) { +) : ThunkFactory(thunkOptions, valueFactory) { override fun propagateUnknowns( getVal1: () -> ExprValue, @@ -470,9 +474,9 @@ internal class LegacyThunkFactory( /** See [ThunkFactory.thunkFold]. */ override fun thunkFold( metas: MetaContainer, - argThunks: List, + argThunks: List>, op: (ExprValue, ExprValue) -> ExprValue - ): ThunkEnv { + ): Thunk { require(argThunks.isNotEmpty()) { "argThunks must not be empty" } val firstThunk = argThunks.first() @@ -498,9 +502,9 @@ internal class LegacyThunkFactory( /** See [ThunkFactory.thunkAndMap]. */ override fun thunkAndMap( metas: MetaContainer, - argThunks: List, + argThunks: List>, op: (ExprValue, ExprValue) -> Boolean - ): ThunkEnv { + ): Thunk { require(argThunks.size >= 2) { "argThunks must have at least two elements" } val firstThunk = argThunks.first() @@ -534,7 +538,7 @@ internal class LegacyThunkFactory( } /** - * Handles exceptions appropriately for a run-time [ThunkEnv] respecting [TypingMode.LEGACY] behaviour. + * Handles exceptions appropriately for a run-time [Thunk] respecting [TypingMode.LEGACY] behaviour. * * - The [SourceLocationMeta] will be extracted from [MetaContainer] and included in any [EvaluationException] that * is thrown, if present. @@ -551,7 +555,7 @@ internal class LegacyThunkFactory( } catch (e: EvaluationException) { throw populateErrorContext(e, sourceLocation) } catch (e: Exception) { - compileOptions.thunkOptions.handleExceptionForLegacyMode(e, sourceLocation) + thunkOptions.handleExceptionForLegacyMode(e, sourceLocation) } } @@ -559,10 +563,10 @@ internal class LegacyThunkFactory( * Provides methods for constructing new thunks according to the specified [CompileOptions] and for * [TypingMode.PERMISSIVE] behaviour. */ -internal class PermissiveThunkFactory( - compileOptions: CompileOptions, +internal class PermissiveThunkFactory( + thunkOptions: ThunkOptions, valueFactory: ExprValueFactory -) : ThunkFactory(compileOptions, valueFactory) { +) : ThunkFactory(thunkOptions, valueFactory) { override fun propagateUnknowns( getVal1: () -> ExprValue, @@ -628,9 +632,9 @@ internal class PermissiveThunkFactory( /** See [ThunkFactory.thunkFold]. */ override fun thunkFold( metas: MetaContainer, - argThunks: List, + argThunks: List>, op: (ExprValue, ExprValue) -> ExprValue - ): ThunkEnv { + ): Thunk { require(argThunks.isNotEmpty()) { "argThunks must not be empty" } return thunkEnv(metas) { env -> @@ -654,9 +658,9 @@ internal class PermissiveThunkFactory( /** See [ThunkFactory.thunkAndMap]. */ override fun thunkAndMap( metas: MetaContainer, - argThunks: List, + argThunks: List>, op: (ExprValue, ExprValue) -> Boolean - ): ThunkEnv { + ): Thunk { require(argThunks.size >= 2) { "argThunks must have at least two elements" } return thunkEnv(metas) thunkBlock@{ env -> @@ -684,7 +688,7 @@ internal class PermissiveThunkFactory( } /** - * Handles exceptions appropriately for a run-time [ThunkEnv] respecting [TypingMode.PERMISSIVE] behaviour. + * Handles exceptions appropriately for a run-time [Thunk] respecting [TypingMode.PERMISSIVE] behaviour. * * - Exceptions thrown by [block] that are [EvaluationException] are caught and [MissingExprValue] is returned. * - Exceptions thrown by [block] that are not an [EvaluationException] cause an [EvaluationException] to be thrown @@ -697,13 +701,13 @@ internal class PermissiveThunkFactory( try { block() } catch (e: EvaluationException) { - compileOptions.thunkOptions.handleExceptionForPermissiveMode(e, sourceLocation) + thunkOptions.handleExceptionForPermissiveMode(e, sourceLocation) when (e.errorCode.errorBehaviorInPermissiveMode) { // Rethrows the exception as it does in LEGACY mode. ErrorBehaviorInPermissiveMode.THROW_EXCEPTION -> throw populateErrorContext(e, sourceLocation) ErrorBehaviorInPermissiveMode.RETURN_MISSING -> valueFactory.missingValue } } catch (e: Exception) { - compileOptions.thunkOptions.handleExceptionForLegacyMode(e, sourceLocation) + thunkOptions.handleExceptionForLegacyMode(e, sourceLocation) } } diff --git a/lang/src/org/partiql/lang/eval/builtins/BuiltinFunctions.kt b/lang/src/org/partiql/lang/eval/builtins/BuiltinFunctions.kt index 5a685f37db..d4c8fa8d2b 100644 --- a/lang/src/org/partiql/lang/eval/builtins/BuiltinFunctions.kt +++ b/lang/src/org/partiql/lang/eval/builtins/BuiltinFunctions.kt @@ -15,15 +15,20 @@ package org.partiql.lang.eval.builtins import com.amazon.ion.system.IonSystemBuilder +import org.partiql.lang.eval.DEFAULT_COMPARATOR import org.partiql.lang.eval.EvaluationSession import org.partiql.lang.eval.ExprFunction import org.partiql.lang.eval.ExprValue import org.partiql.lang.eval.ExprValueFactory import org.partiql.lang.eval.stringValue +import org.partiql.lang.eval.unnamedValue import org.partiql.lang.types.AnyOfType import org.partiql.lang.types.FunctionSignature import org.partiql.lang.types.StaticType import org.partiql.lang.types.UnknownArguments +import java.util.TreeSet + +internal const val DYNAMIC_LOOKUP_FUNCTION_NAME = "\$__dynamic_lookup__" internal fun createBuiltinFunctionSignatures(): Map = // Creating a new IonSystem in this instance is not the problem it would normally be since we are @@ -40,6 +45,7 @@ internal fun createBuiltinFunctions(valueFactory: ExprValueFactory) = createCharacterLength("character_length", valueFactory), createCharacterLength("char_length", valueFactory), createUtcNow(valueFactory), + createFilterDistinct(valueFactory), DateAddExprFunction(valueFactory), DateDiffExprFunction(valueFactory), ExtractExprFunction(valueFactory), @@ -77,6 +83,29 @@ internal fun createUtcNow(valueFactory: ExprValueFactory): ExprFunction = object valueFactory.newTimestamp(session.now) } +internal fun createFilterDistinct(valueFactory: ExprValueFactory): ExprFunction = object : ExprFunction { + override val signature = FunctionSignature( + "filter_distinct", + listOf(StaticType.unionOf(StaticType.BAG, StaticType.LIST, StaticType.SEXP, StaticType.STRUCT)), + returnType = StaticType.BAG + ) + + override fun callWithRequired(session: EvaluationSession, required: List): ExprValue { + val argument = required.first() + // We cannot use a [HashSet] here because [ExprValue] does not implement .equals() and .hashCode() + val encountered = TreeSet(DEFAULT_COMPARATOR) + return valueFactory.newBag( + sequence { + argument.asSequence().forEach { + if (!encountered.contains(it)) { + encountered.add(it.unnamedValue()) + yield(it) + } + } + } + ) + } +} internal fun createCharacterLength(name: String, valueFactory: ExprValueFactory): ExprFunction = object : ExprFunction { override val signature: FunctionSignature diff --git a/lang/src/org/partiql/lang/eval/builtins/DynamicLookupExprFunction.kt b/lang/src/org/partiql/lang/eval/builtins/DynamicLookupExprFunction.kt new file mode 100644 index 0000000000..080e75c8da --- /dev/null +++ b/lang/src/org/partiql/lang/eval/builtins/DynamicLookupExprFunction.kt @@ -0,0 +1,106 @@ +package org.partiql.lang.eval.builtins + +import org.partiql.lang.errors.ErrorCode +import org.partiql.lang.eval.BindingCase +import org.partiql.lang.eval.BindingName +import org.partiql.lang.eval.EvaluationException +import org.partiql.lang.eval.EvaluationSession +import org.partiql.lang.eval.ExprFunction +import org.partiql.lang.eval.ExprValue +import org.partiql.lang.eval.ExprValueType +import org.partiql.lang.eval.physical.throwUndefinedVariableException +import org.partiql.lang.eval.stringValue +import org.partiql.lang.types.FunctionSignature +import org.partiql.lang.types.StaticType +import org.partiql.lang.types.VarargFormalParameter + +/** + * Performs dynamic variable resolution. Query authors should never call this function directly (and indeed it is + * named to avoid collision with the names of custom functions)--instead, the query planner injects call sites + * to this function to perform dynamic variable resolution of undefined variables. This provides a migration path + * for legacy customers that depend on this behavior. + * + * Arguments: + * + * 1. variable name (must be a symbol) + * 2. case sensitivity (must be a symbol; one of: `case_insensitive` or `case_sensitive`) + * 3. lookup strategy (must be a symbol; one of: `globals_then_locals` or `locals_then_globals`) + * 4. A variadic list of values to be searched. Only struct are searched. This is required because it is not + * currently possible to know the types of these arguments within the variable resolution pass + * ([org.partiql.lang.planner.transforms.LogicalToLogicalResolvedVisitorTransform]). Therefore all variables + * in the current scope must be included in the list of values to be searched. + * TODO: when the open type system's static type inferencer is working, static type information can be used to identify + * and remove non-struct types from call sites to this function. + * + * The name of this function is [DYNAMIC_LOOKUP_FUNCTION_NAME], which includes a unique prefix and suffix so as to + * avoid clashes with user-defined functions. + */ +class DynamicLookupExprFunction : ExprFunction { + override val signature: FunctionSignature + get() { + return FunctionSignature( + name = DYNAMIC_LOOKUP_FUNCTION_NAME, + // Required parameters are: variable name, case sensitivity and lookup strategy + requiredParameters = listOf(StaticType.SYMBOL, StaticType.SYMBOL, StaticType.SYMBOL), + variadicParameter = VarargFormalParameter(StaticType.ANY, 0..Int.MAX_VALUE), + returnType = StaticType.ANY + ) + } + + override fun callWithVariadic( + session: EvaluationSession, + required: List, + variadic: List + ): ExprValue { + val variableName = required[0].stringValue() + + val caseSensitivity = when (val caseSensitivityParameterValue = required[1].stringValue()) { + "case_sensitive" -> BindingCase.SENSITIVE + "case_insensitive" -> BindingCase.INSENSITIVE + else -> throw EvaluationException( + message = "Invalid case sensitivity: $caseSensitivityParameterValue", + errorCode = ErrorCode.INTERNAL_ERROR, + internal = true + ) + } + + val bindingName = BindingName(variableName, caseSensitivity) + + val globalsFirst = when (val lookupStrategyParameterValue = required[2].stringValue()) { + "locals_then_globals" -> false + "globals_then_locals" -> true + else -> throw EvaluationException( + message = "Invalid lookup strategy: $lookupStrategyParameterValue", + errorCode = ErrorCode.INTERNAL_ERROR, + internal = true + ) + } + + val found = when { + globalsFirst -> { + session.globals[bindingName] ?: searchLocals(variadic, bindingName) + } + else -> { + searchLocals(variadic, bindingName) ?: session.globals[bindingName] + } + } + + if (found == null) { + // We don't know the metas inside ExprFunction implementations. The ThunkFactory error handlers + // should add line & col info to the exception & rethrow anyway. + throwUndefinedVariableException(bindingName, metas = null) + } else { + return found + } + } + + private fun searchLocals(possibleLocations: List, bindingName: BindingName) = + possibleLocations.asSequence().map { + when (it.type) { + ExprValueType.STRUCT -> + it.bindings[bindingName] + else -> + null + } + }.firstOrNull { it != null } +} diff --git a/lang/src/org/partiql/lang/eval/physical/EvaluatorState.kt b/lang/src/org/partiql/lang/eval/physical/EvaluatorState.kt new file mode 100644 index 0000000000..fd9e711226 --- /dev/null +++ b/lang/src/org/partiql/lang/eval/physical/EvaluatorState.kt @@ -0,0 +1,35 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + */ + +package org.partiql.lang.eval.physical + +import org.partiql.lang.eval.EvaluationSession +import org.partiql.lang.eval.ExprValue + +/** + * Contains state needed during query evaluation such as an instance of [EvaluationSession] and an array of [registers] + * for each local variable that is part of the query. + * + * Since the elements of [registers] are mutable, when/if we decide to make query execution multi-threaded, we'll have + * to take care to not share [EvaluatorState] instances among different threads. + * + * @param session The evaluation session. + * @param registers An array of registers containing [ExprValue]s needed during query execution. Generally, there is + * one register per local variable. This is an array (and not a [List]) because its semantics match exactly what we + * need: fixed length but mutable elements. + */ +internal class EvaluatorState( + val session: EvaluationSession, + val registers: Array +) diff --git a/lang/src/org/partiql/lang/eval/physical/OffsetLimitHelpers.kt b/lang/src/org/partiql/lang/eval/physical/OffsetLimitHelpers.kt new file mode 100644 index 0000000000..f68ac2d049 --- /dev/null +++ b/lang/src/org/partiql/lang/eval/physical/OffsetLimitHelpers.kt @@ -0,0 +1,106 @@ +package org.partiql.lang.eval.physical + +import com.amazon.ion.IntegerSize +import com.amazon.ion.IonInt +import org.partiql.lang.ast.SourceLocationMeta +import org.partiql.lang.errors.ErrorCode +import org.partiql.lang.errors.Property +import org.partiql.lang.eval.ExprValueType +import org.partiql.lang.eval.err +import org.partiql.lang.eval.errorContextFrom +import org.partiql.lang.eval.numberValue + +// The functions in this file look very similar and so the temptation to DRY is quite strong.... +// However, there are enough subtle differences between them that avoiding the duplication isn't worth it. + +internal fun evalLimitRowCount(rowCountThunk: PhysicalPlanThunk, env: EvaluatorState, limitLocationMeta: SourceLocationMeta?): Long { + val limitExprValue = rowCountThunk(env) + + if (limitExprValue.type != ExprValueType.INT) { + err( + "LIMIT value was not an integer", + ErrorCode.EVALUATOR_NON_INT_LIMIT_VALUE, + errorContextFrom(limitLocationMeta).also { + it[Property.ACTUAL_TYPE] = limitExprValue.type.toString() + }, + internal = false + ) + } + + // `Number.toLong()` (used below) does *not* cause an overflow exception if the underlying [Number] + // implementation (i.e. Decimal or BigInteger) exceeds the range that can be represented by Longs. + // This can cause very confusing behavior if the user specifies a LIMIT value that exceeds + // Long.MAX_VALUE, because no results will be returned from their query. That no overflow exception + // is thrown is not a problem as long as PartiQL's restriction of integer values to +/- 2^63 remains. + // We throw an exception here if the value exceeds the supported range (say if we change that + // restriction or if a custom [ExprValue] is provided which exceeds that value). + val limitIonValue = limitExprValue.ionValue as IonInt + if (limitIonValue.integerSize == IntegerSize.BIG_INTEGER) { + err( + "IntegerSize.BIG_INTEGER not supported for LIMIT values", + ErrorCode.INTERNAL_ERROR, + errorContextFrom(limitLocationMeta), + internal = true + ) + } + + val limitValue = limitExprValue.numberValue().toLong() + + if (limitValue < 0) { + err( + "negative LIMIT", + ErrorCode.EVALUATOR_NEGATIVE_LIMIT, + errorContextFrom(limitLocationMeta), + internal = false + ) + } + + // we can't use the Kotlin's Sequence.take(n) for this since it accepts only an integer. + // this references [Sequence.take(count: Long): Sequence] defined in [org.partiql.util]. + return limitValue +} + +internal fun evalOffsetRowCount(rowCountThunk: PhysicalPlanThunk, env: EvaluatorState, offsetLocationMeta: SourceLocationMeta?): Long { + val offsetExprValue = rowCountThunk(env) + + if (offsetExprValue.type != ExprValueType.INT) { + err( + "OFFSET value was not an integer", + ErrorCode.EVALUATOR_NON_INT_OFFSET_VALUE, + errorContextFrom(offsetLocationMeta).also { + it[Property.ACTUAL_TYPE] = offsetExprValue.type.toString() + }, + internal = false + ) + } + + // `Number.toLong()` (used below) does *not* cause an overflow exception if the underlying [Number] + // implementation (i.e. Decimal or BigInteger) exceeds the range that can be represented by Longs. + // This can cause very confusing behavior if the user specifies a OFFSET value that exceeds + // Long.MAX_VALUE, because no results will be returned from their query. That no overflow exception + // is thrown is not a problem as long as PartiQL's restriction of integer values to +/- 2^63 remains. + // We throw an exception here if the value exceeds the supported range (say if we change that + // restriction or if a custom [ExprValue] is provided which exceeds that value). + val offsetIonValue = offsetExprValue.ionValue as IonInt + if (offsetIonValue.integerSize == IntegerSize.BIG_INTEGER) { + err( + "IntegerSize.BIG_INTEGER not supported for OFFSET values", + ErrorCode.INTERNAL_ERROR, + errorContextFrom(offsetLocationMeta), + internal = true + ) + } + + val offsetValue = offsetExprValue.numberValue().toLong() + + if (offsetValue < 0) { + err( + "negative OFFSET", + ErrorCode.EVALUATOR_NEGATIVE_OFFSET, + errorContextFrom(offsetLocationMeta), + internal = false + ) + } + + return offsetValue +} diff --git a/lang/src/org/partiql/lang/eval/physical/PhysicalBexprToThunkConverter.kt b/lang/src/org/partiql/lang/eval/physical/PhysicalBexprToThunkConverter.kt new file mode 100644 index 0000000000..d4ceafda0a --- /dev/null +++ b/lang/src/org/partiql/lang/eval/physical/PhysicalBexprToThunkConverter.kt @@ -0,0 +1,335 @@ +package org.partiql.lang.eval.physical + +import com.amazon.ionelement.api.BoolElement +import com.amazon.ionelement.api.MetaContainer +import org.partiql.lang.domains.PartiqlPhysical +import org.partiql.lang.eval.ExprValue +import org.partiql.lang.eval.ExprValueFactory +import org.partiql.lang.eval.ExprValueType +import org.partiql.lang.eval.Thunk +import org.partiql.lang.eval.ThunkValue +import org.partiql.lang.eval.address +import org.partiql.lang.eval.booleanValue +import org.partiql.lang.eval.isUnknown +import org.partiql.lang.eval.name +import org.partiql.lang.eval.relation.RelationIterator +import org.partiql.lang.eval.relation.RelationScope +import org.partiql.lang.eval.relation.RelationType +import org.partiql.lang.eval.relation.relation +import org.partiql.lang.eval.sourceLocationMeta +import org.partiql.lang.eval.unnamedValue +import org.partiql.lang.util.toIntExact + +private val DEFAULT_IMPL = PartiqlPhysical.build { impl("default") } + +/** A specialization of [Thunk] that we use for evaluation of physical plans. */ +internal typealias PhysicalPlanThunk = Thunk + +/** A specialization of [ThunkValue] that we use for evaluation of physical plans. */ +internal typealias PhysicalPlanThunkValue = ThunkValue + +internal class PhysicalBexprToThunkConverter( + private val exprConverter: PhysicalExprToThunkConverter, + private val valueFactory: ExprValueFactory, +) : PartiqlPhysical.Bexpr.Converter { + + private fun blockNonDefaultImpl(i: PartiqlPhysical.Impl) { + if (i != DEFAULT_IMPL) { + TODO("Support non-default operator implementations") + } + } + + override fun convertProject(node: PartiqlPhysical.Bexpr.Project): RelationThunkEnv { + TODO("not implemented") + } + + override fun convertScan(node: PartiqlPhysical.Bexpr.Scan): RelationThunkEnv { + blockNonDefaultImpl(node.i) + + val exprThunk = exprConverter.convert(node.expr) + val asIndex = node.asDecl.index.value.toIntExact() + val atIndex = node.atDecl?.index?.value?.toIntExact() ?: -1 + val byIndex = node.byDecl?.index?.value?.toIntExact() ?: -1 + + return relationThunk(node.metas) { env -> + val valueToScan = exprThunk.invoke(env) + + // coerces non-collection types to a singleton Sequence<>. + val rows: Sequence = when (valueToScan.type) { + ExprValueType.LIST, ExprValueType.BAG -> valueToScan.asSequence() + else -> sequenceOf(valueToScan) + } + + relation(RelationType.BAG) { + var rowsIter: Iterator = rows.iterator() + while (rowsIter.hasNext()) { + val item = rowsIter.next() + env.registers[asIndex] = item.unnamedValue() // Remove any ordinal (output is a bag) + + if (atIndex >= 0) { + env.registers[atIndex] = item.name ?: valueFactory.missingValue + } + + if (byIndex >= 0) { + env.registers[byIndex] = item.address ?: valueFactory.missingValue + } + yield() + } + } + } + } + + override fun convertFilter(node: PartiqlPhysical.Bexpr.Filter): RelationThunkEnv { + blockNonDefaultImpl(node.i) + + val predicateThunk = exprConverter.convert(node.predicate) + val sourceThunk = this.convert(node.source) + + return relationThunk(node.metas) { env -> + val sourceToFilter = sourceThunk(env) + createFilterRelItr(sourceToFilter, predicateThunk, env) + } + } + + override fun convertJoin(node: PartiqlPhysical.Bexpr.Join): RelationThunkEnv { + blockNonDefaultImpl(node.i) + + val leftThunk = this.convert(node.left) + val rightThunk = this.convert(node.right) + val predicateThunk = node.predicate?.let { exprConverter.convert(it).takeIf { !node.predicate.isLitTrue() } } + + return when (node.joinType) { + is PartiqlPhysical.JoinType.Inner -> { + createInnerJoinThunk(node.metas, leftThunk, rightThunk, predicateThunk) + } + is PartiqlPhysical.JoinType.Left -> { + val rightVariableIndexes = node.right.extractAccessibleVarDecls().map { it.index.value.toIntExact() } + createLeftJoinThunk( + joinMetas = node.metas, + leftThunk = leftThunk, + rightThunk = rightThunk, + rightVariableIndexes = rightVariableIndexes, + predicateThunk = predicateThunk + ) + } + is PartiqlPhysical.JoinType.Right -> { + // Note that this is the same as the left join but the right and left sides are swapped. + val leftVariableIndexes = node.left.extractAccessibleVarDecls().map { it.index.value.toIntExact() } + createLeftJoinThunk( + joinMetas = node.metas, + leftThunk = rightThunk, + rightThunk = leftThunk, + rightVariableIndexes = leftVariableIndexes, + predicateThunk = predicateThunk + ) + } + is PartiqlPhysical.JoinType.Full -> TODO("Full join") + } + } + + private fun createInnerJoinThunk( + joinMetas: MetaContainer, + leftThunk: RelationThunkEnv, + rightThunk: RelationThunkEnv, + predicateThunk: PhysicalPlanThunk? + ) = if (predicateThunk == null) { + relationThunk(joinMetas) { env -> + createCrossJoinRelItr(leftThunk, rightThunk, env) + } + } else { + relationThunk(joinMetas) { env -> + val crossJoinRelItr = createCrossJoinRelItr(leftThunk, rightThunk, env) + createFilterRelItr(crossJoinRelItr, predicateThunk, env) + } + } + + private fun createCrossJoinRelItr( + leftThunk: RelationThunkEnv, + rightThunk: RelationThunkEnv, + env: EvaluatorState + ): RelationIterator { + return relation(RelationType.BAG) { + val leftItr = leftThunk(env) + while (leftItr.nextRow()) { + val rightItr = rightThunk(env) + while (rightItr.nextRow()) { + yield() + } + } + } + } + + private fun createLeftJoinThunk( + joinMetas: MetaContainer, + leftThunk: RelationThunkEnv, + rightThunk: RelationThunkEnv, + rightVariableIndexes: List, + predicateThunk: PhysicalPlanThunk? + ) = + relationThunk(joinMetas) { env -> + createLeftJoinRelItr(leftThunk, rightThunk, rightVariableIndexes, predicateThunk, env) + } + + /** + * Like [createCrossJoinRelItr], but the right-hand relation is padded with unknown values in the event + * that it is empty or that the predicate does not match. + */ + private fun createLeftJoinRelItr( + leftThunk: RelationThunkEnv, + rightThunk: RelationThunkEnv, + rightVariableIndexes: List, + predicateThunk: PhysicalPlanThunk?, + env: EvaluatorState + ): RelationIterator { + return if (predicateThunk == null) { + relation(RelationType.BAG) { + val leftItr = leftThunk(env) + while (leftItr.nextRow()) { + val rightItr = rightThunk(env) + // if the rightItr does has a row... + if (rightItr.nextRow()) { + yield() // yield current row + yieldAll(rightItr) // yield remaining rows + } else { + // no row--yield padded row + yieldPaddedUnknowns(rightVariableIndexes, env) + } + } + } + } else { + relation(RelationType.BAG) { + val leftItr = leftThunk(env) + while (leftItr.nextRow()) { + val rightItr = rightThunk(env) + var yieldedSomething = false + while (rightItr.nextRow()) { + if (coercePredicateResult(predicateThunk(env))) { + yield() + yieldedSomething = true + } + } + // If we still haven't yielded anything, we still need to emit a row with right-hand side variables + // padded with unknowns. + if (!yieldedSomething) { + yieldPaddedUnknowns(rightVariableIndexes, env) + } + } + } + } + } + + private suspend fun RelationScope.yieldPaddedUnknowns( + rightVariableIndexes: List, + env: EvaluatorState + ) { + rightVariableIndexes.forEach { env.registers[it] = valueFactory.nullValue } + yield() + } + + private fun PartiqlPhysical.Bexpr.extractAccessibleVarDecls(): List = + // This fold traverses a [PartiqlPhysical.Bexpr] node and extracts all variable declarations within + // It avoids recursing into sub-queries. + object : PartiqlPhysical.VisitorFold>() { + override fun visitVarDecl( + node: PartiqlPhysical.VarDecl, + accumulator: List + ): List = accumulator + node + + /** + * Avoids recursion into expressions, since these may contain sub-queries with other var-decls that we don't + * care about here. + */ + override fun walkExpr( + node: PartiqlPhysical.Expr, + accumulator: List + ): List { + return accumulator + } + }.walkBexpr(this, emptyList()) + + private fun createFilterRelItr( + relItr: RelationIterator, + predicateThunk: PhysicalPlanThunk, + env: EvaluatorState + ) = relation(RelationType.BAG) { + while (true) { + if (!relItr.nextRow()) { + break + } else { + val matches = predicateThunk(env) + if (coercePredicateResult(matches)) { + yield() + } + } + } + } + + private fun coercePredicateResult(value: ExprValue): Boolean = + when { + value.isUnknown() -> false + else -> value.booleanValue() // <-- throws if [value] is not a boolean. + } + + override fun convertOffset(node: PartiqlPhysical.Bexpr.Offset): RelationThunkEnv { + val rowCountThunk = exprConverter.convert(node.rowCount) + val sourceThunk = this.convert(node.source) + val rowCountLocation = node.rowCount.metas.sourceLocationMeta + return relationThunk(node.metas) { env -> + val skipCount: Long = evalOffsetRowCount(rowCountThunk, env, rowCountLocation) + relation(RelationType.BAG) { + val sourceRel = sourceThunk(env) + var rowCount = 0L + while (rowCount++ < skipCount) { + // stop iterating if we finish run out of rows before we hit the offset. + if (!sourceRel.nextRow()) { + return@relation + } + } + + yieldAll(sourceRel) + } + } + } + + override fun convertLimit(node: PartiqlPhysical.Bexpr.Limit): RelationThunkEnv { + val rowCountThunk = exprConverter.convert(node.rowCount) + val sourceThunk = this.convert(node.source) + val rowCountLocation = node.rowCount.metas.sourceLocationMeta + return relationThunk(node.metas) { env -> + val limitCount = evalLimitRowCount(rowCountThunk, env, rowCountLocation) + val rowIter = sourceThunk(env) + relation(RelationType.BAG) { + var rowCount = 0L + while (rowCount++ < limitCount && rowIter.nextRow()) { + yield() + } + } + } + } + + override fun convertLet(node: PartiqlPhysical.Bexpr.Let): RelationThunkEnv { + val sourceThunk = this.convert(node.source) + class CompiledBinding(val index: Int, val valueThunk: PhysicalPlanThunk) + val compiledBindings = node.bindings.map { + CompiledBinding( + it.decl.index.value.toIntExact(), + exprConverter.convert(it.value) + ) + } + return relationThunk(node.metas) { env -> + val sourceItr = sourceThunk(env) + + relation(sourceItr.relType) { + while (sourceItr.nextRow()) { + compiledBindings.forEach { + env.registers[it.index] = it.valueThunk(env) + } + yield() + } + } + } + } +} + +private fun PartiqlPhysical.Expr.isLitTrue() = + this is PartiqlPhysical.Expr.Lit && this.value is BoolElement && this.value.booleanValue diff --git a/lang/src/org/partiql/lang/eval/physical/PhysicalExprToThunkConverter.kt b/lang/src/org/partiql/lang/eval/physical/PhysicalExprToThunkConverter.kt new file mode 100644 index 0000000000..57f11c12d1 --- /dev/null +++ b/lang/src/org/partiql/lang/eval/physical/PhysicalExprToThunkConverter.kt @@ -0,0 +1,13 @@ +package org.partiql.lang.eval.physical + +import org.partiql.lang.domains.PartiqlPhysical + +/** + * Simple API that defines a method to convert a [PartiqlPhysical.Expr] to a [PhysicalPlanThunk]. + * + * Intended to prevent [PhysicalBexprToThunkConverter] from having to take a direct dependency on + * [org.partiql.lang.eval.EvaluatingCompiler]. + */ +internal interface PhysicalExprToThunkConverter { + fun convert(expr: PartiqlPhysical.Expr): PhysicalPlanThunk +} diff --git a/lang/src/org/partiql/lang/eval/physical/PhysicalExprToThunkConverterImpl.kt b/lang/src/org/partiql/lang/eval/physical/PhysicalExprToThunkConverterImpl.kt new file mode 100644 index 0000000000..9fdeca4d4b --- /dev/null +++ b/lang/src/org/partiql/lang/eval/physical/PhysicalExprToThunkConverterImpl.kt @@ -0,0 +1,1902 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + */ + +package org.partiql.lang.eval.physical + +import com.amazon.ion.IonString +import com.amazon.ion.IonValue +import com.amazon.ion.Timestamp +import com.amazon.ionelement.api.MetaContainer +import com.amazon.ionelement.api.toIonValue +import org.partiql.lang.ast.SourceLocationMeta +import org.partiql.lang.ast.sourceLocation +import org.partiql.lang.ast.toPartiQlMetaContainer +import org.partiql.lang.domains.PartiqlPhysical +import org.partiql.lang.domains.staticType +import org.partiql.lang.domains.toBindingCase +import org.partiql.lang.errors.ErrorCode +import org.partiql.lang.errors.Property +import org.partiql.lang.errors.PropertyValueMap +import org.partiql.lang.eval.AnyOfCastTable +import org.partiql.lang.eval.Arguments +import org.partiql.lang.eval.BaseExprValue +import org.partiql.lang.eval.BindingName +import org.partiql.lang.eval.CastFunc +import org.partiql.lang.eval.DEFAULT_COMPARATOR +import org.partiql.lang.eval.ErrorDetails +import org.partiql.lang.eval.EvaluationException +import org.partiql.lang.eval.EvaluationSession +import org.partiql.lang.eval.ExprFunction +import org.partiql.lang.eval.ExprValue +import org.partiql.lang.eval.ExprValueFactory +import org.partiql.lang.eval.ExprValueType +import org.partiql.lang.eval.Expression +import org.partiql.lang.eval.Named +import org.partiql.lang.eval.ProjectionIterationBehavior +import org.partiql.lang.eval.RequiredArgs +import org.partiql.lang.eval.RequiredWithOptional +import org.partiql.lang.eval.RequiredWithVariadic +import org.partiql.lang.eval.SequenceExprValue +import org.partiql.lang.eval.StructOrdering +import org.partiql.lang.eval.ThunkValue +import org.partiql.lang.eval.TypedOpBehavior +import org.partiql.lang.eval.TypingMode +import org.partiql.lang.eval.booleanValue +import org.partiql.lang.eval.builtins.storedprocedure.StoredProcedure +import org.partiql.lang.eval.call +import org.partiql.lang.eval.cast +import org.partiql.lang.eval.compareTo +import org.partiql.lang.eval.createErrorSignaler +import org.partiql.lang.eval.createThunkFactory +import org.partiql.lang.eval.err +import org.partiql.lang.eval.errInvalidArgumentType +import org.partiql.lang.eval.errNoContext +import org.partiql.lang.eval.errorContextFrom +import org.partiql.lang.eval.errorIf +import org.partiql.lang.eval.exprEquals +import org.partiql.lang.eval.fillErrorContext +import org.partiql.lang.eval.isNotUnknown +import org.partiql.lang.eval.isUnknown +import org.partiql.lang.eval.like.parsePattern +import org.partiql.lang.eval.namedValue +import org.partiql.lang.eval.numberValue +import org.partiql.lang.eval.rangeOver +import org.partiql.lang.eval.stringValue +import org.partiql.lang.eval.syntheticColumnName +import org.partiql.lang.eval.time.Time +import org.partiql.lang.eval.unnamedValue +import org.partiql.lang.eval.visitors.PartiqlPhysicalSanityValidator +import org.partiql.lang.planner.EvaluatorOptions +import org.partiql.lang.types.AnyOfType +import org.partiql.lang.types.AnyType +import org.partiql.lang.types.FunctionSignature +import org.partiql.lang.types.IntType +import org.partiql.lang.types.SingleType +import org.partiql.lang.types.StaticType +import org.partiql.lang.types.TypedOpParameter +import org.partiql.lang.types.UnknownArguments +import org.partiql.lang.types.UnsupportedTypeCheckException +import org.partiql.lang.types.toTypedOpParameter +import org.partiql.lang.util.checkThreadInterrupted +import org.partiql.lang.util.codePointSequence +import org.partiql.lang.util.div +import org.partiql.lang.util.isZero +import org.partiql.lang.util.minus +import org.partiql.lang.util.plus +import org.partiql.lang.util.rem +import org.partiql.lang.util.stringValue +import org.partiql.lang.util.times +import org.partiql.lang.util.timestampValue +import org.partiql.lang.util.toIntExact +import org.partiql.lang.util.totalMinutes +import org.partiql.lang.util.unaryMinus +import java.math.BigDecimal +import java.util.LinkedList +import java.util.TreeSet +import java.util.regex.Pattern +import kotlin.collections.ArrayList + +/** + * A basic "compiler" that converts an instance of [PartiqlPhysical.Expr] to an [Expression]. + * + * This is a modified copy of the legacy `EvaluatingCompiler` class, which is now legacy. + * The primary differences between this class an `EvaluatingCompiler` are: + * + * - All references to `PartiqlPhysical` are replaced with `PartiqlPhysical`. + * - `EvaluatingCompiler` compiles "monolithic" SFW queries--this class compiles relational + * operators (in concert with [PhysicalBexprToThunkConverter]). + * + * This implementation produces a "compiled" form consisting of context-threaded + * code in the form of a tree of [PhysicalPlanThunk]s. An overview of this technique can be found + * [here][1]. + * + * **Note:** *threaded* in this context is used in how the code gets *threaded* together for + * interpretation and **not** the concurrency primitive. That is to say this code is NOT thread + * safe. + * + * [1]: https://www.complang.tuwien.ac.at/anton/lvas/sem06w/fest.pdf + */ +internal class PhysicalExprToThunkConverterImpl( + private val valueFactory: ExprValueFactory, + private val functions: Map, + private val customTypedOpParameters: Map, + private val procedures: Map, + private val evaluatorOptions: EvaluatorOptions = EvaluatorOptions.standard() +) : PhysicalExprToThunkConverter { + private val errorSignaler = evaluatorOptions.typingMode.createErrorSignaler(valueFactory) + private val thunkFactory = evaluatorOptions.typingMode.createThunkFactory( + evaluatorOptions.thunkOptions, + valueFactory + ) + + private fun Number.exprValue(): ExprValue = when (this) { + is Int -> valueFactory.newInt(this) + is Long -> valueFactory.newInt(this) + is Double -> valueFactory.newFloat(this) + is BigDecimal -> valueFactory.newDecimal(this) + else -> errNoContext( + "Cannot convert number to expression value: $this", + errorCode = ErrorCode.EVALUATOR_INVALID_CONVERSION, + internal = true + ) + } + + private fun Boolean.exprValue(): ExprValue = valueFactory.newBoolean(this) + private fun String.exprValue(): ExprValue = valueFactory.newString(this) + + /** + * Compiles a [PartiqlPhysical.Statement] tree to an [Expression]. + * + * Checks [Thread.interrupted] before every expression and sub-expression is compiled + * and throws [InterruptedException] if [Thread.interrupted] it has been set in the + * hope that long-running compilations may be aborted by the caller. + */ + fun compile(plan: PartiqlPhysical.Plan): Expression { + PartiqlPhysicalSanityValidator(evaluatorOptions).walkPlan(plan) + + val thunk = compileAstStatement(plan.stmt) + + return object : Expression { + override fun eval(session: EvaluationSession): ExprValue { + val env = EvaluatorState( + session = session, + registers = Array(plan.locals.size) { valueFactory.missingValue } + ) + + return thunk(env) + } + } + } + + override fun convert(expr: PartiqlPhysical.Expr): PhysicalPlanThunk = this.compileAstExpr(expr) + + /** + * Compiles the specified [PartiqlPhysical.Statement] into a [PhysicalPlanThunk]. + * + * This function will [InterruptedException] if [Thread.interrupted] has been set. + */ + private fun compileAstStatement(ast: PartiqlPhysical.Statement): PhysicalPlanThunk { + return when (ast) { + is PartiqlPhysical.Statement.Query -> compileAstExpr(ast.expr) + is PartiqlPhysical.Statement.Exec -> compileExec(ast) + } + } + + private fun compileAstExpr(expr: PartiqlPhysical.Expr): PhysicalPlanThunk { + checkThreadInterrupted() + val metas = expr.metas + + return when (expr) { + is PartiqlPhysical.Expr.Lit -> compileLit(expr, metas) + is PartiqlPhysical.Expr.Missing -> compileMissing(metas) + is PartiqlPhysical.Expr.LocalId -> compileLocalId(expr, metas) + is PartiqlPhysical.Expr.GlobalId -> compileGlobalId(expr) + is PartiqlPhysical.Expr.SimpleCase -> compileSimpleCase(expr, metas) + is PartiqlPhysical.Expr.SearchedCase -> compileSearchedCase(expr, metas) + is PartiqlPhysical.Expr.Path -> compilePath(expr, metas) + is PartiqlPhysical.Expr.Struct -> compileStruct(expr) + is PartiqlPhysical.Expr.CallAgg -> compileCallAgg(expr, metas) + is PartiqlPhysical.Expr.Parameter -> compileParameter(expr, metas) + is PartiqlPhysical.Expr.Date -> compileDate(expr, metas) + is PartiqlPhysical.Expr.LitTime -> compileLitTime(expr, metas) + + // arithmetic operations + is PartiqlPhysical.Expr.Plus -> compilePlus(expr, metas) + is PartiqlPhysical.Expr.Times -> compileTimes(expr, metas) + is PartiqlPhysical.Expr.Minus -> compileMinus(expr, metas) + is PartiqlPhysical.Expr.Divide -> compileDivide(expr, metas) + is PartiqlPhysical.Expr.Modulo -> compileModulo(expr, metas) + + // comparison operators + is PartiqlPhysical.Expr.And -> compileAnd(expr, metas) + is PartiqlPhysical.Expr.Between -> compileBetween(expr, metas) + is PartiqlPhysical.Expr.Eq -> compileEq(expr, metas) + is PartiqlPhysical.Expr.Gt -> compileGt(expr, metas) + is PartiqlPhysical.Expr.Gte -> compileGte(expr, metas) + is PartiqlPhysical.Expr.Lt -> compileLt(expr, metas) + is PartiqlPhysical.Expr.Lte -> compileLte(expr, metas) + is PartiqlPhysical.Expr.Like -> compileLike(expr, metas) + is PartiqlPhysical.Expr.InCollection -> compileIn(expr, metas) + + // logical operators + is PartiqlPhysical.Expr.Ne -> compileNe(expr, metas) + is PartiqlPhysical.Expr.Or -> compileOr(expr, metas) + + // unary + is PartiqlPhysical.Expr.Not -> compileNot(expr, metas) + is PartiqlPhysical.Expr.Pos -> compilePos(expr, metas) + is PartiqlPhysical.Expr.Neg -> compileNeg(expr, metas) + + // other operators + is PartiqlPhysical.Expr.Concat -> compileConcat(expr, metas) + is PartiqlPhysical.Expr.Call -> compileCall(expr, metas) + is PartiqlPhysical.Expr.NullIf -> compileNullIf(expr, metas) + is PartiqlPhysical.Expr.Coalesce -> compileCoalesce(expr, metas) + + // "typed" operators (RHS is a data type and not an expression) + is PartiqlPhysical.Expr.Cast -> compileCast(expr, metas) + is PartiqlPhysical.Expr.IsType -> compileIs(expr, metas) + is PartiqlPhysical.Expr.CanCast -> compileCanCast(expr, metas) + is PartiqlPhysical.Expr.CanLosslessCast -> compileCanLosslessCast(expr, metas) + + // sequence constructors + is PartiqlPhysical.Expr.List -> compileSeq(ExprValueType.LIST, expr.values, metas) + is PartiqlPhysical.Expr.Sexp -> compileSeq(ExprValueType.SEXP, expr.values, metas) + is PartiqlPhysical.Expr.Bag -> compileSeq(ExprValueType.BAG, expr.values, metas) + + // set operators + is PartiqlPhysical.Expr.Intersect, + is PartiqlPhysical.Expr.Union, + is PartiqlPhysical.Expr.Except -> { + err( + "${expr.javaClass.canonicalName} is not yet supported", + ErrorCode.EVALUATOR_FEATURE_NOT_SUPPORTED_YET, + errorContextFrom(metas).also { + it[Property.FEATURE_NAME] = expr.javaClass.canonicalName + }, + internal = false + ) + } + is PartiqlPhysical.Expr.BindingsToValues -> compileBindingsToValues(expr) + } + } + + private fun compileBindingsToValues(expr: PartiqlPhysical.Expr.BindingsToValues): PhysicalPlanThunk { + val mapThunk = compileAstExpr(expr.exp) + val bexprThunk: RelationThunkEnv = PhysicalBexprToThunkConverter(this, thunkFactory.valueFactory) + .convert(expr.query) + + return thunkFactory.thunkEnv(expr.metas) { env -> + val elements = sequence { + val relItr = bexprThunk(env) + while (relItr.nextRow()) { + yield(mapThunk(env)) + } + } + valueFactory.newBag(elements) + } + } + + private fun compileAstExprs(args: List) = args.map { compileAstExpr(it) } + + private fun compileNullIf(expr: PartiqlPhysical.Expr.NullIf, metas: MetaContainer): PhysicalPlanThunk { + val expr1Thunk = compileAstExpr(expr.expr1) + val expr2Thunk = compileAstExpr(expr.expr2) + + // Note: NULLIF does not propagate the unknown values and .exprEquals provides the correct semantics. + return thunkFactory.thunkEnv(metas) { env -> + val expr1Value = expr1Thunk(env) + val expr2Value = expr2Thunk(env) + when { + expr1Value.exprEquals(expr2Value) -> valueFactory.nullValue + else -> expr1Value + } + } + } + + private fun compileCoalesce(expr: PartiqlPhysical.Expr.Coalesce, metas: MetaContainer): PhysicalPlanThunk { + val argThunks = compileAstExprs(expr.args) + + return thunkFactory.thunkEnv(metas) { env -> + var nullFound = false + var knownValue: ExprValue? = null + for (thunk in argThunks) { + val argValue = thunk(env) + if (argValue.isNotUnknown()) { + knownValue = argValue + // No need to execute remaining thunks to save computation as first non-unknown value is found + break + } + if (argValue.type == ExprValueType.NULL) { + nullFound = true + } + } + when (knownValue) { + null -> when { + evaluatorOptions.typingMode == TypingMode.PERMISSIVE && !nullFound -> valueFactory.missingValue + else -> valueFactory.nullValue + } + else -> knownValue + } + } + } + + /** + * Returns a function that accepts an [ExprValue] as an argument and returns true it is `NULL`, `MISSING`, or + * within the range specified by [range]. + */ + private fun integerValueValidator( + range: LongRange + ): (ExprValue) -> Boolean = { value -> + when (value.type) { + ExprValueType.NULL, ExprValueType.MISSING -> true + ExprValueType.INT -> { + val longValue: Long = value.scalar.numberValue()?.toLong() + ?: error( + "ExprValue.numberValue() must not be `NULL` when its type is INT." + + "This indicates that the ExprValue instance has a bug." + ) + + // PRO-TIP: make sure to use the `Long` primitive type here with `.contains` otherwise + // Kotlin will use the version of `.contains` that treats [range] as a collection, and it will + // be very slow! + range.contains(longValue) + } + else -> error( + "The expression's static type was supposed to be INT but instead it was ${value.type}" + + "This may indicate the presence of a bug in the type inferencer." + ) + } + } + + /** + * For operators which could return integer type, check integer overflow in case of [TypingMode.PERMISSIVE]. + */ + private fun resolveIntConstraint(computeThunk: PhysicalPlanThunk, metas: MetaContainer): PhysicalPlanThunk = + when (val staticTypes = metas.staticType?.type?.getTypes()) { + // No staticType, can't validate integer size. + null -> computeThunk + else -> { + when (evaluatorOptions.typingMode) { + TypingMode.LEGACY -> { + // integer size constraints have not been tested under [TypingMode.LEGACY] because the + // [StaticTypeInferenceVisitorTransform] doesn't support being used with legacy mode yet. + // throw an exception in case we encounter this untested scenario. This might work fine, but I + // wouldn't bet on it. + val hasConstrainedInteger = staticTypes.any { + it is IntType && it.rangeConstraint != IntType.IntRangeConstraint.UNCONSTRAINED + } + if (hasConstrainedInteger) { + TODO("Legacy mode doesn't support integer size constraints yet.") + } else { + computeThunk + } + } + TypingMode.PERMISSIVE -> { + val biggestIntegerType = staticTypes.filterIsInstance().maxByOrNull { + it.rangeConstraint.numBytes + } + when (biggestIntegerType) { + is IntType -> { + val validator = integerValueValidator(biggestIntegerType.rangeConstraint.validRange) + + thunkFactory.thunkEnv(metas) { env -> + val naryResult = computeThunk(env) + errorSignaler.errorIf( + !validator(naryResult), + ErrorCode.EVALUATOR_INTEGER_OVERFLOW, + { ErrorDetails(metas, "Integer overflow", errorContextFrom(metas)) }, + { naryResult } + ) + } + } + // If there is no IntType StaticType, can't validate the integer size either. + null -> computeThunk + else -> computeThunk + } + } + } + } + } + + private fun compilePlus(expr: PartiqlPhysical.Expr.Plus, metas: MetaContainer): PhysicalPlanThunk { + if (expr.operands.size < 2) { + error("Internal Error: PartiqlPhysical.Expr.Plus must have at least 2 arguments") + } + + val argThunks = compileAstExprs(expr.operands) + + val computeThunk = thunkFactory.thunkFold(metas, argThunks) { lValue, rValue -> + (lValue.numberValue() + rValue.numberValue()).exprValue() + } + + return resolveIntConstraint(computeThunk, metas) + } + + private fun compileMinus(expr: PartiqlPhysical.Expr.Minus, metas: MetaContainer): PhysicalPlanThunk { + if (expr.operands.size < 2) { + error("Internal Error: PartiqlPhysical.Expr.Minus must have at least 2 arguments") + } + + val argThunks = compileAstExprs(expr.operands) + + val computeThunk = thunkFactory.thunkFold(metas, argThunks) { lValue, rValue -> + (lValue.numberValue() - rValue.numberValue()).exprValue() + } + + return resolveIntConstraint(computeThunk, metas) + } + + private fun compilePos(expr: PartiqlPhysical.Expr.Pos, metas: MetaContainer): PhysicalPlanThunk { + val exprThunk = compileAstExpr(expr.expr) + + val computeThunk = thunkFactory.thunkEnvOperands(metas, exprThunk) { _, value -> + // Invoking .numberValue() here makes this essentially just a type check + value.numberValue() + // Original value is returned unmodified. + value + } + + return resolveIntConstraint(computeThunk, metas) + } + + private fun compileNeg(expr: PartiqlPhysical.Expr.Neg, metas: MetaContainer): PhysicalPlanThunk { + val exprThunk = compileAstExpr(expr.expr) + + val computeThunk = thunkFactory.thunkEnvOperands(metas, exprThunk) { _, value -> + (-value.numberValue()).exprValue() + } + + return resolveIntConstraint(computeThunk, metas) + } + + private fun compileTimes(expr: PartiqlPhysical.Expr.Times, metas: MetaContainer): PhysicalPlanThunk { + val argThunks = compileAstExprs(expr.operands) + + val computeThunk = thunkFactory.thunkFold(metas, argThunks) { lValue, rValue -> + (lValue.numberValue() * rValue.numberValue()).exprValue() + } + + return resolveIntConstraint(computeThunk, metas) + } + + private fun compileDivide(expr: PartiqlPhysical.Expr.Divide, metas: MetaContainer): PhysicalPlanThunk { + val argThunks = compileAstExprs(expr.operands) + + val computeThunk = thunkFactory.thunkFold(metas, argThunks) { lValue, rValue -> + val denominator = rValue.numberValue() + + errorSignaler.errorIf( + denominator.isZero(), + ErrorCode.EVALUATOR_DIVIDE_BY_ZERO, + { ErrorDetails(metas, "/ by zero") } + ) { + try { + (lValue.numberValue() / denominator).exprValue() + } catch (e: ArithmeticException) { + // Setting the internal flag as true as it is not clear what + // ArithmeticException may be thrown by the above + throw EvaluationException( + cause = e, + errorCode = ErrorCode.EVALUATOR_ARITHMETIC_EXCEPTION, + internal = true + ) + } + } + } + + return resolveIntConstraint(computeThunk, metas) + } + + private fun compileModulo(expr: PartiqlPhysical.Expr.Modulo, metas: MetaContainer): PhysicalPlanThunk { + val argThunks = compileAstExprs(expr.operands) + + val computeThunk = thunkFactory.thunkFold(metas, argThunks) { lValue, rValue -> + val denominator = rValue.numberValue() + if (denominator.isZero()) { + err("% by zero", ErrorCode.EVALUATOR_MODULO_BY_ZERO, errorContext = null, internal = false) + } + + (lValue.numberValue() % denominator).exprValue() + } + + return resolveIntConstraint(computeThunk, metas) + } + + private fun compileEq(expr: PartiqlPhysical.Expr.Eq, metas: MetaContainer): PhysicalPlanThunk { + val argThunks = compileAstExprs(expr.operands) + + return thunkFactory.thunkAndMap(metas, argThunks) { lValue, rValue -> + (lValue.exprEquals(rValue)) + } + } + + private fun compileNe(expr: PartiqlPhysical.Expr.Ne, metas: MetaContainer): PhysicalPlanThunk { + val argThunks = compileAstExprs(expr.operands) + + return thunkFactory.thunkFold(metas, argThunks) { lValue, rValue -> + ((!lValue.exprEquals(rValue)).exprValue()) + } + } + + private fun compileLt(expr: PartiqlPhysical.Expr.Lt, metas: MetaContainer): PhysicalPlanThunk { + val argThunks = compileAstExprs(expr.operands) + + return thunkFactory.thunkAndMap(metas, argThunks) { lValue, rValue -> lValue < rValue } + } + + private fun compileLte(expr: PartiqlPhysical.Expr.Lte, metas: MetaContainer): PhysicalPlanThunk { + val argThunks = compileAstExprs(expr.operands) + + return thunkFactory.thunkAndMap(metas, argThunks) { lValue, rValue -> lValue <= rValue } + } + + private fun compileGt(expr: PartiqlPhysical.Expr.Gt, metas: MetaContainer): PhysicalPlanThunk { + val argThunks = compileAstExprs(expr.operands) + + return thunkFactory.thunkAndMap(metas, argThunks) { lValue, rValue -> lValue > rValue } + } + + private fun compileGte(expr: PartiqlPhysical.Expr.Gte, metas: MetaContainer): PhysicalPlanThunk { + val argThunks = compileAstExprs(expr.operands) + + return thunkFactory.thunkAndMap(metas, argThunks) { lValue, rValue -> lValue >= rValue } + } + + private fun compileBetween(expr: PartiqlPhysical.Expr.Between, metas: MetaContainer): PhysicalPlanThunk { + val valueThunk = compileAstExpr(expr.value) + val fromThunk = compileAstExpr(expr.from) + val toThunk = compileAstExpr(expr.to) + + return thunkFactory.thunkEnvOperands(metas, valueThunk, fromThunk, toThunk) { _, v, f, t -> + (v >= f && v <= t).exprValue() + } + } + + /** + * `IN` can *almost* be thought of has being syntactic sugar for the `OR` operator. + * + * `a IN (b, c, d)` is equivalent to `a = b OR a = c OR a = d`. On deep inspection, there + * are important implications to this regarding propagation of unknown values. Specifically, the + * presence of any unknown in `b`, `c`, or `d` will result in unknown propagation iif `a` does not + * equal `b`, `c`, or `d`. i.e.: + * + * - `1 in (null, 2, 3)` -> `null` + * - `2 in (null, 2, 3)` -> `true` + * - `2 in (1, 2, 3)` -> `true` + * - `0 in (1, 2, 4)` -> `false` + * + * `IN` is varies from the `OR` operator in that this behavior holds true when other types of expressions are + * used on the right side of `IN` such as sub-queries and variables whose value is that of a list or bag. + */ + private fun compileIn(expr: PartiqlPhysical.Expr.InCollection, metas: MetaContainer): PhysicalPlanThunk { + val args = expr.operands + val leftThunk = compileAstExpr(args[0]) + val rightOp = args[1] + + fun isOptimizedCase(values: List): Boolean = values.all { it is PartiqlPhysical.Expr.Lit && !it.value.isNull } + + fun optimizedCase(values: List): PhysicalPlanThunk { + // Put all the literals in the sequence into a pre-computed map to be checked later by the thunk. + // If the left-hand value is one of these we can short-circuit with a result of TRUE. + // This is the fastest possible case and allows for hundreds of literal values (or more) in the + // sequence without a huge performance penalty. + // NOTE: we cannot use a [HashSet<>] here because [ExprValue] does not implement [Object.hashCode] or + // [Object.equals]. + val precomputedLiteralsMap = values + .filterIsInstance() + .mapTo(TreeSet(DEFAULT_COMPARATOR)) { + valueFactory.newFromIonValue( + it.value.toIonValue( + valueFactory.ion + ) + ) + } + + // the compiled thunk simply checks if the left side is contained on the right side. + // thunkEnvOperands takes care of unknown propagation for the left side; for the right, + // this unknown propagation does not apply since we've eliminated the possibility of unknowns above. + return thunkFactory.thunkEnvOperands(metas, leftThunk) { _, leftValue -> + precomputedLiteralsMap.contains(leftValue).exprValue() + } + } + + return when { + // We can significantly optimize this if rightArg is a sequence constructor which is composed of entirely + // of non-null literal values. + rightOp is PartiqlPhysical.Expr.List && isOptimizedCase(rightOp.values) -> optimizedCase(rightOp.values) + rightOp is PartiqlPhysical.Expr.Bag && isOptimizedCase(rightOp.values) -> optimizedCase(rightOp.values) + rightOp is PartiqlPhysical.Expr.Sexp && isOptimizedCase(rightOp.values) -> optimizedCase(rightOp.values) + // The unoptimized case... + else -> { + val rightThunk = compileAstExpr(rightOp) + + // Legacy mode: + // Returns FALSE when the right side of IN is not a sequence + // Returns NULL if the right side is MISSING or any value on the right side is MISSING + // Permissive mode: + // Returns MISSING when the right side of IN is not a sequence + // Returns MISSING if the right side is MISSING or any value on the right side is MISSING + val (propagateMissingAs, propagateNotASeqAs) = with(valueFactory) { + when (evaluatorOptions.typingMode) { + TypingMode.LEGACY -> nullValue to newBoolean(false) + TypingMode.PERMISSIVE -> missingValue to missingValue + } + } + + // Note that standard unknown propagation applies to the left and right operands. Both [TypingMode]s + // are handled by [ThunkFactory.thunkEnvOperands] and that additional rules for unknown propagation are + // implemented within the thunk for the values within the sequence on the right side of IN. + thunkFactory.thunkEnvOperands(metas, leftThunk, rightThunk) { _, leftValue, rightValue -> + var nullSeen = false + var missingSeen = false + + when { + rightValue.type == ExprValueType.MISSING -> propagateMissingAs + !rightValue.type.isSequence -> propagateNotASeqAs + else -> { + rightValue.forEach { + when (it.type) { + ExprValueType.NULL -> nullSeen = true + ExprValueType.MISSING -> missingSeen = true + // short-circuit to TRUE on the first matching value + else -> if (it.exprEquals(leftValue)) { + return@thunkEnvOperands valueFactory.newBoolean(true) + } + } + } + // If we make it here then there was no match. Propagate MISSING, NULL or return false. + // Note that if both MISSING and NULL was encountered, MISSING takes precedence. + when { + missingSeen -> propagateMissingAs + nullSeen -> valueFactory.nullValue + else -> valueFactory.newBoolean(false) + } + } + } + } + } + } + } + + private fun compileNot(expr: PartiqlPhysical.Expr.Not, metas: MetaContainer): PhysicalPlanThunk { + val argThunk = compileAstExpr(expr.expr) + + return thunkFactory.thunkEnvOperands(metas, argThunk) { _, value -> + (!value.booleanValue()).exprValue() + } + } + + private fun compileAnd(expr: PartiqlPhysical.Expr.And, metas: MetaContainer): PhysicalPlanThunk { + val argThunks = compileAstExprs(expr.operands) + + // can't use the null propagation supplied by [ThunkFactory.thunkEnv] here because AND short-circuits on + // false values and *NOT* on NULL or MISSING + return when (evaluatorOptions.typingMode) { + TypingMode.LEGACY -> thunkFactory.thunkEnv(metas) thunk@{ env -> + var hasUnknowns = false + argThunks.forEach { currThunk -> + val currValue = currThunk(env) + when { + currValue.isUnknown() -> hasUnknowns = true + // Short circuit only if we encounter a known false value. + !currValue.booleanValue() -> return@thunk valueFactory.newBoolean(false) + } + } + + when (hasUnknowns) { + true -> valueFactory.nullValue + false -> valueFactory.newBoolean(true) + } + } + TypingMode.PERMISSIVE -> thunkFactory.thunkEnv(metas) thunk@{ env -> + var hasNull = false + var hasMissing = false + argThunks.forEach { currThunk -> + val currValue = currThunk(env) + when (currValue.type) { + // Short circuit only if we encounter a known false value. + ExprValueType.BOOL -> if (!currValue.booleanValue()) return@thunk valueFactory.newBoolean(false) + ExprValueType.NULL -> hasNull = true + // type mismatch, return missing + else -> hasMissing = true + } + } + + when { + hasMissing -> valueFactory.missingValue + hasNull -> valueFactory.nullValue + else -> valueFactory.newBoolean(true) + } + } + } + } + + private fun compileOr(expr: PartiqlPhysical.Expr.Or, metas: MetaContainer): PhysicalPlanThunk { + val argThunks = compileAstExprs(expr.operands) + + // can't use the null propagation supplied by [ThunkFactory.thunkEnv] here because OR short-circuits on + // true values and *NOT* on NULL or MISSING + return when (evaluatorOptions.typingMode) { + TypingMode.LEGACY -> + thunkFactory.thunkEnv(metas) thunk@{ env -> + var hasUnknowns = false + argThunks.forEach { currThunk -> + val currValue = currThunk(env) + // How null-propagation works for OR is rather weird according to the SQL-92 spec. + // Nulls are propagated like other expressions only when none of the terms are TRUE. + // If any one of them is TRUE, then the entire expression evaluates to TRUE, i.e.: + // NULL OR TRUE -> TRUE + // NULL OR FALSE -> NULL + // (strange but true) + when { + currValue.isUnknown() -> hasUnknowns = true + currValue.booleanValue() -> return@thunk valueFactory.newBoolean(true) + } + } + + when (hasUnknowns) { + true -> valueFactory.nullValue + false -> valueFactory.newBoolean(false) + } + } + TypingMode.PERMISSIVE -> thunkFactory.thunkEnv(metas) thunk@{ env -> + var hasNull = false + var hasMissing = false + argThunks.forEach { currThunk -> + val currValue = currThunk(env) + when (currValue.type) { + // Short circuit only if we encounter a known true value. + ExprValueType.BOOL -> if (currValue.booleanValue()) return@thunk valueFactory.newBoolean(true) + ExprValueType.NULL -> hasNull = true + else -> hasMissing = true // type mismatch, return missing. + } + } + + when { + hasMissing -> valueFactory.missingValue + hasNull -> valueFactory.nullValue + else -> valueFactory.newBoolean(false) + } + } + } + } + + private fun compileConcat(expr: PartiqlPhysical.Expr.Concat, metas: MetaContainer): PhysicalPlanThunk { + val argThunks = compileAstExprs(expr.operands) + + return thunkFactory.thunkFold(metas, argThunks) { lValue, rValue -> + val lType = lValue.type + val rType = rValue.type + + if (lType.isText && rType.isText) { + // null/missing propagation is handled before getting here + (lValue.stringValue() + rValue.stringValue()).exprValue() + } else { + err( + "Wrong argument type for ||", + ErrorCode.EVALUATOR_CONCAT_FAILED_DUE_TO_INCOMPATIBLE_TYPE, + errorContextFrom(metas).also { + it[Property.ACTUAL_ARGUMENT_TYPES] = listOf(lType, rType).toString() + }, + internal = false + ) + } + } + } + + private fun compileCall(expr: PartiqlPhysical.Expr.Call, metas: MetaContainer): PhysicalPlanThunk { + val funcArgThunks = compileAstExprs(expr.args) + val func = functions[expr.funcName.text] ?: err( + "No such function: ${expr.funcName.text}", + ErrorCode.EVALUATOR_NO_SUCH_FUNCTION, + errorContextFrom(metas).also { + it[Property.FUNCTION_NAME] = expr.funcName.text + }, + internal = false + ) + + // Check arity + if (funcArgThunks.size !in func.signature.arity) { + val errorContext = errorContextFrom(metas).also { + it[Property.FUNCTION_NAME] = func.signature.name + it[Property.EXPECTED_ARITY_MIN] = func.signature.arity.first + it[Property.EXPECTED_ARITY_MAX] = func.signature.arity.last + it[Property.ACTUAL_ARITY] = funcArgThunks.size + } + + val message = when { + func.signature.arity.first == 1 && func.signature.arity.last == 1 -> + "${func.signature.name} takes a single argument, received: ${funcArgThunks.size}" + func.signature.arity.first == func.signature.arity.last -> + "${func.signature.name} takes exactly ${func.signature.arity.first} arguments, received: ${funcArgThunks.size}" + else -> + "${func.signature.name} takes between ${func.signature.arity.first} and " + + "${func.signature.arity.last} arguments, received: ${funcArgThunks.size}" + } + + throw EvaluationException( + message, + ErrorCode.EVALUATOR_INCORRECT_NUMBER_OF_ARGUMENTS_TO_FUNC_CALL, + errorContext, + internal = false + ) + } + + fun checkArgumentTypes(signature: FunctionSignature, args: List): Arguments { + fun checkArgumentType(formalStaticType: StaticType, actualArg: ExprValue, position: Int) { + val formalExprValueTypeDomain = formalStaticType.typeDomain + + val actualExprValueType = actualArg.type + val actualStaticType = StaticType.fromExprValue(actualArg) + + if (!actualStaticType.isSubTypeOf(formalStaticType)) { + errInvalidArgumentType( + signature = signature, + position = position, + expectedTypes = formalExprValueTypeDomain.toList(), + actualType = actualExprValueType + ) + } + } + + val required = args.take(signature.requiredParameters.size) + val rest = args.drop(signature.requiredParameters.size) + + signature.requiredParameters.zip(required).forEachIndexed { idx, (expected, actual) -> + checkArgumentType(expected, actual, idx + 1) + } + + return if (signature.optionalParameter != null && rest.isNotEmpty()) { + val opt = rest.last() + checkArgumentType(signature.optionalParameter, opt, required.size + 1) + RequiredWithOptional(required, opt) + } else if (signature.variadicParameter != null) { + rest.forEachIndexed { idx, arg -> + checkArgumentType(signature.variadicParameter.type, arg, required.size + 1 + idx) + } + RequiredWithVariadic(required, rest) + } else { + RequiredArgs(required) + } + } + + val computeThunk = when (func.signature.unknownArguments) { + UnknownArguments.PROPAGATE -> thunkFactory.thunkEnvOperands(metas, funcArgThunks) { env, values -> + val checkedArgs = checkArgumentTypes(func.signature, values) + func.call(env.session, checkedArgs) + } + UnknownArguments.PASS_THRU -> thunkFactory.thunkEnv(metas) { env -> + val funcArgValues = funcArgThunks.map { it(env) } + val checkedArgs = checkArgumentTypes(func.signature, funcArgValues) + func.call(env.session, checkedArgs) + } + } + + return resolveIntConstraint(computeThunk, metas) + } + + private fun compileLit(expr: PartiqlPhysical.Expr.Lit, metas: MetaContainer): PhysicalPlanThunk { + val value = valueFactory.newFromIonValue(expr.value.toIonValue(valueFactory.ion)) + + return thunkFactory.thunkEnv(metas) { value } + } + + private fun compileMissing(metas: MetaContainer): PhysicalPlanThunk = + thunkFactory.thunkEnv(metas) { valueFactory.missingValue } + + private fun compileGlobalId(expr: PartiqlPhysical.Expr.GlobalId): PhysicalPlanThunk { + val bindingCase = expr.case.toBindingCase() + return thunkFactory.thunkEnv(expr.metas) { env -> + val bindingName = BindingName(expr.uniqueId.text, bindingCase) + env.session.globals[bindingName] ?: throwUndefinedVariableException(bindingName, expr.metas) + } + } + + @Suppress("UNUSED_PARAMETER") + private fun compileLocalId(expr: PartiqlPhysical.Expr.LocalId, metas: MetaContainer): PhysicalPlanThunk { + val localIndex = expr.index.value.toIntExact() + return thunkFactory.thunkEnv(metas) { env -> + env.registers[localIndex] + } + } + + private fun compileParameter(expr: PartiqlPhysical.Expr.Parameter, metas: MetaContainer): PhysicalPlanThunk { + val ordinal = expr.index.value.toInt() + val index = ordinal - 1 + + return { env -> + val params = env.session.parameters + if (params.size <= index) { + throw EvaluationException( + "Unbound parameter for ordinal: $ordinal", + ErrorCode.EVALUATOR_UNBOUND_PARAMETER, + errorContextFrom(metas).also { + it[Property.EXPECTED_PARAMETER_ORDINAL] = ordinal + it[Property.BOUND_PARAMETER_COUNT] = params.size + }, + internal = false + ) + } + params[index] + } + } + + /** + * Returns a lambda that implements the `IS` operator type check according to the current + * [TypedOpBehavior]. + */ + private fun makeIsCheck( + staticType: SingleType, + typedOpParameter: TypedOpParameter, + metas: MetaContainer + ): (ExprValue) -> Boolean { + val exprValueType = staticType.runtimeType + + // The "simple" type match function only looks at the [ExprValueType] of the [ExprValue] + // and invokes the custom [validationThunk] if one exists. + val simpleTypeMatchFunc = { expValue: ExprValue -> + val isTypeMatch = when (exprValueType) { + // MISSING IS NULL and NULL IS MISSING + ExprValueType.NULL -> expValue.type.isUnknown + else -> expValue.type == exprValueType + } + (isTypeMatch && typedOpParameter.validationThunk?.let { it(expValue) } != false) + } + + return when (evaluatorOptions.typedOpBehavior) { + TypedOpBehavior.LEGACY -> simpleTypeMatchFunc + TypedOpBehavior.HONOR_PARAMETERS -> { expValue: ExprValue -> + staticType.allTypes.any { + val matchesStaticType = try { + it.isInstance(expValue) + } catch (e: UnsupportedTypeCheckException) { + err( + e.message!!, + ErrorCode.UNIMPLEMENTED_FEATURE, + errorContextFrom(metas), + internal = true + ) + } + + when { + !matchesStaticType -> false + else -> when (val validator = typedOpParameter.validationThunk) { + null -> true + else -> validator(expValue) + } + } + } + } + } + } + + private fun compileIs(expr: PartiqlPhysical.Expr.IsType, metas: MetaContainer): PhysicalPlanThunk { + val expThunk = compileAstExpr(expr.value) + val typedOpParameter = expr.type.toTypedOpParameter(customTypedOpParameters) + if (typedOpParameter.staticType is AnyType) { + return thunkFactory.thunkEnv(metas) { valueFactory.newBoolean(true) } + } + if (evaluatorOptions.typedOpBehavior == TypedOpBehavior.HONOR_PARAMETERS && expr.type is PartiqlPhysical.Type.FloatType && expr.type.precision != null) { + err( + "FLOAT precision parameter is unsupported", + ErrorCode.SEMANTIC_FLOAT_PRECISION_UNSUPPORTED, + errorContextFrom(expr.type.metas), + internal = false + ) + } + + val typeMatchFunc = when (val staticType = typedOpParameter.staticType) { + is SingleType -> makeIsCheck(staticType, typedOpParameter, metas) + is AnyOfType -> staticType.types.map { childType -> + when (childType) { + is SingleType -> makeIsCheck(childType, typedOpParameter, metas) + else -> err( + "Union type cannot have ANY or nested AnyOf type for IS", + ErrorCode.SEMANTIC_UNION_TYPE_INVALID, + errorContextFrom(metas), + internal = true + ) + } + }.let { typeMatchFuncs -> + { expValue: ExprValue -> typeMatchFuncs.any { func -> func(expValue) } } + } + is AnyType -> throw IllegalStateException("Unexpected ANY type in IS compilation") + } + + return thunkFactory.thunkEnv(metas) { env -> + val expValue = expThunk(env) + typeMatchFunc(expValue).exprValue() + } + } + + private fun compileCastHelper(value: PartiqlPhysical.Expr, asType: PartiqlPhysical.Type, metas: MetaContainer): PhysicalPlanThunk { + val expThunk = compileAstExpr(value) + val typedOpParameter = asType.toTypedOpParameter(customTypedOpParameters) + if (typedOpParameter.staticType is AnyType) { + return expThunk + } + if (evaluatorOptions.typedOpBehavior == TypedOpBehavior.HONOR_PARAMETERS && asType is PartiqlPhysical.Type.FloatType && asType.precision != null) { + err( + "FLOAT precision parameter is unsupported", + ErrorCode.SEMANTIC_FLOAT_PRECISION_UNSUPPORTED, + errorContextFrom(asType.metas), + internal = false + ) + } + + fun typeOpValidate( + value: ExprValue, + castOutput: ExprValue, + typeName: String, + locationMeta: SourceLocationMeta? + ) { + if (typedOpParameter.validationThunk?.let { it(castOutput) } == false) { + val errorContext = PropertyValueMap().also { + it[Property.CAST_FROM] = value.type.toString() + it[Property.CAST_TO] = typeName + } + + locationMeta?.let { fillErrorContext(errorContext, it) } + + throw EvaluationException( + "Validation failure for $asType", + ErrorCode.EVALUATOR_CAST_FAILED, + errorContext, + internal = false + ) + } + } + + fun singleTypeCastFunc(singleType: SingleType): CastFunc { + val locationMeta = metas.sourceLocationMeta + return { value -> + val castOutput = value.cast( + singleType, + valueFactory, + evaluatorOptions.typedOpBehavior, + locationMeta, + evaluatorOptions.defaultTimezoneOffset + ) + typeOpValidate(value, castOutput, singleType.runtimeType.toString(), locationMeta) + castOutput + } + } + + fun compileSingleTypeCast(singleType: SingleType): PhysicalPlanThunk { + val castFunc = singleTypeCastFunc(singleType) + // We do not use thunkFactory here because we want to explicitly avoid + // the optional evaluation-time type check for CAN_CAST below. + // Can cast needs that returns false if an + // exception is thrown during a normal cast operation. + return { env -> + val valueToCast = expThunk(env) + castFunc(valueToCast) + } + } + + fun compileCast(type: StaticType): PhysicalPlanThunk = when (type) { + is SingleType -> compileSingleTypeCast(type) + is AnyOfType -> { + val locationMeta = metas.sourceLocationMeta + val castTable = AnyOfCastTable(type, metas, valueFactory, ::singleTypeCastFunc); + + // We do not use thunkFactory here because we want to explicitly avoid + // the optional evaluation-time type check for CAN_CAST below. + // note that this would interfere with the error handling for can_cast that returns false if an + // exception is thrown during a normal cast operation. + { env -> + val sourceValue = expThunk(env) + castTable.cast(sourceValue).also { + // TODO put the right type name here + typeOpValidate(sourceValue, it, "", locationMeta) + } + } + } + is AnyType -> throw IllegalStateException("Unreachable code") + } + + return compileCast(typedOpParameter.staticType) + } + + private fun compileCast(expr: PartiqlPhysical.Expr.Cast, metas: MetaContainer): PhysicalPlanThunk = + thunkFactory.thunkEnv(metas, compileCastHelper(expr.value, expr.asType, metas)) + + private fun compileCanCast(expr: PartiqlPhysical.Expr.CanCast, metas: MetaContainer): PhysicalPlanThunk { + val typedOpParameter = expr.asType.toTypedOpParameter(customTypedOpParameters) + if (typedOpParameter.staticType is AnyType) { + return thunkFactory.thunkEnv(metas) { valueFactory.newBoolean(true) } + } + + val expThunk = compileAstExpr(expr.value) + + // TODO consider making this more efficient by not directly delegating to CAST + // TODO consider also making the operand not double evaluated (e.g. having expThunk memoize) + val castThunkEnv = compileCastHelper(expr.value, expr.asType, expr.metas) + return thunkFactory.thunkEnv(metas) { env -> + val sourceValue = expThunk(env) + try { + when { + // NULL/MISSING can cast to anything as themselves + sourceValue.isUnknown() -> valueFactory.newBoolean(true) + else -> { + val castedValue = castThunkEnv(env) + when { + // NULL/MISSING from cast is a permissive way to signal failure + castedValue.isUnknown() -> valueFactory.newBoolean(false) + else -> valueFactory.newBoolean(true) + } + } + } + } catch (e: EvaluationException) { + if (e.internal) { + throw e + } + valueFactory.newBoolean(false) + } + } + } + + private fun compileCanLosslessCast(expr: PartiqlPhysical.Expr.CanLosslessCast, metas: MetaContainer): PhysicalPlanThunk { + val typedOpParameter = expr.asType.toTypedOpParameter(customTypedOpParameters) + if (typedOpParameter.staticType is AnyType) { + return thunkFactory.thunkEnv(metas) { valueFactory.newBoolean(true) } + } + + val expThunk = compileAstExpr(expr.value) + + // TODO consider making this more efficient by not directly delegating to CAST + val castThunkEnv = compileCastHelper(expr.value, expr.asType, expr.metas) + return thunkFactory.thunkEnv(metas) { env -> + val sourceValue = expThunk(env) + val sourceType = StaticType.fromExprValue(sourceValue) + + fun roundTrip(): ExprValue { + val castedValue = castThunkEnv(env) + + val locationMeta = metas.sourceLocationMeta + fun castFunc(singleType: SingleType) = + { value: ExprValue -> + value.cast( + singleType, + valueFactory, + evaluatorOptions.typedOpBehavior, + locationMeta, + evaluatorOptions.defaultTimezoneOffset + ) + } + + val roundTripped = when (sourceType) { + is SingleType -> castFunc(sourceType)(castedValue) + is AnyOfType -> { + val castTable = AnyOfCastTable(sourceType, metas, valueFactory, ::castFunc) + castTable.cast(sourceValue) + } + // Should not be possible + is AnyType -> throw IllegalStateException("ANY type is not configured correctly in compiler") + } + + val lossless = sourceValue.exprEquals(roundTripped) + return valueFactory.newBoolean(lossless) + } + + try { + when (sourceValue.type) { + // NULL can cast to anything as itself + ExprValueType.NULL -> valueFactory.newBoolean(true) + + // Short-circuit timestamp -> date roundtrip if precision isn't [Timestamp.Precision.DAY] or + // [Timestamp.Precision.MONTH] or [Timestamp.Precision.YEAR] + ExprValueType.TIMESTAMP -> when (typedOpParameter.staticType) { + StaticType.DATE -> when (sourceValue.ionValue.timestampValue().precision) { + Timestamp.Precision.DAY, Timestamp.Precision.MONTH, Timestamp.Precision.YEAR -> roundTrip() + else -> valueFactory.newBoolean(false) + } + StaticType.TIME -> valueFactory.newBoolean(false) + else -> roundTrip() + } + + // For all other cases, attempt a round-trip of the value through the source and dest types + else -> roundTrip() + } + } catch (e: EvaluationException) { + if (e.internal) { + throw e + } + valueFactory.newBoolean(false) + } + } + } + + private fun compileSimpleCase(expr: PartiqlPhysical.Expr.SimpleCase, metas: MetaContainer): PhysicalPlanThunk { + val valueThunk = compileAstExpr(expr.expr) + val branchThunks = expr.cases.pairs.map { Pair(compileAstExpr(it.first), compileAstExpr(it.second)) } + val elseThunk = when (expr.default) { + null -> thunkFactory.thunkEnv(metas) { valueFactory.nullValue } + else -> compileAstExpr(expr.default) + } + + return thunkFactory.thunkEnv(metas) thunk@{ env -> + val caseValue = valueThunk(env) + // if the case value is unknown then we can short-circuit to the elseThunk directly + when { + caseValue.isUnknown() -> elseThunk(env) + else -> { + branchThunks.forEach { bt -> + val branchValue = bt.first(env) + // Just skip any branch values that are unknown, which we consider the same as false here. + when { + branchValue.isUnknown() -> { /* intentionally blank */ + } + else -> { + if (caseValue.exprEquals(branchValue)) { + return@thunk bt.second(env) + } + } + } + } + } + } + elseThunk(env) + } + } + + private fun compileSearchedCase(expr: PartiqlPhysical.Expr.SearchedCase, metas: MetaContainer): PhysicalPlanThunk { + val branchThunks = expr.cases.pairs.map { compileAstExpr(it.first) to compileAstExpr(it.second) } + val elseThunk = when (expr.default) { + null -> thunkFactory.thunkEnv(metas) { valueFactory.nullValue } + else -> compileAstExpr(expr.default) + } + + return when (evaluatorOptions.typingMode) { + TypingMode.LEGACY -> thunkFactory.thunkEnv(metas) thunk@{ env -> + branchThunks.forEach { bt -> + val conditionValue = bt.first(env) + // Any unknown value is considered the same as false. + // Note that .booleanValue() here will throw an EvaluationException if + // the data type is not boolean. + // TODO: .booleanValue does not have access to metas, so the EvaluationException is reported to be + // at the line & column of the CASE statement, not the predicate, unfortunately. + if (conditionValue.isNotUnknown() && conditionValue.booleanValue()) { + return@thunk bt.second(env) + } + } + elseThunk(env) + } + // Permissive mode propagates data type mismatches as MISSING, which is + // equivalent to false for searched CASE predicates. To simplify this, + // all we really need to do is consider any non-boolean result from the + // predicate to be false. + TypingMode.PERMISSIVE -> thunkFactory.thunkEnv(metas) thunk@{ env -> + branchThunks.forEach { bt -> + val conditionValue = bt.first(env) + if (conditionValue.type == ExprValueType.BOOL && conditionValue.booleanValue()) { + return@thunk bt.second(env) + } + } + elseThunk(env) + } + } + } + + private fun compileStruct(expr: PartiqlPhysical.Expr.Struct): PhysicalPlanThunk { + val structParts = compileStructParts(expr.parts) + + val ordering = if (expr.parts.none { it is PartiqlPhysical.StructPart.StructFields }) + StructOrdering.ORDERED + else + StructOrdering.UNORDERED + + return thunkFactory.thunkEnv(expr.metas) { env -> + val columns = mutableListOf() + for (element in structParts) { + when (element) { + is CompiledStructPart.Field -> { + val fieldName = element.nameThunk(env) + when (evaluatorOptions.typingMode) { + TypingMode.LEGACY -> + if (!fieldName.type.isText) { + err( + "Found struct field key to be of type ${fieldName.type}", + ErrorCode.EVALUATOR_NON_TEXT_STRUCT_FIELD_KEY, + errorContextFrom(expr.metas.sourceLocationMeta).also { pvm -> + pvm[Property.ACTUAL_TYPE] = fieldName.type.toString() + }, + internal = false + ) + } + TypingMode.PERMISSIVE -> + if (!fieldName.type.isText) { + continue + } + } + val fieldValue = element.valueThunk(env) + columns.add(fieldValue.namedValue(fieldName)) + } + is CompiledStructPart.StructMerge -> { + for (projThunk in element.thunks) { + val value = projThunk(env) + if (value.type == ExprValueType.MISSING) continue + + val children = value.asSequence() + if (!children.any() || value.type.isSequence) { + val name = syntheticColumnName(columns.size).exprValue() + columns.add(value.namedValue(name)) + } else { + val valuesToProject = + when (evaluatorOptions.projectionIteration) { + ProjectionIterationBehavior.FILTER_MISSING -> { + value.filter { it.type != ExprValueType.MISSING } + } + ProjectionIterationBehavior.UNFILTERED -> value + } + for (childValue in valuesToProject) { + val namedFacet = childValue.asFacet(Named::class.java) + val name = namedFacet?.name + ?: syntheticColumnName(columns.size).exprValue() + columns.add(childValue.namedValue(name)) + } + } + } + } + } + } + createStructExprValue(columns.asSequence(), ordering) + } + } + + private fun compileStructParts(projectItems: List): List = + projectItems.map { it -> + when (it) { + is PartiqlPhysical.StructPart.StructField -> { + val fieldThunk = compileAstExpr(it.fieldName) + val valueThunk = compileAstExpr(it.value) + CompiledStructPart.Field(fieldThunk, valueThunk) + } + is PartiqlPhysical.StructPart.StructFields -> { + CompiledStructPart.StructMerge(listOf(compileAstExpr(it.partExpr))) + } + } + } + + private fun compileSeq(seqType: ExprValueType, itemExprs: List, metas: MetaContainer): PhysicalPlanThunk { + require(seqType.isSequence) { "seqType must be a sequence!" } + + val itemThunks = compileAstExprs(itemExprs) + + val makeItemThunkSequence = when (seqType) { + ExprValueType.BAG -> { env: EvaluatorState -> + itemThunks.asSequence().map { itemThunk -> + // call to unnamedValue() makes sure we don't expose any underlying value name/ordinal + itemThunk(env).unnamedValue() + } + } + else -> { env: EvaluatorState -> + itemThunks.asSequence().mapIndexed { i, itemThunk -> itemThunk(env).namedValue(i.exprValue()) } + } + } + + return thunkFactory.thunkEnv(metas) { env -> + // todo: use valueFactory.newSequence() instead. + SequenceExprValue( + valueFactory.ion, + seqType, + makeItemThunkSequence(env) + ) + } + } + + @Suppress("UNUSED_PARAMETER") + private fun compileCallAgg(expr: PartiqlPhysical.Expr.CallAgg, metas: MetaContainer): PhysicalPlanThunk = TODO("call_agg") + + private fun compilePath(expr: PartiqlPhysical.Expr.Path, metas: MetaContainer): PhysicalPlanThunk { + val rootThunk = compileAstExpr(expr.root) + val remainingComponents = LinkedList() + + expr.steps.forEach { remainingComponents.addLast(it) } + + val componentThunk = compilePathComponents(remainingComponents, metas) + + return thunkFactory.thunkEnv(metas) { env -> + val rootValue = rootThunk(env) + componentThunk(env, rootValue) + } + } + + private fun compilePathComponents( + remainingComponents: LinkedList, + pathMetas: MetaContainer + ): PhysicalPlanThunkValue { + + val componentThunks = ArrayList>() + + while (!remainingComponents.isEmpty()) { + val pathComponent = remainingComponents.removeFirst() + val componentMetas = pathComponent.metas + componentThunks.add( + when (pathComponent) { + is PartiqlPhysical.PathStep.PathExpr -> { + val indexExpr = pathComponent.index + val caseSensitivity = pathComponent.case + when { + // If indexExpr is a literal string, there is no need to evaluate it--just compile a + // thunk that directly returns a bound value + indexExpr is PartiqlPhysical.Expr.Lit && indexExpr.value.toIonValue(valueFactory.ion) is IonString -> { + val lookupName = BindingName( + indexExpr.value.toIonValue(valueFactory.ion).stringValue()!!, + caseSensitivity.toBindingCase() + ) + thunkFactory.thunkEnvValue(componentMetas) { _, componentValue -> + componentValue.bindings[lookupName] ?: valueFactory.missingValue + } + } + else -> { + val indexThunk = compileAstExpr(indexExpr) + thunkFactory.thunkEnvValue(componentMetas) { env, componentValue -> + val indexValue = indexThunk(env) + when { + indexValue.type == ExprValueType.INT -> { + componentValue.ordinalBindings[indexValue.numberValue().toInt()] + } + indexValue.type.isText -> { + val lookupName = + BindingName(indexValue.stringValue(), caseSensitivity.toBindingCase()) + componentValue.bindings[lookupName] + } + else -> { + when (evaluatorOptions.typingMode) { + TypingMode.LEGACY -> err( + "Cannot convert index to int/string: $indexValue", + ErrorCode.EVALUATOR_INVALID_CONVERSION, + errorContextFrom(componentMetas), + internal = false + ) + TypingMode.PERMISSIVE -> valueFactory.missingValue + } + } + } ?: valueFactory.missingValue + } + } + } + } + is PartiqlPhysical.PathStep.PathUnpivot -> { + when { + !remainingComponents.isEmpty() -> { + val tempThunk = compilePathComponents(remainingComponents, pathMetas) + thunkFactory.thunkEnvValue(componentMetas) { env, componentValue -> + val mapped = componentValue.unpivot() + .flatMap { tempThunk(env, it).rangeOver() } + .asSequence() + valueFactory.newBag(mapped) + } + } + else -> + thunkFactory.thunkEnvValue(componentMetas) { _, componentValue -> + valueFactory.newBag(componentValue.unpivot().asSequence()) + } + } + } + // this is for `path[*].component` + is PartiqlPhysical.PathStep.PathWildcard -> { + when { + !remainingComponents.isEmpty() -> { + val hasMoreWildCards = + remainingComponents.filterIsInstance().any() + val tempThunk = compilePathComponents(remainingComponents, pathMetas) + + when { + !hasMoreWildCards -> thunkFactory.thunkEnvValue(componentMetas) { env, componentValue -> + val mapped = componentValue + .rangeOver() + .map { tempThunk(env, it) } + .asSequence() + + valueFactory.newBag(mapped) + } + else -> thunkFactory.thunkEnvValue(componentMetas) { env, componentValue -> + val mapped = componentValue + .rangeOver() + .flatMap { + val tempValue = tempThunk(env, it) + tempValue + } + .asSequence() + + valueFactory.newBag(mapped) + } + } + } + else -> { + thunkFactory.thunkEnvValue(componentMetas) { _, componentValue -> + val mapped = componentValue.rangeOver().asSequence() + valueFactory.newBag(mapped) + } + } + } + } + } + ) + } + return when (componentThunks.size) { + 1 -> componentThunks.first() + else -> thunkFactory.thunkEnvValue(pathMetas) { env, rootValue -> + componentThunks.fold(rootValue) { componentValue, componentThunk -> + componentThunk(env, componentValue) + } + } + } + } + + /** + * Given an AST node that represents a `LIKE` predicate, return an ExprThunk that evaluates a `LIKE` predicate. + * + * Three cases + * + * 1. All arguments are literals, then compile and run the pattern + * 1. Search pattern and escape pattern are literals, compile the pattern. Running the pattern deferred to evaluation time. + * 1. Pattern or escape (or both) are *not* literals, compile and running of pattern deferred to evaluation time. + * + * ``` + * LIKE [ESCAPE ] + * ``` + * + * @return a thunk that when provided with an environment evaluates the `LIKE` predicate + */ + private fun compileLike(expr: PartiqlPhysical.Expr.Like, metas: MetaContainer): PhysicalPlanThunk { + val valueExpr = expr.value + val patternExpr = expr.pattern + val escapeExpr = expr.escape + + val patternLocationMeta = patternExpr.metas.toPartiQlMetaContainer().sourceLocation + val escapeLocationMeta = escapeExpr?.metas?.toPartiQlMetaContainer()?.sourceLocation + + // This is so that null short-circuits can be supported. + fun getRegexPattern(pattern: ExprValue, escape: ExprValue?): (() -> Pattern)? { + val patternArgs = listOfNotNull(pattern, escape) + when { + patternArgs.any { it.type.isUnknown } -> return null + patternArgs.any { !it.type.isText } -> return { + err( + "LIKE expression must be given non-null strings as input", + ErrorCode.EVALUATOR_LIKE_INVALID_INPUTS, + errorContextFrom(metas).also { + it[Property.LIKE_PATTERN] = pattern.ionValue.toString() + if (escape != null) it[Property.LIKE_ESCAPE] = escape.ionValue.toString() + }, + internal = false + ) + } + else -> { + val (patternString: String, escapeChar: Int?) = + checkPattern(pattern.ionValue, patternLocationMeta, escape?.ionValue, escapeLocationMeta) + val likeRegexPattern = when { + patternString.isEmpty() -> Pattern.compile("") + else -> parsePattern(patternString, escapeChar) + } + return { likeRegexPattern } + } + } + } + + fun matchRegexPattern(value: ExprValue, likePattern: (() -> Pattern)?): ExprValue { + return when { + likePattern == null || value.type.isUnknown -> valueFactory.nullValue + !value.type.isText -> err( + "LIKE expression must be given non-null strings as input", + ErrorCode.EVALUATOR_LIKE_INVALID_INPUTS, + errorContextFrom(metas).also { + it[Property.LIKE_VALUE] = value.ionValue.toString() + }, + internal = false + ) + else -> valueFactory.newBoolean(likePattern().matcher(value.stringValue()).matches()) + } + } + + val valueThunk = compileAstExpr(valueExpr) + + // If the pattern and escape expressions are literals then we can compile the pattern now and + // re-use it with every execution. Otherwise, we must re-compile the pattern every time. + return when { + patternExpr is PartiqlPhysical.Expr.Lit && (escapeExpr == null || escapeExpr is PartiqlPhysical.Expr.Lit) -> { + val patternParts = getRegexPattern( + valueFactory.newFromIonValue(patternExpr.value.toIonValue(valueFactory.ion)), + (escapeExpr as? PartiqlPhysical.Expr.Lit)?.value?.toIonValue(valueFactory.ion) + ?.let { valueFactory.newFromIonValue(it) } + ) + + // If valueExpr is also a literal then we can evaluate this at compile time and return a constant. + if (valueExpr is PartiqlPhysical.Expr.Lit) { + val resultValue = matchRegexPattern( + valueFactory.newFromIonValue(valueExpr.value.toIonValue(valueFactory.ion)), + patternParts + ) + return thunkFactory.thunkEnv(metas) { resultValue } + } else { + thunkFactory.thunkEnvOperands(metas, valueThunk) { _, value -> + matchRegexPattern(value, patternParts) + } + } + } + else -> { + val patternThunk = compileAstExpr(patternExpr) + when (escapeExpr) { + null -> { + // thunk that re-compiles the DFA every evaluation without a custom escape sequence + thunkFactory.thunkEnvOperands(metas, valueThunk, patternThunk) { _, value, pattern -> + val pps = getRegexPattern(pattern, null) + matchRegexPattern(value, pps) + } + } + else -> { + // thunk that re-compiles the pattern every evaluation but *with* a custom escape sequence + val escapeThunk = compileAstExpr(escapeExpr) + thunkFactory.thunkEnvOperands( + metas, + valueThunk, + patternThunk, + escapeThunk + ) { _, value, pattern, escape -> + val pps = getRegexPattern(pattern, escape) + matchRegexPattern(value, pps) + } + } + } + } + } + } + + /** + * Given the pattern and optional escape character in a `LIKE` predicate as [IonValue]s + * check their validity based on the SQL92 spec and return a triple that contains in order + * + * - the search pattern as a string + * - the escape character, possibly `null` + * - the length of the search pattern. The length of the search pattern is either + * - the length of the string representing the search pattern when no escape character is used + * - the length of the string representing the search pattern without counting uses of the escape character + * when an escape character is used + * + * A search pattern is valid when + * 1. pattern is not null + * 1. pattern contains characters where `_` means any 1 character and `%` means any string of length 0 or more + * 1. if the escape character is specified then pattern can be deterministically partitioned into character groups where + * 1. A length 1 character group consists of any character other than the ESCAPE character + * 1. A length 2 character group consists of the ESCAPE character followed by either `_` or `%` or the ESCAPE character itself + * + * @param pattern search pattern + * @param escape optional escape character provided in the `LIKE` predicate + * + * @return a triple that contains in order the search pattern as a [String], optionally the code point for the escape character if one was provided + * and the size of the search pattern excluding uses of the escape character + */ + private fun checkPattern( + pattern: IonValue, + patternLocationMeta: SourceLocationMeta?, + escape: IonValue?, + escapeLocationMeta: SourceLocationMeta? + ): Pair { + + val patternString = pattern.stringValue() + ?: err( + "Must provide a non-null value for PATTERN in a LIKE predicate: $pattern", + ErrorCode.EVALUATOR_LIKE_PATTERN_INVALID_ESCAPE_SEQUENCE, + errorContextFrom(patternLocationMeta), + internal = false + ) + + escape?.let { + val escapeCharString = checkEscapeChar(escape, escapeLocationMeta) + val escapeCharCodePoint = escapeCharString.codePointAt(0) // escape is a string of length 1 + val validEscapedChars = setOf('_'.toInt(), '%'.toInt(), escapeCharCodePoint) + val iter = patternString.codePointSequence().iterator() + + while (iter.hasNext()) { + val current = iter.next() + if (current == escapeCharCodePoint && (!iter.hasNext() || !validEscapedChars.contains(iter.next()))) { + err( + "Invalid escape sequence : $patternString", + ErrorCode.EVALUATOR_LIKE_PATTERN_INVALID_ESCAPE_SEQUENCE, + errorContextFrom(patternLocationMeta).apply { + set(Property.LIKE_PATTERN, patternString) + set(Property.LIKE_ESCAPE, escapeCharString) + }, + internal = false + ) + } + } + return Pair(patternString, escapeCharCodePoint) + } + return Pair(patternString, null) + } + + /** + * Given an [IonValue] to be used as the escape character in a `LIKE` predicate check that it is + * a valid character based on the SQL Spec. + * + * + * A value is a valid escape when + * 1. it is 1 character long, and, + * 1. Cannot be null (SQL92 spec marks this cases as *unknown*) + * + * @param escape value provided as an escape character for a `LIKE` predicate + * + * @return the escape character as a [String] or throws an exception when the input is invalid + */ + private fun checkEscapeChar(escape: IonValue, locationMeta: SourceLocationMeta?): String { + val escapeChar = escape.stringValue() ?: err( + "Must provide a value when using ESCAPE in a LIKE predicate: $escape", + ErrorCode.EVALUATOR_LIKE_PATTERN_INVALID_ESCAPE_SEQUENCE, + errorContextFrom(locationMeta), + internal = false + ) + when (escapeChar) { + "" -> { + err( + "Cannot use empty character as ESCAPE character in a LIKE predicate: $escape", + ErrorCode.EVALUATOR_LIKE_PATTERN_INVALID_ESCAPE_SEQUENCE, + errorContextFrom(locationMeta), + internal = false + ) + } + else -> { + if (escapeChar.trim().length != 1) { + err( + "Escape character must have size 1 : $escapeChar", + ErrorCode.EVALUATOR_LIKE_PATTERN_INVALID_ESCAPE_SEQUENCE, + errorContextFrom(locationMeta), + internal = false + ) + } + } + } + return escapeChar + } + + private fun compileExec(node: PartiqlPhysical.Statement.Exec): PhysicalPlanThunk { + val metas = node.metas + val procedureName = node.procedureName.text + val procedure = procedures[procedureName] ?: err( + "No such stored procedure: $procedureName", + ErrorCode.EVALUATOR_NO_SUCH_PROCEDURE, + errorContextFrom(metas).also { + it[Property.PROCEDURE_NAME] = procedureName + }, + internal = false + ) + + val args = node.args + // Check arity + if (args.size !in procedure.signature.arity) { + val errorContext = errorContextFrom(metas).also { + it[Property.EXPECTED_ARITY_MIN] = procedure.signature.arity.first + it[Property.EXPECTED_ARITY_MAX] = procedure.signature.arity.last + } + + val message = when { + procedure.signature.arity.first == 1 && procedure.signature.arity.last == 1 -> + "${procedure.signature.name} takes a single argument, received: ${args.size}" + procedure.signature.arity.first == procedure.signature.arity.last -> + "${procedure.signature.name} takes exactly ${procedure.signature.arity.first} arguments, received: ${args.size}" + else -> + "${procedure.signature.name} takes between ${procedure.signature.arity.first} and " + + "${procedure.signature.arity.last} arguments, received: ${args.size}" + } + + throw EvaluationException( + message, + ErrorCode.EVALUATOR_INCORRECT_NUMBER_OF_ARGUMENTS_TO_PROCEDURE_CALL, + errorContext, + internal = false + ) + } + + // Compile the procedure's arguments + val argThunks = compileAstExprs(args) + + return thunkFactory.thunkEnv(metas) { env -> + val procedureArgValues = argThunks.map { it(env) } + procedure.call(env.session, procedureArgValues) + } + } + + private fun compileDate(expr: PartiqlPhysical.Expr.Date, metas: MetaContainer): PhysicalPlanThunk = + thunkFactory.thunkEnv(metas) { + valueFactory.newDate( + expr.year.value.toInt(), + expr.month.value.toInt(), + expr.day.value.toInt() + ) + } + + private fun compileLitTime(expr: PartiqlPhysical.Expr.LitTime, metas: MetaContainer): PhysicalPlanThunk = + thunkFactory.thunkEnv(metas) { + // Add the default time zone if the type "TIME WITH TIME ZONE" does not have an explicitly specified time zone. + valueFactory.newTime( + Time.of( + expr.value.hour.value.toInt(), + expr.value.minute.value.toInt(), + expr.value.second.value.toInt(), + expr.value.nano.value.toInt(), + expr.value.precision.value.toInt(), + if (expr.value.withTimeZone.value && expr.value.tzMinutes == null) evaluatorOptions.defaultTimezoneOffset.totalMinutes else expr.value.tzMinutes?.value?.toInt() + ) + ) + } + + /** A special wrapper for `UNPIVOT` values as a BAG. */ + private class UnpivotedExprValue(private val values: Iterable) : BaseExprValue() { + override val type = ExprValueType.BAG + override fun iterator() = values.iterator() + + // XXX this value is only ever produced in a FROM iteration, thus none of these should ever be called + override val ionValue + get() = throw UnsupportedOperationException("Synthetic value cannot provide ion value") + } + + /** Unpivots a `struct`, and synthesizes a synthetic singleton `struct` for other [ExprValue]. */ + internal fun ExprValue.unpivot(): ExprValue = when { + // special case for our special UNPIVOT value to avoid double wrapping + this is UnpivotedExprValue -> this + // Wrap into a pseudo-BAG + type == ExprValueType.STRUCT || type == ExprValueType.MISSING -> UnpivotedExprValue(this) + // for non-struct, this wraps any value into a BAG with a synthetic name + else -> UnpivotedExprValue( + listOf( + this.namedValue(valueFactory.newString(syntheticColumnName(0))) + ) + ) + } + + private fun createStructExprValue(seq: Sequence, ordering: StructOrdering) = + valueFactory.newStruct( + when (evaluatorOptions.projectionIteration) { + ProjectionIterationBehavior.FILTER_MISSING -> seq.filter { it.type != ExprValueType.MISSING } + ProjectionIterationBehavior.UNFILTERED -> seq + }, + ordering + ) +} + +internal val MetaContainer.sourceLocationMeta get() = this[SourceLocationMeta.TAG] as? SourceLocationMeta + +internal fun StaticType.getTypes() = when (val flattened = this.flatten()) { + is AnyOfType -> flattened.types + else -> listOf(this) +} + +/** + * Represents an element in a select list that is to be projected into the final result. + * i.e. an expression, or a (project_all) node. + */ +private sealed class CompiledStructPart { + + /** + * Represents a single compiled expression to be projected into the final result. + * Given `SELECT a + b as value FROM foo`: + * - `name` is "value" + * - `thunk` is compiled expression, i.e. `a + b` + */ + class Field(val nameThunk: PhysicalPlanThunk, val valueThunk: PhysicalPlanThunk) : CompiledStructPart() + + /** + * Represents a wildcard ((path_project_all) node) expression to be projected into the final result. + * This covers two cases. For `SELECT foo.* FROM foo`, `exprThunks` contains a single compiled expression + * `foo`. + * + * For `SELECT * FROM foo, bar, bat`, `exprThunks` would contain a compiled expression for each of `foo`, `bar` and + * `bat`. + */ + class StructMerge(val thunks: List) : CompiledStructPart() +} diff --git a/lang/src/org/partiql/lang/eval/physical/RelationThunk.kt b/lang/src/org/partiql/lang/eval/physical/RelationThunk.kt new file mode 100644 index 0000000000..571e741a27 --- /dev/null +++ b/lang/src/org/partiql/lang/eval/physical/RelationThunk.kt @@ -0,0 +1,45 @@ +package org.partiql.lang.eval.physical + +import com.amazon.ionelement.api.MetaContainer +import org.partiql.lang.ast.SourceLocationMeta +import org.partiql.lang.errors.ErrorCode +import org.partiql.lang.errors.Property +import org.partiql.lang.eval.EvaluationException +import org.partiql.lang.eval.ThunkFactory +import org.partiql.lang.eval.errorContextFrom +import org.partiql.lang.eval.fillErrorContext +import org.partiql.lang.eval.relation.RelationIterator + +/** A thunk that returns a [RelationIterator], which is the result of evaluating a relational operator. */ +internal typealias RelationThunkEnv = (EvaluatorState) -> RelationIterator + +/** + * Invokes [t] with error handling like is supplied by [ThunkFactory]. + * + * This function is not currently in [ThunkFactory] to avoid complicating it further. If a need arises, it could be + * moved. + */ +internal inline fun relationThunk(metas: MetaContainer, crossinline t: RelationThunkEnv): RelationThunkEnv { + val sourceLocationMeta = metas[SourceLocationMeta.TAG] as? SourceLocationMeta + return { env: EvaluatorState -> + try { + t(env) + } catch (e: EvaluationException) { + // Only add source location data to the error context if it doesn't already exist + // in [errorContext]. + if (!e.errorContext.hasProperty(Property.LINE_NUMBER)) { + sourceLocationMeta?.let { fillErrorContext(e.errorContext, sourceLocationMeta) } + } + throw e + } catch (e: Exception) { + val message = e.message ?: "" + throw EvaluationException( + "Generic exception, $message", + errorCode = ErrorCode.EVALUATOR_GENERIC_EXCEPTION, + errorContext = errorContextFrom(sourceLocationMeta), + cause = e, + internal = true + ) + } + } +} diff --git a/lang/src/org/partiql/lang/eval/physical/UndefinedVariableUtil.kt b/lang/src/org/partiql/lang/eval/physical/UndefinedVariableUtil.kt new file mode 100644 index 0000000000..ee6bd16d04 --- /dev/null +++ b/lang/src/org/partiql/lang/eval/physical/UndefinedVariableUtil.kt @@ -0,0 +1,31 @@ +package org.partiql.lang.eval.physical + +import com.amazon.ionelement.api.MetaContainer +import org.partiql.lang.errors.ErrorCode +import org.partiql.lang.errors.Property +import org.partiql.lang.errors.UNBOUND_QUOTED_IDENTIFIER_HINT +import org.partiql.lang.eval.BindingCase +import org.partiql.lang.eval.BindingName +import org.partiql.lang.eval.EvaluationException +import org.partiql.lang.eval.errorContextFrom +import org.partiql.lang.util.propertyValueMapOf + +internal fun throwUndefinedVariableException( + bindingName: BindingName, + metas: MetaContainer? +): Nothing { + val (errorCode, hint) = when (bindingName.bindingCase) { + BindingCase.SENSITIVE -> + ErrorCode.EVALUATOR_QUOTED_BINDING_DOES_NOT_EXIST to " $UNBOUND_QUOTED_IDENTIFIER_HINT" + BindingCase.INSENSITIVE -> + ErrorCode.EVALUATOR_BINDING_DOES_NOT_EXIST to "" + } + throw EvaluationException( + message = "No such binding: ${bindingName.name}.$hint", + errorCode = errorCode, + errorContext = (metas?.let { errorContextFrom(metas) } ?: propertyValueMapOf()).also { + it[Property.BINDING_NAME] = bindingName.name + }, + internal = false + ) +} diff --git a/lang/src/org/partiql/lang/eval/relation/Relation.kt b/lang/src/org/partiql/lang/eval/relation/Relation.kt new file mode 100644 index 0000000000..9b8915861b --- /dev/null +++ b/lang/src/org/partiql/lang/eval/relation/Relation.kt @@ -0,0 +1,90 @@ +package org.partiql.lang.eval.relation + +import kotlin.coroutines.Continuation +import kotlin.coroutines.CoroutineContext +import kotlin.coroutines.EmptyCoroutineContext +import kotlin.coroutines.createCoroutine +import kotlin.coroutines.intrinsics.COROUTINE_SUSPENDED +import kotlin.coroutines.intrinsics.suspendCoroutineUninterceptedOrReturn +import kotlin.coroutines.resume + +/** + * Builds a [RelationIterator] that yields after every step in evaluating a relational operator. + * + * This is inspired heavily by Kotlin's [sequence] but for [RelationIterator] instead of [Sequence]. + */ +internal fun relation( + seqType: RelationType, + block: suspend RelationScope.() -> Unit +): RelationIterator { + val iterator = RelationBuilderIterator(seqType, block) + iterator.nextStep = block.createCoroutine(receiver = iterator, completion = iterator) + return iterator +} + +@DslMarker +@Target(AnnotationTarget.CLASS, AnnotationTarget.TYPE) +annotation class RelationDsl + +/** Defines functions within a block supplied to [relation]. */ +@RelationDsl +internal interface RelationScope { + /** Suspends the coroutine. Should be called after processing the current row of the relation. */ + suspend fun yield() + + /** Yields once for every row remaining in [relItr]. */ + suspend fun yieldAll(relItr: RelationIterator) +} + +private class RelationBuilderIterator( + override val relType: RelationType, + block: suspend RelationScope.() -> Unit +) : RelationScope, RelationIterator, Continuation { + var yielded = false + + var nextStep: Continuation? = block.createCoroutine(receiver = this, completion = this) + + override suspend fun yield() { + yielded = true + suspendCoroutineUninterceptedOrReturn { c -> + nextStep = c + COROUTINE_SUSPENDED + } + } + + override suspend fun yieldAll(relItr: RelationIterator) { + while (relItr.nextRow()) { + yield() + } + } + + override fun nextRow(): Boolean { + // if nextStep is null it means we've reached the end of the relation, but nextRow() was called again + // for some reason. This probably indicates a bug since we should not in general be attempting to + // read a `RelationIterator` after it has exhausted. + if (nextStep == null) { + error( + "Relation was previously exhausted. " + + "Please don't call nextRow() again after it returns false the first time." + ) + } + val step = nextStep!! + nextStep = null + step.resume(Unit) + + return if (yielded) { + yielded = false + true + } else { + false + } + } + + // Completion continuation implementation + override fun resumeWith(result: Result) { + result.getOrThrow() // just rethrow exception if it is there + } + + override val context: CoroutineContext + get() = EmptyCoroutineContext +} diff --git a/lang/src/org/partiql/lang/eval/relation/RelationIterator.kt b/lang/src/org/partiql/lang/eval/relation/RelationIterator.kt new file mode 100644 index 0000000000..d6807ef443 --- /dev/null +++ b/lang/src/org/partiql/lang/eval/relation/RelationIterator.kt @@ -0,0 +1,37 @@ +package org.partiql.lang.eval.relation + +import org.partiql.lang.domains.PartiqlPhysical.Expr.BindingsToValues +import org.partiql.lang.eval.ExprValue +import org.partiql.lang.eval.physical.EvaluatorState + +enum class RelationType { BAG, LIST } + +/** + * Represents an iterator that is returned by a relational operator during evaluation. + * + * This is a "faux" iterator in a sense, because it doesn't provide direct access to a current element. + * + * When initially created, the [RelationIterator] is positioned "before" the first element. [nextRow] should be called + * to advance the iterator to the first row. + * + * We do not use [Iterator] for this purpose because it is not a natural fit. There are two reasons: + * + * 1. [Iterator.next] returns the current element, but this isn't actually an iterator over a collection. Instead, + * execution of [nextRow] may have a side effect of populating value(s) in the current [EvaluatorState.registers] + * array. Bridge operators such as [BindingsToValues] are responsible for extracting current values from + * [EvaluatorState.registers] and converting them to the appropriate container [ExprValue]s. + * 2. [Iterator.hasNext] would require knowing if additional rows remain after the current row, but in a few cases + * including filters and joins that requires advancing through possibly all remaining rows to see if any remaining row + * matches the predicate. This is awkward to implement and would force eager evaluation of the [Iterator]. + */ +internal interface RelationIterator { + val relType: RelationType + + /** + * Advances the iterator to the next row. + * + * Returns true to indicate that the next row was found and that [EvaluatorState.registers] have been updated for + * the current row. False if there are no more rows. + */ + fun nextRow(): Boolean +} diff --git a/lang/src/org/partiql/lang/eval/visitors/PartiqlAstSanityValidator.kt b/lang/src/org/partiql/lang/eval/visitors/PartiqlAstSanityValidator.kt index c86dd7141a..39d4b89ceb 100644 --- a/lang/src/org/partiql/lang/eval/visitors/PartiqlAstSanityValidator.kt +++ b/lang/src/org/partiql/lang/eval/visitors/PartiqlAstSanityValidator.kt @@ -23,7 +23,6 @@ import org.partiql.lang.ast.IsCountStarMeta import org.partiql.lang.ast.passes.SemanticException import org.partiql.lang.domains.PartiqlAst import org.partiql.lang.domains.addSourceLocation -import org.partiql.lang.domains.errorContextFrom import org.partiql.lang.errors.ErrorCode import org.partiql.lang.errors.Property import org.partiql.lang.errors.PropertyValueMap @@ -31,6 +30,7 @@ import org.partiql.lang.eval.CompileOptions import org.partiql.lang.eval.EvaluationException import org.partiql.lang.eval.TypedOpBehavior import org.partiql.lang.eval.err +import org.partiql.lang.eval.errorContextFrom import org.partiql.pig.runtime.LongPrimitive /** diff --git a/lang/src/org/partiql/lang/eval/visitors/PartiqlPhysicalSanityValidator.kt b/lang/src/org/partiql/lang/eval/visitors/PartiqlPhysicalSanityValidator.kt new file mode 100644 index 0000000000..2fb63ad7b0 --- /dev/null +++ b/lang/src/org/partiql/lang/eval/visitors/PartiqlPhysicalSanityValidator.kt @@ -0,0 +1,121 @@ +package org.partiql.lang.eval.visitors + +import com.amazon.ionelement.api.IntElement +import com.amazon.ionelement.api.IntElementSize +import com.amazon.ionelement.api.MetaContainer +import com.amazon.ionelement.api.TextElement +import org.partiql.lang.ast.IsCountStarMeta +import org.partiql.lang.ast.passes.SemanticException +import org.partiql.lang.domains.PartiqlPhysical +import org.partiql.lang.domains.addSourceLocation +import org.partiql.lang.errors.ErrorCode +import org.partiql.lang.errors.Property +import org.partiql.lang.errors.PropertyValueMap +import org.partiql.lang.eval.EvaluationException +import org.partiql.lang.eval.TypedOpBehavior +import org.partiql.lang.eval.err +import org.partiql.lang.eval.errorContextFrom +import org.partiql.lang.planner.EvaluatorOptions +import org.partiql.lang.util.propertyValueMapOf +import org.partiql.pig.runtime.LongPrimitive + +/** + * Provides rules for basic AST sanity checks that should be performed before any attempt at further physical + * plan processing. This is provided as a distinct [PartiqlPhysical.Visitor] so that the planner and evaluator may + * assume that the physical plan has passed the checks performed here. + * + * Any exception thrown by this class should always be considered an indication of a bug. + */ +class PartiqlPhysicalSanityValidator(private val evaluatorOptions: EvaluatorOptions) : PartiqlPhysical.Visitor() { + + /** + * Quick validation step to make sure the indexes of any variables make sense. + * It is unlikely that this check will ever fail, but if it does, it likely means there's a bug in + * [org.partiql.lang.planner.transforms.VariableIdAllocator] or that the plan was malformed by other means. + */ + override fun visitPlan(node: PartiqlPhysical.Plan) { + node.locals.forEachIndexed { idx, it -> + if (it.registerIndex.value != idx.toLong()) { + throw EvaluationException( + message = "Variable index must match ordinal position of variable", + errorCode = ErrorCode.INTERNAL_ERROR, + errorContext = propertyValueMapOf(), + internal = true + ) + } + } + super.visitPlan(node) + } + + override fun visitExprLit(node: PartiqlPhysical.Expr.Lit) { + val ionValue = node.value + val metas = node.metas + if (node.value is IntElement && ionValue.integerSize == IntElementSize.BIG_INTEGER) { + throw EvaluationException( + message = "Int overflow or underflow at compile time", + errorCode = ErrorCode.SEMANTIC_LITERAL_INT_OVERFLOW, + errorContext = errorContextFrom(metas), + internal = false + ) + } + } + + private fun validateDecimalOrNumericType(scale: LongPrimitive?, precision: LongPrimitive?, metas: MetaContainer) { + if (scale != null && precision != null && evaluatorOptions.typedOpBehavior == TypedOpBehavior.HONOR_PARAMETERS) { + if (scale.value !in 0..precision.value) { + err( + "Scale ${scale.value} should be between 0 and precision ${precision.value}", + errorCode = ErrorCode.SEMANTIC_INVALID_DECIMAL_ARGUMENTS, + errorContext = errorContextFrom(metas), + internal = false + ) + } + } + } + + override fun visitTypeDecimalType(node: PartiqlPhysical.Type.DecimalType) { + validateDecimalOrNumericType(node.scale, node.precision, node.metas) + } + + override fun visitTypeNumericType(node: PartiqlPhysical.Type.NumericType) { + validateDecimalOrNumericType(node.scale, node.precision, node.metas) + } + + override fun visitExprCallAgg(node: PartiqlPhysical.Expr.CallAgg) { + val setQuantifier = node.setq + val metas = node.metas + if (setQuantifier is PartiqlPhysical.SetQuantifier.Distinct && metas.containsKey(IsCountStarMeta.TAG)) { + err( + "COUNT(DISTINCT *) is not supported", + ErrorCode.EVALUATOR_COUNT_DISTINCT_STAR, + errorContextFrom(metas), + internal = false + ) + } + } + + override fun visitExprStruct(node: PartiqlPhysical.Expr.Struct) { + node.parts.forEach { part -> + when (part) { + is PartiqlPhysical.StructPart.StructField -> { + if (part.fieldName is PartiqlPhysical.Expr.Missing || + (part.fieldName is PartiqlPhysical.Expr.Lit && part.fieldName.value !is TextElement) + ) { + val type = when (part.fieldName) { + is PartiqlPhysical.Expr.Lit -> part.fieldName.value.type.toString() + else -> "MISSING" + } + throw SemanticException( + "Found struct part to be of type $type", + ErrorCode.SEMANTIC_NON_TEXT_STRUCT_FIELD_KEY, + PropertyValueMap().addSourceLocation(part.fieldName.metas).also { pvm -> + pvm[Property.ACTUAL_TYPE] = type + } + ) + } + } + is PartiqlPhysical.StructPart.StructFields -> { /* intentionally empty */ } + } + } + } +} diff --git a/lang/src/org/partiql/lang/planner/EvaluatorOptions.kt b/lang/src/org/partiql/lang/planner/EvaluatorOptions.kt new file mode 100644 index 0000000000..546ed1a835 --- /dev/null +++ b/lang/src/org/partiql/lang/planner/EvaluatorOptions.kt @@ -0,0 +1,85 @@ +package org.partiql.lang.planner + +import org.partiql.lang.eval.ProjectionIterationBehavior +import org.partiql.lang.eval.ThunkOptions +import org.partiql.lang.eval.TypedOpBehavior +import org.partiql.lang.eval.TypingMode +import java.time.ZoneOffset + +/* + +Differences between CompilerOptions and EvaluatorOptions: + +- There is no EvaluatorOptions equivalent for CompileOptions.visitorTransformMode since the planner always runs some basic + normalization and variable resolution passes *before* the customer can inject their own transforms. +- There is no EvaluatorOptions equivalent for CompileOptions.thunkReturnTypeAssertions since PlannerPipeline does not +support the static type inferencer (yet). +- EvaluatorOptions.allowUndefinedVariables is new. +- EvaluatorOptions has no equivalent for CompileOptions.undefinedVariableBehavior -- this was added for backward +compatibility on behalf of a customer we don't have anymore. Internal bug number is IONSQL-134. + */ + +/** + * Specifies options that effect the behavior of the PartiQL physical plan evaluator. + * + * @param defaultTimezoneOffset Default timezone offset to be used when TIME WITH TIME ZONE does not explicitly + * specify the time zone. Defaults to [ZoneOffset.UTC]. + */ +@Suppress("DataClassPrivateConstructor") +data class EvaluatorOptions private constructor ( + val projectionIteration: ProjectionIterationBehavior = ProjectionIterationBehavior.FILTER_MISSING, + val thunkOptions: ThunkOptions = ThunkOptions.standard(), + val typingMode: TypingMode = TypingMode.LEGACY, + val typedOpBehavior: TypedOpBehavior = TypedOpBehavior.LEGACY, + val defaultTimezoneOffset: ZoneOffset = ZoneOffset.UTC +) { + companion object { + + /** + * Creates a java style builder that will choose the default values for any unspecified options. + */ + @JvmStatic + fun builder() = Builder() + + /** + * Creates a java style builder that will clone the [EvaluatorOptions] passed to the constructor. + */ + @JvmStatic + fun builder(options: EvaluatorOptions) = Builder(options) + + /** + * Kotlin style builder that will choose the default values for any unspecified options. + */ + fun build(block: Builder.() -> Unit) = Builder().apply(block).build() + + /** + * Kotlin style builder that will clone the [EvaluatorOptions] passed to the constructor. + */ + fun build(options: EvaluatorOptions, block: Builder.() -> Unit) = Builder(options).apply(block).build() + + /** + * Creates a [EvaluatorOptions] instance with the standard values for use by the legacy AST compiler. + */ + @JvmStatic + fun standard() = Builder().build() + } + + /** + * Builds a [EvaluatorOptions] instance. + */ + class Builder(private var options: EvaluatorOptions = EvaluatorOptions()) { + + fun projectionIteration(value: ProjectionIterationBehavior) = set { copy(projectionIteration = value) } + fun typingMode(value: TypingMode) = set { copy(typingMode = value) } + fun typedOpBehavior(value: TypedOpBehavior) = set { copy(typedOpBehavior = value) } + fun thunkOptions(value: ThunkOptions) = set { copy(thunkOptions = value) } + fun defaultTimezoneOffset(value: ZoneOffset) = set { copy(defaultTimezoneOffset = value) } + + private inline fun set(block: EvaluatorOptions.() -> EvaluatorOptions): Builder { + options = block(options) + return this + } + + fun build() = options + } +} diff --git a/lang/src/org/partiql/lang/planner/MetadataResolver.kt b/lang/src/org/partiql/lang/planner/MetadataResolver.kt new file mode 100644 index 0000000000..af13cb00e2 --- /dev/null +++ b/lang/src/org/partiql/lang/planner/MetadataResolver.kt @@ -0,0 +1,65 @@ +package org.partiql.lang.planner + +import org.partiql.lang.eval.BindingCase +import org.partiql.lang.eval.BindingName + +/** Indicates the result of an attempt to resolve a global variable to its customer supplied unique identifier. */ +sealed class ResolutionResult { + /** + * A success case, indicates the [uniqueId] of the match to the [BindingName] in the global scope. + * Typically, this is defined by the storage layer. + * + * In the future, this will likely contain much more than just a unique id. It might include detailed schema + * information about global variables. + */ + data class GlobalVariable(val uniqueId: String) : ResolutionResult() + + /** + * A success case, indicates the [index] of the only possible match to the [BindingName] in a local lexical scope. + * This is `internal` because [index] is an implementation detail that shouldn't be accessible outside of this + * library. + */ + internal data class LocalVariable(val index: Int) : ResolutionResult() + + /** A failure case, indicates that resolution did not match any variable. */ + object Undefined : ResolutionResult() +} + +/** + * Supplies the query planner with metadata about the current database. Meant to be implemented by the application + * embedding PartiQL. + * + * Metadata is associated with global variables. Global variables can be tables or (less commonly) any other + * application specific global variable. + * + * In the future, new methods could be added which expose information about other types of database metadata such as + * available indexes and table statistics. + */ +interface MetadataResolver { + /** + * Implementations try to resolve a variable which is typically a database table to a schema + * using [bindingName]. [bindingName] includes both the name as specified by the query author and a [BindingCase] + * which indicates if query author included double quotes (") which mean the lookup should be case-sensitive. + * + * Implementations of this function must return: + * + * - [ResolutionResult.GlobalVariable] if [bindingName] matches a global variable (typically a database table). + * - [ResolutionResult.Undefined] if no identifier matches [bindingName]. + * + * When determining if a variable name matches a global variable, it is important to consider if the comparison + * should be case-sensitive or case-insensitive. @see [BindingName.bindingCase]. In the event that more than one + * variable matches a case-insensitive [BindingName], the implementation must still select one of them + * without providing an error. (This is consistent with Postres's behavior in this scenario.) + * + * Note that while [ResolutionResult.LocalVariable] exists, it is intentionally marked `internal` and cannot + * be used outside this project. + */ + fun resolveVariable(bindingName: BindingName): ResolutionResult +} + +private val EMPTY: MetadataResolver = object : MetadataResolver { + override fun resolveVariable(bindingName: BindingName): ResolutionResult = ResolutionResult.Undefined +} + +/** Convenience function for obtaining an instance of [MetadataResolver] with no defined variables. */ +fun emptyMetadataResolver(): MetadataResolver = EMPTY diff --git a/lang/src/org/partiql/lang/planner/PassResult.kt b/lang/src/org/partiql/lang/planner/PassResult.kt new file mode 100644 index 0000000000..19e2d4111d --- /dev/null +++ b/lang/src/org/partiql/lang/planner/PassResult.kt @@ -0,0 +1,16 @@ +package org.partiql.lang.planner +import org.partiql.lang.errors.Problem + +sealed class PassResult { + /** + * Indicates query planning was successful and includes a list of any warnings that were encountered along the way. + */ + data class Success(val result: TResult, val warnings: List) : PassResult() + + /** + * Indicates query planning was not successful and includes a list of errors and warnings that were encountered + * along the way. Encountering both errors and warnings, as well as multiple errors is possible since we are not + * required to stop when encountering the first error. + */ + data class Error(val errors: List) : PassResult() +} diff --git a/lang/src/org/partiql/lang/planner/PlannerPipeline.kt b/lang/src/org/partiql/lang/planner/PlannerPipeline.kt new file mode 100644 index 0000000000..ae3b1aaad8 --- /dev/null +++ b/lang/src/org/partiql/lang/planner/PlannerPipeline.kt @@ -0,0 +1,410 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + */ + +package org.partiql.lang.planner + +import com.amazon.ion.IonSystem +import org.partiql.lang.SqlException +import org.partiql.lang.ast.SourceLocationMeta +import org.partiql.lang.domains.PartiqlAst +import org.partiql.lang.domains.PartiqlPhysical +import org.partiql.lang.errors.Problem +import org.partiql.lang.errors.ProblemCollector +import org.partiql.lang.errors.Property +import org.partiql.lang.eval.ExprFunction +import org.partiql.lang.eval.ExprValueFactory +import org.partiql.lang.eval.Expression +import org.partiql.lang.eval.ThunkReturnTypeAssertions +import org.partiql.lang.eval.builtins.DynamicLookupExprFunction +import org.partiql.lang.eval.builtins.createBuiltinFunctions +import org.partiql.lang.eval.builtins.storedprocedure.StoredProcedure +import org.partiql.lang.eval.physical.PhysicalExprToThunkConverterImpl +import org.partiql.lang.planner.transforms.PlanningProblemDetails +import org.partiql.lang.planner.transforms.normalize +import org.partiql.lang.planner.transforms.toDefaultPhysicalPlan +import org.partiql.lang.planner.transforms.toLogicalPlan +import org.partiql.lang.planner.transforms.toResolvedPlan +import org.partiql.lang.syntax.Parser +import org.partiql.lang.syntax.SqlParser +import org.partiql.lang.syntax.SyntaxException +import org.partiql.lang.types.CustomType + +/** + * [PlannerPipeline] is the main interface for planning and compiling PartiQL queries into instances of [Expression] + * which can be executed. + * + * This class was originally derived from [org.partiql.lang.CompilerPipeline], which is the main compiler entry point + * for the legacy AST compiler. The main difference is that the logical and physical plans have taken the place of + * PartiQL's AST, and that after parsing several passes over the AST are performed: + * + * - It is transformed into a logical plan + * - Variables are resolved. + * - It is converted into a physical plan. + * + * In the future additional passes will exist to include optimizations like predicate & projection push-down, and + * will be extensible by customers who can include their own optimizations. + * + * Two basic scenarios for using this interface: + * + * 1. You want to plan and compile a query once, but don't care about re-using the plan across process instances. In + * this scenario, simply use the [planAndCompile] function to obtain an instance of [Expression] that can be + * invoked directly to evaluate a query. + * 2. You want to plan a query once and share a planned query across process boundaries. In this scenario, use + * [plan] to perform query planning and obtain an instance of [PartiqlPhysical.Statement] which can be serialized to + * Ion text or binary format. On the other side of the process boundary, use [compile] to turn the + * [PartiqlPhysical.Statement] query plan into an [Expression]. Compilation itself should be relatively fast. + * + * The provided builder companion creates an instance of [PlannerPipeline] that is NOT thread safe and should NOT be + * used to compile queries concurrently. If used in a multithreaded application, use one instance of [PlannerPipeline] + * per thread. + */ +interface PlannerPipeline { + val valueFactory: ExprValueFactory + + /** + * Plans a query only but does not compile it. + * + * - Parses the specified SQL string, producing an AST. + * - Converts the AST to a logical plan. + * - Resolves all global and local variables in the logical plan, assigning unique indexes to local variables + * and calling [MetadataResolver.resolveVariable] to obtain unique identifiers global values such as tables that + * are specific to the application embedding PartiQL, and optionally converts undefined variables to dynamic + * lookups. + * - Converts the logical plan to a physical plan with `(impl default)` operators. + * + * @param query The text of the SQL statement or expression to be planned. + * @return [PassResult.Success] containing an instance of [PartiqlPhysical.Statement] and any applicable warnings + * if planning was successful or [PassResult.Error] if not. + */ + fun plan(query: String): PassResult + + /** + * Compiles the previously planned [PartiqlPhysical.Statement] instance. + * + * @param physicalPlan The physical query plan. + * @return [PassResult.Success] containing an instance of [PartiqlPhysical.Statement] and any applicable warnings + * if compilation was successful or [PassResult.Error] if not. + */ + fun compile(physicalPlan: PartiqlPhysical.Plan): PassResult + + /** + * Plans and compiles a query. + * + * @param query The text of the SQL statement or expression to be planned and compiled. + * @return [PassResult.Success] containing an instance of [PartiqlPhysical.Statement] and any applicable warnings + * if compiling and planning was successful or [PassResult.Error] if not. + */ + fun planAndCompile(query: String): PassResult = + when (val planResult = plan(query)) { + is PassResult.Error -> PassResult.Error(planResult.errors) + is PassResult.Success -> { + when (val compileResult = compile(planResult.result)) { + is PassResult.Error -> compileResult + is PassResult.Success -> PassResult.Success( + compileResult.result, + // Need to include any warnings that may have been discovered during planning. + planResult.warnings + compileResult.warnings + ) + } + } + } + + @Suppress("DeprecatedCallableAddReplaceWith", "DEPRECATION") + companion object { + private const val WARNING = "WARNING: PlannerPipeline is EXPERIMENTAL and has incomplete language support! " + + "For production use, see org.partiql.lang.CompilerPipeline which is stable and supports all PartiQL " + + "features." + + /** Kotlin style builder for [PlannerPipeline]. If calling from Java instead use [builder]. */ + @Deprecated(WARNING) + fun build(ion: IonSystem, block: Builder.() -> Unit) = build(ExprValueFactory.standard(ion), block) + + /** Kotlin style builder for [PlannerPipeline]. If calling from Java instead use [builder]. */ + @Deprecated(WARNING) + fun build(valueFactory: ExprValueFactory, block: Builder.() -> Unit) = Builder(valueFactory).apply(block).build() + + /** Fluent style builder. If calling from Kotlin instead use the [build] method. */ + @JvmStatic + @Deprecated(WARNING) + fun builder(ion: IonSystem): Builder = builder(ExprValueFactory.standard(ion)) + + /** Fluent style builder. If calling from Kotlin instead use the [build] method. */ + @JvmStatic + @Deprecated(WARNING) + fun builder(valueFactory: ExprValueFactory): Builder = Builder(valueFactory) + + /** Returns an implementation of [PlannerPipeline] with all properties set to their defaults. */ + @JvmStatic + @Deprecated(WARNING) + fun standard(ion: IonSystem): PlannerPipeline = standard(ExprValueFactory.standard(ion)) + + /** Returns an implementation of [PlannerPipeline] with all properties set to their defaults. */ + @JvmStatic + @Deprecated(WARNING) + fun standard(valueFactory: ExprValueFactory): PlannerPipeline = builder(valueFactory).build() + } + + /** + * An implementation of the builder pattern for instances of [PlannerPipeline]. The created instance of + * [PlannerPipeline] is NOT thread safe and should NOT be used to compile queries concurrently. If used in a + * multithreaded application, use one instance of [PlannerPipeline] per thread. + */ + class Builder(val valueFactory: ExprValueFactory) { + private var parser: Parser? = null + private var evaluatorOptions: EvaluatorOptions? = null + private val customFunctions: MutableMap = HashMap() + private var customDataTypes: List = listOf() + private val customProcedures: MutableMap = HashMap() + private var metadataResolver: MetadataResolver = emptyMetadataResolver() + private var allowUndefinedVariables: Boolean = false + private var enableLegacyExceptionHandling: Boolean = false + + /** + * Specifies the [Parser] to be used to turn an PartiQL query into an instance of [PartiqlAst]. + * The default is [SqlParser]. + */ + fun sqlParser(p: Parser): Builder = this.apply { + parser = p + } + + /** + * Options affecting evaluation-time behavior. The default is [EvaluatorOptions.standard]. + */ + fun evaluatorOptions(options: EvaluatorOptions): Builder = this.apply { + evaluatorOptions = options + } + + /** + * A nested builder for compilation options. The default is [EvaluatorOptions.standard]. + * + * Avoid the use of this overload if calling from Java and instead use the overload accepting an instance + * of [EvaluatorOptions]. + * + * There is no need to call [Builder.build] when using this method. + */ + fun evaluatorOptions(block: EvaluatorOptions.Builder.() -> Unit): Builder = + evaluatorOptions(EvaluatorOptions.build(block)) + + /** + * Add a custom function which will be callable by the compiled queries. + * + * Functions added here will replace any built-in function with the same name. + * + * This function is marked as internal to prevent it from being used outside the tests in this + * project--it will be replaced during implementation of the open type system. + * https://github.com/partiql/partiql-lang-kotlin/milestone/4 + */ + internal fun addFunction(function: ExprFunction): Builder = this.apply { + customFunctions[function.signature.name] = function + } + + /** + * Add custom types to CAST/IS operators to. + * + * Built-in types will take precedence over custom types in case of a name collision. + * + * This function is marked as internal to prevent it from being used outside the tests in this + * project--it will be replaced during implementation of the open type system. + * https://github.com/partiql/partiql-lang-kotlin/milestone/4 + */ + internal fun customDataTypes(customTypes: List) = this.apply { + customDataTypes = customTypes + } + + /** + * Add a custom stored procedure which will be callable by the compiled queries. + * + * Stored procedures added here will replace any built-in procedure with the same name. + * This function is marked as internal to prevent it from being used outside the tests in this + * project--it will be replaced during implementation of the open type system. + * https://github.com/partiql/partiql-lang-kotlin/milestone/4 + */ + internal fun addProcedure(procedure: StoredProcedure): Builder = this.apply { + customProcedures[procedure.signature.name] = procedure + } + + /** + * Adds the [MetadataResolver] for global variables. + * + * [metadataResolver] is queried during query planning to fetch metadata information such as table schemas. + */ + fun metadataResolver(bindings: MetadataResolver): Builder = this.apply { + this.metadataResolver = bindings + } + + /** + * Sets a flag indicating if undefined variables are allowed. + * + * When allowed, undefined variables are rewritten to dynamic lookups. This is intended to provide a migration + * path for legacy PartiQL customers who depend on dynamic lookup of undefined variables to use the query + * planner & phys. algebra. New customers should not enable this. + */ + fun allowUndefinedVariables(allow: Boolean = true): Builder = this.apply { + this.allowUndefinedVariables = allow + } + + /** + * Prevents [SqlException] that occur during compilation from being converted into [Problem]s. + * + * This is for compatibility with the legacy unit test suite, which hasn't been updated to handle + * [Problem]s yet. + */ + internal fun enableLegacyExceptionHandling(): Builder = this.apply { + enableLegacyExceptionHandling = true + } + + /** Builds the actual implementation of [PlannerPipeline]. */ + fun build(): PlannerPipeline { + val compileOptionsToUse = evaluatorOptions ?: EvaluatorOptions.standard() + + when (compileOptionsToUse.thunkOptions.thunkReturnTypeAssertions) { + ThunkReturnTypeAssertions.DISABLED -> { /* take no action */ } + ThunkReturnTypeAssertions.ENABLED -> error( + "TODO: Support ThunkReturnTypeAssertions.ENABLED " + + "need a static type pass first)" + ) + } + + val builtinFunctions = createBuiltinFunctions(valueFactory) + DynamicLookupExprFunction() + val builtinFunctionsMap = builtinFunctions.associateBy { + it.signature.name + } + + // customFunctions must be on the right side of + here to ensure that they overwrite any + // built-in functions with the same name. + val allFunctionsMap = builtinFunctionsMap + customFunctions + return PlannerPipelineImpl( + valueFactory = valueFactory, + parser = parser ?: SqlParser(valueFactory.ion, this.customDataTypes), + evaluatorOptions = compileOptionsToUse, + functions = allFunctionsMap, + customDataTypes = customDataTypes, + procedures = customProcedures, + metadataResolver = metadataResolver, + allowUndefinedVariables = allowUndefinedVariables, + enableLegacyExceptionHandling = enableLegacyExceptionHandling + ) + } + } +} + +internal class PlannerPipelineImpl( + override val valueFactory: ExprValueFactory, + private val parser: Parser, + val evaluatorOptions: EvaluatorOptions, + val functions: Map, + val customDataTypes: List, + val procedures: Map, + val metadataResolver: MetadataResolver, + val allowUndefinedVariables: Boolean, + val enableLegacyExceptionHandling: Boolean +) : PlannerPipeline { + + init { + when (evaluatorOptions.thunkOptions.thunkReturnTypeAssertions) { + ThunkReturnTypeAssertions.DISABLED -> { + /** intentionally blank. */ + } + ThunkReturnTypeAssertions.ENABLED -> + // Need a type inferencer pass on resolved logical algebra to support this. + TODO("Support for EvaluatorOptions.thunkReturnTypeAsserts == ThunkReturnTypeAssertions.ENABLED") + } + } + + val customTypedOpParameters = customDataTypes.map { customType -> + (customType.aliases + customType.name).map { alias -> + Pair(alias.toLowerCase(), customType.typedOpParameter) + } + }.flatten().toMap() + + override fun plan(query: String): PassResult { + val ast = try { + parser.parseAstStatement(query) + } catch (ex: SyntaxException) { + val problem = Problem( + SourceLocationMeta( + ex.errorContext[Property.LINE_NUMBER]?.longValue() ?: -1, + ex.errorContext[Property.COLUMN_NUMBER]?.longValue() ?: -1 + ), + PlanningProblemDetails.ParseError(ex.generateMessageNoLocation()) + ) + return PassResult.Error(listOf(problem)) + } + // Now run the AST thru each pass until we arrive at the physical algebra. + + // Normalization--synthesizes any unspecified `AS` aliases, converts `SELECT *` to `SELECT f.*[, ...]` ... + val normalizedAst = ast.normalize() + + // ast -> logical plan + val logicalPlan = normalizedAst.toLogicalPlan() + + // logical plan -> resolved logical plan + val problemHandler = ProblemCollector() + val resolvedLogicalPlan = logicalPlan.toResolvedPlan(problemHandler, metadataResolver, allowUndefinedVariables) + // If there are unresolved variables after attempting to resolve variables, then we can't proceed. + if (problemHandler.hasErrors) { + return PassResult.Error(problemHandler.problems) + } + + // Possible future passes: + // - type checking and inferencing? + // - constant folding + // - common sub-expression removal + // - push down predicates & projections on top of their scan nodes. + // - customer supplied rewrites of resolved logical plan. + + // resolved logical plan -> physical plan. + // this will give all relational operators `(impl default)`. + val physicalPlan = resolvedLogicalPlan.toDefaultPhysicalPlan() + + // Future work: invoke passes to choose relational operator implementations other than `(impl default)`. + // Future work: fully push down predicates and projections into their physical read operators. + // Future work: customer supplied rewrites of phsyical plan + + // If we reach this far, we're successful. If there were any problems at all, they were just warnings. + return PassResult.Success(physicalPlan, problemHandler.problems) + } + + override fun compile(physicalPlan: PartiqlPhysical.Plan): PassResult { + val compiler = PhysicalExprToThunkConverterImpl( + valueFactory = valueFactory, + functions = functions, + customTypedOpParameters = customTypedOpParameters, + procedures = procedures, + evaluatorOptions = evaluatorOptions + ) + + val expression = when { + enableLegacyExceptionHandling -> compiler.compile(physicalPlan) + else -> { + // Legacy exception handling is disabled, convert any [SqlException] into a Problem and return + // PassResult.Error. + try { + compiler.compile(physicalPlan) + } catch (e: SqlException) { + val problem = Problem( + SourceLocationMeta( + e.errorContext[Property.LINE_NUMBER]?.longValue() ?: -1, + e.errorContext[Property.COLUMN_NUMBER]?.longValue() ?: -1 + ), + PlanningProblemDetails.CompileError(e.generateMessageNoLocation()) + ) + return PassResult.Error(listOf(problem)) + } + } + } + + return PassResult.Success(expression, listOf()) + } +} diff --git a/lang/src/org/partiql/lang/planner/transforms/AstNormalize.kt b/lang/src/org/partiql/lang/planner/transforms/AstNormalize.kt new file mode 100644 index 0000000000..9e7d362b34 --- /dev/null +++ b/lang/src/org/partiql/lang/planner/transforms/AstNormalize.kt @@ -0,0 +1,25 @@ +package org.partiql.lang.planner.transforms + +import org.partiql.lang.domains.PartiqlAst +import org.partiql.lang.eval.visitors.FromSourceAliasVisitorTransform +import org.partiql.lang.eval.visitors.PipelinedVisitorTransform +import org.partiql.lang.eval.visitors.SelectListItemAliasVisitorTransform +import org.partiql.lang.eval.visitors.SelectStarVisitorTransform + +/** + * Executes the [SelectListItemAliasVisitorTransform], [FromSourceAliasVisitorTransform] and + * [SelectStarVisitorTransform] passes on the receiver. + */ +fun PartiqlAst.Statement.normalize(): PartiqlAst.Statement { + // Since these passes all work on PartiqlAst, we can use a PipelinedVisitorTransform which executes each + // specified VisitorTransform in sequence. + val transforms = PipelinedVisitorTransform( + // Synthesizes unspecified `SELECT AS ...` aliases + SelectListItemAliasVisitorTransform(), + // Synthesizes unspecified `FROM AS ...` aliases + FromSourceAliasVisitorTransform(), + // Changes `SELECT * FROM a, b` to SELECT a.*, b.* FROM a, b` + SelectStarVisitorTransform() + ) + return transforms.transformStatement(this) +} diff --git a/lang/src/org/partiql/lang/planner/transforms/AstToLogicalVisitorTransform.kt b/lang/src/org/partiql/lang/planner/transforms/AstToLogicalVisitorTransform.kt new file mode 100644 index 0000000000..8b179008fe --- /dev/null +++ b/lang/src/org/partiql/lang/planner/transforms/AstToLogicalVisitorTransform.kt @@ -0,0 +1,168 @@ +package org.partiql.lang.planner.transforms + +import org.partiql.lang.domains.PartiqlAst +import org.partiql.lang.domains.PartiqlAstToPartiqlLogicalVisitorTransform +import org.partiql.lang.domains.PartiqlLogical + +/** + * Transforms an instance of [PartiqlAst.Statement] to [PartiqlLogical.Statement]. This representation of the query + * expresses the intent of the query author in terms of PartiQL's relational algebra instead of it its AST. + * + * Performs no semantic checks. + * + * This conversion (and the logical algebra) are early in their lifecycle and so only a limited subset of SFW queries + * are transformable. See `AstToLogicalVisitorTransformTests` to see which queries are transformable. + */ +internal fun PartiqlAst.Statement.toLogicalPlan(): PartiqlLogical.Plan = + PartiqlLogical.build { + plan( + AstToLogicalVisitorTransform.transformStatement(this@toLogicalPlan), + version = PLAN_VERSION_NUMBER + ) + } + +private object AstToLogicalVisitorTransform : PartiqlAstToPartiqlLogicalVisitorTransform() { + + override fun transformExprSelect(node: PartiqlAst.Expr.Select): PartiqlLogical.Expr { + checkForUnsupportedSelectClauses(node) + + var algebra: PartiqlLogical.Bexpr = FromSourceToBexpr.convert(node.from) + + algebra = node.fromLet?.let { fromLet -> + PartiqlLogical.build { + let(algebra, fromLet.letBindings.map { transformLetBinding(it) }, node.fromLet.metas) + } + } ?: algebra + + algebra = node.where?.let { + PartiqlLogical.build { filter(transformExpr(it), algebra, it.metas) } + } ?: algebra + + algebra = node.offset?.let { + PartiqlLogical.build { offset(transformExpr(it), algebra, node.offset.metas) } + } ?: algebra + + algebra = node.limit?.let { + PartiqlLogical.build { limit(transformExpr(it), algebra, node.limit.metas) } + } ?: algebra + + return convertProjectionToBindingsToValues(node, algebra) + } + + private fun convertProjectionToBindingsToValues(node: PartiqlAst.Expr.Select, algebra: PartiqlLogical.Bexpr) = + PartiqlLogical.build { + bindingsToValues( + when (val project = node.project) { + is PartiqlAst.Projection.ProjectValue -> transformExpr(project.value) + is PartiqlAst.Projection.ProjectList -> { + struct( + List(project.projectItems.size) { idx -> + when (val projectItem = project.projectItems[idx]) { + is PartiqlAst.ProjectItem.ProjectExpr -> + structField( + lit( + projectItem.asAlias?.toIonElement() + ?: errAstNotNormalized("SELECT-list item alias not specified") + ), + transformExpr(projectItem.expr), + ) + is PartiqlAst.ProjectItem.ProjectAll -> { + structFields(transformExpr(projectItem.expr), projectItem.metas) + } + } + } + ) + } + is PartiqlAst.Projection.ProjectStar -> + // `SELECT * FROM bar AS b` is rewritten to `SELECT b.* FROM bar as b` by + // [SelectStarVisitorTransform]. Therefore, there is no need to support `SELECT *` here. + errAstNotNormalized("Expected SELECT * to be removed") + + is PartiqlAst.Projection.ProjectPivot -> TODO("PIVOT ...") + }, + algebra, + node.project.metas + ) + }.let { q -> + // in case of SELECT DISTINCT, wrap bindingsToValues in call to filter_distinct + when (node.setq) { + null, is PartiqlAst.SetQuantifier.All -> q + is PartiqlAst.SetQuantifier.Distinct -> PartiqlLogical.build { call("filter_distinct", q) } + } + } + + /** + * Throws [NotImplementedError] if any `SELECT` clauses were used that are not mappable to [PartiqlLogical]. + * + * This function is temporary and will be removed when all the clauses of the `SELECT` expression are mappable + * to [PartiqlLogical]. + */ + private fun checkForUnsupportedSelectClauses(node: PartiqlAst.Expr.Select) { + when { + node.group != null -> TODO("Support for GROUP BY") + node.order != null -> TODO("Support for ORDER BY") + node.having != null -> TODO("Support for HAVING") + } + } + + override fun transformLetBinding(node: PartiqlAst.LetBinding): PartiqlLogical.LetBinding = + PartiqlLogical.build { + letBinding( + transformExpr(node.expr), + varDecl_(node.name, node.name.metas), + node.metas + ) + } + + override fun transformStatementDml(node: PartiqlAst.Statement.Dml): PartiqlLogical.Statement { + TODO("Support for DML") + } + + override fun transformStatementDdl(node: PartiqlAst.Statement.Ddl): PartiqlLogical.Statement { + TODO("Support for DDL") + } + + override fun transformExprStruct(node: PartiqlAst.Expr.Struct): PartiqlLogical.Expr = + PartiqlLogical.build { + struct( + node.fields.map { + structField( + transformExpr(it.first), + transformExpr(it.second) + ) + }, + metas = node.metas + ) + } +} + +private object FromSourceToBexpr : PartiqlAst.FromSource.Converter { + + override fun convertScan(node: PartiqlAst.FromSource.Scan): PartiqlLogical.Bexpr { + val asAlias = node.asAlias ?: errAstNotNormalized("Expected as alias to be non-null") + return PartiqlLogical.build { + scan( + AstToLogicalVisitorTransform.transformExpr(node.expr), + varDecl_(asAlias, asAlias.metas), + node.atAlias?.let { varDecl_(it, it.metas) }, + node.byAlias?.let { varDecl_(it, it.metas) }, + node.metas + ) + } + } + + override fun convertUnpivot(node: PartiqlAst.FromSource.Unpivot): PartiqlLogical.Bexpr { + TODO("Support for UNPIVOT") + } + + override fun convertJoin(node: PartiqlAst.FromSource.Join): PartiqlLogical.Bexpr = + PartiqlLogical.build { + join( + joinType = AstToLogicalVisitorTransform.transformJoinType(node.type), + left = convert(node.left), + right = convert(node.right), + predicate = node.predicate?.let { AstToLogicalVisitorTransform.transformExpr(it) }, + node.metas + ) + } +} diff --git a/lang/src/org/partiql/lang/planner/transforms/LogicalResolvedToDefaultPhysicalVisitorTransform.kt b/lang/src/org/partiql/lang/planner/transforms/LogicalResolvedToDefaultPhysicalVisitorTransform.kt new file mode 100644 index 0000000000..6462757edf --- /dev/null +++ b/lang/src/org/partiql/lang/planner/transforms/LogicalResolvedToDefaultPhysicalVisitorTransform.kt @@ -0,0 +1,94 @@ +package org.partiql.lang.planner.transforms + +import org.partiql.lang.domains.PartiqlLogicalResolved +import org.partiql.lang.domains.PartiqlLogicalResolvedToPartiqlPhysicalVisitorTransform +import org.partiql.lang.domains.PartiqlPhysical + +/** + * Transforms an instance of [PartiqlLogicalResolved.Statement] to [PartiqlPhysical.Statement], + * specifying `(impl default)` for each relational operator. + */ +internal fun PartiqlLogicalResolved.Plan.toDefaultPhysicalPlan() = + LogicalResolvedToDefaultPhysicalVisitorTransform().transformPlan(this) + +internal val DEFAULT_IMPL = PartiqlPhysical.build { impl("default") } + +internal class LogicalResolvedToDefaultPhysicalVisitorTransform : PartiqlLogicalResolvedToPartiqlPhysicalVisitorTransform() { + + /** Copies [PartiqlLogicalResolved.Bexpr.Scan] to [PartiqlPhysical.Bexpr.Scan], adding the default impl. */ + override fun transformBexprScan(node: PartiqlLogicalResolved.Bexpr.Scan): PartiqlPhysical.Bexpr { + val thiz = this + return PartiqlPhysical.build { + scan( + i = DEFAULT_IMPL, + expr = thiz.transformExpr(node.expr), + asDecl = thiz.transformVarDecl(node.asDecl), + atDecl = node.atDecl?.let { thiz.transformVarDecl(it) }, + byDecl = node.byDecl?.let { thiz.transformVarDecl(it) }, + metas = node.metas + ) + } + } + + /** Copies [PartiqlLogicalResolved.Bexpr.Filter] to [PartiqlPhysical.Bexpr.Filter], adding the default impl. */ + override fun transformBexprFilter(node: PartiqlLogicalResolved.Bexpr.Filter): PartiqlPhysical.Bexpr { + val thiz = this + return PartiqlPhysical.build { + filter( + i = DEFAULT_IMPL, + predicate = thiz.transformExpr(node.predicate), + source = thiz.transformBexpr(node.source), + metas = node.metas + ) + } + } + + override fun transformBexprJoin(node: PartiqlLogicalResolved.Bexpr.Join): PartiqlPhysical.Bexpr { + val thiz = this + return PartiqlPhysical.build { + join( + i = DEFAULT_IMPL, + joinType = thiz.transformJoinType(node.joinType), + left = thiz.transformBexpr(node.left), + right = thiz.transformBexpr(node.right), + predicate = node.predicate?.let { thiz.transformExpr(it) }, + metas = node.metas + ) + } + } + + override fun transformBexprOffset(node: PartiqlLogicalResolved.Bexpr.Offset): PartiqlPhysical.Bexpr { + val thiz = this + return PartiqlPhysical.build { + offset( + i = DEFAULT_IMPL, + rowCount = thiz.transformExpr(node.rowCount), + source = thiz.transformBexpr(node.source), + metas = node.metas + ) + } + } + + override fun transformBexprLimit(node: PartiqlLogicalResolved.Bexpr.Limit): PartiqlPhysical.Bexpr { + val thiz = this + return PartiqlPhysical.build { + limit( + i = DEFAULT_IMPL, + rowCount = thiz.transformExpr(node.rowCount), + source = thiz.transformBexpr(node.source), + metas = node.metas + ) + } + } + + override fun transformBexprLet(node: PartiqlLogicalResolved.Bexpr.Let): PartiqlPhysical.Bexpr { + val thiz = this + return PartiqlPhysical.build { + let( + i = DEFAULT_IMPL, + source = thiz.transformBexpr(node.source), + bindings = node.bindings.map { transformLetBinding(it) } + ) + } + } +} diff --git a/lang/src/org/partiql/lang/planner/transforms/LogicalToLogicalResolvedVisitorTransform.kt b/lang/src/org/partiql/lang/planner/transforms/LogicalToLogicalResolvedVisitorTransform.kt new file mode 100644 index 0000000000..4c0f569e14 --- /dev/null +++ b/lang/src/org/partiql/lang/planner/transforms/LogicalToLogicalResolvedVisitorTransform.kt @@ -0,0 +1,433 @@ +package org.partiql.lang.planner.transforms + +import com.amazon.ionelement.api.ionSymbol +import org.partiql.lang.ast.sourceLocation +import org.partiql.lang.domains.PartiqlLogical +import org.partiql.lang.domains.PartiqlLogicalResolved +import org.partiql.lang.domains.PartiqlLogicalToPartiqlLogicalResolvedVisitorTransform +import org.partiql.lang.domains.toBindingCase +import org.partiql.lang.errors.Problem +import org.partiql.lang.errors.ProblemHandler +import org.partiql.lang.eval.BindingName +import org.partiql.lang.eval.builtins.DYNAMIC_LOOKUP_FUNCTION_NAME +import org.partiql.lang.planner.MetadataResolver +import org.partiql.lang.planner.ResolutionResult +import org.partiql.pig.runtime.asPrimitive + +/** + * Resolves all variables by rewriting `(id )` to + * `(id )`) or `(global_id )`, or a `$__dynamic_lookup__` call site (if enabled). + * + * Local variables are resolved independently within this pass, but we rely on [resolver] to resolve global variables. + * + * There are actually two passes here: + * 1. All [PartiqlLogical.VarDecl] nodes are allocated unique indexes (which is stored in a meta). This pass is + * relatively simple. + * 2. Then, during the transform from the `partiql_logical` domain to the `partiql_logical_resolved` domain, we + * determine if the `id` node refers to a global variable, local variable or undefined variable. For global variables, + * the `id` node is replaced with `(global_id )`. For local variables, the original `id` node is + * replaced with a `(id )`), where `` is the index of the corresponding `var_decl`. + * + * When [allowUndefinedVariables] is `false`, the [problemHandler] is notified of any undefined variables. Resolution + * does not stop on the first undefined variable, rather we keep going to provide the end user any additional error + * messaging, unless [ProblemHandler.handleProblem] throws an exception when an error is logged. **If any undefined + * variables are detected, in order to allow traversal to continue, a fake index value (-1) is used in place of a real + * one and the resolved logical plan returned by this function is guaranteed to be invalid.** **Therefore, it is the + * responsibility of callers to check if any problems have been logged with + * [org.partiql.lang.errors.ProblemSeverity.ERROR] and to abort further query planning if so.** + * + * When [allowUndefinedVariables] is `true`, undefined variables are transformed into a dynamic lookup call site, which + * is semantically equivalent to the behavior of the AST evaluator in the same scenario. For example, `name` in the + * query below is undefined: + * + * ```sql + * SELECT name + * FROM foo AS f, bar AS b + * ``` + * Is effectively rewritten to: + * + * ```sql + * SELECT "$__dynamic_lookup__"('name', 'locals_then_globals', 'case_insensitive', f, b) + * FROM foo AS f, bar AS b + * ``` + * + * When `$__dynamic_lookup__` is invoked it will look for the value of `name` in the following locations: (All field + * / variable name comparisons are case-insensitive in this example, although we could have also specified + * `case_sensitive`.) + * + * - The fields of `f` if it is a struct. + * - The fields of `b` if it is a struct. + * - The global scope. + * + * The first value found is returned and the others are ignored. Local variables are searched first + * (`locals_then_globals`) because of the context of the undefined variable. (`name` is not within a `FROM` source.) + * However, to support SQL's FROM-clause semantics this pass specifies `globals_then_locals` when the variable is within + * a `FROM` source, which causes globals to be searched first. + * + * This behavior is backward compatible with the legacy AST evaluator. Furthermore, this rewrite allows us to avoid + * having to support this kind of dynamic lookup within the plan evaluator, thereby reducing its complexity. This + * rewrite can also be disabled entirely by setting [allowUndefinedVariables] to `false`, in which case undefined + * variables to result in a plan-time error instead. + */ +internal fun PartiqlLogical.Plan.toResolvedPlan( + problemHandler: ProblemHandler, + resolver: MetadataResolver, + allowUndefinedVariables: Boolean = false +): PartiqlLogicalResolved.Plan { + // Allocate a unique id for each `VarDecl` + val (planWithAllocatedVariables, allLocals) = this.allocateVariableIds() + + // Transform to `partiql_logical_resolved` while resolving variables. + val resolvedSt = LogicalToLogicalResolvedVisitorTransform(allowUndefinedVariables, problemHandler, resolver) + .transformPlan(planWithAllocatedVariables) + .copy(locals = allLocals) + + return resolvedSt +} + +/** + * A local scope is a list of variable declarations that are produced by a relational operator and an optional + * reference to a parent scope. This is handled separately from global variables. + * + * This is a [List] of [PartiqlLogical.VarDecl] and not a [Map] or some other more efficient data structure + * because most variable lookups are case-insensitive, which makes storing them in a [Map] and benefiting from it hard. + */ +private data class LocalScope(val varDecls: List) + +private data class LogicalToLogicalResolvedVisitorTransform( + /** If set to `true`, do not log errors about undefined variables. Rewrite such variables to a `dynamic_id` node. */ + val allowUndefinedVariables: Boolean, + /** Where to send error reports. */ + private val problemHandler: ProblemHandler, + /** If a variable is not found using [inputScope], we will attempt to locate the binding here instead. */ + private val globals: MetadataResolver, + +) : PartiqlLogicalToPartiqlLogicalResolvedVisitorTransform() { + /** The current [LocalScope]. */ + private var inputScope: LocalScope = LocalScope(emptyList()) + + private enum class VariableLookupStrategy { + LOCALS_THEN_GLOBALS, + GLOBALS_THEN_LOCALS + } + + /** + * This is set to [VariableLookupStrategy.GLOBALS_THEN_LOCALS] for the `` in `(scan ...)` nodes and + * [VariableLookupStrategy.LOCALS_THEN_GLOBALS] for everything else. This is we resolve globals first within + * a `FROM`. + */ + private var currentVariableLookupStrategy: VariableLookupStrategy = VariableLookupStrategy.LOCALS_THEN_GLOBALS + + private fun withVariableLookupStrategy(nextVariableLookupStrategy: VariableLookupStrategy, block: () -> T): T { + val lastVariableLookupStrategy = this.currentVariableLookupStrategy + this.currentVariableLookupStrategy = nextVariableLookupStrategy + return block().also { + this.currentVariableLookupStrategy = lastVariableLookupStrategy + } + } + + private fun withInputScope(nextScope: LocalScope, block: () -> T): T { + val lastScope = inputScope + inputScope = nextScope + return block().also { + inputScope = lastScope + } + } + + private fun PartiqlLogical.Expr.Id.asGlobalId(uniqueId: String): PartiqlLogicalResolved.Expr.GlobalId = + PartiqlLogicalResolved.build { + globalId_( + uniqueId = uniqueId.asPrimitive(), + case = this@LogicalToLogicalResolvedVisitorTransform.transformCaseSensitivity(this@asGlobalId.case), + metas = this@asGlobalId.metas + ) + } + + private fun PartiqlLogical.Expr.Id.asLocalId(index: Int): PartiqlLogicalResolved.Expr = + PartiqlLogicalResolved.build { + localId_(index.asPrimitive(), this@asLocalId.metas) + } + + private fun PartiqlLogical.Expr.Id.asErrorId(): PartiqlLogicalResolved.Expr = + PartiqlLogicalResolved.build { + localId_((-1).asPrimitive(), this@asErrorId.metas) + } + + override fun transformPlan(node: PartiqlLogical.Plan): PartiqlLogicalResolved.Plan = + PartiqlLogicalResolved.build { + plan_( + stmt = transformStatement(node.stmt), + version = node.version, + locals = emptyList(), // NOTE: locals will be populated by caller + metas = node.metas + ) + } + + override fun transformBexprScan_expr(node: PartiqlLogical.Bexpr.Scan): PartiqlLogicalResolved.Expr = + withVariableLookupStrategy(VariableLookupStrategy.GLOBALS_THEN_LOCALS) { + super.transformBexprScan_expr(node) + } + + override fun transformBexprJoin_right(node: PartiqlLogical.Bexpr.Join): PartiqlLogicalResolved.Bexpr { + // No need to change the current scope of the node.left. Node.right gets the current scope + + // the left output scope. + val leftOutputScope = getOutputScope(node.left) + val rightInputScope = inputScope.concatenate(leftOutputScope) + return withInputScope(rightInputScope) { + this.transformBexpr(node.right) + } + } + + override fun transformBexprLet(node: PartiqlLogical.Bexpr.Let): PartiqlLogicalResolved.Bexpr { + val thiz = this + return PartiqlLogicalResolved.build { + let( + source = transformBexpr(node.source), + bindings = withInputScope(getOutputScope(node.source)) { + // This "wonderful" (depending on your definition of the term) bit of code performs a fold + // combined with a map... The accumulator is a Pair, + // LocalScope>. + // accumulator.first: the current list of let bindings that have been transformed so far + // accumulator.second: an instance of LocalScope that includes all the variables defined up to + // this point, not including the current let binding. + val initial = emptyList() to thiz.inputScope + val (newBindings: List, _: LocalScope) = + node.bindings.fold(initial) { accumulator, current -> + // Each let binding's expression should be resolved within the scope of the *last* + // let binding (or the current scope if this is the first let binding). + val resolvedValueExpr = withInputScope(accumulator.second) { + thiz.transformExpr(current.value) + } + val nextScope = LocalScope(listOf(current.decl)).concatenate(accumulator.second) + val transformedLetBindings = accumulator.first + PartiqlLogicalResolved.build { + letBinding(resolvedValueExpr, transformVarDecl(current.decl)) + } + transformedLetBindings to nextScope + } + newBindings + } + ) + } + } + + // We are currently using bindings_to_values to denote a sub-query, which works for all the use cases we are + // presented with today, as every SELECT statement is replaced with `bindings_to_values at the top level. + override fun transformExprBindingsToValues(node: PartiqlLogical.Expr.BindingsToValues): PartiqlLogicalResolved.Expr = + // If we are in the expr of a scan node, we need to reset the lookup strategy + withVariableLookupStrategy(VariableLookupStrategy.LOCALS_THEN_GLOBALS) { + super.transformExprBindingsToValues(node) + } + + /** + * Grabs the index meta added by [VariableIdAllocator] and stores it as an element in + * [PartiqlLogicalResolved.VarDecl]. + */ + override fun transformVarDecl(node: PartiqlLogical.VarDecl): PartiqlLogicalResolved.VarDecl = + PartiqlLogicalResolved.build { + varDecl(node.indexMeta.toLong()) + } + + /** + * Returns [ResolutionResult.LocalVariable] if [bindingName] refers to a local variable. + * + * Otherwise, returns [ResolutionResult.Undefined]. (Elsewhere, [globals] will be checked next.) + */ + private fun lookupLocalVariable(bindingName: BindingName): ResolutionResult { + val found = this.inputScope.varDecls.firstOrNull { bindingName.isEquivalentTo(it.name.text) } + return if (found == null) { + ResolutionResult.Undefined + } else { + ResolutionResult.LocalVariable(found.indexMeta) + } + } + + /** + * Resolves the logical `(id ...)` node node to a `(local_id ...)`, `(global_id ...)`, or dynamic `(id...)` + * variable. + */ + override fun transformExprId(node: PartiqlLogical.Expr.Id): PartiqlLogicalResolved.Expr { + val bindingName = BindingName(node.name.text, node.case.toBindingCase()) + + val resolutionResult = if ( + this.currentVariableLookupStrategy == VariableLookupStrategy.GLOBALS_THEN_LOCALS && + node.qualifier is PartiqlLogical.ScopeQualifier.Unqualified + ) { + // look up variable in globals first, then locals + when (val globalResolutionResult = globals.resolveVariable(bindingName)) { + ResolutionResult.Undefined -> lookupLocalVariable(bindingName) + else -> globalResolutionResult + } + } else { + // look up variable in locals first, then globals. + when (val localResolutionResult = lookupLocalVariable(bindingName)) { + ResolutionResult.Undefined -> globals.resolveVariable(bindingName) + else -> localResolutionResult + } + } + return when (resolutionResult) { + is ResolutionResult.GlobalVariable -> { + node.asGlobalId(resolutionResult.uniqueId) + } + is ResolutionResult.LocalVariable -> { + node.asLocalId(resolutionResult.index) + } + ResolutionResult.Undefined -> { + if (this.allowUndefinedVariables) { + node.asDynamicLookupCallsite( + currentDynamicResolutionCandidates() + .map { + PartiqlLogicalResolved.build { + localId(it.indexMeta.toLong()) + } + } + ) + } else { + node.asErrorId().also { + problemHandler.handleProblem( + Problem( + node.metas.sourceLocation ?: error("MetaContainer is missing SourceLocationMeta"), + PlanningProblemDetails.UndefinedVariable( + node.name.text, + node.case is PartiqlLogical.CaseSensitivity.CaseSensitive + ) + ) + ) + } + } + } + } + } + + /** + * Returns a list of variables accessible from the current scope which contain variables that may contain + * an unqualified variable, in the order that they should be searched. + */ + fun currentDynamicResolutionCandidates(): List = + inputScope.varDecls.filter { it.includeInDynamicResolution } + + override fun transformExprBindingsToValues_exp(node: PartiqlLogical.Expr.BindingsToValues): PartiqlLogicalResolved.Expr { + val bindings = getOutputScope(node.query).concatenate(this.inputScope) + return withInputScope(bindings) { + this.transformExpr(node.exp) + } + } + + override fun transformBexprFilter_predicate(node: PartiqlLogical.Bexpr.Filter): PartiqlLogicalResolved.Expr { + val bindings = getOutputScope(node.source) + return withInputScope(bindings) { + this.transformExpr(node.predicate) + } + } + + override fun transformBexprJoin_predicate(node: PartiqlLogical.Bexpr.Join): PartiqlLogicalResolved.Expr? { + val bindings = getOutputScope(node) + return withInputScope(bindings) { + node.predicate?.let { this.transformExpr(it) } + } + } + + /** + * This should be called any time we create a [LocalScope] with more than one variable to prevent duplicate + * variable names. When checking for duplication, the letter case of the variable names is not considered. + * + * Example: + * + * ``` + * SELECT * FROM foo AS X AT x + * duplicate variable: ^ + * ``` + */ + private fun checkForDuplicateVariables(varDecls: List) { + val usedVariableNames = hashSetOf() + varDecls.forEach { varDecl -> + val loweredVariableName = varDecl.name.text.toLowerCase() + if (usedVariableNames.contains(loweredVariableName)) { + this.problemHandler.handleProblem( + Problem( + varDecl.metas.sourceLocation ?: error("VarDecl was missing source location meta"), + PlanningProblemDetails.VariablePreviouslyDefined(varDecl.name.text) + ) + ) + } + usedVariableNames.add(loweredVariableName) + } + } + + /** + * Computes a [LocalScope] for containing all of the variables that are output from [bexpr]. + */ + private fun getOutputScope(bexpr: PartiqlLogical.Bexpr): LocalScope = + when (bexpr) { + is PartiqlLogical.Bexpr.Filter -> getOutputScope(bexpr.source) + is PartiqlLogical.Bexpr.Limit -> getOutputScope(bexpr.source) + is PartiqlLogical.Bexpr.Offset -> getOutputScope(bexpr.source) + is PartiqlLogical.Bexpr.Scan -> { + LocalScope( + listOfNotNull(bexpr.asDecl.markForDynamicResolution(), bexpr.atDecl, bexpr.byDecl).also { + checkForDuplicateVariables(it) + } + ) + } + is PartiqlLogical.Bexpr.Join -> { + val (leftBexpr, rightBexpr) = when (bexpr.joinType) { + is PartiqlLogical.JoinType.Full, + is PartiqlLogical.JoinType.Inner, + is PartiqlLogical.JoinType.Left -> bexpr.left to bexpr.right + // right join is same as left join but right and left operands are swapped. + is PartiqlLogical.JoinType.Right -> bexpr.right to bexpr.left + } + val leftScope = getOutputScope(leftBexpr) + val rightScope = getOutputScope(rightBexpr) + // right scope is first to allow RHS variables to "shadow" LHS variables. + rightScope.concatenate(leftScope) + } + is PartiqlLogical.Bexpr.Let -> { + val sourceScope = getOutputScope(bexpr.source) + // Note that .reversed() is important here to ensure that variable shadowing works correctly. + val letVariables = bexpr.bindings.reversed().map { it.decl } + sourceScope.concatenate(letVariables) + } + } + + private fun LocalScope.concatenate(other: LocalScope): LocalScope = + this.concatenate(other.varDecls) + + private fun LocalScope.concatenate(other: List): LocalScope { + val concatenatedScopeVariables = this.varDecls + other + return LocalScope(concatenatedScopeVariables) + } + + private fun PartiqlLogical.Expr.Id.asDynamicLookupCallsite( + search: List + ): PartiqlLogicalResolved.Expr { + val caseSensitivityString = when (case) { + is PartiqlLogical.CaseSensitivity.CaseInsensitive -> "case_insensitive" + is PartiqlLogical.CaseSensitivity.CaseSensitive -> "case_sensitive" + } + val variableLookupStrategy = when (currentVariableLookupStrategy) { + // If we are not in a FROM source, ignore the scope qualifier + VariableLookupStrategy.LOCALS_THEN_GLOBALS -> VariableLookupStrategy.LOCALS_THEN_GLOBALS + // If we are in a FROM source, allow scope qualifier to override the current variable lookup strategy. + VariableLookupStrategy.GLOBALS_THEN_LOCALS -> when (this.qualifier) { + is PartiqlLogical.ScopeQualifier.LocalsFirst -> VariableLookupStrategy.LOCALS_THEN_GLOBALS + is PartiqlLogical.ScopeQualifier.Unqualified -> VariableLookupStrategy.GLOBALS_THEN_LOCALS + } + }.toString().toLowerCase() + return PartiqlLogicalResolved.build { + call( + funcName = DYNAMIC_LOOKUP_FUNCTION_NAME, + args = listOf( + lit(name.toIonElement()), + lit(ionSymbol(caseSensitivityString)), + lit(ionSymbol(variableLookupStrategy)), + ) + search, + metas = this@asDynamicLookupCallsite.metas + ) + } + } +} + +/** Marks a variable for dynamic resolution--i.e. if undefined, this vardecl will be included in any dynamic_id lookup. */ +fun PartiqlLogical.VarDecl.markForDynamicResolution() = this.withMeta("\$include_in_dynamic_resolution", Unit) +/** Returns true of the [VarDecl] has been marked to participate in unqualified field resolution */ +val PartiqlLogical.VarDecl.includeInDynamicResolution get() = this.metas.containsKey("\$include_in_dynamic_resolution") diff --git a/lang/src/org/partiql/lang/planner/transforms/PlanningProblemDetails.kt b/lang/src/org/partiql/lang/planner/transforms/PlanningProblemDetails.kt new file mode 100644 index 0000000000..5819479d2e --- /dev/null +++ b/lang/src/org/partiql/lang/planner/transforms/PlanningProblemDetails.kt @@ -0,0 +1,46 @@ +package org.partiql.lang.planner.transforms + +import org.partiql.lang.errors.ProblemDetails +import org.partiql.lang.errors.ProblemSeverity + +/** + * Contains detailed information about errors that may occur during query planning. + * + * This information can be used to generate end-user readable error messages and is also easy to assert + * equivalence in unit tests. + */ +sealed class PlanningProblemDetails( + override val severity: ProblemSeverity, + val messageFormatter: () -> String +) : ProblemDetails { + + override val message: String get() = messageFormatter() + + data class ParseError(val parseErrorMessage: String) : + PlanningProblemDetails(ProblemSeverity.ERROR, { parseErrorMessage }) + + data class CompileError(val errorMessage: String) : + PlanningProblemDetails(ProblemSeverity.ERROR, { errorMessage }) + + data class UndefinedVariable(val variableName: String, val caseSensitive: Boolean) : + PlanningProblemDetails( + ProblemSeverity.ERROR, + { + "Undefined variable '$variableName'." + + if (caseSensitive) { + // Individuals that are new to SQL often try to use double quotes for string literals. + // Let's help them out a bit. + " Hint: did you intend to use single-quotes (') here? Remember that double-quotes (\") denote " + + "quoted identifiers and single-quotes denote strings." + } else { + "" + } + } + ) + + data class VariablePreviouslyDefined(val variableName: String) : + PlanningProblemDetails( + ProblemSeverity.ERROR, + { "The variable '$variableName' was previously defined." } + ) +} diff --git a/lang/src/org/partiql/lang/planner/transforms/Util.kt b/lang/src/org/partiql/lang/planner/transforms/Util.kt new file mode 100644 index 0000000000..73f1fd6716 --- /dev/null +++ b/lang/src/org/partiql/lang/planner/transforms/Util.kt @@ -0,0 +1,30 @@ + +package org.partiql.lang.planner.transforms + +/** + * This is the semantic version number of the logical and physical plans supported by this version of PartiQL. This + * deals only with compatibility of trees that have been persisted as s-expressions with their PIG-generated + * classes. The format is: `.`. One or both of these will need to be changed when the following + * events happen: + * + * - Increment `` and set `` to `0` when a change to `partiql.ion` is introduced that will cause the + * persisted s-expressions to fail to load under the new version. Examples include: + * - Making an element non-nullable that was previously nullable. + * - Renaming any type or sum variant. + * - Removing a sum variant. + * - Adding or removing any element of any product type. + * - Changing the data type of any element. + * - Adding a required field to a record type. + * - Increment `` when a change to `partiql.ion` is introduced that will *not* cause the persisted s-expressions + * to fail to load into the PIG-generated classes. Examples include: + * - Adding a new, optional (nullable) field to a record type. + * - Adding a new sum variant. + * - Changing an element that was previously non-nullable nullable. + * + * It would be nice to embed semantic version in the PIG type universe somehow, but this isn't yet implemented, so we + * have to include it here for now. See: https://github.com/partiql/partiql-ir-generator/issues/121 + */ +const val PLAN_VERSION_NUMBER = "0.0" + +internal fun errAstNotNormalized(message: String): Nothing = + error("$message - have the basic visitor transforms been executed first?") diff --git a/lang/src/org/partiql/lang/planner/transforms/VariableIdAllocator.kt b/lang/src/org/partiql/lang/planner/transforms/VariableIdAllocator.kt new file mode 100644 index 0000000000..de3ffaeb68 --- /dev/null +++ b/lang/src/org/partiql/lang/planner/transforms/VariableIdAllocator.kt @@ -0,0 +1,44 @@ +package org.partiql.lang.planner.transforms + +import org.partiql.lang.domains.PartiqlLogical +import org.partiql.lang.domains.PartiqlLogicalResolved + +/** + * Allocates register indexes for all local variables in the plan. + * + * Returns pair containing a logical plan where all `var_decl`s have a [VARIABLE_ID_META_TAG] meta indicating the + * variable index (which can be utilized later when establishing variable scoping) and list of all local variables + * declared within the plan, which becomes the `locals` sub-node of the `plan` node. + */ +internal fun PartiqlLogical.Plan.allocateVariableIds(): Pair> { + val allLocals = mutableListOf() + val planWithAllocatedVariables = VariableIdAllocator(allLocals).transformPlan(this) + return planWithAllocatedVariables to allLocals.toList() +} + +private const val VARIABLE_ID_META_TAG = "\$variable_id" + +internal val PartiqlLogical.VarDecl.indexMeta + get() = this.metas[VARIABLE_ID_META_TAG] as? Int ?: error("Meta $VARIABLE_ID_META_TAG was not present") + +/** + * Allocates a unique index to every `var_decl` in the logical plan. We use metas for this step to avoid a having + * create another permuted domain. + */ +private class VariableIdAllocator( + val allLocals: MutableList +) : PartiqlLogical.VisitorTransform() { + private var nextVariableId = 0 + + override fun transformVarDecl(node: PartiqlLogical.VarDecl): PartiqlLogical.VarDecl = + node.withMeta(VARIABLE_ID_META_TAG, nextVariableId).also { + + allLocals.add( + PartiqlLogicalResolved.build { + localVariable(node.name.text, nextVariableId.toLong()) + } + ) + + nextVariableId++ + } +} diff --git a/lang/src/org/partiql/lang/types/PartiqlAstTypeExtensions.kt b/lang/src/org/partiql/lang/types/PartiqlAstTypeExtensions.kt index d9ef4f9096..d726074cae 100644 --- a/lang/src/org/partiql/lang/types/PartiqlAstTypeExtensions.kt +++ b/lang/src/org/partiql/lang/types/PartiqlAstTypeExtensions.kt @@ -1,85 +1,20 @@ package org.partiql.lang.types import org.partiql.lang.domains.PartiqlAst +import org.partiql.lang.domains.PartiqlPhysical -/** - * Helper to convert [PartiqlAst.Type] in AST to a [TypedOpParameter]. - */ -fun PartiqlAst.Type.toTypedOpParameter(customTypedOpParameters: Map): TypedOpParameter = when (this) { - is PartiqlAst.Type.MissingType -> TypedOpParameter(StaticType.MISSING) - is PartiqlAst.Type.NullType -> TypedOpParameter(StaticType.NULL) - is PartiqlAst.Type.BooleanType -> TypedOpParameter(StaticType.BOOL) - is PartiqlAst.Type.SmallintType -> TypedOpParameter(IntType(IntType.IntRangeConstraint.SHORT)) - is PartiqlAst.Type.Integer4Type -> TypedOpParameter(IntType(IntType.IntRangeConstraint.INT4)) - is PartiqlAst.Type.Integer8Type -> TypedOpParameter(IntType(IntType.IntRangeConstraint.LONG)) - is PartiqlAst.Type.IntegerType -> TypedOpParameter(IntType(IntType.IntRangeConstraint.LONG)) - is PartiqlAst.Type.FloatType, is PartiqlAst.Type.RealType, is PartiqlAst.Type.DoublePrecisionType -> TypedOpParameter(StaticType.FLOAT) - is PartiqlAst.Type.DecimalType -> when { - this.precision == null && this.scale == null -> TypedOpParameter(StaticType.DECIMAL) - this.precision != null && this.scale == null -> TypedOpParameter(DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(this.precision.value.toInt()))) - else -> TypedOpParameter( - DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(this.precision!!.value.toInt(), this.scale!!.value.toInt())) - ) - } - is PartiqlAst.Type.NumericType -> when { - this.precision == null && this.scale == null -> TypedOpParameter(StaticType.DECIMAL) - this.precision != null && this.scale == null -> TypedOpParameter(DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(this.precision.value.toInt()))) - else -> TypedOpParameter( - DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(this.precision!!.value.toInt(), this.scale!!.value.toInt())) - ) - } - is PartiqlAst.Type.TimestampType -> TypedOpParameter(StaticType.TIMESTAMP) - is PartiqlAst.Type.CharacterType -> when { - this.length == null -> TypedOpParameter(StringType(StringType.StringLengthConstraint.Constrained(NumberConstraint.Equals(1)))) - else -> TypedOpParameter( - StringType( - StringType.StringLengthConstraint.Constrained( - NumberConstraint.Equals(this.length.value.toInt()) - ) - ) - ) - } - is PartiqlAst.Type.CharacterVaryingType -> when (this.length) { - null -> TypedOpParameter(StringType(StringType.StringLengthConstraint.Unconstrained)) - else -> TypedOpParameter(StringType(StringType.StringLengthConstraint.Constrained(NumberConstraint.UpTo(this.length.value.toInt())))) - } - is PartiqlAst.Type.StringType -> TypedOpParameter(StaticType.STRING) - is PartiqlAst.Type.SymbolType -> TypedOpParameter(StaticType.SYMBOL) - is PartiqlAst.Type.ClobType -> TypedOpParameter(StaticType.CLOB) - is PartiqlAst.Type.BlobType -> TypedOpParameter(StaticType.BLOB) - is PartiqlAst.Type.StructType -> TypedOpParameter(StaticType.STRUCT) - is PartiqlAst.Type.TupleType -> TypedOpParameter(StaticType.STRUCT) - is PartiqlAst.Type.ListType -> TypedOpParameter(StaticType.LIST) - is PartiqlAst.Type.SexpType -> TypedOpParameter(StaticType.SEXP) - is PartiqlAst.Type.BagType -> TypedOpParameter(StaticType.BAG) - is PartiqlAst.Type.AnyType -> TypedOpParameter(StaticType.ANY) - is PartiqlAst.Type.CustomType -> - customTypedOpParameters.mapKeys { - (k, _) -> - k.toLowerCase() - }[this.name.text.toLowerCase()] ?: error("Could not find parameter for $this") - is PartiqlAst.Type.DateType -> TypedOpParameter(StaticType.DATE) - is PartiqlAst.Type.TimeType -> TypedOpParameter( - TimeType(this.precision?.value?.toInt(), withTimeZone = false) - ) - is PartiqlAst.Type.TimeWithTimeZoneType -> TypedOpParameter( - TimeType(this.precision?.value?.toInt(), withTimeZone = true) - ) - is PartiqlAst.Type.EsAny, - is PartiqlAst.Type.EsBoolean, - is PartiqlAst.Type.EsFloat, - is PartiqlAst.Type.EsInteger, - is PartiqlAst.Type.EsText, - is PartiqlAst.Type.RsBigint, - is PartiqlAst.Type.RsBoolean, - is PartiqlAst.Type.RsDoublePrecision, - is PartiqlAst.Type.RsInteger, - is PartiqlAst.Type.RsReal, - is PartiqlAst.Type.RsVarcharMax, - is PartiqlAst.Type.SparkBoolean, - is PartiqlAst.Type.SparkDouble, - is PartiqlAst.Type.SparkFloat, - is PartiqlAst.Type.SparkInteger, - is PartiqlAst.Type.SparkLong, - is PartiqlAst.Type.SparkShort -> error("$this node should not be present in PartiQLAST. Consider transforming the AST using CustomTypeVisitorTransform.") +/** Helper to convert [PartiqlAst.Type] in AST to a [TypedOpParameter]. */ +fun PartiqlAst.Type.toTypedOpParameter(customTypedOpParameters: Map): TypedOpParameter { + // hack: to avoid duplicating the function `PartiqlAst.Type.toTypedOpParameter`, we have to convert this + // PartiqlAst.Type to PartiqlPhysical.Type. The easiest way to do that without using a visitor transform + // (which is overkill and comes with some downsides for something this simple), is to transform to and from + // s-expressions again. This will work without difficulty as long as PartiqlAst.Type remains unchanged in all + // permuted domains between PartiqlAst and PartiqlPhysical. + + // This is really just a temporary measure, however, which must exist for as long as the type inferencer works only + // on PartiqlAst. When it has been migrated to use PartiqlPhysical instead, there should no longer be a reason + // to keep this function around. + val sexp = this.toIonElement() + val physicalType = PartiqlPhysical.transform(sexp) as PartiqlPhysical.Type + return physicalType.toTypedOpParameter(customTypedOpParameters) } diff --git a/lang/src/org/partiql/lang/types/PartiqlPhysicalTypeExtensions.kt b/lang/src/org/partiql/lang/types/PartiqlPhysicalTypeExtensions.kt new file mode 100644 index 0000000000..71f20f06ce --- /dev/null +++ b/lang/src/org/partiql/lang/types/PartiqlPhysicalTypeExtensions.kt @@ -0,0 +1,85 @@ +package org.partiql.lang.types + +import org.partiql.lang.domains.PartiqlPhysical + +/** + * Helper to convert [PartiqlPhysical.Type] in AST to a [TypedOpParameter]. + */ +fun PartiqlPhysical.Type.toTypedOpParameter(customTypedOpParameters: Map): TypedOpParameter = when (this) { + is PartiqlPhysical.Type.MissingType -> TypedOpParameter(StaticType.MISSING) + is PartiqlPhysical.Type.NullType -> TypedOpParameter(StaticType.NULL) + is PartiqlPhysical.Type.BooleanType -> TypedOpParameter(StaticType.BOOL) + is PartiqlPhysical.Type.SmallintType -> TypedOpParameter(IntType(IntType.IntRangeConstraint.SHORT)) + is PartiqlPhysical.Type.Integer4Type -> TypedOpParameter(IntType(IntType.IntRangeConstraint.INT4)) + is PartiqlPhysical.Type.Integer8Type -> TypedOpParameter(IntType(IntType.IntRangeConstraint.LONG)) + is PartiqlPhysical.Type.IntegerType -> TypedOpParameter(IntType(IntType.IntRangeConstraint.LONG)) + is PartiqlPhysical.Type.FloatType, is PartiqlPhysical.Type.RealType, is PartiqlPhysical.Type.DoublePrecisionType -> TypedOpParameter(StaticType.FLOAT) + is PartiqlPhysical.Type.DecimalType -> when { + this.precision == null && this.scale == null -> TypedOpParameter(StaticType.DECIMAL) + this.precision != null && this.scale == null -> TypedOpParameter(DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(this.precision.value.toInt()))) + else -> TypedOpParameter( + DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(this.precision!!.value.toInt(), this.scale!!.value.toInt())) + ) + } + is PartiqlPhysical.Type.NumericType -> when { + this.precision == null && this.scale == null -> TypedOpParameter(StaticType.DECIMAL) + this.precision != null && this.scale == null -> TypedOpParameter(DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(this.precision.value.toInt()))) + else -> TypedOpParameter( + DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(this.precision!!.value.toInt(), this.scale!!.value.toInt())) + ) + } + is PartiqlPhysical.Type.TimestampType -> TypedOpParameter(StaticType.TIMESTAMP) + is PartiqlPhysical.Type.CharacterType -> when { + this.length == null -> TypedOpParameter(StringType(StringType.StringLengthConstraint.Constrained(NumberConstraint.Equals(1)))) + else -> TypedOpParameter( + StringType( + StringType.StringLengthConstraint.Constrained( + NumberConstraint.Equals(this.length.value.toInt()) + ) + ) + ) + } + is PartiqlPhysical.Type.CharacterVaryingType -> when (this.length) { + null -> TypedOpParameter(StringType(StringType.StringLengthConstraint.Unconstrained)) + else -> TypedOpParameter(StringType(StringType.StringLengthConstraint.Constrained(NumberConstraint.UpTo(this.length.value.toInt())))) + } + is PartiqlPhysical.Type.StringType -> TypedOpParameter(StaticType.STRING) + is PartiqlPhysical.Type.SymbolType -> TypedOpParameter(StaticType.SYMBOL) + is PartiqlPhysical.Type.ClobType -> TypedOpParameter(StaticType.CLOB) + is PartiqlPhysical.Type.BlobType -> TypedOpParameter(StaticType.BLOB) + is PartiqlPhysical.Type.StructType -> TypedOpParameter(StaticType.STRUCT) + is PartiqlPhysical.Type.TupleType -> TypedOpParameter(StaticType.STRUCT) + is PartiqlPhysical.Type.ListType -> TypedOpParameter(StaticType.LIST) + is PartiqlPhysical.Type.SexpType -> TypedOpParameter(StaticType.SEXP) + is PartiqlPhysical.Type.BagType -> TypedOpParameter(StaticType.BAG) + is PartiqlPhysical.Type.AnyType -> TypedOpParameter(StaticType.ANY) + is PartiqlPhysical.Type.CustomType -> + customTypedOpParameters.mapKeys { + (k, _) -> + k.toLowerCase() + }[this.name.text.toLowerCase()] ?: error("Could not find parameter for $this") + is PartiqlPhysical.Type.DateType -> TypedOpParameter(StaticType.DATE) + is PartiqlPhysical.Type.TimeType -> TypedOpParameter( + TimeType(this.precision?.value?.toInt(), withTimeZone = false) + ) + is PartiqlPhysical.Type.TimeWithTimeZoneType -> TypedOpParameter( + TimeType(this.precision?.value?.toInt(), withTimeZone = true) + ) + is PartiqlPhysical.Type.EsAny, + is PartiqlPhysical.Type.EsBoolean, + is PartiqlPhysical.Type.EsFloat, + is PartiqlPhysical.Type.EsInteger, + is PartiqlPhysical.Type.EsText, + is PartiqlPhysical.Type.RsBigint, + is PartiqlPhysical.Type.RsBoolean, + is PartiqlPhysical.Type.RsDoublePrecision, + is PartiqlPhysical.Type.RsInteger, + is PartiqlPhysical.Type.RsReal, + is PartiqlPhysical.Type.RsVarcharMax, + is PartiqlPhysical.Type.SparkBoolean, + is PartiqlPhysical.Type.SparkDouble, + is PartiqlPhysical.Type.SparkFloat, + is PartiqlPhysical.Type.SparkInteger, + is PartiqlPhysical.Type.SparkLong, + is PartiqlPhysical.Type.SparkShort -> error("$this node should not be present in PartiqlPhysical. Consider transforming the AST using CustomTypeVisitorTransform.") +} diff --git a/lang/test/org/partiql/lang/eval/ErrorSignalerTests.kt b/lang/test/org/partiql/lang/eval/ErrorSignalerTests.kt index f3de26e967..3c509ddb91 100644 --- a/lang/test/org/partiql/lang/eval/ErrorSignalerTests.kt +++ b/lang/test/org/partiql/lang/eval/ErrorSignalerTests.kt @@ -37,8 +37,8 @@ class ErrorSignalerTests { ex } assertEquals(ex.errorCode, ErrorCode.EVALUATOR_CAST_FAILED) - assertEquals(ex.errorContext!![Property.LINE_NUMBER]!!.longValue(), 4L) - assertEquals(ex.errorContext!![Property.COLUMN_NUMBER]!!.longValue(), 2L) + assertEquals(ex.errorContext[Property.LINE_NUMBER]!!.longValue(), 4L) + assertEquals(ex.errorContext[Property.COLUMN_NUMBER]!!.longValue(), 2L) } private fun runTest(ctx1: ErrorSignaler, value: Int): ExprValue = diff --git a/lang/test/org/partiql/lang/eval/EvaluatingCompilerExceptionsTest.kt b/lang/test/org/partiql/lang/eval/EvaluatingCompilerExceptionsTest.kt index f691b218f5..d3fd56a2f2 100644 --- a/lang/test/org/partiql/lang/eval/EvaluatingCompilerExceptionsTest.kt +++ b/lang/test/org/partiql/lang/eval/EvaluatingCompilerExceptionsTest.kt @@ -21,6 +21,7 @@ import org.junit.jupiter.params.provider.ArgumentsSource import org.partiql.lang.errors.ErrorCode import org.partiql.lang.errors.Property import org.partiql.lang.eval.evaluatortestframework.EvaluatorErrorTestCase +import org.partiql.lang.eval.evaluatortestframework.EvaluatorTestTarget import org.partiql.lang.util.ArgumentsProviderBase import org.partiql.lang.util.propertyValueMapOf import org.partiql.lang.util.rootCause @@ -33,7 +34,6 @@ class EvaluatingCompilerExceptionsTest : EvaluatorTestBase() { // to follow a pattern that we'd like to change anyway. // FIXME - these tests don't seem to work, and when enabled the options are set but the `FLOAT` type is missing // the parameter at the point we test it in the EvaluatingCompiler - // XXX - for some reason, @Ignore did not work on this parameterized test. @Disabled @ParameterizedTest @ArgumentsSource(ErrorTestCasesTestCases::class) @@ -99,28 +99,32 @@ class EvaluatingCompilerExceptionsTest : EvaluatorTestBase() { """SELECT VALUE a FROM `[{v:5}]` AS item, @item.v AS a, @item.v AS a""", ErrorCode.EVALUATOR_AMBIGUOUS_BINDING, expectedErrorContext = propertyValueMapOf(1, 14, Property.BINDING_NAME to "a", Property.BINDING_NAME_MATCHES to "a, a"), - expectedPermissiveModeResult = "<>" + expectedPermissiveModeResult = "<>", + target = EvaluatorTestTarget.COMPILER_PIPELINE ) @Test fun topLevelCountStar() = runEvaluatorErrorTestCase( """COUNT(*)""", ErrorCode.EVALUATOR_COUNT_START_NOT_ALLOWED, - expectedErrorContext = propertyValueMapOf(1, 1) + expectedErrorContext = propertyValueMapOf(1, 1), + target = EvaluatorTestTarget.COMPILER_PIPELINE ) @Test fun selectValueCountStar() = runEvaluatorErrorTestCase( """SELECT VALUE COUNT(*) FROM numbers""", ErrorCode.EVALUATOR_COUNT_START_NOT_ALLOWED, - expectedErrorContext = propertyValueMapOf(1, 14) + expectedErrorContext = propertyValueMapOf(1, 14), + target = EvaluatorTestTarget.COMPILER_PIPELINE ) @Test fun selectListNestedAggregateCall() = runEvaluatorErrorTestCase( """SELECT SUM(AVG(n)) FROM <> AS n""", ErrorCode.EVALUATOR_INVALID_ARGUMENTS_FOR_AGG_FUNCTION, - expectedErrorContext = propertyValueMapOf(1, 12) + expectedErrorContext = propertyValueMapOf(1, 12), + target = EvaluatorTestTarget.COMPILER_PIPELINE ) private val sqlWithUndefinedVariable = "SELECT VALUE y FROM << 'el1' >> AS x" @@ -140,7 +144,8 @@ class EvaluatingCompilerExceptionsTest : EvaluatorTestBase() { // Same query as previous test--but DO NOT throw exception this time because of UndefinedVariableBehavior.MISSING runEvaluatorTestCase( sqlWithUndefinedVariable, expectedResult = "[null]", - compileOptionsBuilderBlock = { undefinedVariable(UndefinedVariableBehavior.MISSING) } + compileOptionsBuilderBlock = { undefinedVariable(UndefinedVariableBehavior.MISSING) }, + target = EvaluatorTestTarget.COMPILER_PIPELINE ) private val sqlWithUndefinedQuotedVariable = "SELECT VALUE \"y\" FROM << 'el1' >> AS x" @@ -151,7 +156,7 @@ class EvaluatingCompilerExceptionsTest : EvaluatorTestBase() { sqlWithUndefinedQuotedVariable, ErrorCode.EVALUATOR_QUOTED_BINDING_DOES_NOT_EXIST, propertyValueMapOf(1, 14, Property.BINDING_NAME to "y"), - expectedPermissiveModeResult = "<>" + expectedPermissiveModeResult = "<>", ) } @@ -160,7 +165,8 @@ class EvaluatingCompilerExceptionsTest : EvaluatorTestBase() { // Same query as previous test--but DO NOT throw exception this time because of UndefinedVariableBehavior.MISSING runEvaluatorTestCase( sqlWithUndefinedQuotedVariable, expectedResult = "[null]", - compileOptionsBuilderBlock = { undefinedVariable(UndefinedVariableBehavior.MISSING) } + compileOptionsBuilderBlock = { undefinedVariable(UndefinedVariableBehavior.MISSING) }, + target = EvaluatorTestTarget.COMPILER_PIPELINE ) @Test @@ -187,14 +193,16 @@ class EvaluatingCompilerExceptionsTest : EvaluatorTestBase() { fun rightJoin() = runEvaluatorErrorTestCase( "SELECT * FROM animals AS a RIGHT CROSS JOIN animal_types AS a_type WHERE a.type = a_type.id", ErrorCode.EVALUATOR_FEATURE_NOT_SUPPORTED_YET, - expectedErrorContext = propertyValueMapOf(1, 28, Property.FEATURE_NAME to "RIGHT and FULL JOIN") + expectedErrorContext = propertyValueMapOf(1, 28, Property.FEATURE_NAME to "RIGHT and FULL JOIN"), + target = EvaluatorTestTarget.COMPILER_PIPELINE ) @Test fun outerJoin() = runEvaluatorErrorTestCase( "SELECT * FROM animals AS a OUTER CROSS JOIN animal_types AS a_type WHERE a.type = a_type.id", ErrorCode.EVALUATOR_FEATURE_NOT_SUPPORTED_YET, - expectedErrorContext = propertyValueMapOf(1, 28, Property.FEATURE_NAME to "RIGHT and FULL JOIN") + expectedErrorContext = propertyValueMapOf(1, 28, Property.FEATURE_NAME to "RIGHT and FULL JOIN"), + target = EvaluatorTestTarget.COMPILER_PIPELINE ) @Test diff --git a/lang/test/org/partiql/lang/eval/EvaluatingCompilerFromLetTests.kt b/lang/test/org/partiql/lang/eval/EvaluatingCompilerFromLetTests.kt index 6294e13edd..4efadf9484 100644 --- a/lang/test/org/partiql/lang/eval/EvaluatingCompilerFromLetTests.kt +++ b/lang/test/org/partiql/lang/eval/EvaluatingCompilerFromLetTests.kt @@ -6,6 +6,7 @@ import org.partiql.lang.errors.ErrorCode import org.partiql.lang.errors.Property import org.partiql.lang.eval.evaluatortestframework.EvaluatorErrorTestCase import org.partiql.lang.eval.evaluatortestframework.EvaluatorTestCase +import org.partiql.lang.eval.evaluatortestframework.EvaluatorTestTarget import org.partiql.lang.util.ArgumentsProviderBase import org.partiql.lang.util.propertyValueMapOf import org.partiql.lang.util.to @@ -35,28 +36,44 @@ class EvaluatingCompilerFromLetTests : EvaluatorTestBase() { // LET used in GROUP BY EvaluatorTestCase( "SELECT * FROM C LET region AS X GROUP BY X", - """<< {'X': `EU`}, {'X': `NA`} >>""" + """<< {'X': `EU`}, {'X': `NA`} >>""", + target = EvaluatorTestTarget.COMPILER_PIPELINE // no support in physical plans yet for GROUP BY ), // LET used in projection after GROUP BY EvaluatorTestCase( "SELECT foo FROM B LET 100 AS foo GROUP BY B.id, foo", - """<< {'foo': 100}, {'foo': 100} >>""" + """<< {'foo': 100}, {'foo': 100} >>""", + target = EvaluatorTestTarget.COMPILER_PIPELINE // no support in physical plans yet for GROUP BY ), // LET used in HAVING after GROUP BY EvaluatorTestCase( "SELECT B.id FROM B LET 100 AS foo GROUP BY B.id, foo HAVING B.id > foo", - """<< {'id': 200} >>""" + """<< {'id': 200} >>""", + target = EvaluatorTestTarget.COMPILER_PIPELINE // no support in physical plans yet for HAVING ), // LET shadowed binding EvaluatorTestCase( "SELECT X FROM A LET 1 AS X, 2 AS X", """<< {'X': 2} >>""" ), + + // For the two tests immediately below--one tests the AST evaluator only and the other tests + // the phys. plan evaluator only. The query is the same but the expected result is different + // because the legacy AST evaluator has a bug not present in the physical plan evaluator: + // https://github.com/partiql/partiql-lang-kotlin/issues/549 + // LET shadowing FROM binding EvaluatorTestCase( "SELECT * FROM A LET 100 AS A", - """<< {'_1': 100} >>""" + """<< { '_1': 100 } >>""", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ), + EvaluatorTestCase( + "SELECT * FROM A LET 100 AS A", + """<< { 'id': 1 }>>""", + target = EvaluatorTestTarget.PLANNER_PIPELINE ), + // LET using other variables EvaluatorTestCase( "SELECT X, Y FROM A LET 1 AS X, X + 1 AS Y", @@ -80,7 +97,8 @@ class EvaluatingCompilerFromLetTests : EvaluatorTestBase() { // LET calling function with GROUP BY and aggregation EvaluatorTestCase( "SELECT C.region, MAX(nameLength) AS maxLen FROM C LET char_length(C.name) AS nameLength GROUP BY C.region", - """<< {'region': `EU`, 'maxLen': 6}, {'region': `NA`, 'maxLen': 9} >>""" + """<< {'region': `EU`, 'maxLen': 6}, {'region': `NA`, 'maxLen': 9} >>""", + target = EvaluatorTestTarget.COMPILER_PIPELINE // no support in physical plans yet for GROUP BY ), // LET outer query has correct value EvaluatorTestCase( @@ -152,7 +170,8 @@ class EvaluatingCompilerFromLetTests : EvaluatorTestBase() { Property.LINE_NUMBER to 1L, Property.COLUMN_NUMBER to 63L, Property.BINDING_NAME to "foo" - ) + ), + targetPipeline = EvaluatorTestTarget.COMPILER_PIPELINE // no support in physical plans yet for GROUP BY ), // LET binding referenced in projection not in GROUP BY EvaluatorErrorTestCase( @@ -162,7 +181,8 @@ class EvaluatingCompilerFromLetTests : EvaluatorTestBase() { Property.LINE_NUMBER to 1L, Property.COLUMN_NUMBER to 8L, Property.BINDING_NAME to "foo" - ) + ), + targetPipeline = EvaluatorTestTarget.COMPILER_PIPELINE // no support in physical plans yet for GROUP BY ) ) } diff --git a/lang/test/org/partiql/lang/eval/EvaluatingCompilerGroupByTest.kt b/lang/test/org/partiql/lang/eval/EvaluatingCompilerGroupByTest.kt index 45ffc74b22..d7012e938c 100644 --- a/lang/test/org/partiql/lang/eval/EvaluatingCompilerGroupByTest.kt +++ b/lang/test/org/partiql/lang/eval/EvaluatingCompilerGroupByTest.kt @@ -19,6 +19,7 @@ import org.junit.Test import org.partiql.lang.errors.ErrorCode import org.partiql.lang.errors.Property import org.partiql.lang.eval.evaluatortestframework.EvaluatorTestCase +import org.partiql.lang.eval.evaluatortestframework.EvaluatorTestTarget import org.partiql.lang.util.propertyValueMapOf class EvaluatingCompilerGroupByTest : EvaluatorTestBase() { @@ -87,7 +88,13 @@ class EvaluatingCompilerGroupByTest : EvaluatorTestBase() { ).toSession() private fun runTest(tc: EvaluatorTestCase, session: EvaluationSession) = - super.runEvaluatorTestCase(tc.copy(implicitPermissiveModeTest = false), session) + super.runEvaluatorTestCase( + tc.copy( + implicitPermissiveModeTest = false, // we are manually setting typing mode + targetPipeline = EvaluatorTestTarget.COMPILER_PIPELINE // no support in physical plans yet for GROUP BY + ), + session + ) companion object { @@ -1142,7 +1149,8 @@ class EvaluatingCompilerGroupByTest : EvaluatorTestBase() { "SELECT foo AS someSelectListAlias FROM <<{ 'a': 1 }>> GROUP BY someSelectListAlias", ErrorCode.EVALUATOR_BINDING_DOES_NOT_EXIST, propertyValueMapOf(1, 64, Property.BINDING_NAME to "someSelectListAlias"), - expectedPermissiveModeResult = "<<{}>>" + expectedPermissiveModeResult = "<<{}>>", + target = EvaluatorTestTarget.COMPILER_PIPELINE ) } @@ -1151,7 +1159,8 @@ class EvaluatingCompilerGroupByTest : EvaluatorTestBase() { runEvaluatorErrorTestCase( "SELECT MAX(@v2), @v2 FROM `[1, 2.0, 3e0, 4, 5d0]` AS v2", ErrorCode.EVALUATOR_VARIABLE_NOT_INCLUDED_IN_GROUP_BY, - propertyValueMapOf(1, 19, Property.BINDING_NAME to "v2") + propertyValueMapOf(1, 19, Property.BINDING_NAME to "v2"), + target = EvaluatorTestTarget.COMPILER_PIPELINE ) } @@ -1160,7 +1169,8 @@ class EvaluatingCompilerGroupByTest : EvaluatorTestBase() { runEvaluatorErrorTestCase( "SELECT * FROM << {'a': 1 } >> AS f GROUP BY f.a HAVING f.id = 1", ErrorCode.EVALUATOR_VARIABLE_NOT_INCLUDED_IN_GROUP_BY, - propertyValueMapOf(1, 56, Property.BINDING_NAME to "f") + propertyValueMapOf(1, 56, Property.BINDING_NAME to "f"), + target = EvaluatorTestTarget.COMPILER_PIPELINE ) } @@ -1169,7 +1179,8 @@ class EvaluatingCompilerGroupByTest : EvaluatorTestBase() { runEvaluatorErrorTestCase( "SELECT VALUE f.id FROM << {'a': 'b' } >> AS f GROUP BY f.a", ErrorCode.EVALUATOR_VARIABLE_NOT_INCLUDED_IN_GROUP_BY, - propertyValueMapOf(1, 14, Property.BINDING_NAME to "f") + propertyValueMapOf(1, 14, Property.BINDING_NAME to "f"), + target = EvaluatorTestTarget.COMPILER_PIPELINE ) } @@ -1183,6 +1194,7 @@ class EvaluatingCompilerGroupByTest : EvaluatorTestBase() { ErrorCode.EVALUATOR_VARIABLE_NOT_INCLUDED_IN_GROUP_BY, propertyValueMapOf(2, 28, Property.BINDING_NAME to "O"), session = session, + target = EvaluatorTestTarget.COMPILER_PIPELINE ) } @@ -1196,7 +1208,8 @@ class EvaluatingCompilerGroupByTest : EvaluatorTestBase() { ErrorCode.EVALUATOR_QUOTED_BINDING_DOES_NOT_EXIST, propertyValueMapOf(2, 28, Property.BINDING_NAME to "O"), expectedPermissiveModeResult = "<<{'_2': 10}>>", - session = session + session = session, + target = EvaluatorTestTarget.COMPILER_PIPELINE ) } @@ -1210,7 +1223,8 @@ class EvaluatingCompilerGroupByTest : EvaluatorTestBase() { """, expectedErrorCode = ErrorCode.EVALUATOR_VARIABLE_NOT_INCLUDED_IN_GROUP_BY, expectedErrorContext = propertyValueMapOf(2, 41, Property.BINDING_NAME to "c"), - session = session + session = session, + target = EvaluatorTestTarget.COMPILER_PIPELINE ) } @@ -1225,7 +1239,8 @@ class EvaluatingCompilerGroupByTest : EvaluatorTestBase() { """, expectedErrorCode = ErrorCode.EVALUATOR_VARIABLE_NOT_INCLUDED_IN_GROUP_BY, expectedErrorContext = propertyValueMapOf(2, 41, Property.BINDING_NAME to "o"), - session = session + session = session, + target = EvaluatorTestTarget.COMPILER_PIPELINE ) } @@ -1244,7 +1259,8 @@ class EvaluatingCompilerGroupByTest : EvaluatorTestBase() { """, expectedErrorCode = ErrorCode.EVALUATOR_VARIABLE_NOT_INCLUDED_IN_GROUP_BY, expectedErrorContext = propertyValueMapOf(2, 37, Property.BINDING_NAME to "o"), - session = session + session = session, + target = EvaluatorTestTarget.COMPILER_PIPELINE ) } @@ -1261,7 +1277,8 @@ class EvaluatingCompilerGroupByTest : EvaluatorTestBase() { """, expectedErrorCode = ErrorCode.EVALUATOR_VARIABLE_NOT_INCLUDED_IN_GROUP_BY, expectedErrorContext = propertyValueMapOf(4, 28, Property.BINDING_NAME to "o"), - session = session + session = session, + target = EvaluatorTestTarget.COMPILER_PIPELINE ) } } diff --git a/lang/test/org/partiql/lang/eval/EvaluatingCompilerHavingTest.kt b/lang/test/org/partiql/lang/eval/EvaluatingCompilerHavingTest.kt index 9e39601ef2..8ea52232c7 100644 --- a/lang/test/org/partiql/lang/eval/EvaluatingCompilerHavingTest.kt +++ b/lang/test/org/partiql/lang/eval/EvaluatingCompilerHavingTest.kt @@ -18,6 +18,7 @@ import junitparams.Parameters import org.junit.Test import org.partiql.lang.errors.ErrorCode import org.partiql.lang.eval.evaluatortestframework.EvaluatorTestCase +import org.partiql.lang.eval.evaluatortestframework.EvaluatorTestTarget import org.partiql.lang.util.propertyValueMapOf class EvaluatingCompilerHavingTest : EvaluatorTestBase() { @@ -54,7 +55,11 @@ class EvaluatingCompilerHavingTest : EvaluatorTestBase() { @Test @Parameters - fun groupByHavingTest(tc: EvaluatorTestCase) = runEvaluatorTestCase(tc, session) + fun groupByHavingTest(tc: EvaluatorTestCase) = + runEvaluatorTestCase( + tc.copy(targetPipeline = EvaluatorTestTarget.COMPILER_PIPELINE), // Phys. Algebra doesn't yet support HAVING + session + ) fun parametersForGroupByHavingTest() = listOf( @@ -179,7 +184,8 @@ class EvaluatingCompilerHavingTest : EvaluatorTestBase() { runEvaluatorErrorTestCase( query = "SELECT foo.bar FROM bat HAVING 1 = 1", expectedErrorCode = ErrorCode.SEMANTIC_HAVING_USED_WITHOUT_GROUP_BY, - expectedErrorContext = propertyValueMapOf(1, 1) + expectedErrorContext = propertyValueMapOf(1, 1), + target = EvaluatorTestTarget.COMPILER_PIPELINE ) } } diff --git a/lang/test/org/partiql/lang/eval/EvaluatingCompilerLimitTests.kt b/lang/test/org/partiql/lang/eval/EvaluatingCompilerLimitTests.kt index 090ad180ff..11501436d4 100644 --- a/lang/test/org/partiql/lang/eval/EvaluatingCompilerLimitTests.kt +++ b/lang/test/org/partiql/lang/eval/EvaluatingCompilerLimitTests.kt @@ -3,6 +3,7 @@ package org.partiql.lang.eval import org.junit.Test import org.partiql.lang.errors.ErrorCode import org.partiql.lang.errors.Property +import org.partiql.lang.eval.evaluatortestframework.EvaluatorTestTarget import org.partiql.lang.util.propertyValueMapOf class EvaluatingCompilerLimitTests : EvaluatorTestBase() { @@ -54,7 +55,8 @@ class EvaluatingCompilerLimitTests : EvaluatorTestBase() { runEvaluatorErrorTestCase( """ select * from <<1>> limit -1 """, ErrorCode.EVALUATOR_NEGATIVE_LIMIT, - propertyValueMapOf(1, 29) + propertyValueMapOf(1, 29), + target = EvaluatorTestTarget.COMPILER_PIPELINE, // planner & physical plan have no support LIMIT (yet) ) @Test @@ -69,6 +71,7 @@ class EvaluatingCompilerLimitTests : EvaluatorTestBase() { fun `LIMIT applied after GROUP BY`() = runEvaluatorTestCase( "SELECT g FROM `[{foo: 1, bar: 10}, {foo: 1, bar: 11}]` AS f GROUP BY f.foo GROUP AS g LIMIT 1", - expectedResult = """[ { 'g': [ { 'f': { 'foo': 1, 'bar': 10 } }, { 'f': { 'foo': 1, 'bar': 11 } } ] } ]""" + expectedResult = """[ { 'g': [ { 'f': { 'foo': 1, 'bar': 10 } }, { 'f': { 'foo': 1, 'bar': 11 } } ] } ]""", + target = EvaluatorTestTarget.COMPILER_PIPELINE // planner & physical plan have no support for GROUP BY (yet) ) } diff --git a/lang/test/org/partiql/lang/eval/EvaluatingCompilerNAryIntOverflowTests.kt b/lang/test/org/partiql/lang/eval/EvaluatingCompilerNAryIntOverflowTests.kt index f2dd9e95d5..ad6be42473 100644 --- a/lang/test/org/partiql/lang/eval/EvaluatingCompilerNAryIntOverflowTests.kt +++ b/lang/test/org/partiql/lang/eval/EvaluatingCompilerNAryIntOverflowTests.kt @@ -3,6 +3,7 @@ package org.partiql.lang.eval import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.ArgumentsSource import org.partiql.lang.eval.evaluatortestframework.EvaluatorTestCase +import org.partiql.lang.eval.evaluatortestframework.EvaluatorTestTarget import org.partiql.lang.eval.visitors.StaticTypeInferenceVisitorTransform import org.partiql.lang.types.IntType import org.partiql.lang.types.StaticType @@ -152,18 +153,19 @@ class EvaluatingCompilerNAryIntOverflowTests : EvaluatorTestBase() { globals(defaultEnv.valueBindings) } - // We use EvaluatorTestCase/runTestCase from EvaluatorTestBase here instead of assertEval - // because the expected values are expressed in PartiQL syntax, but with assertEval it's expressed in - // Ion syntax. val etc = EvaluatorTestCase( query = tc.sqlUnderTest, expectedResult = tc.expectedPermissiveModeResult, + implicitPermissiveModeTest = false, compilerPipelineBuilderBlock = { globalTypeBindings(defaultEnv.typeBindings) compileOptions { typingMode(TypingMode.PERMISSIVE) } - } + }, + // These tests requires support for globalTypeBindings and thus a static type inference pass + // which is not (yet) supported by `PlannerPipeline` + target = EvaluatorTestTarget.COMPILER_PIPELINE ) runEvaluatorTestCase(etc, session) diff --git a/lang/test/org/partiql/lang/eval/EvaluatingCompilerOffsetTests.kt b/lang/test/org/partiql/lang/eval/EvaluatingCompilerOffsetTests.kt index ede0f12e86..e1e20bd622 100644 --- a/lang/test/org/partiql/lang/eval/EvaluatingCompilerOffsetTests.kt +++ b/lang/test/org/partiql/lang/eval/EvaluatingCompilerOffsetTests.kt @@ -6,6 +6,7 @@ import org.partiql.lang.errors.ErrorCode import org.partiql.lang.errors.Property import org.partiql.lang.eval.evaluatortestframework.EvaluatorErrorTestCase import org.partiql.lang.eval.evaluatortestframework.EvaluatorTestCase +import org.partiql.lang.eval.evaluatortestframework.EvaluatorTestTarget import org.partiql.lang.util.ArgumentsProviderBase import org.partiql.lang.util.propertyValueMapOf import org.partiql.lang.util.to @@ -58,12 +59,14 @@ class EvaluatingCompilerOffsetTests : EvaluatorTestBase() { // LIMIT 2 and OFFSET 2 should return third and fourth results EvaluatorTestCase( "SELECT * FROM foo GROUP BY a LIMIT 2 OFFSET 2", - "<<{'a': 3}, {'a': 4}>>" + "<<{'a': 3}, {'a': 4}>>", + target = EvaluatorTestTarget.COMPILER_PIPELINE // PlannerPipeline doesn't support GROUP BY yet ), // LIMIT and OFFSET applied after GROUP BY EvaluatorTestCase( "SELECT * FROM foo GROUP BY a LIMIT 1 OFFSET 1", - "<<{'a': 2}>>" + "<<{'a': 2}>>", + target = EvaluatorTestTarget.COMPILER_PIPELINE // PlannerPipeline doesn't support GROUP BY yet ), // OFFSET value can be subtraction of 2 numbers EvaluatorTestCase( @@ -88,7 +91,8 @@ class EvaluatingCompilerOffsetTests : EvaluatorTestBase() { // OFFSET with GROUP BY and HAVING EvaluatorTestCase( "SELECT * FROM foo GROUP BY a HAVING a > 2 LIMIT 1 OFFSET 1", - "<<{'a': 4}>>" + "<<{'a': 4}>>", + target = EvaluatorTestTarget.COMPILER_PIPELINE // PlannerPipeline doesn't support GROUP BY yet ), // OFFSET with PIVOT EvaluatorTestCase( @@ -97,7 +101,8 @@ class EvaluatingCompilerOffsetTests : EvaluatorTestBase() { FROM <<{'a': 1, 'b':'I'}, {'a': 2, 'b':'II'}, {'a': 3, 'b':'III'}>> AS foo LIMIT 1 OFFSET 1 """.trimIndent(), - "{'II': 2}" + "{'II': 2}", + target = EvaluatorTestTarget.COMPILER_PIPELINE // PlannerPipeline doesn't support PIVOT yet ) ) } diff --git a/lang/test/org/partiql/lang/eval/EvaluatingCompilerOrderByTests.kt b/lang/test/org/partiql/lang/eval/EvaluatingCompilerOrderByTests.kt index 6406ec2a44..4946669e5b 100644 --- a/lang/test/org/partiql/lang/eval/EvaluatingCompilerOrderByTests.kt +++ b/lang/test/org/partiql/lang/eval/EvaluatingCompilerOrderByTests.kt @@ -3,6 +3,7 @@ package org.partiql.lang.eval import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.ArgumentsSource import org.partiql.lang.eval.evaluatortestframework.EvaluatorTestCase +import org.partiql.lang.eval.evaluatortestframework.EvaluatorTestTarget import org.partiql.lang.util.ArgumentsProviderBase class EvaluatingCompilerOrderByTests : EvaluatorTestBase() { @@ -276,7 +277,10 @@ class EvaluatingCompilerOrderByTests : EvaluatorTestBase() { @ParameterizedTest @ArgumentsSource(ArgsProviderValid::class) fun validTests(tc: EvaluatorTestCase) = runEvaluatorTestCase( - tc = tc.copy(excludeLegacySerializerAssertions = true), + tc = tc.copy( + excludeLegacySerializerAssertions = true, + targetPipeline = EvaluatorTestTarget.COMPILER_PIPELINE, // planner & phys. alg. have no support for ORDER BY (yet) + ), session = session ) } diff --git a/lang/test/org/partiql/lang/eval/EvaluatingCompilerUnknownValuesTest.kt b/lang/test/org/partiql/lang/eval/EvaluatingCompilerUnknownValuesTest.kt index 28dac32ef6..624444c090 100644 --- a/lang/test/org/partiql/lang/eval/EvaluatingCompilerUnknownValuesTest.kt +++ b/lang/test/org/partiql/lang/eval/EvaluatingCompilerUnknownValuesTest.kt @@ -20,6 +20,7 @@ import org.junit.jupiter.params.provider.ArgumentsSource import org.partiql.lang.errors.ErrorCode import org.partiql.lang.errors.Property import org.partiql.lang.eval.evaluatortestframework.EvaluatorTestCase +import org.partiql.lang.eval.evaluatortestframework.EvaluatorTestTarget import org.partiql.lang.eval.evaluatortestframework.ExpectedResultFormat import org.partiql.lang.types.FunctionSignature import org.partiql.lang.types.StaticType @@ -622,166 +623,243 @@ class EvaluatingCompilerUnknownValuesTest : EvaluatorTestBase() { // //////////////////////////////////////////////// @Test - fun aggregateSumWithNull() = runEvaluatorTestCase("SELECT sum(x.n) from nullSample as x", nullSample, "[{_1: 4}]") + fun aggregateSumWithNull() = runEvaluatorTestCase( + "SELECT sum(x.n) from nullSample as x", + nullSample, + "[{_1: 4}]", + target = EvaluatorTestTarget.COMPILER_PIPELINE, // planner & phys. alg. have no support for aggregates (yet) + ) @Test fun aggregateSumWithMissing() = runEvaluatorTestCase( "SELECT sum(x.n) from missingSample as x", missingSample, - "[{_1: 3}]" + "[{_1: 3}]", + target = EvaluatorTestTarget.COMPILER_PIPELINE, // planner & phys. alg. have no support for aggregates (yet) ) @Test fun aggregateSumWithMissingAndNull() = runEvaluatorTestCase( "SELECT sum(x.n) from missingAndNullSample as x", missingAndNullSample, - "[{_1: 9}]" + "[{_1: 9}]", + target = EvaluatorTestTarget.COMPILER_PIPELINE, // planner & phys. alg. have no support for aggregates (yet) ) @Test - fun aggregateMinWithNull() = runEvaluatorTestCase("SELECT min(x.n) from nullSample as x", nullSample, "[{_1: 1}]") + fun aggregateMinWithNull() = runEvaluatorTestCase( + "SELECT min(x.n) from nullSample as x", + nullSample, + "[{_1: 1}]", + target = EvaluatorTestTarget.COMPILER_PIPELINE, // planner & phys. alg. have no support for aggregates (yet) + ) @Test fun aggregateMinWithMissing() = runEvaluatorTestCase( "SELECT min(x.n) from missingSample as x", missingSample, - "[{_1: 1}]" + "[{_1: 1}]", + target = EvaluatorTestTarget.COMPILER_PIPELINE, // planner & phys. alg. have no support for aggregates (yet) ) @Test fun aggregateMinWithMissingAndNull() = runEvaluatorTestCase( "SELECT min(x.n) from missingAndNullSample as x", missingAndNullSample, - "[{_1: 2}]" + "[{_1: 2}]", + target = EvaluatorTestTarget.COMPILER_PIPELINE, // planner & phys. alg. have no support for aggregates (yet) ) @Test - fun aggregateAvgWithNull() = runEvaluatorTestCase("SELECT avg(x.n) from nullSample as x", nullSample, "[{_1: 2.}]") + fun aggregateAvgWithNull() = runEvaluatorTestCase( + "SELECT avg(x.n) from nullSample as x", + nullSample, + "[{_1: 2.}]", + target = EvaluatorTestTarget.COMPILER_PIPELINE, // planner & phys. alg. have no support for aggregates (yet) + ) @Test fun aggregateAvgWithMissing() = runEvaluatorTestCase( "SELECT avg(x.n) from missingSample as x", missingSample, - "[{_1: 1.5}]" + "[{_1: 1.5}]", + target = EvaluatorTestTarget.COMPILER_PIPELINE, // planner & phys. alg. have no support for aggregates (yet) ) @Test fun aggregateAvgWithMissingAndNull() = runEvaluatorTestCase( "SELECT avg(x.n) from missingAndNullSample as x", missingAndNullSample, - "[{_1: 3.}]" + "[{_1: 3.}]", + target = EvaluatorTestTarget.COMPILER_PIPELINE, // planner & phys. alg. have no support for aggregates (yet) ) @Test fun aggregateCountWithNull() = runEvaluatorTestCase( "SELECT count(x.n) from nullSample as x", nullSample, - "[{_1: 2}]" + "[{_1: 2}]", + target = EvaluatorTestTarget.COMPILER_PIPELINE, // planner & phys. alg. have no support for aggregates (yet) ) @Test fun aggregateCountWithMissing() = runEvaluatorTestCase( "SELECT count(x.n) from missingSample as x", missingSample, - "[{_1: 2}]" + "[{_1: 2}]", + target = EvaluatorTestTarget.COMPILER_PIPELINE, // planner & phys. alg. have no support for aggregates (yet) ) @Test fun aggregateCountWithMissingAndNull() = runEvaluatorTestCase( "SELECT count(x.n) from missingAndNullSample as x", missingAndNullSample, - "[{_1: 3}]" + "[{_1: 3}]", + target = EvaluatorTestTarget.COMPILER_PIPELINE, // planner & phys. alg. have no support for aggregates (yet) ) @Test - fun countEmpty() = runEvaluatorTestCase("SELECT count(*) from `[]`", expectedResult = "[{_1: 0}]") + fun countEmpty() = runEvaluatorTestCase( + "SELECT count(*) from `[]`", + expectedResult = "[{_1: 0}]", + target = EvaluatorTestTarget.COMPILER_PIPELINE, // planner & phys. alg. have no support for aggregates (yet) + ) @Test fun countEmptyTuple() = - runEvaluatorTestCase("SELECT count(*) from `[{}]`", expectedResult = "[{_1: 1}]") + runEvaluatorTestCase( + "SELECT count(*) from `[{}]`", + expectedResult = "[{_1: 1}]", + target = EvaluatorTestTarget.COMPILER_PIPELINE, // planner & phys. alg. have no support for aggregates (yet) + ) @Test - fun sumEmpty() = runEvaluatorTestCase("SELECT sum(x.i) from `[]` as x", expectedResult = "[{_1: null}]") + fun sumEmpty() = runEvaluatorTestCase( + "SELECT sum(x.i) from `[]` as x", + expectedResult = "[{_1: null}]", + target = EvaluatorTestTarget.COMPILER_PIPELINE, // planner & phys. alg. have no support for aggregates (yet) + ) @Test fun sumEmptyTuple() = - runEvaluatorTestCase("SELECT sum(x.i) from `[{}]` as x", expectedResult = "[{_1: null}]") + runEvaluatorTestCase( + "SELECT sum(x.i) from `[{}]` as x", + expectedResult = "[{_1: null}]", + target = EvaluatorTestTarget.COMPILER_PIPELINE, // planner & phys. alg. have no support for aggregates (yet) + ) @Test - fun avgEmpty() = runEvaluatorTestCase("SELECT avg(x.i) from `[]` as x", expectedResult = "[{_1: null}]") + fun avgEmpty() = runEvaluatorTestCase( + "SELECT avg(x.i) from `[]` as x", + expectedResult = "[{_1: null}]", + target = EvaluatorTestTarget.COMPILER_PIPELINE, // planner & phys. alg. have no support for aggregates (yet) + ) @Test fun avgEmptyTuple() = - runEvaluatorTestCase("SELECT avg(x.i) from `[{}]` as x", expectedResult = "[{_1: null}]") + runEvaluatorTestCase( + "SELECT avg(x.i) from `[{}]` as x", + expectedResult = "[{_1: null}]", + target = EvaluatorTestTarget.COMPILER_PIPELINE, // planner & phys. alg. have no support for aggregates (yet) + ) @Test fun avgSomeEmptyTuples() = runEvaluatorTestCase( "SELECT avg(x.i) from `[{i: 1}, {}, {i:3}]` as x", - expectedResult = "[{_1: 2.}]" + expectedResult = "[{_1: 2.}]", + target = EvaluatorTestTarget.COMPILER_PIPELINE, // planner & phys. alg. have no support for aggregates (yet) ) @Test fun avgSomeEmptyAndNullTuples() = runEvaluatorTestCase( "SELECT avg(x.i) from `[{i: 1}, {}, {i:null}, {i:3}]` as x", - expectedResult = "[{_1: 2.}]" + expectedResult = "[{_1: 2.}]", + target = EvaluatorTestTarget.COMPILER_PIPELINE, // planner & phys. alg. have no support for aggregates (yet) ) @Test fun minSomeEmptyTuples() = runEvaluatorTestCase( "SELECT min(x.i) from `[{i: null}, {}, {i:3}]` as x", - expectedResult = "[{_1: 3}]" + expectedResult = "[{_1: 3}]", + target = EvaluatorTestTarget.COMPILER_PIPELINE, // planner & phys. alg. have no support for aggregates (yet) ) @Test fun maxSomeEmptyTuples() = runEvaluatorTestCase( "SELECT max(x.i) from `[{i: null}, {}, {i:3}, {i:10}]` as x", - expectedResult = "[{_1: 10}]" + expectedResult = "[{_1: 10}]", + target = EvaluatorTestTarget.COMPILER_PIPELINE, // planner & phys. alg. have no support for aggregates (yet) ) @Test - fun minEmpty() = runEvaluatorTestCase("SELECT min(x.i) from `[]` as x", expectedResult = "[{_1: null}]") + fun minEmpty() = runEvaluatorTestCase( + "SELECT min(x.i) from `[]` as x", + expectedResult = "[{_1: null}]", + target = EvaluatorTestTarget.COMPILER_PIPELINE, // planner & phys. alg. have no support for aggregates (yet) + ) @Test fun minEmptyTuple() = - runEvaluatorTestCase("SELECT min(x.i) from `[{}]` as x", expectedResult = "[{_1: null}]") + runEvaluatorTestCase( + "SELECT min(x.i) from `[{}]` as x", + expectedResult = "[{_1: null}]", + target = EvaluatorTestTarget.COMPILER_PIPELINE, // planner & phys. alg. have no support for aggregates (yet) + ) @Test - fun maxEmpty() = runEvaluatorTestCase("SELECT max(x.i) from `[]` as x", expectedResult = "[{_1: null}]") + fun maxEmpty() = runEvaluatorTestCase( + "SELECT max(x.i) from `[]` as x", + expectedResult = "[{_1: null}]", + target = EvaluatorTestTarget.COMPILER_PIPELINE, // planner & phys. alg. have no support for aggregates (yet) + ) @Test fun maxEmptyTuple() = - runEvaluatorTestCase("SELECT max(x.i) from `[{}]` as x", expectedResult = "[{_1: null}]") + runEvaluatorTestCase( + "SELECT max(x.i) from `[{}]` as x", + expectedResult = "[{_1: null}]", + target = EvaluatorTestTarget.COMPILER_PIPELINE, // planner & phys. alg. have no support for aggregates (yet) + ) @Test fun maxSomeEmptyTuple() = runEvaluatorTestCase( "SELECT max(x.i) from `[{}, {i:1}, {}, {i:2}]` as x", - expectedResult = "[{_1: 2}]" + expectedResult = "[{_1: 2}]", + target = EvaluatorTestTarget.COMPILER_PIPELINE, // planner & phys. alg. have no support for aggregates (yet) ) @Test fun minSomeEmptyTuple() = runEvaluatorTestCase( "SELECT min(x.i) from `[{}, {i:1}, {}, {i:2}]` as x", - expectedResult = "[{_1: 1}]" + expectedResult = "[{_1: 1}]", + target = EvaluatorTestTarget.COMPILER_PIPELINE, // planner & phys. alg. have no support for aggregates (yet) ) @Test fun sumSomeEmptyTuple() = runEvaluatorTestCase( "SELECT sum(x.i) from `[{}, {i:1}, {}, {i:2}]` as x", - expectedResult = "[{_1: 3}]" + expectedResult = "[{_1: 3}]", + target = EvaluatorTestTarget.COMPILER_PIPELINE, // planner & phys. alg. have no support for aggregates (yet) ) @Test fun countSomeEmptyTuple() = runEvaluatorTestCase( "SELECT count(x.i) from `[{}, {i:1}, {}, {i:2}]` as x", - expectedResult = "[{_1: 2}]" + expectedResult = "[{_1: 2}]", + target = EvaluatorTestTarget.COMPILER_PIPELINE, // planner & phys. alg. have no support for aggregates (yet) ) @Test fun countStar() = runEvaluatorTestCase( "SELECT count(*) from `[{}, {i:1}, {}, {i:2}]` as x", - expectedResult = "[{_1: 4}]" + expectedResult = "[{_1: 4}]", + target = EvaluatorTestTarget.COMPILER_PIPELINE, // planner & phys. alg. have no support for aggregates (yet) ) @Test fun countLiteral() = - runEvaluatorTestCase("SELECT count(1) from `[{}, {}, {}, {}]` as x", expectedResult = "[{_1: 4}]") + runEvaluatorTestCase( + "SELECT count(1) from `[{}, {}, {}, {}]` as x", + expectedResult = "[{_1: 4}]", + target = EvaluatorTestTarget.COMPILER_PIPELINE, // planner & phys. alg. have no support for aggregates (yet) + ) } diff --git a/lang/test/org/partiql/lang/eval/EvaluatorStaticTypeTests.kt b/lang/test/org/partiql/lang/eval/EvaluatorStaticTypeTests.kt index 61ded6d36d..d0cd758364 100644 --- a/lang/test/org/partiql/lang/eval/EvaluatorStaticTypeTests.kt +++ b/lang/test/org/partiql/lang/eval/EvaluatorStaticTypeTests.kt @@ -3,6 +3,7 @@ package org.partiql.lang.eval import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.MethodSource import org.partiql.lang.ION +import org.partiql.lang.eval.evaluatortestframework.EvaluatorTestTarget import org.partiql.lang.util.testdsl.IonResultTestCase import org.partiql.lang.util.testdsl.runTestCase @@ -173,13 +174,12 @@ class EvaluatorStaticTypeTests { "aggregateInSubqueryOfSelect", "aggregateInSubqueryOfSelectValue", "aggregateWithAliasingInSubqueryOfSelectValue" - ) @JvmStatic @Suppress("unused") fun evaluatorStaticTypeTests() = EVALUATOR_TEST_SUITE.getAllTests( - EvaluatorTests.SKIP_LIST.union(FAILING_TESTS) + EvaluatorTests.AST_EVALUATOR_SKIP_LIST.union(FAILING_TESTS) ).map { it.copy( compileOptionsBuilderBlock = { @@ -187,8 +187,10 @@ class EvaluatorStaticTypeTests { // set permissive mode typingMode(TypingMode.PERMISSIVE) - // enable evaluation time type checking - evaluationTimeTypeChecks(ThunkReturnTypeAssertions.ENABLED) + thunkOptions { + // enable evaluation time type checking + evaluationTimeTypeChecks(ThunkReturnTypeAssertions.ENABLED) + } } ) } @@ -200,7 +202,9 @@ class EvaluatorStaticTypeTests { tc.runTestCase( valueFactory = valueFactory, db = mockDb, + // the planner doesn't yet support type inferencing pass needed to make this work + EvaluatorTestTarget.COMPILER_PIPELINE, // Enable the static type inferencer for this - compilerPipelineBuilderBlock = { this.globalTypeBindings(mockDb.typeBindings) } + compilerPipelineBuilderBlock = { this.globalTypeBindings(mockDb.typeBindings) }, ) } diff --git a/lang/test/org/partiql/lang/eval/EvaluatorTestBase.kt b/lang/test/org/partiql/lang/eval/EvaluatorTestBase.kt index a3571c81a4..c18db638ec 100644 --- a/lang/test/org/partiql/lang/eval/EvaluatorTestBase.kt +++ b/lang/test/org/partiql/lang/eval/EvaluatorTestBase.kt @@ -26,15 +26,18 @@ import org.partiql.lang.SqlException import org.partiql.lang.TestBase import org.partiql.lang.errors.ErrorCode import org.partiql.lang.errors.PropertyValueMap -import org.partiql.lang.eval.evaluatortestframework.AstEvaluatorTestAdapter import org.partiql.lang.eval.evaluatortestframework.AstRewriterBaseTestAdapter +import org.partiql.lang.eval.evaluatortestframework.CompilerPipelineFactory import org.partiql.lang.eval.evaluatortestframework.EvaluatorErrorTestCase import org.partiql.lang.eval.evaluatortestframework.EvaluatorTestAdapter import org.partiql.lang.eval.evaluatortestframework.EvaluatorTestCase +import org.partiql.lang.eval.evaluatortestframework.EvaluatorTestTarget import org.partiql.lang.eval.evaluatortestframework.ExpectedResultFormat import org.partiql.lang.eval.evaluatortestframework.LegacySerializerTestAdapter import org.partiql.lang.eval.evaluatortestframework.MultipleTestAdapter import org.partiql.lang.eval.evaluatortestframework.PartiqlAstExprNodeRoundTripAdapter +import org.partiql.lang.eval.evaluatortestframework.PipelineEvaluatorTestAdapter +import org.partiql.lang.eval.evaluatortestframework.PlannerPipelineFactory import org.partiql.lang.util.asSequence import org.partiql.lang.util.newFromIonText @@ -44,7 +47,8 @@ import org.partiql.lang.util.newFromIonText abstract class EvaluatorTestBase : TestBase() { private val testHarness: EvaluatorTestAdapter = MultipleTestAdapter( listOf( - AstEvaluatorTestAdapter(), + PipelineEvaluatorTestAdapter(CompilerPipelineFactory()), + PipelineEvaluatorTestAdapter(PlannerPipelineFactory()), PartiqlAstExprNodeRoundTripAdapter(), LegacySerializerTestAdapter(), AstRewriterBaseTestAdapter() @@ -71,6 +75,7 @@ abstract class EvaluatorTestBase : TestBase() { excludeLegacySerializerAssertions: Boolean = false, expectedResultFormat: ExpectedResultFormat = ExpectedResultFormat.ION_WITHOUT_BAG_AND_MISSING_ANNOTATIONS, includePermissiveModeTest: Boolean = true, + target: EvaluatorTestTarget = EvaluatorTestTarget.ALL_PIPELINES, compileOptionsBuilderBlock: CompileOptions.Builder.() -> Unit = { }, compilerPipelineBuilderBlock: CompilerPipeline.Builder.() -> Unit = { }, extraResultAssertions: (ExprValue) -> Unit = { } @@ -82,6 +87,7 @@ abstract class EvaluatorTestBase : TestBase() { expectedResultFormat = expectedResultFormat, excludeLegacySerializerAssertions = excludeLegacySerializerAssertions, implicitPermissiveModeTest = includePermissiveModeTest, + target = target, compileOptionsBuilderBlock = compileOptionsBuilderBlock, compilerPipelineBuilderBlock = compilerPipelineBuilderBlock, extraResultAssertions = extraResultAssertions @@ -112,6 +118,7 @@ abstract class EvaluatorTestBase : TestBase() { compileOptionsBuilderBlock: CompileOptions.Builder.() -> Unit = { }, addtionalExceptionAssertBlock: (SqlException) -> Unit = { }, implicitPermissiveModeTest: Boolean = true, + target: EvaluatorTestTarget = EvaluatorTestTarget.ALL_PIPELINES, session: EvaluationSession = EvaluationSession.standard() ) { val tc = EvaluatorErrorTestCase( @@ -121,9 +128,10 @@ abstract class EvaluatorTestBase : TestBase() { expectedInternalFlag = expectedInternalFlag, expectedPermissiveModeResult = expectedPermissiveModeResult, excludeLegacySerializerAssertions = excludeLegacySerializerAssertions, + implicitPermissiveModeTest = implicitPermissiveModeTest, + targetPipeline = target, compileOptionsBuilderBlock = compileOptionsBuilderBlock, compilerPipelineBuilderBlock = compilerPipelineBuilderBlock, - implicitPermissiveModeTest = implicitPermissiveModeTest, additionalExceptionAssertBlock = addtionalExceptionAssertBlock, ) diff --git a/lang/test/org/partiql/lang/eval/EvaluatorTestSuite.kt b/lang/test/org/partiql/lang/eval/EvaluatorTestSuite.kt index 9e5135ca00..156cc081c8 100644 --- a/lang/test/org/partiql/lang/eval/EvaluatorTestSuite.kt +++ b/lang/test/org/partiql/lang/eval/EvaluatorTestSuite.kt @@ -368,7 +368,7 @@ internal val EVALUATOR_TEST_SUITE: IonResultTestSuite = defineTestSuite { """ ) } - group("select-where") { + group("select_where") { test( "selectWhereStringEqualsSameCase", """SELECT * FROM animals as a WHERE a.name = 'Kumo' """, @@ -387,7 +387,7 @@ internal val EVALUATOR_TEST_SUITE: IonResultTestSuite = defineTestSuite { """ ) } - group("select-join") { + group("select_join") { test( "selectJoin", """SELECT * FROM animals AS a, animal_types AS t WHERE a.type = t.id""", diff --git a/lang/test/org/partiql/lang/eval/EvaluatorTests.kt b/lang/test/org/partiql/lang/eval/EvaluatorTests.kt index 6d92bf6738..4707b5a9c0 100644 --- a/lang/test/org/partiql/lang/eval/EvaluatorTests.kt +++ b/lang/test/org/partiql/lang/eval/EvaluatorTests.kt @@ -17,6 +17,7 @@ package org.partiql.lang.eval import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.MethodSource import org.partiql.lang.ION +import org.partiql.lang.eval.evaluatortestframework.EvaluatorTestTarget import org.partiql.lang.util.testdsl.IonResultTestCase import org.partiql.lang.util.testdsl.runTestCase @@ -25,7 +26,7 @@ class EvaluatorTests { private val mockDb = EVALUATOR_TEST_SUITE.mockDb(valueFactory) companion object { - val SKIP_LIST = hashSetOf( + val AST_EVALUATOR_SKIP_LIST = hashSetOf( // https://github.com/partiql/partiql-lang-kotlin/issues/169 "selectDistinctStarLists", "selectDistinctStarBags", "selectDistinctStarMixed", @@ -36,8 +37,8 @@ class EvaluatorTests { @JvmStatic @Suppress("UNUSED") - fun evaluatorTests(): List { - val unskippedTests = EVALUATOR_TEST_SUITE.getAllTests(SKIP_LIST) + fun astEvaluatorTests(): List { + val unskippedTests = EVALUATOR_TEST_SUITE.getAllTests(AST_EVALUATOR_SKIP_LIST) return unskippedTests.map { it.copy( @@ -58,9 +59,71 @@ class EvaluatorTests { ) } } + + private val PLAN_EVALUATOR_SKIP_LIST = hashSetOf( + // below this line use features not supported by the current physical algebra compiler. + // most fail due to not supporting foundational nodes like id, global_id and scan yet. + // PartiQL's test cases are not all that cleanly separated. + "selectCorrelatedUnpivot", // TODO: Support UNPIVOT in physical plans + "nestedSelectJoinWithUnpivot", // TODO: Support UNPIVOT in physical plans + "nestedSelectJoinLimit", // TODO: Support UNPIVOT in physical plans + "pivotFrom", // TODO: Support PIVOT in physical plans + "pivotLiteralFieldNameFrom", // TODO: Support PIVOT in physical plans + "pivotBadFieldType", // TODO: Support PIVOT in physical plans + "pivotUnpivotWithWhereLimit", // TODO: Support PIVOT in physical plans + "topLevelCountDistinct", // TODO: Support aggregates in physical plans + "topLevelCount", // TODO: Support aggregates in physical plans + "topLevelAllCount", // TODO: Support aggregates in physical plans + "topLevelSum", // TODO: Support aggregates in physical plans + "topLevelAllSum", // TODO: Support aggregates in physical plans + "topLevelDistinctSum", // TODO: Support aggregates in physical plans + "topLevelMin", // TODO: Support aggregates in physical plans + "topLevelDistinctMin", // TODO: Support aggregates in physical plans + "topLevelAllMin", // TODO: Support aggregates in physical plans + "topLevelMax", // TODO: Support aggregates in physical plans + "topLevelDistinctMax", // TODO: Support aggregates in physical plans + "topLevelAllMax", // TODO: Support aggregates in physical plans + "topLevelAvg", // TODO: Support aggregates in physical plans + "topLevelDistinctAvg", // TODO: Support aggregates in physical plans + "topLevelAvgOnlyInt", // TODO: Support aggregates in physical plans + "selectValueAggregate", // TODO: Support aggregates in physical plans + "selectListCountStar", // TODO: Support aggregates in physical plans + "selectListCountVariable", // TODO: Support aggregates in physical plans + "selectListMultipleAggregates", // TODO: Support aggregates in physical plans + "selectListMultipleAggregatesNestedQuery", // TODO: Support aggregates in physical plans + "aggregateInSubqueryOfSelect", // TODO: Support aggregates in physical plans + "aggregateInSubqueryOfSelectValue", // TODO: Support aggregates in physical plans + "aggregateWithAliasingInSubqueryOfSelectValue", // TODO: Support aggregates in physical plans + "selectDistinctWithAggregate", // TODO: Support aggregates in physical plans + "selectDistinctAggregationWithGroupBy", // TODO: Support GROUP BY in physical plans + "selectDistinctWithGroupBy", // TODO: Support GROUP BY in physical plans + "unpivotStructWithMissingField", // TODO: Support UNPIVOT in physical plans + "unpivotMissing", // TODO: Support UNPIVOT in physical plans + "unpivotEmptyStruct", // TODO: Support UNPIVOT in physical plans + "unpivotMissingWithAsAndAt", // TODO: Support UNPIVOT in physical plans + "unpivotMissingCrossJoinWithAsAndAt", // TODO: Support UNPIVOT in physical plans + + // UndefinedVariableBehavior.MISSING not supported by plan evaluator + "undefinedUnqualifiedVariableWithUndefinedVariableBehaviorMissing", + "undefinedUnqualifiedVariableIsNullExprWithUndefinedVariableBehaviorMissing", + "undefinedUnqualifiedVariableIsMissingExprWithUndefinedVariableBehaviorMissing", + "undefinedUnqualifiedVariableInSelectWithUndefinedVariableBehaviorMissing", + ) + + @JvmStatic + @Suppress("UNUSED") + fun planEvaluatorTests(): List = + // Since the physical plan evaluator is a modified copy of the AST evaluator, it inherits the + // AST evaluator's current skip list. The physical plan evaluator also doesn't yet implement + // everything that the AST evaluator does, so has a separate skip list. + astEvaluatorTests().filter { it.name !in PLAN_EVALUATOR_SKIP_LIST } } @ParameterizedTest - @MethodSource("evaluatorTests") - fun allTests(tc: IonResultTestCase) = tc.runTestCase(valueFactory, mockDb) + @MethodSource("astEvaluatorTests") + fun astEvaluatorTests(tc: IonResultTestCase) = tc.runTestCase(valueFactory, mockDb, EvaluatorTestTarget.COMPILER_PIPELINE) + + @ParameterizedTest + @MethodSource("planEvaluatorTests") + fun planEvaluatorTests(tc: IonResultTestCase) = tc.runTestCase(valueFactory, mockDb, EvaluatorTestTarget.PLANNER_PIPELINE) } diff --git a/lang/test/org/partiql/lang/eval/QuotedIdentifierTests.kt b/lang/test/org/partiql/lang/eval/QuotedIdentifierTests.kt index 3f7500a487..3eb3a9c641 100644 --- a/lang/test/org/partiql/lang/eval/QuotedIdentifierTests.kt +++ b/lang/test/org/partiql/lang/eval/QuotedIdentifierTests.kt @@ -17,6 +17,7 @@ package org.partiql.lang.eval import org.junit.Test import org.partiql.lang.errors.ErrorCode import org.partiql.lang.errors.Property +import org.partiql.lang.eval.evaluatortestframework.EvaluatorTestTarget import org.partiql.lang.eval.evaluatortestframework.ExpectedResultFormat import org.partiql.lang.util.propertyValueMapOf @@ -60,6 +61,8 @@ class QuotedIdentifierTests : EvaluatorTestBase() { session = simpleSession, expectedResult = "MISSING", expectedResultFormat = ExpectedResultFormat.PARTIQL, + // planner & physical plan have no support for UndefinedVariableBehavior.MISSING (and may never) + target = EvaluatorTestTarget.COMPILER_PIPELINE, compileOptionsBuilderBlock = { undefinedVariableMissingCompileOptionBlock() }, ) runEvaluatorTestCase( @@ -67,6 +70,8 @@ class QuotedIdentifierTests : EvaluatorTestBase() { session = simpleSession, expectedResult = "MISSING", expectedResultFormat = ExpectedResultFormat.PARTIQL, + // planner & physical plan have no support for UndefinedVariableBehavior.MISSING (and may never) + target = EvaluatorTestTarget.COMPILER_PIPELINE, compileOptionsBuilderBlock = { undefinedVariableMissingCompileOptionBlock() }, ) @@ -75,6 +80,8 @@ class QuotedIdentifierTests : EvaluatorTestBase() { query = "\"Abc\"", session = simpleSession, expectedResult = "1", + // planner & physical plan have no support for UndefinedVariableBehavior.MISSING (and may never) + target = EvaluatorTestTarget.COMPILER_PIPELINE, compileOptionsBuilderBlock = undefinedVariableMissingCompileOptionBlock ) } @@ -109,7 +116,8 @@ class QuotedIdentifierTests : EvaluatorTestBase() { Property.BINDING_NAME_MATCHES to "Abc, aBc, abC" ), expectedPermissiveModeResult = "MISSING", - session = simpleSession + session = simpleSession, + target = EvaluatorTestTarget.COMPILER_PIPELINE // Planner will never throw ambiguous binding error ) } diff --git a/lang/test/org/partiql/lang/eval/SimpleEvaluatingCompilerTests.kt b/lang/test/org/partiql/lang/eval/SimpleEvaluatingCompilerTests.kt index 59f5b88342..31bf1bab1c 100644 --- a/lang/test/org/partiql/lang/eval/SimpleEvaluatingCompilerTests.kt +++ b/lang/test/org/partiql/lang/eval/SimpleEvaluatingCompilerTests.kt @@ -17,6 +17,7 @@ package org.partiql.lang.eval import org.junit.Test import org.partiql.lang.errors.ErrorCode import org.partiql.lang.errors.Property +import org.partiql.lang.eval.evaluatortestframework.EvaluatorTestTarget import org.partiql.lang.util.propertyValueMapOf class SimpleEvaluatingCompilerTests : EvaluatorTestBase() { @@ -73,7 +74,9 @@ class SimpleEvaluatingCompilerTests : EvaluatorTestBase() { {name:"d",val:4}, {name:"e",val:5}, {name:"f",val:6} - ]""" + ]""", + // planner & physical plan have no support for UNPIVOT (yet) + target = EvaluatorTestTarget.COMPILER_PIPELINE, ) @Test @@ -87,6 +90,27 @@ class SimpleEvaluatingCompilerTests : EvaluatorTestBase() { ]""" ) + private val sessionWithG = mapOf( + "table_1" to "[{a:[{b: 1}, {b:2}]}]", + "g" to "{a: \"from global variable g\"}" + ).toSession() + + /** Demonstrates that without the scope qualifier ('@'), the `g` in `g.a' refers to global `g`. */ + @Test + fun joinWithoutScopeQualifier() = runEvaluatorTestCase( + """SELECT g2 FROM table_1 AS g, g.a AS g2""", + expectedResult = "[{g2:\"from global variable g\"}]", + session = sessionWithG + ) + + /** Demonstrates that with the scope qualifier ('@'), the `g` in `@g.a' refers to local `g`. */ + @Test + fun joinWithScopeQualifier() = runEvaluatorTestCase( + """SELECT g2 FROM table_1 AS g, @g.a AS g2""", + expectedResult = "[{g2:{b:1}},{g2:{b:2}}]", + session = sessionWithG + ) + @Test fun simpleJoinWithCondition() = runEvaluatorTestCase( """ @@ -113,4 +137,25 @@ class SimpleEvaluatingCompilerTests : EvaluatorTestBase() { propertyValueMapOf(1, 1, Property.CAST_FROM to "SYMBOL", Property.CAST_TO to "INT"), expectedPermissiveModeResult = "MISSING" ) + + @Test + fun sum() { + // Note: planner & phys. alg. have no support for aggregates (yet) + runEvaluatorTestCase("SUM(`[1, 2, 3]`)", expectedResult = "6", target = EvaluatorTestTarget.COMPILER_PIPELINE) + runEvaluatorTestCase("SUM(`[1, 2e0, 3e0]`)", expectedResult = "6e0", target = EvaluatorTestTarget.COMPILER_PIPELINE) + runEvaluatorTestCase("SUM(`[1, 2d0, 3d0]`)", expectedResult = "6d0", target = EvaluatorTestTarget.COMPILER_PIPELINE) + runEvaluatorTestCase("SUM(`[1, 2e0, 3d0]`)", expectedResult = "6d0", target = EvaluatorTestTarget.COMPILER_PIPELINE) + runEvaluatorTestCase("SUM(`[1, 2d0, 3e0]`)", expectedResult = "6d0", target = EvaluatorTestTarget.COMPILER_PIPELINE) + } + + @Test + fun max() { + // Note: planner & phys. alg. have no support for aggregates (yet) + runEvaluatorTestCase("max(`[1, 2, 3]`)", expectedResult = "3", target = EvaluatorTestTarget.COMPILER_PIPELINE) + runEvaluatorTestCase("max(`[1, 2.0, 3]`)", expectedResult = "3", target = EvaluatorTestTarget.COMPILER_PIPELINE) + runEvaluatorTestCase("max(`[1, 2e0, 3e0]`)", expectedResult = "3e0", target = EvaluatorTestTarget.COMPILER_PIPELINE) + runEvaluatorTestCase("max(`[1, 2d0, 3d0]`)", expectedResult = "3d0", target = EvaluatorTestTarget.COMPILER_PIPELINE) + runEvaluatorTestCase("max(`[1, 2e0, 3d0]`)", expectedResult = "3d0", target = EvaluatorTestTarget.COMPILER_PIPELINE) + runEvaluatorTestCase("max(`[1, 2d0, 3e0]`)", expectedResult = "3e0", target = EvaluatorTestTarget.COMPILER_PIPELINE) + } } diff --git a/lang/test/org/partiql/lang/eval/ThunkFactoryTests.kt b/lang/test/org/partiql/lang/eval/ThunkFactoryTests.kt index c288dda71b..3d12ea3b6a 100644 --- a/lang/test/org/partiql/lang/eval/ThunkFactoryTests.kt +++ b/lang/test/org/partiql/lang/eval/ThunkFactoryTests.kt @@ -20,8 +20,8 @@ import kotlin.test.assertEquals class ThunkFactoryTests { companion object { - private val compileOptions = CompileOptions.build { - this.evaluationTimeTypeChecks(ThunkReturnTypeAssertions.ENABLED) + private val compileOptions = ThunkOptions.build { + evaluationTimeTypeChecks(ThunkReturnTypeAssertions.ENABLED) } private val valueFactory = ExprValueFactory.standard(ION) @@ -38,7 +38,7 @@ class ThunkFactoryTests { val expectedType: StaticType, val thunkReturnValue: ExprValue, val expectError: Boolean, - internal val thunkFactory: ThunkFactory + internal val thunkFactory: ThunkFactory ) { val metas = metaContainerOf(StaticTypeMeta(expectedType)) val fakeThunk = thunkFactory.thunkEnv(IRRELEVANT_METAS) { IRRELEVANT } diff --git a/lang/test/org/partiql/lang/eval/TypingModeTests.kt b/lang/test/org/partiql/lang/eval/TypingModeTests.kt index 28adb263d7..fc4022e5ce 100644 --- a/lang/test/org/partiql/lang/eval/TypingModeTests.kt +++ b/lang/test/org/partiql/lang/eval/TypingModeTests.kt @@ -58,18 +58,18 @@ class TypingModeTests : EvaluatorTestBase() { expectedErrorCode = tc.expectedLegacyError.errorCode, expectedPermissiveModeResult = tc.expectedPermissiveModeResult, addtionalExceptionAssertBlock = { ex: SqlException -> - // Have to use the addtionalExceptionAssertBlock instead of error context for this + // Have to use the additionalExceptionAssertBlock instead of error context for this // because there are a few cases with error context values other than line & column that we don't // account for in [TestCase]. assertEquals( "line number", tc.expectedLegacyError.lineNum.toLong(), - ex.errorContext?.get(Property.LINE_NUMBER)?.longValue() + ex.errorContext[Property.LINE_NUMBER]?.longValue() ) assertEquals( "column number", tc.expectedLegacyError.charOffset.toLong(), - ex.errorContext?.get(Property.COLUMN_NUMBER)?.longValue() + ex.errorContext[Property.COLUMN_NUMBER]?.longValue() ) } ) diff --git a/lang/test/org/partiql/lang/eval/builtins/BuiltInFunctionTestExtensions.kt b/lang/test/org/partiql/lang/eval/builtins/BuiltInFunctionTestExtensions.kt index 8adc72f79b..952b58c9e0 100644 --- a/lang/test/org/partiql/lang/eval/builtins/BuiltInFunctionTestExtensions.kt +++ b/lang/test/org/partiql/lang/eval/builtins/BuiltInFunctionTestExtensions.kt @@ -6,6 +6,7 @@ import org.partiql.lang.eval.Bindings import org.partiql.lang.eval.EvaluationSession import org.partiql.lang.eval.ExprValue import org.partiql.lang.eval.ExprValueFactory +import org.partiql.lang.eval.evaluatortestframework.EvaluatorTestTarget import org.partiql.lang.util.newFromIonText /** @@ -19,8 +20,13 @@ internal fun checkInvalidArgType(funcName: String, syntaxSuffix: String = "(", a * Internal function used by ExprFunctionTest to test invalid arity. */ internal val invalidArityChecker = InvalidArityChecker() -internal fun checkInvalidArity(funcName: String, minArity: Int, maxArity: Int) = - invalidArityChecker.checkInvalidArity(funcName, minArity, maxArity) +internal fun checkInvalidArity( + funcName: String, + minArity: Int, + maxArity: Int, + targetPipeline: EvaluatorTestTarget = EvaluatorTestTarget.ALL_PIPELINES +) = + invalidArityChecker.checkInvalidArity(funcName, minArity, maxArity, targetPipeline) private val valueFactory = ExprValueFactory.standard(ION) diff --git a/lang/test/org/partiql/lang/eval/builtins/InvalidArityChecker.kt b/lang/test/org/partiql/lang/eval/builtins/InvalidArityChecker.kt index 9fe3b14bf9..848f884f5a 100644 --- a/lang/test/org/partiql/lang/eval/builtins/InvalidArityChecker.kt +++ b/lang/test/org/partiql/lang/eval/builtins/InvalidArityChecker.kt @@ -3,6 +3,7 @@ package org.partiql.lang.eval.builtins import org.partiql.lang.errors.ErrorCode import org.partiql.lang.errors.Property import org.partiql.lang.eval.EvaluatorTestBase +import org.partiql.lang.eval.evaluatortestframework.EvaluatorTestTarget import org.partiql.lang.util.propertyValueMapOf /** @@ -31,7 +32,7 @@ class InvalidArityChecker : EvaluatorTestBase() { * @param maxArity is the maximum arity of an ExprFunction. * @param minArity is the minimum arity of an ExprFunction. */ - fun checkInvalidArity(funcName: String, minArity: Int, maxArity: Int) { + fun checkInvalidArity(funcName: String, minArity: Int, maxArity: Int, targetPipeline: EvaluatorTestTarget) { if (minArity < 0) throw IllegalStateException("Minimum arity has to be larger than 0.") if (maxArity < minArity) throw IllegalStateException("Maximum arity has to be larger than or equal to minimum arity.") @@ -44,7 +45,7 @@ class InvalidArityChecker : EvaluatorTestBase() { else -> sb.append(",null") } if (curArity < minArity || curArity > maxArity) { // If less or more argument provided, we catch invalid arity error - assertThrowsInvalidArity("$sb)", funcName, curArity, minArity, maxArity) + assertThrowsInvalidArity("$sb)", funcName, curArity, minArity, maxArity, targetPipeline) } } } @@ -54,7 +55,8 @@ class InvalidArityChecker : EvaluatorTestBase() { funcName: String, actualArity: Int, minArity: Int, - maxArity: Int + maxArity: Int, + targetPipeline: EvaluatorTestTarget ) = runEvaluatorErrorTestCase( query = query, expectedErrorCode = ErrorCode.EVALUATOR_INCORRECT_NUMBER_OF_ARGUMENTS_TO_FUNC_CALL, @@ -65,5 +67,6 @@ class InvalidArityChecker : EvaluatorTestBase() { Property.EXPECTED_ARITY_MAX to maxArity, Property.ACTUAL_ARITY to actualArity ), + target = targetPipeline ) } diff --git a/lang/test/org/partiql/lang/eval/builtins/aggfunctions/AvgTests.kt b/lang/test/org/partiql/lang/eval/builtins/aggfunctions/AvgTests.kt index 9042bba673..b2e83bda2d 100644 --- a/lang/test/org/partiql/lang/eval/builtins/aggfunctions/AvgTests.kt +++ b/lang/test/org/partiql/lang/eval/builtins/aggfunctions/AvgTests.kt @@ -3,40 +3,71 @@ package org.partiql.lang.eval.builtins.aggfunctions import org.junit.Test import org.partiql.lang.errors.ErrorCode import org.partiql.lang.eval.EvaluatorTestBase +import org.partiql.lang.eval.evaluatortestframework.EvaluatorTestTarget class AvgTests : EvaluatorTestBase() { @Test - fun avgNull() = runEvaluatorTestCase("AVG([null, null])", expectedResult = "null") + fun avgNull() = runEvaluatorTestCase( + query = "AVG([null, null])", + expectedResult = "null", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun avgMissing() = runEvaluatorTestCase("AVG([missing, missing])", expectedResult = "null") + fun avgMissing() = runEvaluatorTestCase( + query = "AVG([missing, missing])", + expectedResult = "null", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun avgInt() = runEvaluatorTestCase("AVG(`[1, 2, 3]`)", expectedResult = "2.") + fun avgInt() = runEvaluatorTestCase( + query = "AVG(`[1, 2, 3]`)", + expectedResult = "2.", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun avgMixed0() = runEvaluatorTestCase("AVG(`[1, 2e0, 3e0]`)", expectedResult = "2.") + fun avgMixed0() = runEvaluatorTestCase( + query = "AVG(`[1, 2e0, 3e0]`)", + expectedResult = "2.", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun avgMixed1() = runEvaluatorTestCase("AVG(`[1, 2d0, 3d0]`)", expectedResult = "2.") + fun avgMixed1() = runEvaluatorTestCase( + query = "AVG(`[1, 2d0, 3d0]`)", + expectedResult = "2.", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun avgMixed2() = runEvaluatorTestCase("AVG(`[1, 2e0, 3d0]`)", expectedResult = "2.") + fun avgMixed2() = runEvaluatorTestCase( + query = "AVG(`[1, 2e0, 3d0]`)", + expectedResult = "2.", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun avgMixed3() = runEvaluatorTestCase("AVG(`[1, 2d0, 3e0]`)", expectedResult = "2.") + fun avgMixed3() = runEvaluatorTestCase( + query = "AVG(`[1, 2d0, 3e0]`)", + expectedResult = "2.", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test fun avgOverflow() = runEvaluatorErrorTestCase( - "AVG([1, 9223372036854775807])", + query = "AVG([1, 9223372036854775807])", ErrorCode.EVALUATOR_INTEGER_OVERFLOW, - expectedPermissiveModeResult = "MISSING" + expectedPermissiveModeResult = "MISSING", + target = EvaluatorTestTarget.COMPILER_PIPELINE ) @Test fun avgUnderflow() = runEvaluatorErrorTestCase( - "AVG([-1, -9223372036854775808])", + query = "AVG([-1, -9223372036854775808])", ErrorCode.EVALUATOR_INTEGER_OVERFLOW, - expectedPermissiveModeResult = "MISSING" + expectedPermissiveModeResult = "MISSING", + target = EvaluatorTestTarget.COMPILER_PIPELINE ) } diff --git a/lang/test/org/partiql/lang/eval/builtins/aggfunctions/CountTests.kt b/lang/test/org/partiql/lang/eval/builtins/aggfunctions/CountTests.kt index ae0821d45e..b363471d8a 100644 --- a/lang/test/org/partiql/lang/eval/builtins/aggfunctions/CountTests.kt +++ b/lang/test/org/partiql/lang/eval/builtins/aggfunctions/CountTests.kt @@ -2,56 +2,116 @@ package org.partiql.lang.eval.builtins.aggfunctions import org.junit.Test import org.partiql.lang.eval.EvaluatorTestBase +import org.partiql.lang.eval.evaluatortestframework.EvaluatorTestTarget class CountTests : EvaluatorTestBase() { @Test - fun countEmpty() = runEvaluatorTestCase("COUNT(`[]`)", expectedResult = "0") + fun countEmpty() = + runEvaluatorTestCase(query = "COUNT(`[]`)", expectedResult = "0", target = EvaluatorTestTarget.COMPILER_PIPELINE) @Test - fun countNull() = runEvaluatorTestCase("COUNT([null, null])", expectedResult = "0") + fun countNull() = runEvaluatorTestCase( + query = "COUNT([null, null])", + expectedResult = "0", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun countMissing() = runEvaluatorTestCase("COUNT([missing])", expectedResult = "0") + fun countMissing() = + runEvaluatorTestCase("COUNT([missing])", expectedResult = "0", target = EvaluatorTestTarget.COMPILER_PIPELINE) @Test - fun countBoolean() = runEvaluatorTestCase("COUNT(`[true, false]`)", expectedResult = "2") + fun countBoolean() = runEvaluatorTestCase( + query = "COUNT(`[true, false]`)", + expectedResult = "2", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun countInt() = runEvaluatorTestCase("COUNT(`[1, 2, 3]`)", expectedResult = "3") + fun countInt() = + runEvaluatorTestCase("COUNT(`[1, 2, 3]`)", expectedResult = "3", target = EvaluatorTestTarget.COMPILER_PIPELINE) @Test - fun countDecimal() = runEvaluatorTestCase("COUNT(`[1e0, 2e0, 3e0]`)", expectedResult = "3") + fun countDecimal() = runEvaluatorTestCase( + query = "COUNT(`[1e0, 2e0, 3e0]`)", + expectedResult = "3", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun countFloat() = runEvaluatorTestCase("COUNT(`[1e0, 2e0, 3e0]`)", expectedResult = "3") + fun countFloat() = runEvaluatorTestCase( + query = "COUNT(`[1e0, 2e0, 3e0]`)", + expectedResult = "3", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun countString() = runEvaluatorTestCase("COUNT(`[\"1\", \"2\", \"3\"]`)", expectedResult = "3") + fun countString() = runEvaluatorTestCase( + query = "COUNT(`[\"1\", \"2\", \"3\"]`)", + expectedResult = "3", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun countTimestamp() = runEvaluatorTestCase("COUNT(`[2020-01-01T00:00:00Z, 2020-01-01T00:00:01Z]`)", expectedResult = "2") + fun countTimestamp() = runEvaluatorTestCase( + query = "COUNT(`[2020-01-01T00:00:00Z, 2020-01-01T00:00:01Z]`)", + expectedResult = "2", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun countBlob() = runEvaluatorTestCase("COUNT(`[{{ aaaa }}, {{ aaab }}]`)", expectedResult = "2") + fun countBlob() = runEvaluatorTestCase( + query = "COUNT(`[{{ aaaa }}, {{ aaab }}]`)", + expectedResult = "2", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun countClob() = runEvaluatorTestCase("COUNT(`[{{ \"aaaa\" }}, {{ \"aaab\" }}]`)", expectedResult = "2") + fun countClob() = runEvaluatorTestCase( + query = "COUNT(`[{{ \"aaaa\" }}, {{ \"aaab\" }}]`)", + expectedResult = "2", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun countSexp() = runEvaluatorTestCase("COUNT(`[(1), (2)]`)", expectedResult = "2") + fun countSexp() = runEvaluatorTestCase( + query = "COUNT(`[(1), (2)]`)", + expectedResult = "2", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun countList() = runEvaluatorTestCase("COUNT(`[[1], [2]]`)", expectedResult = "2") + fun countList() = runEvaluatorTestCase( + query = "COUNT(`[[1], [2]]`)", + expectedResult = "2", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun countBag() = runEvaluatorTestCase("COUNT([<<1>>, <<2>>])", expectedResult = "2") + fun countBag() = runEvaluatorTestCase( + query = "COUNT([<<1>>, <<2>>])", + expectedResult = "2", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun countStruct() = runEvaluatorTestCase("COUNT(`[{'a':1}, {'a':2}]`)", expectedResult = "2") + fun countStruct() = runEvaluatorTestCase( + query = "COUNT(`[{'a':1}, {'a':2}]`)", + expectedResult = "2", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun countMixed0() = runEvaluatorTestCase("COUNT([null, missing, 1, 2])", expectedResult = "2") + fun countMixed0() = runEvaluatorTestCase( + query = "COUNT([null, missing, 1, 2])", + expectedResult = "2", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun countMixed1() = runEvaluatorTestCase("COUNT([1, '2', true, `2020-01-01T00:00:00Z`, `{{ aaaa }}`])", expectedResult = "5") + fun countMixed1() = runEvaluatorTestCase( + query = "COUNT([1, '2', true, `2020-01-01T00:00:00Z`, `{{ aaaa }}`])", + expectedResult = "5", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) } diff --git a/lang/test/org/partiql/lang/eval/builtins/aggfunctions/MaxTests.kt b/lang/test/org/partiql/lang/eval/builtins/aggfunctions/MaxTests.kt index 2d833293f2..a27caf8607 100644 --- a/lang/test/org/partiql/lang/eval/builtins/aggfunctions/MaxTests.kt +++ b/lang/test/org/partiql/lang/eval/builtins/aggfunctions/MaxTests.kt @@ -2,77 +2,174 @@ package org.partiql.lang.eval.builtins.aggfunctions import org.junit.Test import org.partiql.lang.eval.EvaluatorTestBase +import org.partiql.lang.eval.evaluatortestframework.EvaluatorTestTarget class MaxTests : EvaluatorTestBase() { @Test - fun maxNull() = runEvaluatorTestCase("max([null, null])", expectedResult = "null") + fun maxNull() = runEvaluatorTestCase( + query = "max([null, null])", + expectedResult = "null", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun maxMissing() = runEvaluatorTestCase("max([missing, missing])", expectedResult = "null") + fun maxMissing() = runEvaluatorTestCase( + query = "max([missing, missing])", + expectedResult = "null", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun maxNumber0() = runEvaluatorTestCase("max(`[1, 2, 3]`)", expectedResult = "3") + fun maxNumber0() = runEvaluatorTestCase( + query = "max(`[1, 2, 3]`)", + expectedResult = "3", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun maxNumber1() = runEvaluatorTestCase("max(`[1, 2.0, 3]`)", expectedResult = "3") + fun maxNumber1() = runEvaluatorTestCase( + query = "max(`[1, 2.0, 3]`)", + expectedResult = "3", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun maxNumber2() = runEvaluatorTestCase("max(`[1, 2e0, 3e0]`)", expectedResult = "3e0") + fun maxNumber2() = runEvaluatorTestCase( + query = "max(`[1, 2e0, 3e0]`)", + expectedResult = "3e0", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun maxNumber3() = runEvaluatorTestCase("max(`[1, 2d0, 3d0]`)", expectedResult = "3d0") + fun maxNumber3() = runEvaluatorTestCase( + query = "max(`[1, 2d0, 3d0]`)", + expectedResult = "3d0", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun maxNumber4() = runEvaluatorTestCase("max(`[1, 2e0, 3d0]`)", expectedResult = "3d0") + fun maxNumber4() = runEvaluatorTestCase( + query = "max(`[1, 2e0, 3d0]`)", + expectedResult = "3d0", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun maxNumber5() = runEvaluatorTestCase("max(`[1, 2d0, 3e0]`)", expectedResult = "3e0") + fun maxNumber5() = runEvaluatorTestCase( + query = "max(`[1, 2d0, 3e0]`)", + expectedResult = "3e0", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun maxString0() = runEvaluatorTestCase("max(['a', 'abc', '3'])", expectedResult = "\"abc\"") + fun maxString0() = runEvaluatorTestCase( + query = "max(['a', 'abc', '3'])", + expectedResult = "\"abc\"", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun maxString1() = runEvaluatorTestCase("max(['1', '2', '3', null])", expectedResult = "\"3\"") + fun maxString1() = runEvaluatorTestCase( + query = "max(['1', '2', '3', null])", + expectedResult = "\"3\"", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun maxTimestamp0() = runEvaluatorTestCase("max([`2020-01-01T00:00:00Z`, `2020-01-01T00:00:01Z`, `2020-01-01T00:00:02Z`])", expectedResult = "2020-01-01T00:00:02Z") + fun maxTimestamp0() = runEvaluatorTestCase( + query = "max([`2020-01-01T00:00:00Z`, `2020-01-01T00:00:01Z`, `2020-01-01T00:00:02Z`])", + expectedResult = "2020-01-01T00:00:02Z", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun maxTimestamp1() = runEvaluatorTestCase("max([`2020-01-01T00:00:00Z`, `2020-01-01T00:01:00Z`, `2020-01-01T00:02:00Z`])", expectedResult = "2020-01-01T00:02:00Z") + fun maxTimestamp1() = runEvaluatorTestCase( + query = "max([`2020-01-01T00:00:00Z`, `2020-01-01T00:01:00Z`, `2020-01-01T00:02:00Z`])", + expectedResult = "2020-01-01T00:02:00Z", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun maxTimestamp2() = runEvaluatorTestCase("max([`2020-01-01T00:00:00Z`, `2020-01-01T01:00:00Z`, `2020-01-01T02:00:00Z`])", expectedResult = "2020-01-01T02:00:00Z") + fun maxTimestamp2() = runEvaluatorTestCase( + query = "max([`2020-01-01T00:00:00Z`, `2020-01-01T01:00:00Z`, `2020-01-01T02:00:00Z`])", + expectedResult = "2020-01-01T02:00:00Z", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun maxTimestamp3() = runEvaluatorTestCase("max([`2020-01-01T00:00:00Z`, `2020-01-02T00:00:00Z`, `2020-01-03T00:00:00Z`])", expectedResult = "2020-01-03T00:00:00Z") + fun maxTimestamp3() = runEvaluatorTestCase( + query = "max([`2020-01-01T00:00:00Z`, `2020-01-02T00:00:00Z`, `2020-01-03T00:00:00Z`])", + expectedResult = "2020-01-03T00:00:00Z", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun maxTimestamp4() = runEvaluatorTestCase("max([`2020-01-01T00:00:00Z`, `2020-02-01T00:00:00Z`, `2020-03-01T00:00:00Z`])", expectedResult = "2020-03-01T00:00:00Z") + fun maxTimestamp4() = runEvaluatorTestCase( + query = "max([`2020-01-01T00:00:00Z`, `2020-02-01T00:00:00Z`, `2020-03-01T00:00:00Z`])", + expectedResult = "2020-03-01T00:00:00Z", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun maxTimestamp5() = runEvaluatorTestCase("max([`2020-01-01T00:00:00Z`, `2021-01-01T00:00:00Z`, `2022-01-01T00:00:00Z`])", expectedResult = "2022-01-01T00:00:00Z") + fun maxTimestamp5() = runEvaluatorTestCase( + query = "max([`2020-01-01T00:00:00Z`, `2021-01-01T00:00:00Z`, `2022-01-01T00:00:00Z`])", + expectedResult = "2022-01-01T00:00:00Z", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun maxTimestamp6() = runEvaluatorTestCase("max([`2020-01-01T00:00:00Z`, `2020-01-01T00:00:01Z`, `2020-01-01T00:00:02Z`, null])", expectedResult = "2020-01-01T00:00:02Z") + fun maxTimestamp6() = runEvaluatorTestCase( + query = "max([`2020-01-01T00:00:00Z`, `2020-01-01T00:00:01Z`, `2020-01-01T00:00:02Z`, null])", + expectedResult = "2020-01-01T00:00:02Z", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun maxBoolean() = runEvaluatorTestCase("max([false, true])", expectedResult = "true") + fun maxBoolean() = runEvaluatorTestCase( + query = "max([false, true])", + expectedResult = "true", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun maxBlob() = runEvaluatorTestCase("max([`{{ aaaa }}`, `{{ aaab }}`])", expectedResult = "{{aaab}}") + fun maxBlob() = runEvaluatorTestCase( + query = "max([`{{ aaaa }}`, `{{ aaab }}`])", + expectedResult = "{{aaab}}", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun maxClob() = runEvaluatorTestCase("max([`{{\"a\"}}`, `{{\"b\"}}`])", expectedResult = "{{\"b\"}}") + fun maxClob() = runEvaluatorTestCase( + query = "max([`{{\"a\"}}`, `{{\"b\"}}`])", + expectedResult = "{{\"b\"}}", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun maxMixed0() = runEvaluatorTestCase("max([false, 1])", expectedResult = "1") + fun maxMixed0() = runEvaluatorTestCase( + query = "max([false, 1])", + expectedResult = "1", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun maxMixed1() = runEvaluatorTestCase("max([`2020-01-01T00:00:00Z`, 1])", expectedResult = "2020-01-01T00:00:00Z") + fun maxMixed1() = runEvaluatorTestCase( + query = "max([`2020-01-01T00:00:00Z`, 1])", + expectedResult = "2020-01-01T00:00:00Z", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun maxMixed2() = runEvaluatorTestCase("max([`2020-01-01T00:00:00Z`, '1'])", expectedResult = "\"1\"") + fun maxMixed2() = runEvaluatorTestCase( + query = "max([`2020-01-01T00:00:00Z`, '1'])", + expectedResult = "\"1\"", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun maxMixed3() = runEvaluatorTestCase("max([`{{\"abcd\"}}`, '1'])", expectedResult = "{{\"abcd\"}}") + fun maxMixed3() = runEvaluatorTestCase( + query = "max([`{{\"abcd\"}}`, '1'])", + expectedResult = "{{\"abcd\"}}", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) } diff --git a/lang/test/org/partiql/lang/eval/builtins/aggfunctions/MinTests.kt b/lang/test/org/partiql/lang/eval/builtins/aggfunctions/MinTests.kt index 682f8b8b3a..e4d69457df 100644 --- a/lang/test/org/partiql/lang/eval/builtins/aggfunctions/MinTests.kt +++ b/lang/test/org/partiql/lang/eval/builtins/aggfunctions/MinTests.kt @@ -2,77 +2,174 @@ package org.partiql.lang.eval.builtins.aggfunctions import org.junit.Test import org.partiql.lang.eval.EvaluatorTestBase +import org.partiql.lang.eval.evaluatortestframework.EvaluatorTestTarget class MinTests : EvaluatorTestBase() { @Test - fun minNull() = runEvaluatorTestCase("min([null, null])", expectedResult = "null") + fun minNull() = runEvaluatorTestCase( + query = "min([null, null])", + expectedResult = "null", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun minMissing() = runEvaluatorTestCase("min([missing, missing])", expectedResult = "null") + fun minMissing() = runEvaluatorTestCase( + query = "min([missing, missing])", + expectedResult = "null", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun minNumber0() = runEvaluatorTestCase("min(`[1, 2, 3]`)", expectedResult = "1") + fun minNumber0() = runEvaluatorTestCase( + query = "min(`[1, 2, 3]`)", + expectedResult = "1", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun minNumber1() = runEvaluatorTestCase("min(`[1, 2.0, 3]`)", expectedResult = "1") + fun minNumber1() = runEvaluatorTestCase( + query = "min(`[1, 2.0, 3]`)", + expectedResult = "1", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun minNumber2() = runEvaluatorTestCase("min(`[1, 2e0, 3e0]`)", expectedResult = "1") + fun minNumber2() = runEvaluatorTestCase( + query = "min(`[1, 2e0, 3e0]`)", + expectedResult = "1", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun minNumber3() = runEvaluatorTestCase("min(`[1, 2d0, 3d0]`)", expectedResult = "1") + fun minNumber3() = runEvaluatorTestCase( + query = "min(`[1, 2d0, 3d0]`)", + expectedResult = "1", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun minNumber4() = runEvaluatorTestCase("min(`[1, 2e0, 3d0]`)", expectedResult = "1") + fun minNumber4() = runEvaluatorTestCase( + query = "min(`[1, 2e0, 3d0]`)", + expectedResult = "1", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun minNumber5() = runEvaluatorTestCase("min(`[1, 2d0, 3e0]`)", expectedResult = "1") + fun minNumber5() = runEvaluatorTestCase( + query = "min(`[1, 2d0, 3e0]`)", + expectedResult = "1", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun minString0() = runEvaluatorTestCase("min(['a', 'abc', '3'])", expectedResult = "\"3\"") + fun minString0() = runEvaluatorTestCase( + query = "min(['a', 'abc', '3'])", + expectedResult = "\"3\"", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun minString1() = runEvaluatorTestCase("min(['1', '2', '3', null])", expectedResult = "\"1\"") + fun minString1() = runEvaluatorTestCase( + query = "min(['1', '2', '3', null])", + expectedResult = "\"1\"", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun minTimestamp0() = runEvaluatorTestCase("min([`2020-01-01T00:00:00Z`, `2020-01-01T00:00:01Z`, `2020-01-01T00:00:02Z`])", expectedResult = "2020-01-01T00:00:00Z") + fun minTimestamp0() = runEvaluatorTestCase( + query = "min([`2020-01-01T00:00:00Z`, `2020-01-01T00:00:01Z`, `2020-01-01T00:00:02Z`])", + expectedResult = "2020-01-01T00:00:00Z", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun minTimestamp1() = runEvaluatorTestCase("min([`2020-01-01T00:00:00Z`, `2020-01-01T00:01:00Z`, `2020-01-01T00:02:00Z`])", expectedResult = "2020-01-01T00:00:00Z") + fun minTimestamp1() = runEvaluatorTestCase( + query = "min([`2020-01-01T00:00:00Z`, `2020-01-01T00:01:00Z`, `2020-01-01T00:02:00Z`])", + expectedResult = "2020-01-01T00:00:00Z", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun minTimestamp2() = runEvaluatorTestCase("min([`2020-01-01T00:00:00Z`, `2020-01-01T01:00:00Z`, `2020-01-01T02:00:00Z`])", expectedResult = "2020-01-01T00:00:00Z") + fun minTimestamp2() = runEvaluatorTestCase( + query = "min([`2020-01-01T00:00:00Z`, `2020-01-01T01:00:00Z`, `2020-01-01T02:00:00Z`])", + expectedResult = "2020-01-01T00:00:00Z", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun minTimestamp3() = runEvaluatorTestCase("min([`2020-01-01T00:00:00Z`, `2020-01-02T00:00:00Z`, `2020-01-03T00:00:00Z`])", expectedResult = "2020-01-01T00:00:00Z") + fun minTimestamp3() = runEvaluatorTestCase( + query = "min([`2020-01-01T00:00:00Z`, `2020-01-02T00:00:00Z`, `2020-01-03T00:00:00Z`])", + expectedResult = "2020-01-01T00:00:00Z", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun minTimestamp4() = runEvaluatorTestCase("min([`2020-01-01T00:00:00Z`, `2020-02-01T00:00:00Z`, `2020-03-01T00:00:00Z`])", expectedResult = "2020-01-01T00:00:00Z") + fun minTimestamp4() = runEvaluatorTestCase( + query = "min([`2020-01-01T00:00:00Z`, `2020-02-01T00:00:00Z`, `2020-03-01T00:00:00Z`])", + expectedResult = "2020-01-01T00:00:00Z", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun minTimestamp5() = runEvaluatorTestCase("min([`2020-01-01T00:00:00Z`, `2021-01-01T00:00:00Z`, `2022-01-01T00:00:00Z`])", expectedResult = "2020-01-01T00:00:00Z") + fun minTimestamp5() = runEvaluatorTestCase( + query = "min([`2020-01-01T00:00:00Z`, `2021-01-01T00:00:00Z`, `2022-01-01T00:00:00Z`])", + expectedResult = "2020-01-01T00:00:00Z", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun minTimestamp6() = runEvaluatorTestCase("min([`2020-01-01T00:00:00Z`, `2020-01-01T00:00:01Z`, `2020-01-01T00:00:02Z`, null])", expectedResult = "2020-01-01T00:00:00Z") + fun minTimestamp6() = runEvaluatorTestCase( + query = "min([`2020-01-01T00:00:00Z`, `2020-01-01T00:00:01Z`, `2020-01-01T00:00:02Z`, null])", + expectedResult = "2020-01-01T00:00:00Z", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun minBoolean() = runEvaluatorTestCase("min([false, true])", expectedResult = "false") + fun minBoolean() = runEvaluatorTestCase( + query = "min([false, true])", + expectedResult = "false", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun minBlob() = runEvaluatorTestCase("min([`{{ aaaa }}`, `{{ aaab }}`])", expectedResult = "{{aaaa}}") + fun minBlob() = runEvaluatorTestCase( + query = "min([`{{ aaaa }}`, `{{ aaab }}`])", + expectedResult = "{{aaaa}}", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun minClob() = runEvaluatorTestCase("min([`{{\"a\"}}`, `{{\"b\"}}`])", expectedResult = "{{\"a\"}}") + fun minClob() = runEvaluatorTestCase( + query = "min([`{{\"a\"}}`, `{{\"b\"}}`])", + expectedResult = "{{\"a\"}}", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun minMixed0() = runEvaluatorTestCase("min([false, 1])", expectedResult = "false") + fun minMixed0() = runEvaluatorTestCase( + query = "min([false, 1])", + expectedResult = "false", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun minMixed1() = runEvaluatorTestCase("min([`2020-01-01T00:00:00Z`, 1])", expectedResult = "1") + fun minMixed1() = runEvaluatorTestCase( + query = "min([`2020-01-01T00:00:00Z`, 1])", + expectedResult = "1", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun minMixed2() = runEvaluatorTestCase("min([`2020-01-01T00:00:00Z`, '1'])", expectedResult = "2020-01-01T00:00:00Z") + fun minMixed2() = runEvaluatorTestCase( + query = "min([`2020-01-01T00:00:00Z`, '1'])", + expectedResult = "2020-01-01T00:00:00Z", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun minMixed3() = runEvaluatorTestCase("min([`{{\"abcd\"}}`, '1'])", expectedResult = "\"1\"") + fun minMixed3() = runEvaluatorTestCase( + query = "min([`{{\"abcd\"}}`, '1'])", + expectedResult = "\"1\"", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) } diff --git a/lang/test/org/partiql/lang/eval/builtins/aggfunctions/SumTests.kt b/lang/test/org/partiql/lang/eval/builtins/aggfunctions/SumTests.kt index e069626be7..15c6f6db2c 100644 --- a/lang/test/org/partiql/lang/eval/builtins/aggfunctions/SumTests.kt +++ b/lang/test/org/partiql/lang/eval/builtins/aggfunctions/SumTests.kt @@ -3,40 +3,71 @@ package org.partiql.lang.eval.builtins.aggfunctions import org.junit.Test import org.partiql.lang.errors.ErrorCode import org.partiql.lang.eval.EvaluatorTestBase +import org.partiql.lang.eval.evaluatortestframework.EvaluatorTestTarget class SumTests : EvaluatorTestBase() { @Test - fun sumNull() = runEvaluatorTestCase("SUM([null, null])", expectedResult = "null") + fun sumNull() = runEvaluatorTestCase( + query = "SUM([null, null])", + expectedResult = "null", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun sumMissing() = runEvaluatorTestCase("SUM([missing, missing])", expectedResult = "null") + fun sumMissing() = runEvaluatorTestCase( + query = "SUM([missing, missing])", + expectedResult = "null", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun sum0() = runEvaluatorTestCase("SUM(`[1, 2, 3]`)", expectedResult = "6") + fun sum0() = runEvaluatorTestCase( + query = "SUM(`[1, 2, 3]`)", + expectedResult = "6", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun sum1() = runEvaluatorTestCase("SUM(`[1, 2e0, 3e0]`)", expectedResult = "6e0") + fun sum1() = runEvaluatorTestCase( + query = "SUM(`[1, 2e0, 3e0]`)", + expectedResult = "6e0", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun sum2() = runEvaluatorTestCase("SUM(`[1, 2d0, 3d0]`)", expectedResult = "6d0") + fun sum2() = runEvaluatorTestCase( + query = "SUM(`[1, 2d0, 3d0]`)", + expectedResult = "6d0", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun sum3() = runEvaluatorTestCase("SUM(`[1, 2e0, 3d0]`)", expectedResult = "6d0") + fun sum3() = runEvaluatorTestCase( + query = "SUM(`[1, 2e0, 3d0]`)", + expectedResult = "6d0", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test - fun sum4() = runEvaluatorTestCase("SUM(`[1, 2d0, 3e0]`)", expectedResult = "6d0") + fun sum4() = runEvaluatorTestCase( + query = "SUM(`[1, 2d0, 3e0]`)", + expectedResult = "6d0", + target = EvaluatorTestTarget.COMPILER_PIPELINE + ) @Test fun sumOverflow() = runEvaluatorErrorTestCase( - "SUM([1, 9223372036854775807])", - ErrorCode.EVALUATOR_INTEGER_OVERFLOW, - expectedPermissiveModeResult = "MISSING" + query = "SUM([1, 9223372036854775807])", + expectedErrorCode = ErrorCode.EVALUATOR_INTEGER_OVERFLOW, + expectedPermissiveModeResult = "MISSING", + target = EvaluatorTestTarget.COMPILER_PIPELINE ) @Test fun sumUnderflow() = runEvaluatorErrorTestCase( - "SUM([-1, -9223372036854775808])", - ErrorCode.EVALUATOR_INTEGER_OVERFLOW, - expectedPermissiveModeResult = "MISSING" + query = "SUM([-1, -9223372036854775808])", + expectedErrorCode = ErrorCode.EVALUATOR_INTEGER_OVERFLOW, + expectedPermissiveModeResult = "MISSING", + target = EvaluatorTestTarget.COMPILER_PIPELINE ) } diff --git a/lang/test/org/partiql/lang/eval/builtins/functions/DynamicLookupExprFunctionTest.kt b/lang/test/org/partiql/lang/eval/builtins/functions/DynamicLookupExprFunctionTest.kt new file mode 100644 index 0000000000..f55ba07792 --- /dev/null +++ b/lang/test/org/partiql/lang/eval/builtins/functions/DynamicLookupExprFunctionTest.kt @@ -0,0 +1,153 @@ +package org.partiql.lang.eval.builtins.functions + +import org.junit.jupiter.api.Test +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ArgumentsSource +import org.partiql.lang.errors.ErrorCode +import org.partiql.lang.errors.Property +import org.partiql.lang.eval.EvaluatorTestBase +import org.partiql.lang.eval.builtins.DYNAMIC_LOOKUP_FUNCTION_NAME +import org.partiql.lang.eval.builtins.ExprFunctionTestCase +import org.partiql.lang.eval.builtins.checkInvalidArity +import org.partiql.lang.eval.evaluatortestframework.EvaluatorErrorTestCase +import org.partiql.lang.eval.evaluatortestframework.EvaluatorTestTarget +import org.partiql.lang.eval.evaluatortestframework.ExpectedResultFormat +import org.partiql.lang.util.ArgumentsProviderBase +import org.partiql.lang.util.propertyValueMapOf +import org.partiql.lang.util.to + +class DynamicLookupExprFunctionTest : EvaluatorTestBase() { + val session = mapOf( + "f" to "{ foo: 42 }", + "b" to "{ bar: 43 }", + "foo" to "44", + ).toSession() + + // Pass test cases + @ParameterizedTest + @ArgumentsSource(ToStringPassCases::class) + fun runPassTests(testCase: ExprFunctionTestCase) = + runEvaluatorTestCase( + query = testCase.source, + expectedResult = testCase.expectedLegacyModeResult, + target = EvaluatorTestTarget.PLANNER_PIPELINE, + expectedResultFormat = ExpectedResultFormat.ION, + session = session + ) + + // We rely on the built-in [DEFAULT_COMPARATOR] for the actual definition of equality, which is not being tested + // here. + class ToStringPassCases : ArgumentsProviderBase() { + override fun getParameters(): List = listOf( + // function signature: $__dynamic_lookup__(, , , *) + // arg #1: the name of the field or variable to locate. + // arg #2: case-insensitive or sensitive + // arg #3: look in globals first or locals first. + // arg #4 and later (variadic): any remaining arguments are the variables to search within, which in general + // are structs. note that in general, these will be local variables, however we don't use local variables + // here to simplify these test cases. + + // locals_then_globals + + // `foo` should be found in the variable f, which is a struct + ExprFunctionTestCase("\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`foo`, `case_insensitive`, `locals_then_globals`, f, b)", "42"), + ExprFunctionTestCase("\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`fOo`, `case_insensitive`, `locals_then_globals`, f, b)", "42"), + ExprFunctionTestCase("\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`FoO`, `case_insensitive`, `locals_then_globals`, f, b)", "42"), + ExprFunctionTestCase("\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`foo`, `case_sensitive`, `locals_then_globals`, f, b)", "42"), + // `bar` should be found in the variable b, which is also a struct + ExprFunctionTestCase("\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`bar`, `case_insensitive`, `locals_then_globals`, f, b)", "43"), + ExprFunctionTestCase("\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`BaR`, `case_insensitive`, `locals_then_globals`, f, b)", "43"), + ExprFunctionTestCase("\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`bAr`, `case_insensitive`, `locals_then_globals`, f, b)", "43"), + ExprFunctionTestCase("\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`bar`, `case_sensitive`, `locals_then_globals`, f, b)", "43"), + + // globals_then_locals + + // The global variable `foo` should be found first, ignoring the `f.foo`, unlike the similar cases above` + ExprFunctionTestCase("\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`foo`, `case_insensitive`, `globals_then_locals`, f, b)", "44"), + ExprFunctionTestCase("\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`fOo`, `case_insensitive`, `globals_then_locals`, f, b)", "44"), + ExprFunctionTestCase("\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`FoO`, `case_insensitive`, `globals_then_locals`, f, b)", "44"), + ExprFunctionTestCase("\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`foo`, `case_sensitive`, `globals_then_locals`, f, b)", "44"), + // `bar` should still be found in the variable b, which is also a struct, since there is no global named `bar`. + ExprFunctionTestCase("\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`bar`, `case_insensitive`, `globals_then_locals`, f, b)", "43"), + ExprFunctionTestCase("\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`BaR`, `case_insensitive`, `globals_then_locals`, f, b)", "43"), + ExprFunctionTestCase("\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`bAr`, `case_insensitive`, `globals_then_locals`, f, b)", "43"), + ExprFunctionTestCase("\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`bar`, `case_sensitive`, `globals_then_locals`, f, b)", "43") + ) + } + + @ParameterizedTest + @ArgumentsSource(MismatchCaseSensitiveCases::class) + fun mismatchedCaseSensitiveTests(testCase: EvaluatorErrorTestCase) = + runEvaluatorErrorTestCase( + testCase.copy( + expectedPermissiveModeResult = "MISSING", + targetPipeline = EvaluatorTestTarget.PLANNER_PIPELINE + ), + session = session + ) + + class MismatchCaseSensitiveCases : ArgumentsProviderBase() { + override fun getParameters(): List = listOf( + // Can't find these variables due to case mismatch when perform case sensitive lookup + EvaluatorErrorTestCase( + query = "\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`fOo`, `case_sensitive`, `locals_then_globals`, f, b)", + expectedErrorCode = ErrorCode.EVALUATOR_QUOTED_BINDING_DOES_NOT_EXIST, + expectedErrorContext = propertyValueMapOf(1, 1, Property.BINDING_NAME to "fOo") + ), + EvaluatorErrorTestCase( + query = "\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`FoO`, `case_sensitive`, `locals_then_globals`, f, b)", + expectedErrorCode = ErrorCode.EVALUATOR_QUOTED_BINDING_DOES_NOT_EXIST, + expectedErrorContext = propertyValueMapOf(1, 1, Property.BINDING_NAME to "FoO") + ), + EvaluatorErrorTestCase( + query = "\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`BaR`, `case_sensitive`, `locals_then_globals`, f, b)", + expectedErrorCode = ErrorCode.EVALUATOR_QUOTED_BINDING_DOES_NOT_EXIST, + expectedErrorContext = propertyValueMapOf(1, 1, Property.BINDING_NAME to "BaR") + ), + EvaluatorErrorTestCase( + query = "\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`bAr`, `case_sensitive`, `locals_then_globals`, f, b)", + expectedErrorCode = ErrorCode.EVALUATOR_QUOTED_BINDING_DOES_NOT_EXIST, + expectedErrorContext = propertyValueMapOf(1, 1, Property.BINDING_NAME to "bAr") + ) + ) + } + + data class InvalidArgTestCase( + val source: String, + val argumentPosition: Int, + val actualArgumentType: String, + ) + + @ParameterizedTest + @ArgumentsSource(InvalidArgCases::class) + fun invalidArgTypeTestCases(testCase: InvalidArgTestCase) = + runEvaluatorErrorTestCase( + query = testCase.source, + expectedErrorCode = ErrorCode.EVALUATOR_INCORRECT_TYPE_OF_ARGUMENTS_TO_FUNC_CALL, + expectedErrorContext = propertyValueMapOf( + 1, 1, + Property.FUNCTION_NAME to DYNAMIC_LOOKUP_FUNCTION_NAME, + Property.EXPECTED_ARGUMENT_TYPES to "SYMBOL", + Property.ACTUAL_ARGUMENT_TYPES to testCase.actualArgumentType, + Property.ARGUMENT_POSITION to testCase.argumentPosition + ), + expectedPermissiveModeResult = "MISSING", + target = EvaluatorTestTarget.PLANNER_PIPELINE + ) + + class InvalidArgCases : ArgumentsProviderBase() { + override fun getParameters(): List = listOf( + InvalidArgTestCase("\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(1, `case_insensitive`, `locals_then_globals`)", 1, "INT"), + InvalidArgTestCase("\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`foo`, 1, `locals_then_globals`)", 2, "INT"), + InvalidArgTestCase("\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`foo`, `case_insensitive`, 1)", 3, "INT") + ) + } + + @Test + fun invalidArityTest() = checkInvalidArity( + funcName = "\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"", + maxArity = Int.MAX_VALUE, + minArity = 3, + targetPipeline = EvaluatorTestTarget.PLANNER_PIPELINE + ) +} diff --git a/lang/test/org/partiql/lang/eval/builtins/functions/FilterDistinctEvaluationTest.kt b/lang/test/org/partiql/lang/eval/builtins/functions/FilterDistinctEvaluationTest.kt new file mode 100644 index 0000000000..df7689abde --- /dev/null +++ b/lang/test/org/partiql/lang/eval/builtins/functions/FilterDistinctEvaluationTest.kt @@ -0,0 +1,73 @@ +package org.partiql.lang.eval.builtins.functions + +import org.junit.Test +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ArgumentsSource +import org.partiql.lang.errors.ErrorCode +import org.partiql.lang.errors.Property +import org.partiql.lang.eval.EvaluatorTestBase +import org.partiql.lang.eval.builtins.ExprFunctionTestCase +import org.partiql.lang.eval.builtins.checkInvalidArity +import org.partiql.lang.util.ArgumentsProviderBase +import org.partiql.lang.util.propertyValueMapOf +import org.partiql.lang.util.to + +class FilterDistinctEvaluationTest : EvaluatorTestBase() { + // Pass test cases + @ParameterizedTest + @ArgumentsSource(ToStringPassCases::class) + fun runPassTests(testCase: ExprFunctionTestCase) = + runEvaluatorTestCase(query = testCase.source, expectedResult = testCase.expectedLegacyModeResult) + + // We rely on the built-in [DEFAULT_COMPARATOR] for the actual definition of equality, which is not being tested + // here. + class ToStringPassCases : ArgumentsProviderBase() { + override fun getParameters(): List = listOf( + + // These three tests ensure we can accept lists, bags, s-expressions and structs + ExprFunctionTestCase("filter_distinct([0, 0, 1])", "[0, 1]"), // list + ExprFunctionTestCase("filter_distinct(<<0, 0, 1>>)", "[0, 1]"), // bag + ExprFunctionTestCase("filter_distinct(SEXP(0, 0, 1))", "[0, 1]"), // s-exp + ExprFunctionTestCase("filter_distinct({'a': 0, 'b': 0, 'c': 1})", "[0, 1]"), // struct + + // Some "smoke tests" to ensure the basic plumbing is working right. + ExprFunctionTestCase("filter_distinct(['foo', 'foo', 1, 1, `symbol`, `symbol`])", "[\"foo\", 1, symbol]"), + ExprFunctionTestCase("filter_distinct([{ 'a': 1 }, { 'a': 1 }, { 'a': 1 }])", "[{ 'a': 1 }]"), + ExprFunctionTestCase("filter_distinct([[1, 1], [1, 1], [2, 2]])", "[[1,1], [2, 2]]"), + ) + } + + // Error test cases: Invalid arguments + data class InvalidArgTestCase( + val source: String, + val actualArgumentType: String + ) + + @ParameterizedTest + @ArgumentsSource(InvalidArgCases::class) + fun toStringInvalidArgumentTests(testCase: InvalidArgTestCase) = runEvaluatorErrorTestCase( + query = testCase.source, + expectedErrorCode = ErrorCode.EVALUATOR_INCORRECT_TYPE_OF_ARGUMENTS_TO_FUNC_CALL, + expectedErrorContext = propertyValueMapOf( + 1, 1, + Property.FUNCTION_NAME to "filter_distinct", + Property.EXPECTED_ARGUMENT_TYPES to "BAG, LIST, SEXP, or STRUCT", + Property.ACTUAL_ARGUMENT_TYPES to testCase.actualArgumentType, + Property.ARGUMENT_POSITION to 1 + ), + expectedPermissiveModeResult = "MISSING", + ) + + class InvalidArgCases : ArgumentsProviderBase() { + override fun getParameters(): List = listOf( + InvalidArgTestCase("filter_distinct(1)", "INT"), + InvalidArgTestCase("filter_distinct(1.0)", "DECIMAL"), + InvalidArgTestCase("filter_distinct('foo')", "STRING"), + InvalidArgTestCase("filter_distinct(`some_symbol`)", "SYMBOL"), + InvalidArgTestCase("filter_distinct(`{{ '''a clob''' }}`)", "CLOB"), + ) + } + + @Test + fun invalidArityTest() = checkInvalidArity(funcName = "filter_distinct", maxArity = 1, minArity = 1) +} diff --git a/lang/test/org/partiql/lang/eval/evaluatortestframework/AbstractPipeline.kt b/lang/test/org/partiql/lang/eval/evaluatortestframework/AbstractPipeline.kt new file mode 100644 index 0000000000..65f97b83d4 --- /dev/null +++ b/lang/test/org/partiql/lang/eval/evaluatortestframework/AbstractPipeline.kt @@ -0,0 +1,15 @@ +package org.partiql.lang.eval.evaluatortestframework + +import org.partiql.lang.eval.ExprValue +import org.partiql.lang.eval.TypingMode + +/** + * Represents an abstract pipeline (either [org.partiql.lang.CompilerPipeline] or + * [org.partiql.lang.planner.PlannerPipeline]) so that [PipelineEvaluatorTestAdapter] can work with either. + * + * Includes only those properties and methods that are required for testing purposes. + */ +interface AbstractPipeline { + val typingMode: TypingMode + fun evaluate(query: String): ExprValue +} diff --git a/lang/test/org/partiql/lang/eval/evaluatortestframework/Assertions.kt b/lang/test/org/partiql/lang/eval/evaluatortestframework/Assertions.kt new file mode 100644 index 0000000000..39808d2933 --- /dev/null +++ b/lang/test/org/partiql/lang/eval/evaluatortestframework/Assertions.kt @@ -0,0 +1,40 @@ +package org.partiql.lang.eval.evaluatortestframework + +import org.partiql.lang.SqlException + +internal fun assertEquals( + expected: Any?, + actual: Any?, + reason: EvaluatorTestFailureReason, + detailsBlock: () -> String +) { + if (expected != actual) { + throw EvaluatorAssertionFailedError(reason, detailsBlock()) + } +} + +internal fun assertDoesNotThrow( + reason: EvaluatorTestFailureReason, + detailsBlock: () -> String, + block: () -> T +): T { + try { + return block() + } catch (ex: Throwable) { + throw EvaluatorAssertionFailedError(reason, detailsBlock(), ex) + } +} + +internal inline fun assertThrowsSqlException( + reason: EvaluatorTestFailureReason, + detailsBlock: () -> String, + block: () -> Unit +): SqlException { + try { + block() + // if we made it here, the test failed. + throw EvaluatorAssertionFailedError(reason, detailsBlock()) + } catch (ex: SqlException) { + return ex + } +} diff --git a/lang/test/org/partiql/lang/eval/evaluatortestframework/CompilerPipelineFactory.kt b/lang/test/org/partiql/lang/eval/evaluatortestframework/CompilerPipelineFactory.kt new file mode 100644 index 0000000000..baa82c7548 --- /dev/null +++ b/lang/test/org/partiql/lang/eval/evaluatortestframework/CompilerPipelineFactory.kt @@ -0,0 +1,51 @@ +package org.partiql.lang.eval.evaluatortestframework + +import org.partiql.lang.CompilerPipeline +import org.partiql.lang.ION +import org.partiql.lang.eval.CompileOptions +import org.partiql.lang.eval.EvaluationSession +import org.partiql.lang.eval.ExprValue +import org.partiql.lang.eval.TypingMode + +internal class CompilerPipelineFactory : PipelineFactory { + override val pipelineName: String + get() = "CompilerPipeline (AST Evaluator)" + + override val target: EvaluatorTestTarget + get() = EvaluatorTestTarget.COMPILER_PIPELINE + + override fun createPipeline( + evaluatorTestDefinition: EvaluatorTestDefinition, + session: EvaluationSession, + forcePermissiveMode: Boolean + ): AbstractPipeline { + val concretePipeline = evaluatorTestDefinition.createCompilerPipeline(forcePermissiveMode) + + return object : AbstractPipeline { + override val typingMode: TypingMode + get() = concretePipeline.compileOptions.typingMode + + override fun evaluate(query: String): ExprValue = + concretePipeline.compile(query).eval(session) + } + } +} + +internal fun EvaluatorTestDefinition.createCompilerPipeline(forcePermissiveMode: Boolean): CompilerPipeline { + + val compileOptions = CompileOptions.build(compileOptionsBuilderBlock).let { co -> + if (forcePermissiveMode) { + CompileOptions.build(co) { + typingMode(TypingMode.PERMISSIVE) + } + } else { + co + } + } + + val concretePipeline = CompilerPipeline.build(ION) { + compileOptions(compileOptions) + this@createCompilerPipeline.compilerPipelineBuilderBlock(this) + } + return concretePipeline +} diff --git a/lang/test/org/partiql/lang/eval/evaluatortestframework/EvaluatorErrorTestCase.kt b/lang/test/org/partiql/lang/eval/evaluatortestframework/EvaluatorErrorTestCase.kt index 3495758280..f1571555a4 100644 --- a/lang/test/org/partiql/lang/eval/evaluatortestframework/EvaluatorErrorTestCase.kt +++ b/lang/test/org/partiql/lang/eval/evaluatortestframework/EvaluatorErrorTestCase.kt @@ -55,6 +55,17 @@ data class EvaluatorErrorTestCase( */ override val implicitPermissiveModeTest: Boolean = true, + /** + * This will be executed to perform additional exceptions on the resulting exception. + */ + val additionalExceptionAssertBlock: (SqlException) -> Unit = { }, + + /** + * Determines which pipeline this test should run against; the [CompilerPipeline], + * [org.partiql.lang.planner.PlannerPipeline] or both. + */ + override val targetPipeline: EvaluatorTestTarget = EvaluatorTestTarget.ALL_PIPELINES, + /** * Builder block for building [CompileOptions]. */ @@ -65,10 +76,6 @@ data class EvaluatorErrorTestCase( */ override val compilerPipelineBuilderBlock: CompilerPipeline.Builder.() -> Unit = { }, - /** - * This will be executed to perform additional exceptions on the resulting exception. - */ - val additionalExceptionAssertBlock: (SqlException) -> Unit = { } ) : EvaluatorTestDefinition { /** This will show up in the IDE's test runner. */ @@ -76,4 +83,36 @@ data class EvaluatorErrorTestCase( val groupNameString = if (groupName == null) "" else "$groupName" return "$groupNameString $query : $expectedErrorCode : $expectedErrorContext" } + + /** A generated and human-readable description of this test case for display in assertion failure messages. */ + fun testDetails( + note: String, + actualErrorCode: ErrorCode? = null, + actualErrorContext: PropertyValueMap? = null, + actualPermissiveModeResult: String? = null, + actualInternalFlag: Boolean? = null, + ): String { + val b = StringBuilder() + b.appendLine("Note : $note") + b.appendLine("Group name : $groupName") + b.appendLine("Query : $query") + b.appendLine("Target pipeline : $targetPipeline") + b.appendLine("Expected error code : $expectedErrorCode") + if (actualErrorCode != null) { + b.appendLine("Actual error code : $actualErrorCode") + } + b.appendLine("Expected error context : $expectedErrorContext") + if (actualErrorContext != null) { + b.appendLine("Actual error context : $actualErrorContext") + } + b.appendLine("Expected internal flag : $expectedInternalFlag") + if (actualErrorContext != null) { + b.appendLine("Actual internal flag : $actualInternalFlag") + } + b.appendLine("Expected permissive mode result: $expectedPermissiveModeResult") + if (actualPermissiveModeResult != null) { + b.appendLine("Actual permissive mode result : $actualPermissiveModeResult") + } + return b.toString() + } } diff --git a/lang/test/org/partiql/lang/eval/evaluatortestframework/EvaluatorTestCase.kt b/lang/test/org/partiql/lang/eval/evaluatortestframework/EvaluatorTestCase.kt index fecffd270b..35bd77a64f 100644 --- a/lang/test/org/partiql/lang/eval/evaluatortestframework/EvaluatorTestCase.kt +++ b/lang/test/org/partiql/lang/eval/evaluatortestframework/EvaluatorTestCase.kt @@ -59,6 +59,13 @@ data class EvaluatorTestCase( * `false`. Note that, when `false`, [expectedPermissiveModeResult] is ignored. */ override val implicitPermissiveModeTest: Boolean = true, + + /** + * Determines which pipeline this test should run against; the [CompilerPipeline], + * [org.partiql.lang.planner.PlannerPipeline] or both. + */ + override val targetPipeline: EvaluatorTestTarget = EvaluatorTestTarget.ALL_PIPELINES, + /** * Builder block for building [CompileOptions]. */ @@ -78,6 +85,7 @@ data class EvaluatorTestCase( expectedResultFormat: ExpectedResultFormat = ExpectedResultFormat.PARTIQL, excludeLegacySerializerAssertions: Boolean = false, implicitPermissiveModeTest: Boolean = true, + target: EvaluatorTestTarget = EvaluatorTestTarget.ALL_PIPELINES, compileOptionsBuilderBlock: CompileOptions.Builder.() -> Unit = { }, compilerPipelineBuilderBlock: CompilerPipeline.Builder.() -> Unit = { }, extraResultAssertions: (ExprValue) -> Unit = { }, @@ -89,6 +97,7 @@ data class EvaluatorTestCase( expectedResultFormat = expectedResultFormat, excludeLegacySerializerAssertions = excludeLegacySerializerAssertions, implicitPermissiveModeTest = implicitPermissiveModeTest, + targetPipeline = target, compileOptionsBuilderBlock = compileOptionsBuilderBlock, compilerPipelineBuilderBlock = compilerPipelineBuilderBlock, extraResultAssertions = extraResultAssertions @@ -97,6 +106,22 @@ data class EvaluatorTestCase( /** This will show up in the IDE's test runner. */ override fun toString() = when { groupName != null -> "$groupName : $query" - else -> "$query" + else -> query + } + + /** A generated and human-readable description of this test case for display in assertion failure messages. */ + fun testDetails(note: String, actualResult: String? = null): String { + val b = StringBuilder() + b.appendLine("Note : $note") + b.appendLine("Group name : $groupName") + b.appendLine("Query : $query") + b.appendLine("Target pipeline : $targetPipeline") + b.appendLine("Expected result : $expectedResult") + if (actualResult != null) { + b.appendLine("Actual result : $actualResult") + } + b.appendLine("Result format : $expectedResultFormat") + + return b.toString() } } diff --git a/lang/test/org/partiql/lang/eval/evaluatortestframework/EvaluatorTestDefinition.kt b/lang/test/org/partiql/lang/eval/evaluatortestframework/EvaluatorTestDefinition.kt index 02eafc0eb3..5b875ed2e8 100644 --- a/lang/test/org/partiql/lang/eval/evaluatortestframework/EvaluatorTestDefinition.kt +++ b/lang/test/org/partiql/lang/eval/evaluatortestframework/EvaluatorTestDefinition.kt @@ -36,6 +36,12 @@ interface EvaluatorTestDefinition { */ val implicitPermissiveModeTest: Boolean + /** + * Determines which pipeline this test should run against; the [CompilerPipeline], + * [org.partiql.lang.planner.PlannerPipeline] or both. + */ + val targetPipeline: EvaluatorTestTarget + /** * Builder block for building [CompileOptions]. */ diff --git a/lang/test/org/partiql/lang/eval/evaluatortestframework/EvaluatorTestTarget.kt b/lang/test/org/partiql/lang/eval/evaluatortestframework/EvaluatorTestTarget.kt new file mode 100644 index 0000000000..295208acbc --- /dev/null +++ b/lang/test/org/partiql/lang/eval/evaluatortestframework/EvaluatorTestTarget.kt @@ -0,0 +1,26 @@ +package org.partiql.lang.eval.evaluatortestframework + +/** + * An indicator of which pipeline(s) each test case should be run against. Useful when one pipeline supports + * a feature that the other one doesn't. + */ +enum class EvaluatorTestTarget { + /** + * Run the test on all pipelines. + * + * Set this option when both pipelines support all features utilized in the test case. + * */ + ALL_PIPELINES, + + /** + * Run the test only on [org.partiql.lang.CompilerPipeline]. Set this when the test case covers features not yet + * supported by [org.partiql.lang.planner.PlannerPipeline] or when testing features unique to the former. + */ + COMPILER_PIPELINE, + + /** + * Run the test only on [org.partiql.lang.planner.PlannerPipeline]. Set this when the test case covers features not + * supported by [org.partiql.lang.CompilerPipeline], or when testing features unique to the former. + */ + PLANNER_PIPELINE +} diff --git a/lang/test/org/partiql/lang/eval/evaluatortestframework/AstEvaluatorTestAdapter.kt b/lang/test/org/partiql/lang/eval/evaluatortestframework/PipelineEvaluatorTestAdapter.kt similarity index 63% rename from lang/test/org/partiql/lang/eval/evaluatortestframework/AstEvaluatorTestAdapter.kt rename to lang/test/org/partiql/lang/eval/evaluatortestframework/PipelineEvaluatorTestAdapter.kt index b6ff78e496..02dea9b25e 100644 --- a/lang/test/org/partiql/lang/eval/evaluatortestframework/AstEvaluatorTestAdapter.kt +++ b/lang/test/org/partiql/lang/eval/evaluatortestframework/PipelineEvaluatorTestAdapter.kt @@ -1,136 +1,32 @@ package org.partiql.lang.eval.evaluatortestframework +import org.junit.jupiter.api.Assertions.assertNotEquals import org.junit.jupiter.api.Assertions.assertNotNull import org.junit.jupiter.api.Assertions.assertNull -import org.partiql.lang.CompilerPipeline import org.partiql.lang.ION -import org.partiql.lang.SqlException import org.partiql.lang.errors.ErrorBehaviorInPermissiveMode -import org.partiql.lang.errors.ErrorCode -import org.partiql.lang.errors.PropertyValueMap -import org.partiql.lang.eval.CompileOptions import org.partiql.lang.eval.EvaluationSession +import org.partiql.lang.eval.ExprValue import org.partiql.lang.eval.TypingMode import org.partiql.lang.eval.cloneAndRemoveBagAndMissingAnnotations import org.partiql.lang.eval.exprEquals -import kotlin.test.assertNotEquals -private fun EvaluatorTestDefinition.createPipeline(forcePermissiveMode: Boolean = false): CompilerPipeline { - val compileOptions = CompileOptions.build(this@createPipeline.compileOptionsBuilderBlock).let { co -> - if (forcePermissiveMode) { - CompileOptions.build(co) { - typingMode(TypingMode.PERMISSIVE) - } - } else { - co - } - } - - return CompilerPipeline.build(ION) { - compileOptions(compileOptions) - this@createPipeline.compilerPipelineBuilderBlock(this) - } -} - -/** A generated and human readable description of this test case for display in assertion failure messages. */ -fun EvaluatorTestCase.testDetails(note: String, actualResult: String? = null): String { - val b = StringBuilder() - b.appendLine("Note : $note") - b.appendLine("Group name : $groupName") - b.appendLine("Query : $query") - b.appendLine("Expected result : $expectedResult") - if (actualResult != null) { - b.appendLine("Actual result : $actualResult") - } - b.appendLine("Result format : $expectedResultFormat") - - return b.toString() -} - -/** A generated and human readable description of this test case for display in assertion failure messages. */ -fun EvaluatorErrorTestCase.testDetails( - note: String, - actualErrorCode: ErrorCode? = null, - actualErrorContext: PropertyValueMap? = null, - actualPermissiveModeResult: String? = null, - actualInternalFlag: Boolean? = null, -): String { - val b = StringBuilder() - b.appendLine("Note : $note") - b.appendLine("Group name : $groupName") - b.appendLine("Query : $query") - b.appendLine("Expected error code : $expectedErrorCode") - if (actualErrorCode != null) { - b.appendLine("Actual error code : $actualErrorCode") - } - b.appendLine("Expected error context : $expectedErrorContext") - if (actualErrorContext != null) { - b.appendLine("Actual error context : $actualErrorContext") - } - b.appendLine("Expected internal flag : $expectedInternalFlag") - if (actualErrorContext != null) { - b.appendLine("Actual internal flag : $actualInternalFlag") - } - b.appendLine("Expected permissive mode result: $expectedPermissiveModeResult") - if (actualPermissiveModeResult != null) { - b.appendLine("Actual permissive mode result : $actualPermissiveModeResult") - } - return b.toString() -} - -private fun assertEquals( - expected: Any?, - actual: Any?, - reason: EvaluatorTestFailureReason, - detailsBlock: () -> String -) { - if (expected != actual) { - throw EvaluatorAssertionFailedError(reason, detailsBlock()) - } -} - -private fun assertDoesNotThrow( - reason: EvaluatorTestFailureReason, - detailsBlock: () -> String, - block: () -> T -): T { - try { - return block() - } catch (ex: Throwable) { - throw EvaluatorAssertionFailedError(reason, detailsBlock(), ex.cause) - } -} - -private inline fun assertThrowsSqlException( - reason: EvaluatorTestFailureReason, - detailsBlock: () -> String, - block: () -> Unit -): SqlException { - try { - block() - // if we made it here, the test failed. - throw EvaluatorAssertionFailedError(reason, detailsBlock()) - } catch (ex: SqlException) { - return ex - } -} - -class AstEvaluatorTestAdapter : EvaluatorTestAdapter { +internal class PipelineEvaluatorTestAdapter( + private val pipelineFactory: PipelineFactory +) : EvaluatorTestAdapter { override fun runEvaluatorTestCase(tc: EvaluatorTestCase, session: EvaluationSession) { - if (tc.implicitPermissiveModeTest) { - val testOpts = CompileOptions.build { tc.compileOptionsBuilderBlock(this) } - assertNotEquals( - TypingMode.PERMISSIVE, testOpts.typingMode, - "Setting TypingMode.PERMISSIVE when implicit permissive mode testing is enabled is redundant" - ) + // Skip execution of this test case if it does not apply to the pipeline supplied by pipelineFactory. + if (tc.targetPipeline != EvaluatorTestTarget.ALL_PIPELINES && pipelineFactory.target != tc.targetPipeline) { + return } + checkRedundantPermissiveMode(tc) // Compile options unmodified... This covers [TypingMode.LEGACY], unless the test explicitly // sets the typing mode. - privateRunEvaluatorTestCase(tc, session, "compile options unaltered") + privateRunEvaluatorTestCase(tc, session, "${pipelineFactory.pipelineName} (compile options unaltered)") - // Unless the tests disable it, run again in permissive mode. + // Unless the test disables it, run again in permissive mode. if (tc.implicitPermissiveModeTest) { privateRunEvaluatorTestCase( tc.copy( @@ -140,7 +36,7 @@ class AstEvaluatorTestAdapter : EvaluatorTestAdapter { } ), session, - "compile options forced to PERMISSIVE mode" + "${pipelineFactory.pipelineName} (compile options forced to PERMISSIVE mode)" ) } } @@ -153,17 +49,17 @@ class AstEvaluatorTestAdapter : EvaluatorTestAdapter { session: EvaluationSession, note: String, ) { - val pipeline = tc.createPipeline() + val pipeline = pipelineFactory.createPipeline(tc, session) - val actualExprValueResult = assertDoesNotThrow( + val actualExprValueResult: ExprValue = assertDoesNotThrow( EvaluatorTestFailureReason.FAILED_TO_EVALUATE_QUERY, { tc.testDetails(note = note) } ) { - pipeline.compile(tc.query).eval(session) + pipeline.evaluate(tc.query) } val (expectedResult, unexpectedResultErrorCode) = - when (pipeline.compileOptions.typingMode) { + when (pipeline.typingMode) { TypingMode.LEGACY -> tc.expectedResult to EvaluatorTestFailureReason.UNEXPECTED_QUERY_RESULT TypingMode.PERMISSIVE -> tc.expectedPermissiveModeResult to EvaluatorTestFailureReason.UNEXPECTED_PERMISSIVE_MODE_RESULT } @@ -194,7 +90,7 @@ class AstEvaluatorTestAdapter : EvaluatorTestAdapter { EvaluatorTestFailureReason.FAILED_TO_EVALUATE_PARTIQL_EXPECTED_RESULT, { tc.testDetails(note = note) } ) { - pipeline.compile(expectedResult).eval(session) + pipeline.evaluate(expectedResult) } if (!expectedExprValueResult.exprEquals(actualExprValueResult)) { @@ -223,7 +119,7 @@ class AstEvaluatorTestAdapter : EvaluatorTestAdapter { session: EvaluationSession, note: String ) { - val compilerPipeline = tc.createPipeline() + val pipeline = pipelineFactory.createPipeline(tc, session) val ex = assertThrowsSqlException( EvaluatorTestFailureReason.EXPECTED_SQL_EXCEPTION_BUT_THERE_WAS_NONE, @@ -234,10 +130,8 @@ class AstEvaluatorTestAdapter : EvaluatorTestAdapter { // .compile OR in .eval. We currently don't make a distinction, so tests cannot assert that certain // errors are compile-time and others are evaluation-time. We really aught to create a way for tests to // indicate when the exception should be thrown. This is undone. - val expression = compilerPipeline.compile(tc.query) - - // The call to .ionValue is important since query execution won't actually begin otherwise. - expression.eval(session).ionValue + // The call to .ionValue below is important since query execution won't actually begin otherwise. + pipeline.evaluate(tc.query).ionValue } assertEquals( @@ -266,15 +160,13 @@ class AstEvaluatorTestAdapter : EvaluatorTestAdapter { } override fun runEvaluatorErrorTestCase(tc: EvaluatorErrorTestCase, session: EvaluationSession) { - if (tc.implicitPermissiveModeTest) { - val testOpts = CompileOptions.build { tc.compileOptionsBuilderBlock(this) } - assertNotEquals( - TypingMode.PERMISSIVE, - testOpts.typingMode, - "Setting TypingMode.PERMISSIVE when implicit permissive mode testing is enabled is redundant" - ) + // Skip execution of this test case if it does not apply to the pipeline supplied by pipelineFactory. + if (tc.targetPipeline != EvaluatorTestTarget.ALL_PIPELINES && pipelineFactory.target != tc.targetPipeline) { + return } + checkRedundantPermissiveMode(tc) + // Run the query once with compile options unmodified. privateRunEvaluatorErrorTestCase( tc = tc.copy( @@ -284,7 +176,7 @@ class AstEvaluatorTestAdapter : EvaluatorTestAdapter { } ), session = session, - note = "Typing mode forced to LEGACY" + note = "${pipelineFactory.pipelineName} (Typing mode forced to LEGACY)" ) when (tc.expectedErrorCode.errorBehaviorInPermissiveMode) { @@ -305,7 +197,7 @@ class AstEvaluatorTestAdapter : EvaluatorTestAdapter { } ), session, - note = "Typing mode forced to PERMISSIVE" + note = "${pipelineFactory.pipelineName} (typing mode forced to PERMISSIVE)" ) } ErrorBehaviorInPermissiveMode.RETURN_MISSING -> { @@ -319,14 +211,14 @@ class AstEvaluatorTestAdapter : EvaluatorTestAdapter { ) // Compute the expected return value - val permissiveModePipeline = tc.createPipeline(forcePermissiveMode = true) + val permissiveModePipeline = pipelineFactory.createPipeline(evaluatorTestDefinition = tc, session, forcePermissiveMode = true) val expectedExprValueForPermissiveMode = assertDoesNotThrow( EvaluatorTestFailureReason.FAILED_TO_EVALUATE_PARTIQL_EXPECTED_RESULT, { tc.testDetails(note = "Evaluating expected permissive mode result") } ) { - permissiveModePipeline.compile(tc.expectedPermissiveModeResult!!).eval(session) + permissiveModePipeline.evaluate(tc.expectedPermissiveModeResult!!) } val actualReturnValueForPermissiveMode = @@ -338,7 +230,7 @@ class AstEvaluatorTestAdapter : EvaluatorTestAdapter { ) } ) { - permissiveModePipeline.compile(tc.query).eval(session) + permissiveModePipeline.evaluate(tc.query) } if (!expectedExprValueForPermissiveMode.exprEquals(actualReturnValueForPermissiveMode)) { @@ -353,4 +245,15 @@ class AstEvaluatorTestAdapter : EvaluatorTestAdapter { } } } + + private fun checkRedundantPermissiveMode(tc: EvaluatorTestDefinition) { + if (tc.implicitPermissiveModeTest) { + val pipeline = pipelineFactory.createPipeline(tc, EvaluationSession.standard()) + assertNotEquals( + TypingMode.PERMISSIVE, + pipeline.typingMode, + "Setting TypingMode.PERMISSIVE when implicit permissive mode testing is enabled is redundant" + ) + } + } } diff --git a/lang/test/org/partiql/lang/eval/evaluatortestframework/AstEvaluatorTestAdapterTests.kt b/lang/test/org/partiql/lang/eval/evaluatortestframework/PipelineEvaluatorTestAdapterTests.kt similarity index 87% rename from lang/test/org/partiql/lang/eval/evaluatortestframework/AstEvaluatorTestAdapterTests.kt rename to lang/test/org/partiql/lang/eval/evaluatortestframework/PipelineEvaluatorTestAdapterTests.kt index 6614505b58..225ec3ca41 100644 --- a/lang/test/org/partiql/lang/eval/evaluatortestframework/AstEvaluatorTestAdapterTests.kt +++ b/lang/test/org/partiql/lang/eval/evaluatortestframework/PipelineEvaluatorTestAdapterTests.kt @@ -8,27 +8,42 @@ import org.partiql.lang.errors.ErrorCode import org.partiql.lang.eval.EvaluationSession import org.partiql.lang.util.propertyValueMapOf +private fun assertTestFails( + testAdapter: PipelineEvaluatorTestAdapter, + expectedReason: EvaluatorTestFailureReason, + tc: EvaluatorTestCase +) { + val ex = assertThrows { + testAdapter.runEvaluatorTestCase(tc, EvaluationSession.standard()) + } + assertEquals(expectedReason, ex.reason) +} + +private fun assertErrorTestFails( + testAdapter: PipelineEvaluatorTestAdapter, + expectedReason: EvaluatorTestFailureReason, + tc: EvaluatorErrorTestCase +) { + val ex = assertThrows { + testAdapter.runEvaluatorErrorTestCase(tc, EvaluationSession.standard()) + } + + assertEquals(expectedReason, ex.reason) +} + /** - * These are just some "smoke tests" to ensure that the essential parts of [AstEvaluatorTestAdapterTests] are + * These are "smoke tests" to ensure that the essential parts of [PipelineEvaluatorTestAdapterTests] are * working correctly. */ -class AstEvaluatorTestAdapterTests { - private val testAdapter = AstEvaluatorTestAdapter() +class PipelineEvaluatorTestAdapterTests { + private val astPipelineTestAdapter = PipelineEvaluatorTestAdapter(CompilerPipelineFactory()) private fun assertTestFails(expectedReason: EvaluatorTestFailureReason, tc: EvaluatorTestCase) { - val ex = assertThrows { - testAdapter.runEvaluatorTestCase(tc, EvaluationSession.standard()) - } - - assertEquals(expectedReason, ex.reason) + assertTestFails(astPipelineTestAdapter, expectedReason, tc) } private fun assertErrorTestFails(expectedReason: EvaluatorTestFailureReason, tc: EvaluatorErrorTestCase) { - val ex = assertThrows { - testAdapter.runEvaluatorErrorTestCase(tc, EvaluationSession.standard()) - } - - assertEquals(expectedReason, ex.reason) + assertErrorTestFails(astPipelineTestAdapter, expectedReason, tc) } class FooException : Exception() @@ -40,7 +55,7 @@ class AstEvaluatorTestAdapterTests { @Test fun `runEvaluatorTestCase - expected result matches - ExpectedResultFormat-ION`() { assertDoesNotThrow("happy path - should not throw") { - testAdapter.runEvaluatorTestCase( + astPipelineTestAdapter.runEvaluatorTestCase( EvaluatorTestCase( query = "1", expectedResult = "1", @@ -54,7 +69,7 @@ class AstEvaluatorTestAdapterTests { @Test fun `runEvaluatorTestCase - different permissive mode result - ExpectedResultFormat-ION`() { assertDoesNotThrow("happy path - should not throw") { - testAdapter.runEvaluatorTestCase( + astPipelineTestAdapter.runEvaluatorTestCase( EvaluatorTestCase( query = "1 + MISSING", // Note:unknown propagation works differently in legacy vs permissive modes. expectedResult = "null", @@ -69,7 +84,7 @@ class AstEvaluatorTestAdapterTests { @Test fun `runEvaluatorTestCase - expected result matches - ExpectedResultFormat-ION (missing)`() { assertDoesNotThrow("happy path - should not throw") { - testAdapter.runEvaluatorTestCase( + astPipelineTestAdapter.runEvaluatorTestCase( EvaluatorTestCase( query = "MISSING", expectedResult = "\$partiql_missing::null", @@ -83,7 +98,7 @@ class AstEvaluatorTestAdapterTests { @Test fun `runEvaluatorTestCase - expected result matches - ExpectedResultFormat-ION (date)`() { assertDoesNotThrow("happy path - should not throw") { - testAdapter.runEvaluatorTestCase( + astPipelineTestAdapter.runEvaluatorTestCase( EvaluatorTestCase( query = "DATE '2001-01-01'", expectedResult = "\$partiql_date::2001-01-01", @@ -97,7 +112,7 @@ class AstEvaluatorTestAdapterTests { @Test fun `runEvaluatorTestCase - expected result matches - ExpectedResultFormat-ION (time)`() { assertDoesNotThrow("happy path - should not throw") { - testAdapter.runEvaluatorTestCase( + astPipelineTestAdapter.runEvaluatorTestCase( EvaluatorTestCase( query = "TIME '12:12:01'", expectedResult = "\$partiql_time::{hour:12,minute:12,second:1.,timezone_hour:null.int,timezone_minute:null.int}", @@ -111,7 +126,7 @@ class AstEvaluatorTestAdapterTests { @Test fun `runEvaluatorTestCase - expected result matches - ExpectedResultFormat-ION_WITHOUT_BAG_AND_MISSING_ANNOTATIONS mode (int)`() { assertDoesNotThrow("happy path - should not throw") { - testAdapter.runEvaluatorTestCase( + astPipelineTestAdapter.runEvaluatorTestCase( EvaluatorTestCase( query = "1", expectedResult = "1", @@ -125,7 +140,7 @@ class AstEvaluatorTestAdapterTests { @Test fun `runEvaluatorTestCase - different permissive mode result - ExpectedResultFormat-ION_WITHOUT_BAG_AND_MISSING_ANNOTATIONS`() { assertDoesNotThrow("happy path - should not throw") { - testAdapter.runEvaluatorTestCase( + astPipelineTestAdapter.runEvaluatorTestCase( EvaluatorTestCase( query = "1 + MISSING", expectedResult = "null", @@ -142,7 +157,7 @@ class AstEvaluatorTestAdapterTests { @Test fun `runEvaluatorTestCase - expected result matches - ExpectedResultFormat-ION_WITHOUT_BAG_AND_MISSING_ANNOTATIONS mode (bag)`() { assertDoesNotThrow("happy path - should not throw") { - testAdapter.runEvaluatorTestCase( + astPipelineTestAdapter.runEvaluatorTestCase( EvaluatorTestCase( query = "<<1>>", // note: In this ExpectedResultFormat we lose the fact that this a BAG and not an @@ -158,7 +173,7 @@ class AstEvaluatorTestAdapterTests { @Test fun `runEvaluatorTestCase - expected result matches - ExpectedResultFormat-ION_WITHOUT_BAG_AND_MISSING_ANNOTATIONS mode (missing)`() { assertDoesNotThrow("happy path - should not throw") { - testAdapter.runEvaluatorTestCase( + astPipelineTestAdapter.runEvaluatorTestCase( EvaluatorTestCase( query = "MISSING", // note: In this ExpectedResultFormat we lose the fact that this MISSING and not an @@ -174,7 +189,7 @@ class AstEvaluatorTestAdapterTests { @Test fun `runEvaluatorTestCase - expected result matches - ExpectedResultFormat-ION_WITHOUT_BAG_AND_MISSING_ANNOTATIONS mode (date)`() { assertDoesNotThrow("happy path - should not throw") { - testAdapter.runEvaluatorTestCase( + astPipelineTestAdapter.runEvaluatorTestCase( EvaluatorTestCase( query = "DATE '2001-01-01'", expectedResult = "\$partiql_date::2001-01-01", @@ -188,7 +203,7 @@ class AstEvaluatorTestAdapterTests { @Test fun `runEvaluatorTestCase - expected result matches - ExpectedResultFormat-ION_WITHOUT_BAG_AND_MISSING_ANNOTATIONS mode (time)`() { assertDoesNotThrow("happy path - should not throw") { - testAdapter.runEvaluatorTestCase( + astPipelineTestAdapter.runEvaluatorTestCase( EvaluatorTestCase( query = "TIME '12:12:01'", expectedResult = "\$partiql_time::{hour:12,minute:12,second:1.,timezone_hour:null.int,timezone_minute:null.int}", @@ -202,7 +217,7 @@ class AstEvaluatorTestAdapterTests { @Test fun `runEvaluatorTestCase - expected result matches - ExpectedResultFormat-ION_WITHOUT_BAG_AND_MISSING_ANNOTATIONS mode`() { assertDoesNotThrow("happy path - should not throw") { - testAdapter.runEvaluatorTestCase( + astPipelineTestAdapter.runEvaluatorTestCase( EvaluatorTestCase( query = "<<1>>", expectedResult = "[1]", @@ -216,7 +231,7 @@ class AstEvaluatorTestAdapterTests { @Test fun `runEvaluatorTestCase - expected result matches - ExpectedResultFormat-STRING mode`() { assertDoesNotThrow("happy path - should not throw") { - testAdapter.runEvaluatorTestCase( + astPipelineTestAdapter.runEvaluatorTestCase( EvaluatorTestCase( query = "SEXP(1, 2, 3)", expectedResult = "`(1 2 3)`", // <-- ExprValue.toString() produces this @@ -313,7 +328,7 @@ class AstEvaluatorTestAdapterTests { @Test fun `runEvaluatorTestCase - extraResultAssertions`() { assertThrows("extraResultAssertions should throw") { - testAdapter.runEvaluatorTestCase( + astPipelineTestAdapter.runEvaluatorTestCase( EvaluatorTestCase( query = "1", expectedResult = "1", @@ -403,8 +418,10 @@ class AstEvaluatorTestAdapterTests { @Test fun `runEvaluatorErrorTestCase - additionalExceptionAssertBlock`() { + // No need to test both test adapters here since additionalExceptionAssertBlock is invoked by + // PipelineEvaluatorTestAdapter. assertThrows("additionalExceptionAssertBlock should throw") { - testAdapter.runEvaluatorErrorTestCase( + astPipelineTestAdapter.runEvaluatorErrorTestCase( EvaluatorErrorTestCase( query = "undefined_function()", expectedErrorCode = ErrorCode.EVALUATOR_NO_SUCH_FUNCTION, diff --git a/lang/test/org/partiql/lang/eval/evaluatortestframework/PipelineFactory.kt b/lang/test/org/partiql/lang/eval/evaluatortestframework/PipelineFactory.kt new file mode 100644 index 0000000000..46e450fd23 --- /dev/null +++ b/lang/test/org/partiql/lang/eval/evaluatortestframework/PipelineFactory.kt @@ -0,0 +1,19 @@ +package org.partiql.lang.eval.evaluatortestframework + +import org.partiql.lang.eval.EvaluationSession + +/** + * The implementation of this interface is passed to the constructor of [PipelineEvaluatorTestAdapter]. Determines + * which pipeline (either [org.partiql.lang.CompilerPipeline] or [org.partiql.lang.planner.PlannerPipeline]) will be + * tested. + */ +internal interface PipelineFactory { + val pipelineName: String + val target: EvaluatorTestTarget + + fun createPipeline( + evaluatorTestDefinition: EvaluatorTestDefinition, + session: EvaluationSession, + forcePermissiveMode: Boolean = false, + ): AbstractPipeline +} diff --git a/lang/test/org/partiql/lang/eval/evaluatortestframework/PlannerPipelineFactory.kt b/lang/test/org/partiql/lang/eval/evaluatortestframework/PlannerPipelineFactory.kt new file mode 100644 index 0000000000..cdb6c59f38 --- /dev/null +++ b/lang/test/org/partiql/lang/eval/evaluatortestframework/PlannerPipelineFactory.kt @@ -0,0 +1,124 @@ +package org.partiql.lang.eval.evaluatortestframework + +import org.junit.jupiter.api.fail +import org.partiql.lang.ION +import org.partiql.lang.eval.BindingName +import org.partiql.lang.eval.EvaluationSession +import org.partiql.lang.eval.ExprValue +import org.partiql.lang.eval.TypingMode +import org.partiql.lang.eval.UndefinedVariableBehavior +import org.partiql.lang.planner.EvaluatorOptions +import org.partiql.lang.planner.MetadataResolver +import org.partiql.lang.planner.PassResult +import org.partiql.lang.planner.PlannerPipeline +import org.partiql.lang.planner.ResolutionResult +import kotlin.test.assertNotEquals +import kotlin.test.assertNull + +/** + * Uses the test infrastructure (which is geared toward the legacy [org.partiql.lang.CompilerPipeline]) to create a + * standard [org.partiql.lang.CompilerPipeline], then creates an equivalent [PlannerPipeline] which is wrapped in + * an instance of [AbstractPipeline] and returned to the caller. + * + * Why? Because the entire test infrastructure (and the many thousands of tests) are heavily dependent on + * [org.partiql.lang.CompilerPipeline]. + * + * TODO: When that class is deprecated or removed we'll want to change the test infrastructure to depend on the + * [PlannerPipeline] instead. + */ +internal class PlannerPipelineFactory : PipelineFactory { + + override val pipelineName: String + get() = "PlannerPipeline (and Physical Plan Evaluator)" + + override val target: EvaluatorTestTarget + get() = EvaluatorTestTarget.PLANNER_PIPELINE + + override fun createPipeline( + evaluatorTestDefinition: EvaluatorTestDefinition, + session: EvaluationSession, + forcePermissiveMode: Boolean + ): AbstractPipeline { + + // Construct a legacy CompilerPipeline + val compilerPipeline = evaluatorTestDefinition.createCompilerPipeline(forcePermissiveMode) + + // Convert it to a PlannerPipeline (to avoid having to refactor many tests cases to use + // PlannerPipeline.Builder and EvaluatorOptions.Builder. + val co = compilerPipeline.compileOptions + + assertNotEquals( + co.undefinedVariable, UndefinedVariableBehavior.MISSING, + "The planner and physical plan evaluator do not support UndefinedVariableBehavior.MISSING. " + + "Please set target = EvaluatorTestTarget.COMPILER_PIPELINE for this test.\n" + + "Test groupName: ${evaluatorTestDefinition.groupName}" + ) + + assertNull( + compilerPipeline.globalTypeBindings, + "The planner and physical plan evaluator do not support globalTypeBindings (yet)" + + "Please set target = EvaluatorTestTarget.COMPILER_PIPELINE for this test." + ) + + val evaluatorOptions = EvaluatorOptions.build { + typingMode(co.typingMode) + thunkOptions(co.thunkOptions) + defaultTimezoneOffset(co.defaultTimezoneOffset) + typedOpBehavior(co.typedOpBehavior) + projectionIteration(co.projectionIteration) + } + + @Suppress("DEPRECATION") + val plannerPipeline = PlannerPipeline.build(ION) { + // this is for support of the existing test suite and may not be desirable for all future tests. + allowUndefinedVariables(true) + + customDataTypes(compilerPipeline.customDataTypes) + + compilerPipeline.functions.values.forEach { this.addFunction(it) } + compilerPipeline.procedures.values.forEach { this.addProcedure(it) } + + evaluatorOptions(evaluatorOptions) + + // For compatibility with the unit test suite, prevent the planner from catching SqlException during query + // compilation and converting them into Problems + enableLegacyExceptionHandling() + + // Create a fake MetadataResolver implementation which defines any global that is also defined in the + // session. + metadataResolver( + object : MetadataResolver { + override fun resolveVariable(bindingName: BindingName): ResolutionResult { + val boundValue = session.globals[bindingName] + return if (boundValue != null) { + // There is no way to tell the actual name of the global variable as it exists + // in session.globals (case may differ). For now we simply have to use binding.name + // as the uniqueId of the variable, however, this is not desirable in production + // scenarios. At minimum, the name of the variable in its original letter-case should be + // used. + ResolutionResult.GlobalVariable(bindingName.name) + } else { + ResolutionResult.Undefined + } + } + } + ) + } + + return object : AbstractPipeline { + override val typingMode: TypingMode + get() = evaluatorOptions.typingMode + + override fun evaluate(query: String): ExprValue { + when (val planningResult = plannerPipeline.planAndCompile(query)) { + is PassResult.Error -> { + fail("Query compilation unexpectedly failed: ${planningResult.errors}") + } + is PassResult.Success -> { + return planningResult.result.eval(session) + } + } + } + } + } +} diff --git a/lang/test/org/partiql/lang/eval/relation/RelationTests.kt b/lang/test/org/partiql/lang/eval/relation/RelationTests.kt new file mode 100644 index 0000000000..24ed0dc1e9 --- /dev/null +++ b/lang/test/org/partiql/lang/eval/relation/RelationTests.kt @@ -0,0 +1,58 @@ +package org.partiql.lang.eval.relation + +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertFalse +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows + +class RelationTests { + + @Test + fun relType() { + val rel = relation(RelationType.BAG) { } + assertEquals(RelationType.BAG, rel.relType) + } + + @Test + fun `0 yields`() { + val rel = relation(RelationType.BAG) { } + assertEquals(RelationType.BAG, rel.relType) + assertFalse(rel.nextRow()) + assertThrows { rel.nextRow() } + } + + @Test + fun `1 yield`() { + val rel = relation(RelationType.BAG) { yield() } + assertTrue(rel.nextRow()) + assertFalse(rel.nextRow()) + assertThrows { rel.nextRow() } + } + + @Test + fun `2 yields`() { + val rel = relation(RelationType.BAG) { + yield() + yield() + } + assertTrue(rel.nextRow()) + assertTrue(rel.nextRow()) + assertFalse(rel.nextRow()) + assertThrows { rel.nextRow() } + } + + @Test + fun `3 yields`() { + val rel = relation(RelationType.BAG) { + yield() + yield() + yield() + } + assertTrue(rel.nextRow()) + assertTrue(rel.nextRow()) + assertTrue(rel.nextRow()) + assertFalse(rel.nextRow()) + assertThrows { rel.nextRow() } + } +} diff --git a/lang/test/org/partiql/lang/eval/visitors/StaticTypeVisitorTransformTests.kt b/lang/test/org/partiql/lang/eval/visitors/StaticTypeVisitorTransformTests.kt index 7a214d4c69..0b7220b6c4 100644 --- a/lang/test/org/partiql/lang/eval/visitors/StaticTypeVisitorTransformTests.kt +++ b/lang/test/org/partiql/lang/eval/visitors/StaticTypeVisitorTransformTests.kt @@ -845,7 +845,7 @@ class StaticTypeVisitorTransformTests : VisitorTransformTestBase() { properties.forEach { (property, expectedValue) -> assertEquals( "${property.propertyName} in error doesn't match", - expectedValue, it.error.errorContext?.get(property)?.value + expectedValue, it.error.errorContext[property]?.value ) } } diff --git a/lang/test/org/partiql/lang/planner/PlannerPipelineSmokeTests.kt b/lang/test/org/partiql/lang/planner/PlannerPipelineSmokeTests.kt new file mode 100644 index 0000000000..e65ee32f0a --- /dev/null +++ b/lang/test/org/partiql/lang/planner/PlannerPipelineSmokeTests.kt @@ -0,0 +1,81 @@ +package org.partiql.lang.planner + +import com.amazon.ion.system.IonSystemBuilder +import com.amazon.ionelement.api.ionInt +import com.amazon.ionelement.api.ionString +import com.amazon.ionelement.api.toIonValue +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Test +import org.partiql.lang.ION +import org.partiql.lang.domains.PartiqlPhysical +import org.partiql.lang.planner.transforms.PLAN_VERSION_NUMBER +import org.partiql.lang.planner.transforms.PlanningProblemDetails +import org.partiql.lang.util.SexpAstPrettyPrinter + +/** + * Query planning primarily consists of AST traversals and rewrites. Each of those are thoroughly tested separately, + * but it is still good to have a simple "smoke test" for the planner pipeline. + */ +class PlannerPipelineSmokeTests { + private val ion = IonSystemBuilder.standard().build() + + @Suppress("DEPRECATION") + private fun createPlannerPipelineForTest(allowUndefinedVariables: Boolean) = PlannerPipeline.build(ion) { + allowUndefinedVariables(allowUndefinedVariables) + metadataResolver(createFakeMetadataResolver("Customer" to "fake_uid_for_Customer")) + } + + @Test + fun `happy path`() { + val pipeline = createPlannerPipelineForTest(allowUndefinedVariables = true) + val result = pipeline.plan("SELECT c.* FROM Customer AS c WHERE c.primaryKey = 42") + + result as PassResult.Success + println(SexpAstPrettyPrinter.format(result.result.toIonElement().asAnyElement().toIonValue(ION))) + + assertEquals( + result, + PassResult.Success( + result = PartiqlPhysical.build { + plan( + stmt = query( + bindingsToValues( + exp = struct(structFields(localId(0))), + query = filter( + i = impl("default"), + predicate = eq( + operands0 = path( + localId(0), + pathExpr(lit(ionString("primaryKey")), caseInsensitive()) + ), + operands1 = lit(ionInt(42)) + ), + source = scan( + i = impl("default"), + expr = globalId("fake_uid_for_Customer", caseInsensitive()), + asDecl = varDecl(0) + ) + ) + ) + ), + locals = listOf(localVariable("c", 0)), + version = PLAN_VERSION_NUMBER + ) + }, + warnings = emptyList() + ) + ) + } + + @Test + fun `undefined variable`() { + val qp = createPlannerPipelineForTest(allowUndefinedVariables = false) + val result = qp.plan("SELECT undefined.* FROM Customer AS c") + assertEquals( + PassResult.Error( + listOf(problem(1, 8, PlanningProblemDetails.UndefinedVariable("undefined", caseSensitive = false))) + ), + result + ) + } +} diff --git a/lang/test/org/partiql/lang/planner/Util.kt b/lang/test/org/partiql/lang/planner/Util.kt new file mode 100644 index 0000000000..b4635b1e16 --- /dev/null +++ b/lang/test/org/partiql/lang/planner/Util.kt @@ -0,0 +1,25 @@ +package org.partiql.lang.planner + +import org.partiql.lang.ast.SourceLocationMeta +import org.partiql.lang.errors.Problem +import org.partiql.lang.errors.ProblemDetails +import org.partiql.lang.eval.BindingName + +/** + * Creates a fake implementation of [MetadataResolver] with the specified [globalVariableNames]. + * + * The fake unique identifier of bound variables is computed to be `fake_uid_for_${globalVariableName}`. + */ +fun createFakeMetadataResolver(vararg globalVariableNames: Pair) = + object : MetadataResolver { + override fun resolveVariable(bindingName: BindingName): ResolutionResult { + val matches = globalVariableNames.filter { bindingName.isEquivalentTo(it.first) } + return when (matches.size) { + 0 -> ResolutionResult.Undefined + else -> ResolutionResult.GlobalVariable(matches.first().second) + } + } + } + +fun problem(line: Int, charOffset: Int, detail: ProblemDetails): Problem = + Problem(SourceLocationMeta(line.toLong(), charOffset.toLong()), detail) diff --git a/lang/test/org/partiql/lang/planner/transforms/AstToLogicalVisitorTransformTests.kt b/lang/test/org/partiql/lang/planner/transforms/AstToLogicalVisitorTransformTests.kt new file mode 100644 index 0000000000..d7b0d249ee --- /dev/null +++ b/lang/test/org/partiql/lang/planner/transforms/AstToLogicalVisitorTransformTests.kt @@ -0,0 +1,152 @@ +package org.partiql.lang.planner.transforms + +import com.amazon.ion.system.IonSystemBuilder +import com.amazon.ionelement.api.ionBool +import com.amazon.ionelement.api.ionInt +import com.amazon.ionelement.api.ionString +import com.amazon.ionelement.api.toIonValue +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.assertDoesNotThrow +import org.junit.jupiter.api.assertThrows +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ArgumentsSource +import org.partiql.lang.domains.PartiqlLogical +import org.partiql.lang.domains.id +import org.partiql.lang.domains.pathExpr +import org.partiql.lang.syntax.SqlParser +import org.partiql.lang.util.ArgumentsProviderBase +import org.partiql.lang.util.SexpAstPrettyPrinter + +/** + * Test cases in this class might seem a little light--that's because [AstToLogicalVisitorTransform] is getting + * heavily exercised during many other integration tests. These should be considered "smoke tests". + */ +class AstToLogicalVisitorTransformTests { + private val ion = IonSystemBuilder.standard().build() + private val parser = SqlParser(ion) + + private fun parseAndTransform(sql: String): PartiqlLogical.Statement { + val parseAstStatement = parser.parseAstStatement(sql) + println(SexpAstPrettyPrinter.format(parseAstStatement.toIonElement().asAnyElement().toIonValue(ion))) + return parseAstStatement.toLogicalPlan().stmt + } + + data class TestCase(val sql: String, val expectedAlgebra: PartiqlLogical.Statement) + + private fun runTestCase(tc: TestCase) { + val algebra = assertDoesNotThrow("Parsing TestCase.sql should not throw") { + parseAndTransform(tc.sql) + } + println(SexpAstPrettyPrinter.format(algebra.toIonElement().asAnyElement().toIonValue(ion))) + Assertions.assertEquals(tc.expectedAlgebra, algebra) + } + + @ParameterizedTest + @ArgumentsSource(ArgumentsForToLogicalTests::class) + fun `to logical`(tc: TestCase) = runTestCase(tc) + + class ArgumentsForToLogicalTests : ArgumentsProviderBase() { + override fun getParameters() = listOf( + TestCase( + // Note: + // `SELECT * FROM bar AS b` is rewritten to `SELECT b.* FROM bar as b` by [SelectStarVisitorTransform]. + // Therefore, there is no need to support `SELECT *` in `AstToLogicalVisitorTransform`. + "SELECT b.* FROM bar AS b", + PartiqlLogical.build { + query( + bindingsToValues( + struct(structFields(id("b"))), + scan(id("bar"), varDecl("b")) + ) + ) + } + ), + TestCase( + // Note: This is supported by the AST -> logical -> physical transformation but should be rejected + // by the planner since it is a full table scan, which we won't support initially. + "SELECT b.* FROM bar AS b WHERE TRUE = TRUE", + PartiqlLogical.build { + query( + bindingsToValues( + struct(structFields(id("b"))), + filter( + eq(lit(ionBool(true)), lit(ionBool(true))), + scan(id("bar"), varDecl("b")) + ) + ) + ) + } + ), + TestCase( + "SELECT b.* FROM bar AS b WHERE b.primaryKey = 42", + PartiqlLogical.build { + query( + bindingsToValues( + struct(structFields(id("b"))), + filter( + eq(path(id("b"), pathExpr(lit(ionString("primaryKey")))), lit(ionInt(42))), + scan(id("bar"), varDecl("b")) + ) + ) + ) + } + ), + TestCase( + "SELECT DISTINCT b.* FROM bar AS b", + PartiqlLogical.build { + query( + call( + "filter_distinct", + bindingsToValues( + struct(structFields(id("b"))), + scan(id("bar"), varDecl("b")) + ) + ) + ) + } + ), + ) + } + + data class TodoTestCase(val sql: String) + @ParameterizedTest + @ArgumentsSource(ArgumentsForToToDoTests::class) + fun todo(tc: TodoTestCase) { + assertThrows("Parsing TestCase.sql should throw NotImplementedError") { + parseAndTransform(tc.sql) + } + } + + /** + * A list of statements that cannot be converted into the logical algebra yet by [AstToLogicalVisitorTransform]. + * This is temporary--in the near future, we will accomplish this with a better language restriction feature which + * blocks all language features except those explicitly allowed. This will be needed to constrain possible queries + * to features supported by specific PartiQL-services. + */ + class ArgumentsForToToDoTests : ArgumentsProviderBase() { + override fun getParameters() = listOf( + // SELECT queries + TodoTestCase("SELECT b.* FROM UNPIVOT x as y"), + TodoTestCase("SELECT b.* FROM bar AS b GROUP BY a"), + TodoTestCase("SELECT b.* FROM bar AS b HAVING x"), + TodoTestCase("SELECT b.* FROM bar AS b ORDER BY y"), + TodoTestCase("PIVOT v AT n FROM data AS d"), + + // DML + TodoTestCase("CREATE TABLE foo"), + TodoTestCase("DROP TABLE foo"), + TodoTestCase("CREATE INDEX ON foo (x)"), + TodoTestCase("DROP INDEX bar ON foo"), + + // DDL + TodoTestCase("INSERT INTO foo VALUE 1"), + TodoTestCase("INSERT INTO foo VALUE 1"), + TodoTestCase("FROM x WHERE a = b SET k = 5"), + TodoTestCase("FROM x INSERT INTO foo VALUES (1, 2)"), + TodoTestCase("UPDATE x SET k = 5"), + TodoTestCase("UPDATE x INSERT INTO k << 1 >>"), + TodoTestCase("DELETE FROM y"), + TodoTestCase("REMOVE y"), + ) + } +} diff --git a/lang/test/org/partiql/lang/planner/transforms/LogicalResolvedToDefaultPhysicalVisitorTransformTests.kt b/lang/test/org/partiql/lang/planner/transforms/LogicalResolvedToDefaultPhysicalVisitorTransformTests.kt new file mode 100644 index 0000000000..c69fca6597 --- /dev/null +++ b/lang/test/org/partiql/lang/planner/transforms/LogicalResolvedToDefaultPhysicalVisitorTransformTests.kt @@ -0,0 +1,69 @@ +package org.partiql.lang.planner.transforms + +import com.amazon.ionelement.api.ionBool +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ArgumentsSource +import org.partiql.lang.domains.PartiqlLogicalResolved +import org.partiql.lang.domains.PartiqlPhysical +import org.partiql.lang.util.ArgumentsProviderBase + +class LogicalResolvedToDefaultPhysicalVisitorTransformTests { + data class TestCase(val input: PartiqlLogicalResolved.Bexpr, val expected: PartiqlPhysical.Bexpr) + + @ParameterizedTest + @ArgumentsSource(ArgumentsForToPhysicalTests::class) + fun `to physical`(tc: TestCase) { + assertEquals(tc.expected, LogicalResolvedToDefaultPhysicalVisitorTransform().transformBexpr(tc.input)) + } + + class ArgumentsForToPhysicalTests : ArgumentsProviderBase() { + override fun getParameters() = listOf( + TestCase( + PartiqlLogicalResolved.build { + scan( + expr = globalId("foo", caseInsensitive()), + asDecl = varDecl(0), + atDecl = varDecl(1), + byDecl = varDecl(2) + ) + }, + PartiqlPhysical.build { + scan( + i = DEFAULT_IMPL, + expr = globalId("foo", caseInsensitive()), + asDecl = varDecl(0), + atDecl = varDecl(1), + byDecl = varDecl(2) + ) + } + ), + TestCase( + PartiqlLogicalResolved.build { + filter( + predicate = lit(ionBool(true)), + source = scan( + expr = globalId("foo", caseInsensitive()), + asDecl = varDecl(0), + atDecl = varDecl(1), + byDecl = varDecl(2) + ) + ) + }, + PartiqlPhysical.build { + filter( + i = DEFAULT_IMPL, + predicate = lit(ionBool(true)), + source = scan( + i = DEFAULT_IMPL, + expr = globalId("foo", caseInsensitive()), + asDecl = varDecl(0), + atDecl = varDecl(1), + byDecl = varDecl(2) + ) + ) + } + ) + ) + } +} diff --git a/lang/test/org/partiql/lang/planner/transforms/LogicalToLogicalResolvedVisitorTransformTests.kt b/lang/test/org/partiql/lang/planner/transforms/LogicalToLogicalResolvedVisitorTransformTests.kt new file mode 100644 index 0000000000..d61e399e32 --- /dev/null +++ b/lang/test/org/partiql/lang/planner/transforms/LogicalToLogicalResolvedVisitorTransformTests.kt @@ -0,0 +1,705 @@ +package org.partiql.lang.planner.transforms + +import com.amazon.ion.system.IonSystemBuilder +import com.amazon.ionelement.api.ionSymbol +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.assertDoesNotThrow +import org.junit.jupiter.api.fail +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ArgumentsSource +import org.partiql.lang.domains.PartiqlLogical +import org.partiql.lang.domains.PartiqlLogicalResolved +import org.partiql.lang.errors.Problem +import org.partiql.lang.errors.ProblemCollector +import org.partiql.lang.eval.BindingCase +import org.partiql.lang.eval.builtins.DYNAMIC_LOOKUP_FUNCTION_NAME +import org.partiql.lang.eval.sourceLocationMeta +import org.partiql.lang.planner.createFakeMetadataResolver +import org.partiql.lang.planner.problem +import org.partiql.lang.syntax.SqlParser +import org.partiql.lang.util.ArgumentsProviderBase +import org.partiql.lang.util.toIntExact + +private fun localVariable(name: String, index: Int) = + PartiqlLogicalResolved.build { localVariable(name, index.toLong()) } + +/** Shortcut for creating a dynamic lookup call site for the expected plans below. */ +private fun PartiqlLogicalResolved.Builder.dynamicLookup( + name: String, + case: BindingCase, + globalsFirst: Boolean = false, + vararg searchTargets: PartiqlLogicalResolved.Expr +) = + call( + DYNAMIC_LOOKUP_FUNCTION_NAME, + listOf( + lit(ionSymbol(name)), + lit( + ionSymbol( + when (case) { + BindingCase.SENSITIVE -> "case_sensitive" + BindingCase.INSENSITIVE -> "case_insensitive" + } + ) + ), + lit( + ionSymbol( + when { + globalsFirst -> "globals_then_locals" + else -> "locals_then_globals" + } + ) + ) + ) + searchTargets + ) + +class LogicalToLogicalResolvedVisitorTransformTests { + data class TestCase( + val sql: String, + val expectation: Expectation, + val allowUndefinedVariables: Boolean = false + ) + + data class ResolvedId( + val line: Int, + val charOffset: Int, + val expr: PartiqlLogicalResolved.Expr + ) { + constructor( + line: Int, + charOffset: Int, + build: PartiqlLogicalResolved.Builder.() -> PartiqlLogicalResolved.Expr + ) : this(line, charOffset, PartiqlLogicalResolved.BUILDER().build()) + + override fun toString(): String { + return "($line, $charOffset): $expr" + } + } + + sealed class Expectation { + data class Success( + val expectedIds: List, + val expectedLocalVariables: List + ) : Expectation() { + constructor(vararg expectedIds: ResolvedId) : this(expectedIds.toList(), emptyList()) + fun withLocals(vararg expectedLocalVariables: PartiqlLogicalResolved.LocalVariable) = + this.copy(expectedLocalVariables = expectedLocalVariables.toList()) + } + data class Problems(val problems: List) : Expectation() { + constructor(vararg problems: Problem) : this(problems.toList()) + } + } + + /** Mock table resolver. That can resolve f, foo, or UPPERCASE_FOO, while respecting case-sensitivity. */ + private val metadataResolver = createFakeMetadataResolver( + *listOf( + "shadow", + "foo", + "bar", + "bat", + "UPPERCASE_FOO", + "case_AMBIGUOUS_foo", + "case_ambiguous_FOO" + ).map { + it to "fake_uid_for_$it" + }.toTypedArray() + ) + + private val ion = IonSystemBuilder.standard().build() + private val parser = SqlParser(ion) + + private fun runTestCase(tc: TestCase) { + val plan: PartiqlLogical.Plan = assertDoesNotThrow { + parser.parseAstStatement(tc.sql).toLogicalPlan() + } + + val problemHandler = ProblemCollector() + + when (tc.expectation) { + is Expectation.Success -> { + val resolved = plan.toResolvedPlan(problemHandler, metadataResolver, tc.allowUndefinedVariables) + + // extract all of the dynamic, global and local ids from the resolved logical plan. + val actualResolvedIds = + object : PartiqlLogicalResolved.VisitorFold>() { + override fun visitExpr( + node: PartiqlLogicalResolved.Expr, + accumulator: List + ): List = + when (node) { + is PartiqlLogicalResolved.Expr.GlobalId, + is PartiqlLogicalResolved.Expr.LocalId -> accumulator + node + is PartiqlLogicalResolved.Expr.Call -> { + if (node.funcName.text == DYNAMIC_LOOKUP_FUNCTION_NAME) { + accumulator + node + } else { + accumulator + } + } + else -> accumulator + } + + // Don't include children of dynamic lookup callsites + override fun walkExprCall( + node: PartiqlLogicalResolved.Expr.Call, + accumulator: List + ): List { + return if (node.funcName.text == DYNAMIC_LOOKUP_FUNCTION_NAME) { + accumulator + } else { + super.walkExprCall(node, accumulator) + } + } + }.walkPlan(resolved, emptyList()) + + assertEquals( + tc.expectation.expectedIds.size, actualResolvedIds.size, + "Number of expected resovled variables must match actual" + ) + + val remainingActualResolvedIds = actualResolvedIds.map { + val location = it.metas.sourceLocationMeta ?: error("$it missing source location meta") + ResolvedId(location.lineNum.toIntExact(), location.charOffset.toIntExact()) { it } + }.filter { expectedId: ResolvedId -> + tc.expectation.expectedIds.none { actualId -> actualId == expectedId } + } + + if (remainingActualResolvedIds.isNotEmpty()) { + val sb = StringBuilder() + sb.appendLine("Unexpected ids:") + remainingActualResolvedIds.forEach { + sb.appendLine(it) + } + sb.appendLine("Expected ids:") + tc.expectation.expectedIds.forEach { + sb.appendLine(it) + } + + fail("Unmatched resolved ids were found.\n$sb") + } + + assertEquals( + tc.expectation.expectedLocalVariables, + resolved.locals, + "Expected and actual local variables must match" + ) + } + is Expectation.Problems -> { + assertDoesNotThrow("Should not throw when variables are undefined") { + plan.toResolvedPlan(problemHandler, metadataResolver) + } + assertEquals(tc.expectation.problems, problemHandler.problems) + } + } + } + + @ParameterizedTest + @ArgumentsSource(CaseInsensitiveGlobalsCases::class) + fun `case-insensitive globals`(tc: TestCase) = runTestCase(tc) + class CaseInsensitiveGlobalsCases : ArgumentsProviderBase() { + override fun getParameters() = listOf( + // Case-insensitive resolution of global variables... + TestCase( + // all uppercase + sql = "FOO", + expectation = Expectation.Success(ResolvedId(1, 1) { globalId("fake_uid_for_foo", caseInsensitive()) }) + ), + TestCase( + // all lower case + "foo", + Expectation.Success(ResolvedId(1, 1) { globalId("fake_uid_for_foo", caseInsensitive()) }) + ), + TestCase( + // mixed case + "fOo", + Expectation.Success(ResolvedId(1, 1) { globalId("fake_uid_for_foo", caseInsensitive()) }) + ), + TestCase( + // undefined + """ foobar """, + Expectation.Problems( + problem( + 1, + 2, + PlanningProblemDetails.UndefinedVariable("foobar", caseSensitive = false) + ) + ) + ), + + // Ambiguous case-insensitive lookup + TestCase( + // ambiguous + """case_ambiguous_foo """, + // In this case, we resolve to the first matching binding. This is consistent with Postres 9.6. + Expectation.Success( + ResolvedId(1, 1) { + globalId("fake_uid_for_case_AMBIGUOUS_foo", caseInsensitive()) + } + ) + ), + + // Case-insensitive resolution of global variables with all uppercase letters... + TestCase( + // all uppercase + "UPPERCASE_FOO", + Expectation.Success( + ResolvedId(1, 1) { + globalId("fake_uid_for_UPPERCASE_FOO", caseInsensitive()) + } + ) + ), + TestCase( + // all lower case + "uppercase_foo", + Expectation.Success( + ResolvedId(1, 1) { + globalId("fake_uid_for_UPPERCASE_FOO", caseInsensitive()) + } + ) + ), + TestCase( + // mixed case + "UpPeRcAsE_fOo", + Expectation.Success( + ResolvedId(1, 1) { + globalId("fake_uid_for_UPPERCASE_FOO", caseInsensitive()) + } + ) + ), + + // undefined variables allowed + TestCase( + // undefined allowed (case-insensitive) + """some_undefined """, + Expectation.Success( + ResolvedId(1, 1) { + dynamicLookup("some_undefined", BindingCase.INSENSITIVE, globalsFirst = false) + } + ), + allowUndefinedVariables = true + ), + ) + } + + @ParameterizedTest + @ArgumentsSource(CaseSensitiveGlobalsCases::class) + fun `case-sensitive globals`(tc: TestCase) = runTestCase(tc) + class CaseSensitiveGlobalsCases : ArgumentsProviderBase() { + override fun getParameters() = listOf( + // Case-sensitive resolution of global variable with all lowercase letters + TestCase( + // all uppercase + "\"FOO\"", + Expectation.Problems( + problem( + 1, + 1, + PlanningProblemDetails.UndefinedVariable("FOO", caseSensitive = true) + ) + ) + ), + TestCase( + // all lowercase + "\"foo\"", + Expectation.Success(ResolvedId(1, 1) { globalId("fake_uid_for_foo", caseSensitive()) }) + ), + TestCase( + // mixed + "\"foO\"", + Expectation.Problems( + problem( + 1, + 1, + PlanningProblemDetails.UndefinedVariable("foO", caseSensitive = true) + ) + ) + ), + + // Case-sensitive resolution of global variables with all uppercase letters + TestCase( + // all uppercase + "\"UPPERCASE_FOO\"", + Expectation.Success( + ResolvedId(1, 1) { + globalId( + "fake_uid_for_UPPERCASE_FOO", caseSensitive() + ) + } + ) + ), + TestCase( + // all lowercase + "\"uppercase_foo\"", + Expectation.Problems( + problem(1, 1, PlanningProblemDetails.UndefinedVariable("uppercase_foo", caseSensitive = true)) + ) + ), + TestCase( + // mixed + "\"UpPeRcAsE_fOo\"", + Expectation.Problems( + problem(1, 1, PlanningProblemDetails.UndefinedVariable("UpPeRcAsE_fOo", caseSensitive = true)) + ) + ), + TestCase( + // not ambiguous when case-sensitive + "\"case_AMBIGUOUS_foo\"", + Expectation.Success( + ResolvedId(1, 1) { + globalId("fake_uid_for_case_AMBIGUOUS_foo", caseSensitive()) + } + ) + ), + TestCase( + // not ambiguous when case-sensitive + "\"case_ambiguous_FOO\"", + Expectation.Success( + ResolvedId(1, 1) { + globalId("fake_uid_for_case_ambiguous_FOO", caseSensitive()) + } + ) + ), + TestCase( + // undefined + """ FOOBAR """, + Expectation.Problems( + problem( + 1, + 2, + PlanningProblemDetails.UndefinedVariable("FOOBAR", caseSensitive = false) + ) + ) + ), + + TestCase( + // undefined allowed (case-sensitive) + "\"some_undefined\"", + Expectation.Success( + ResolvedId(1, 1) { + dynamicLookup("some_undefined", BindingCase.SENSITIVE) + } + ), + allowUndefinedVariables = true + ) + ) + } + + @ParameterizedTest + @ArgumentsSource(CaseInsensitiveLocalsVariablesCases::class) + fun `case-insensitive local variables`(tc: TestCase) = runTestCase(tc) + class CaseInsensitiveLocalsVariablesCases : ArgumentsProviderBase() { + override fun getParameters() = listOf( + // Case-insensitive resolution of local variables with all lowercase letters... + TestCase( + // all uppercase + "SELECT FOO.* FROM 1 AS foo WHERE FOO", + Expectation.Success( + ResolvedId(1, 8) { localId(0) }, + ResolvedId(1, 34) { localId(0) } + ).withLocals(localVariable("foo", 0)) + ), + TestCase( + // all lowercase + "SELECT foo.* FROM 1 AS foo WHERE foo", + Expectation.Success( + ResolvedId(1, 8) { localId(0) }, + ResolvedId(1, 34) { localId(0) } + ).withLocals(localVariable("foo", 0)) + ), + TestCase( + // mixed case + "SELECT FoO.* FROM 1 AS foo WHERE fOo", + Expectation.Success( + ResolvedId(1, 8) { localId(0) }, + ResolvedId(1, 34) { localId(0) } + ).withLocals(localVariable("foo", 0)) + ), + TestCase( + // foobar is undefined (select list) + "SELECT foobar.* FROM [] AS foo", + Expectation.Problems( + problem(1, 8, PlanningProblemDetails.UndefinedVariable("foobar", caseSensitive = false)) + ) + ), + TestCase( + // barbat is undefined (where clause) + "SELECT foo.* FROM [] AS foo WHERE barbat", + Expectation.Problems( + problem(1, 35, PlanningProblemDetails.UndefinedVariable("barbat", caseSensitive = false)) + ) + ) + ) + } + + @ParameterizedTest + @ArgumentsSource(CaseSensitiveLocalVariablesCases::class) + fun `case-sensitive locals variables`(tc: TestCase) = runTestCase(tc) + class CaseSensitiveLocalVariablesCases : ArgumentsProviderBase() { + override fun getParameters() = listOf( + // Case-insensitive resolution of local variables with all lowercase letters... + TestCase( + // all uppercase + "SELECT \"FOO\".* FROM 1 AS foo WHERE \"FOO\"", + Expectation.Problems( + problem(1, 8, PlanningProblemDetails.UndefinedVariable("FOO", caseSensitive = true)), + problem(1, 36, PlanningProblemDetails.UndefinedVariable("FOO", caseSensitive = true)) + ) + ), + TestCase( + // all lowercase + "SELECT \"foo\".* FROM 1 AS foo WHERE \"foo\"", + Expectation.Success( + ResolvedId(1, 8) { localId(0) }, + ResolvedId(1, 36) { localId(0) }, + ).withLocals(localVariable("foo", 0)) + ), + TestCase( + // mixed case + "SELECT \"FoO\".* FROM 1 AS foo WHERE \"fOo\"", + Expectation.Problems( + problem(1, 8, PlanningProblemDetails.UndefinedVariable("FoO", caseSensitive = true)), + problem(1, 36, PlanningProblemDetails.UndefinedVariable("fOo", caseSensitive = true)) + ) + ), + TestCase( + // "foobar" is undefined (select list) + "SELECT \"foobar\".* FROM [] AS foo ", + Expectation.Problems( + problem(1, 8, PlanningProblemDetails.UndefinedVariable("foobar", caseSensitive = true)) + ) + ), + TestCase( + // "barbat" is undefined (where clause) + "SELECT \"foo\".* FROM [] AS foo WHERE \"barbat\"", + Expectation.Problems( + problem(1, 37, PlanningProblemDetails.UndefinedVariable("barbat", caseSensitive = true)) + ) + ) + ) + } + + @ParameterizedTest + @ArgumentsSource(DuplicateVariableCases::class) + fun `duplicate variables`(tc: TestCase) = runTestCase(tc) + class DuplicateVariableCases : ArgumentsProviderBase() { + override fun getParameters() = listOf( + // Duplicate variables with same case + TestCase( + "SELECT {}.* FROM 1 AS a AT a", + Expectation.Problems(problem(1, 28, PlanningProblemDetails.VariablePreviouslyDefined("a"))), + ), + TestCase( + "SELECT {}.* FROM 1 AS a BY a", + Expectation.Problems(problem(1, 28, PlanningProblemDetails.VariablePreviouslyDefined("a"))), + ), + TestCase( + "SELECT {}.* FROM 1 AS notdup AT a BY a", + Expectation.Problems(problem(1, 38, PlanningProblemDetails.VariablePreviouslyDefined("a"))), + ), + TestCase( + "SELECT {}.* FROM 1 AS a AT a BY a", + Expectation.Problems( + problem(1, 28, PlanningProblemDetails.VariablePreviouslyDefined("a")), + problem(1, 33, PlanningProblemDetails.VariablePreviouslyDefined("a")) + ), + ), + // Duplicate variables with different cases + TestCase( + "SELECT {}.* FROM 1 AS a AT A", + Expectation.Problems(problem(1, 28, PlanningProblemDetails.VariablePreviouslyDefined("A"))), + ), + TestCase( + "SELECT {}.* FROM 1 AS A BY a", + Expectation.Problems(problem(1, 28, PlanningProblemDetails.VariablePreviouslyDefined("a"))), + ), + TestCase( + "SELECT {}.* FROM 1 AS notdup AT a BY A", + Expectation.Problems(problem(1, 38, PlanningProblemDetails.VariablePreviouslyDefined("A"))), + ), + TestCase( + "SELECT {}.* FROM 1 AS foo AT fOo BY foO", + Expectation.Problems( + problem(1, 30, PlanningProblemDetails.VariablePreviouslyDefined("fOo")), + problem(1, 37, PlanningProblemDetails.VariablePreviouslyDefined("foO")) + ), + ) + // Future test cases: duplicate variables across joins, i.e. `foo AS a, bar AS a`, etc. + ) + } + + @ParameterizedTest + @ArgumentsSource(MiscLocalVariableCases::class) + fun `misc local variable`(tc: TestCase) = runTestCase(tc) + class MiscLocalVariableCases : ArgumentsProviderBase() { + private fun createScanTestCase(varName: String, expectedIndex: Int) = + TestCase( + "SELECT $varName.* FROM foo AS a AT b BY c", + Expectation.Success( + ResolvedId(1, 8) { localId(expectedIndex.toLong()) }, + ResolvedId(1, 17) { globalId("fake_uid_for_foo", caseInsensitive()) } + ).withLocals(localVariable("a", 0), localVariable("b", 1), localVariable("c", 2)) + ) + + override fun getParameters() = listOf( + // Demonstrates that FROM source AS aliases work + createScanTestCase("a", 0), + // Demonstrates that FROM source AT aliases work + createScanTestCase("b", 1), + // Demonstrates that FROM source BY aliases work + createScanTestCase("c", 2), + + // Covers local variables in select list, global variables in FROM source, local variables in WHERE clause + TestCase( + "SELECT b.* FROM bar AS b WHERE b.primaryKey = 42", + Expectation.Success( + ResolvedId(1, 8) { localId(0) }, + ResolvedId(1, 17) { globalId("fake_uid_for_bar", caseInsensitive()) }, + ResolvedId(1, 32) { localId(0) }, + ).withLocals(localVariable("b", 0)) + ), + + // Demonstrate that globals-first variable lookup only happens in the FROM clause. + TestCase( + "SELECT shadow.* FROM shadow AS shadow", // `shadow` defined here shadows the global `shadow` + Expectation.Success( + ResolvedId(1, 8) { localId(0) }, + ResolvedId(1, 22) { globalId("fake_uid_for_shadow", caseInsensitive()) } + ).withLocals(localVariable("shadow", 0)) + ), + + // JOIN with shadowing + TestCase( + // first `AS s` shadowed by second `AS s`. + "SELECT s.* FROM 1 AS s, @s AS s", + Expectation.Success( + ResolvedId(1, 8) { localId(1) }, + ResolvedId(1, 26) { localId(0) } + ).withLocals(localVariable("s", 0), localVariable("s", 1)) + // Note that these two variables (^) have the same name but different indexes. + ), + ) + } + + @ParameterizedTest + @ArgumentsSource(DynamicIdSearchCases::class) + fun `dynamic_lookup search order cases`(tc: TestCase) = runTestCase(tc) + class DynamicIdSearchCases : ArgumentsProviderBase() { + // The important thing being asserted here is the contents of the dynamicId.search, which + // defines the places we'll look for variables that are unresolved at compile time. + override fun getParameters() = listOf( + // Not in an SFW query (empty search path) + TestCase( + "undefined1 + undefined2", + Expectation.Success( + ResolvedId(1, 1) { dynamicLookup("undefined1", BindingCase.INSENSITIVE, globalsFirst = false) }, + ResolvedId(1, 14) { dynamicLookup("undefined2", BindingCase.INSENSITIVE, globalsFirst = false) } + ), + allowUndefinedVariables = true + ), + + // In select list and where clause + TestCase( + "SELECT undefined1 AS u FROM 1 AS f WHERE undefined2", // 1 from source + Expectation.Success( + ResolvedId(1, 8) { dynamicLookup("undefined1", BindingCase.INSENSITIVE, globalsFirst = false, localId(0)) }, + ResolvedId(1, 42) { dynamicLookup("undefined2", BindingCase.INSENSITIVE, globalsFirst = false, localId(0)) } + ).withLocals(localVariable("f", 0)), + allowUndefinedVariables = true + ), + TestCase( + sql = "SELECT undefined1 AS u FROM 1 AS a, 2 AS b WHERE undefined2", // 2 from sources + Expectation.Success( + ResolvedId(1, 8) { dynamicLookup("undefined1", BindingCase.INSENSITIVE, globalsFirst = false, localId(1), localId(0)) }, + ResolvedId(1, 50) { dynamicLookup("undefined2", BindingCase.INSENSITIVE, globalsFirst = false, localId(1), localId(0)) } + ).withLocals(localVariable("a", 0), localVariable("b", 1)), + allowUndefinedVariables = true + ), + TestCase( + sql = "SELECT undefined1 AS u FROM 1 AS f, 1 AS b, 1 AS t WHERE undefined2", // 3 from sources + Expectation.Success( + ResolvedId(1, 8) { + dynamicLookup("undefined1", BindingCase.INSENSITIVE, globalsFirst = false, localId(2), localId(1), localId(0)) + }, + ResolvedId(1, 58) { + dynamicLookup("undefined2", BindingCase.INSENSITIVE, globalsFirst = false, localId(2), localId(1), localId(0)) + } + ).withLocals(localVariable("f", 0), localVariable("b", 1), localVariable("t", 2)), + allowUndefinedVariables = true + ), + // In from clause + TestCase( + // Wihtout scope override + "SELECT 1 AS x FROM undefined_table AS f", + Expectation.Success( + ResolvedId(1, 20) { dynamicLookup("undefined_table", BindingCase.INSENSITIVE, globalsFirst = true) }, + ).withLocals( + localVariable("f", 0), + ), + allowUndefinedVariables = true + ), + TestCase( + // Wiht scope override + "SELECT 1 AS x FROM @undefined_table AS f", + Expectation.Success( + ResolvedId(1, 21) { dynamicLookup("undefined_table", BindingCase.INSENSITIVE, globalsFirst = false) }, + ).withLocals( + localVariable("f", 0), + ), + allowUndefinedVariables = true + ), + TestCase( + // with correlated join + "SELECT 1 AS x FROM undefined_table AS f, @asdf AS f2", + Expectation.Success( + ResolvedId(1, 20) { dynamicLookup("undefined_table", BindingCase.INSENSITIVE, globalsFirst = true) }, + ResolvedId(1, 43) { dynamicLookup("asdf", BindingCase.INSENSITIVE, globalsFirst = false, localId(0)) } + ).withLocals( + localVariable("f", 0), + localVariable("f2", 1) + ), + allowUndefinedVariables = true + ), + ) + } + + @ParameterizedTest + @ArgumentsSource(SubqueryCases::class) + fun `sub-queries`(tc: TestCase) = runTestCase(tc) + class SubqueryCases : ArgumentsProviderBase() { + override fun getParameters() = listOf( + TestCase( + // inner query does not reference variables outer query + "SELECT b.* FROM (SELECT a.* FROM 1 AS a) AS b", + Expectation.Success( + ResolvedId(1, 8) { localId(1) }, + ResolvedId(1, 25) { localId(0) }, + ).withLocals(localVariable("a", 0), localVariable("b", 1)) + ), + TestCase( + // inner query references variable from outer query. + "SELECT a.*, b.* FROM 1 AS a, (SELECT a.*, b.* FROM 1 AS x) AS b", + Expectation.Success( + // The variables reference in the outer query + ResolvedId(1, 8) { localId(0) }, + ResolvedId(1, 13) { localId(2) }, + // The variables reference in the inner query + ResolvedId(1, 38) { localId(0) }, + // Note that `b` from the outer query is not accessible inside the query so we fall back on dynamic lookup + ResolvedId(1, 43) { dynamicLookup("b", BindingCase.INSENSITIVE, globalsFirst = false, localId(1), localId(0)) } + ).withLocals(localVariable("a", 0), localVariable("x", 1), localVariable("b", 2)), + allowUndefinedVariables = true + ), + + // In FROM source + TestCase( + "SELECT f.*, u.* FROM 1 AS f, undefined AS u", + Expectation.Success( + ResolvedId(1, 8) { localId(0) }, + ResolvedId(1, 13) { localId(1) }, + ResolvedId(1, 30) { dynamicLookup("undefined", BindingCase.INSENSITIVE, globalsFirst = true, localId(0)) } + ).withLocals(localVariable("f", 0), localVariable("u", 1)), + allowUndefinedVariables = true + ), + ) + } +} diff --git a/lang/test/org/partiql/lang/util/testdsl/IonResultTestCase.kt b/lang/test/org/partiql/lang/util/testdsl/IonResultTestCase.kt index e437fbd99d..bc47d17e8c 100644 --- a/lang/test/org/partiql/lang/util/testdsl/IonResultTestCase.kt +++ b/lang/test/org/partiql/lang/util/testdsl/IonResultTestCase.kt @@ -9,10 +9,12 @@ import org.partiql.lang.eval.EVALUATOR_TEST_SUITE import org.partiql.lang.eval.EvaluationSession import org.partiql.lang.eval.ExprValue import org.partiql.lang.eval.ExprValueFactory -import org.partiql.lang.eval.evaluatortestframework.AstEvaluatorTestAdapter -import org.partiql.lang.eval.evaluatortestframework.EvaluatorTestAdapter +import org.partiql.lang.eval.evaluatortestframework.CompilerPipelineFactory import org.partiql.lang.eval.evaluatortestframework.EvaluatorTestCase +import org.partiql.lang.eval.evaluatortestframework.EvaluatorTestTarget import org.partiql.lang.eval.evaluatortestframework.ExpectedResultFormat +import org.partiql.lang.eval.evaluatortestframework.PipelineEvaluatorTestAdapter +import org.partiql.lang.eval.evaluatortestframework.PlannerPipelineFactory import org.partiql.lang.mockdb.MockDb import org.partiql.lang.syntax.SqlParser @@ -71,9 +73,18 @@ data class IonResultTestCase( internal fun IonResultTestCase.runTestCase( valueFactory: ExprValueFactory, db: MockDb, + target: EvaluatorTestTarget, compilerPipelineBuilderBlock: CompilerPipeline.Builder.() -> Unit = { } ) { - val harness: EvaluatorTestAdapter = AstEvaluatorTestAdapter() + val adapter = PipelineEvaluatorTestAdapter( + when (target) { + EvaluatorTestTarget.COMPILER_PIPELINE -> CompilerPipelineFactory() + EvaluatorTestTarget.PLANNER_PIPELINE -> PlannerPipelineFactory() + // We don't support ALL_PIPELINES here because each pipeline needs a separate skip list, which + // is decided by the caller of this function. + EvaluatorTestTarget.ALL_PIPELINES -> error("May only test one pipeline at a time with IonResultTestCase") + } + ) val session = EvaluationSession.build { globals(db.valueBindings) @@ -94,13 +105,13 @@ internal fun IonResultTestCase.runTestCase( ) if (!this.expectFailure) { - harness.runEvaluatorTestCase(tc, session) + adapter.runEvaluatorTestCase(tc, session) } else { val message = "We expect test \"${this.name}\" to fail, but it did not. This check exists to ensure the " + "failing list is up to date." assertThrows(message) { - harness.runEvaluatorTestCase(tc, session) + adapter.runEvaluatorTestCase(tc, session) } } } diff --git a/pts/test/org/partiql/lang/pts/PartiQlPtsEvaluator.kt b/pts/test/org/partiql/lang/pts/PartiQlPtsEvaluator.kt index 373f970401..6d4bf99912 100644 --- a/pts/test/org/partiql/lang/pts/PartiQlPtsEvaluator.kt +++ b/pts/test/org/partiql/lang/pts/PartiQlPtsEvaluator.kt @@ -58,6 +58,9 @@ class PartiQlPtsEvaluator(equality: PtsEquality) : Evaluator(equality) { is ExpectedError -> TestResultSuccess(test) is ExpectedSuccess -> TestFailure(test, e.generateMessage(), TestFailure.FailureReason.UNEXPECTED_ERROR) } + } catch (e: Exception) { + // Other exception types are always failures. + TestFailure(test, "${e.javaClass.canonicalName} : ${e.message}", TestFailure.FailureReason.UNEXPECTED_ERROR) } private fun verifyTestResult(test: TestExpression, actualResult: IonValue): TestResult =