diff --git a/CHANGELOG.md b/CHANGELOG.md index ca6bbe6159..c5620465dc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -51,6 +51,51 @@ This is a pre-release containing: Please note that these changes are subject to future breaking changes without warning. +## [0.14.4] + +### Added +- Added constrained decimal as valid parameter type to functions that take in numeric parameters. +- Added async version of physical plan evaluator `PartiQLCompilerAsync`. + - The following related async APIs have been added: + - `org.partiql.lang.compiler` -- `PartiQLCompilerAsync`, `PartiQLCompilerAsyncBuilder`, `PartiQLCompilerAsyncDefault`, `PartiQLCompilerPipelineAsync` + - `org.partiql.lang.eval` -- `PartiQLStatementAsync` + - `org.partiql.lang.eval.physical` -- `VariableBindingAsync` + - `org.partiql.lang.eval.physical.operators` -- `AggregateOperatorFactoryAsync`, `CompiledGroupKeyAsync`, `CompiledAggregateFunctionAsync`, `FilterRelationalOperatorFactoryAsync`, `JoinRelationalOperatorFactoryAsync`, `LetRelationalOperatorFactoryAsync`, `LimitRelationalOperatorFactoryAsync`, `OffsetRelationalOperatorFactoryAsync`, `ProjectRelationalOperatorFactoryAsync`, `RelationExpressionAsync`, `ScanRelationalOperatorFactoryAsync`, `SortOperatorFactoryAsync`, `CompiledSortKeyAsync`, `UnpivotOperatorFactoryAsync`, `ValueExpressionAsync`, `WindowRelationalOperatorFactoryAsync`, `CompiledWindowFunctionAsync` + - `org.partiql.lang.eval.physical.window` -- `NavigationWindowFunctionAsync`, `WindowFunctionAsync` + - Overall, we see about a 10-20% performance decline in running a single query on the synchronous vs async evaluator + - JMH benchmarks added to partiql-lang: `PartiQLCompilerPipelineBenchmark` and `PartiQLCompilerPipelineAsyncBenchmark` + +### Changed +- Function resolution logic: Now the function resolver would match all possible candidate(based on if the argument can be coerced to the Signature parameter type). 
If there are multiple matches, it will first attempt to pick the one that requires the fewest casts, then pick the function with the highest precedence. +- partiql-cli -- experimental version of CLI now uses the async physical plan evaluator + +### Deprecated +- As part of the additions to make an async physical plan evaluator, the synchronous physical plan evaluator `PartiQLCompiler` has been deprecated. + - The following related APIs have been deprecated + - `org.partiql.lang.compiler` -- `PartiQLCompiler`, `PartiQLCompilerBuilder`, `PartiQLCompilerDefault`, `PartiQLCompilerPipeline` + - `org.partiql.lang.eval` -- `PartiQLStatement` + - `org.partiql.lang.eval.physical` -- `VariableBinding` + - `org.partiql.lang.eval.physical.operators` -- `AggregateOperatorFactory`, `CompiledGroupKey`, `CompiledAggregateFunction`, `FilterRelationalOperatorFactory`, `JoinRelationalOperatorFactory`, `LetRelationalOperatorFactory`, `LimitRelationalOperatorFactory`, `OffsetRelationalOperatorFactory`, `ProjectRelationalOperatorFactory`, `RelationExpression`, `ScanRelationalOperatorFactory`, `SortOperatorFactory`, `CompiledSortKey`, `UnpivotOperatorFactory`, `ValueExpression`, `WindowRelationalOperatorFactory`, `CompiledWindowFunction` + - `org.partiql.lang.eval.physical.window` -- `NavigationWindowFunction`, `WindowFunction` + +### Fixed +- partiql-ast: `SqlDialect` will wrap unary ops (`NOT`, `+`, `-`) in parens + +### Removed + +### Security + +### Contributors +Thank you to all who have contributed! +- @yliuuuu +- @alancai98 + +## [0.14.3] - 2024-02-14 + +### Fixed +- Return type of `partiql-ast`'s `SqlDialect` for `defaultReturn` to be a `SqlBlock` rather than `Nothing` +- Flatten `CASE WHEN` branch type in `PlanTyper` + ### Contributors Thank you to all who have contributed! - @alancai98 @@ -58,6 +103,26 @@ Thank you to all who have contributed! 
- @RCHowell - @yliuuuu +## [0.14.2] - 2024-01-25 + +### Added + +### Changed +- Upgrade IonJava dependency to v1.11.1 + +### Deprecated + +### Fixed + +### Removed + +### Security + +### Contributors +Thank you to all who have contributed! +- @RCHowell +- @alancai98 + ## [0.14.1] - 2024-01-03 ### Added @@ -87,8 +152,8 @@ Thank you to all who have contributed! - Adds top-level IR node creation functions. - Adds `componentN` functions (destructuring) to IR nodes via Kotlin data classes - Adds public `tag` field to IR nodes for associating metadata -- Adds AST Normalization Pass. -- Adds PartiQLPlanner Interface, which is responsible for translate an AST to a Plan. +- Adds AST Normalization Pass. +- Adds PartiQLPlanner Interface, which is responsible for translate an AST to a Plan. - **EXPERIMENTAL** Evaluation of `EXCLUDE` in the `EvaluatingCompiler` - This is currently marked as experimental until the RFC is approved https://github.com/partiql/partiql-lang/issues/27 - This will be added to the `PhysicalPlanCompiler` in an upcoming release @@ -96,13 +161,13 @@ Thank you to all who have contributed! ### Changed - StaticTypeInferencer and PlanTyper will not raise an error when an expression is inferred to `NULL` or `unionOf(NULL, MISSING)`. In these cases the StaticTypeInferencer and PlanTyper will still raise the Problem Code `ExpressionAlwaysReturnsNullOrMissing` but the severity of the problem has been changed to warning. In the case an expression always returns `MISSING`, problem code `ExpressionAlwaysReturnsMissing` will be raised, which will have problem severity of error. -- **Breaking** The default integer literal type is now 32-bit; if the literal can not fit in a 32-bit integer, it overflows to 64-bit. -- **BREAKING** `PartiQLValueType` now distinguishes between Arbitrary Precision Decimal and Fixed Precision Decimal. -- **BREAKING** Function Signature Changes. Now Function signature has two subclasses, `Scalar` and `Aggregation`. 
+- **Breaking** The default integer literal type is now 32-bit; if the literal can not fit in a 32-bit integer, it overflows to 64-bit. +- **BREAKING** `PartiQLValueType` now distinguishes between Arbitrary Precision Decimal and Fixed Precision Decimal. +- **BREAKING** Function Signature Changes. Now Function signature has two subclasses, `Scalar` and `Aggregation`. - **BREAKING** Plugin Changes. Only return one Connector.Factory, use Kotlin fields. JVM signature remains the same. -- **BREAKING** In the produced plan: +- **BREAKING** In the produced plan: - The new plan is fully resolved and typed. - - Operators will be converted to function call. + - Operators will be converted to function call. - Changes the return type of `filter_distinct` to a list if input collection is list - Changes the `PartiQLValue` collections to implement Iterable rather than Sequence, allowing for multiple consumption. - **BREAKING** Moves PartiQLParserBuilder.standard().build() to be PartiQLParser.default(). @@ -110,6 +175,7 @@ Thank you to all who have contributed! ### Deprecated + ### Fixed - Fixes the CLI hanging on invalid queries. See issue #1230. - Fixes Timestamp Type parsing issue. Previously Timestamp Type would get parsed to a Time type. @@ -119,7 +185,7 @@ Thank you to all who have contributed! ### Removed - **Breaking** Removed IR factory in favor of static top-level functions. Change `Ast.foo()` to `foo()` -- **Breaking** Removed `org.partiql.lang.planner.transforms.AstToPlan`. Use `org.partiql.planner.PartiQLPlanner`. +- **Breaking** Removed `org.partiql.lang.planner.transforms.AstToPlan`. Use `org.partiql.planner.PartiQLPlanner`. - **Breaking** Removed `org.partiql.lang.planner.transforms.PartiQLSchemaInferencer`. In order to achieve the same functionality, one would need to use the `org.partiql.planner.PartiQLPlanner`. 
- To get the inferred type of the query result, one can do: `(plan.statement as Statement.Query).root.type` @@ -198,7 +264,7 @@ Thank you to all who have contributed! - Parsing of label patterns within node and edge graph patterns now supports disjunction `|`, conjunction `&`, negation `!`, and grouping. - Adds default `equals` and `hashCode` methods for each generated abstract class of Sprout. This affects the generated -classes in `:partiql-ast` and `:partiql-plan`. + classes in `:partiql-ast` and `:partiql-plan`. - Adds README to `partiql-types` package. - Initializes PartiQL's Code Coverage library - Adds support for BRANCH and BRANCH-CONDITION Coverage @@ -240,12 +306,12 @@ classes in `:partiql-ast` and `:partiql-plan`. - Introduces `isNullCall` and `isNullable` properties to FunctionSignature. - Removed `Nullable...Value` implementations of PartiQLValue and made the standard implementations nullable. - Using PartiQLValueType requires optin; this was a miss from an earlier commit. -- Modified timestamp static type to model precision and time zone. +- Modified timestamp static type to model precision and time zone. ### Deprecated -- **Breaking**: Deprecates the `Arguments`, `RequiredArgs`, `RequiredWithOptional`, and `RequiredWithVariadic` classes, - along with the `callWithOptional()`, `callWithVariadic()`, and the overloaded `call()` methods in the `ExprFunction` class, - marking them with a Deprecation Level of ERROR. Now, it's recommended to use +- **Breaking**: Deprecates the `Arguments`, `RequiredArgs`, `RequiredWithOptional`, and `RequiredWithVariadic` classes, + along with the `callWithOptional()`, `callWithVariadic()`, and the overloaded `call()` methods in the `ExprFunction` class, + marking them with a Deprecation Level of ERROR. Now, it's recommended to use `call(session: EvaluationSession, args: List)` and `callWithRequired()` instead. 
- **Breaking**: Deprecates `optionalParameter` and `variadicParameter` in the `FunctionSignature` with a Deprecation Level of ERROR. Please use multiple implementations of ExprFunction and use the LIST ExprValue to @@ -281,7 +347,7 @@ Thank you to all who have contributed! - Moves PartiqlAst, PartiqlLogical, PartiqlLogicalResolved, and PartiqlPhysical (along with the transforms) to a new project, `partiql-ast`. These are still imported into `partiql-lang` with the `api` annotation. Therefore, no action is required to consume the migrated classes. However, this now gives consumers of the AST, Experimental Plans, - Visitors, and VisitorTransforms the option of importing them directly using: `org.partiql:partiql-ast:${VERSION}`. + Visitors, and VisitorTransforms the option of importing them directly using: `org.partiql:partiql-ast:${VERSION}`. The file `partiql.ion` is still published in the `partiql-lang-kotlin` JAR. - Moves internal class org.partiql.lang.syntax.PartiQLParser to org.partiql.lang.syntax.impl.PartiQLPigParser as we refactor for explicit API. - Moves ANTLR grammar to `partiql-parser` package. The files `PartiQL.g4` and `PartiQLTokens.g4` are still published in the `partiql-lang-kotlin` JAR. @@ -366,15 +432,15 @@ Thank you to all who have contributed! ### Added -- Adds an initial implementation of GPML (Graph Pattern Matching Language), following - PartiQL [RFC-0025](https://github.com/partiql/partiql-docs/blob/main/RFCs/0025-graph-data-model.md) +- Adds an initial implementation of GPML (Graph Pattern Matching Language), following + PartiQL [RFC-0025](https://github.com/partiql/partiql-docs/blob/main/RFCs/0025-graph-data-model.md) and [RFC-0033](https://github.com/partiql/partiql-docs/blob/main/RFCs/0033-graph-query.md). 
This initial implementation includes: - - A file format for external graphs, defined as a schema in ISL (Ion Schema Language), + - A file format for external graphs, defined as a schema in ISL (Ion Schema Language), as well as an in-memory graph data model and a reader for loading external graphs into it. - - CLI shell commands `!add_graph` and `!add_graph_from_file` for bringing - externally-defined graphs into the evaluation environment. - - Evaluation of straight-path patterns with simple label matching and + - CLI shell commands `!add_graph` and `!add_graph_from_file` for bringing + externally-defined graphs into the evaluation environment. + - Evaluation of straight-path patterns with simple label matching and all directed/undirected edge patterns. - Adds new `TupleConstraint` variant, `Ordered`, to represent ordering in `StructType`. See the KDoc for more information. @@ -484,7 +550,7 @@ breaking changes if migrating from v0.9.2. The breaking changes accidentally int ### Added - Adds ability to pipe queries to the CLI. - Adds ability to run PartiQL files as executables by adding support for shebangs. -- Adds experimental syntax for CREATE TABLE, towards addressing +- Adds experimental syntax for CREATE TABLE, towards addressing [#36](https://github.com/partiql/partiql-docs/issues/36) of specifying PartiQL DDL. ### Changed @@ -984,7 +1050,11 @@ breaking changes if migrating from v0.9.2. The breaking changes accidentally int ### Added Initial alpha release of PartiQL. 
-[Unreleased]: https://github.com/partiql/partiql-lang-kotlin/compare/v0.14.1...HEAD +[Unreleased]: https://github.com/partiql/partiql-lang-kotlin/compare/v0.14.5...HEAD +[0.14.5]: https://github.com/partiql/partiql-lang-kotlin/compare/v0.14.4...v0.14.5 +[0.14.4]: https://github.com/partiql/partiql-lang-kotlin/compare/v0.14.3...v0.14.4 +[0.14.3]: https://github.com/partiql/partiql-lang-kotlin/compare/v0.14.2...v0.14.3 +[0.14.2]: https://github.com/partiql/partiql-lang-kotlin/compare/v0.14.1...v0.14.2 [0.14.1]: https://github.com/partiql/partiql-lang-kotlin/compare/v0.14.0-alpha...v0.14.1 [0.14.0-alpha]: https://github.com/partiql/partiql-lang-kotlin/compare/v0.13.2-alpha...v0.14.0-alpha [0.13.2-alpha]: https://github.com/partiql/partiql-lang-kotlin/compare/v0.13.1-alpha...v0.13.2-alpha diff --git a/README.md b/README.md index 178221277b..2e99b55d54 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,7 @@ This project is published to [Maven Central](https://search.maven.org/artifact/o | Group ID | Artifact ID | Recommended Version | |---------------|-----------------------|---------------------| -| `org.partiql` | `partiql-lang-kotlin` | `0.14.1` | +| `org.partiql` | `partiql-lang-kotlin` | `0.14.4` | For Maven builds, add the following to your `pom.xml`: diff --git a/buildSrc/src/main/kotlin/partiql.versions.kt b/buildSrc/src/main/kotlin/partiql.versions.kt index 23697e6a34..e543827c41 100644 --- a/buildSrc/src/main/kotlin/partiql.versions.kt +++ b/buildSrc/src/main/kotlin/partiql.versions.kt @@ -31,12 +31,15 @@ object Versions { const val gson = "2.10.1" const val guava = "31.1-jre" const val ionElement = "1.0.0" - const val ionJava = "1.10.2" + const val ionJava = "1.11.1" const val ionSchema = "1.2.1" const val jansi = "2.4.0" const val jgenhtml = "1.6" const val jline = "3.21.0" - const val jmh = "0.5.3" + const val jmhGradlePlugin = "0.5.3" + const val jmhCore = "1.37" + const val jmhGeneratorAnnprocess = "1.37" + const val jmhGeneratorBytecode = "1.37" const 
val joda = "2.12.1" const val kotlinPoet = "1.11.0" const val kotlinxCollections = "0.3.5" @@ -44,6 +47,8 @@ object Versions { const val kasechange = "1.3.0" const val ktlint = "11.6.0" const val pig = "0.6.2" + const val kotlinxCoroutines = "1.6.0" + const val kotlinxCoroutinesJdk8 = "1.6.0" // Testing const val assertj = "3.11.0" @@ -54,6 +59,7 @@ object Versions { const val junit4Params = "1.1.1" const val mockito = "4.5.0" const val mockk = "1.11.0" + const val kotlinxCoroutinesTest = "1.6.0" } object Deps { @@ -84,6 +90,8 @@ object Deps { const val picoCli = "info.picocli:picocli:${Versions.picoCli}" const val pig = "org.partiql:partiql-ir-generator:${Versions.pig}" const val pigRuntime = "org.partiql:partiql-ir-generator-runtime:${Versions.pig}" + const val kotlinxCoroutines = "org.jetbrains.kotlinx:kotlinx-coroutines-core:${Versions.kotlinxCoroutines}" + const val kotlinxCoroutinesJdk8 = "org.jetbrains.kotlinx:kotlinx-coroutines-jdk8:${Versions.kotlinxCoroutinesJdk8}" // Testing const val assertj = "org.assertj:assertj-core:${Versions.assertj}" @@ -97,6 +105,12 @@ object Deps { const val kotlinTestJunit = "org.jetbrains.kotlin:kotlin-test-junit5:${Versions.kotlin}" const val mockito = "org.mockito:mockito-junit-jupiter:${Versions.mockito}" const val mockk = "io.mockk:mockk:${Versions.mockk}" + const val kotlinxCoroutinesTest = "org.jetbrains.kotlinx:kotlinx-coroutines-test:${Versions.kotlinxCoroutinesTest}" + + // JMH Benchmarking + const val jmhCore = "org.openjdk.jmh:jmh-core:${Versions.jmhCore}" + const val jmhGeneratorAnnprocess = "org.openjdk.jmh:jmh-core:${Versions.jmhGeneratorAnnprocess}" + const val jmhGeneratorBytecode = "org.openjdk.jmh:jmh-core:${Versions.jmhGeneratorBytecode}" } object Plugins { @@ -114,4 +128,4 @@ object Plugins { const val ktlint = "org.jlleitschuh.gradle.ktlint" const val library = "org.gradle.java-library" const val testFixtures = "org.gradle.java-test-fixtures" -} +} \ No newline at end of file diff --git 
a/examples/build.gradle.kts b/examples/build.gradle.kts index 961ef516b5..126a7b37c9 100644 --- a/examples/build.gradle.kts +++ b/examples/build.gradle.kts @@ -26,6 +26,8 @@ dependencies { implementation(project(":partiql-lang")) implementation(project(":partiql-eval")) implementation(project(":partiql-types")) + implementation(Deps.kotlinxCoroutines) + implementation(Deps.kotlinxCoroutinesJdk8) implementation(Deps.awsSdkS3) } diff --git a/examples/src/main/java/org/partiql/examples/PartiQLCompilerPipelineJavaExample.java b/examples/src/main/java/org/partiql/examples/PartiQLCompilerPipelineAsyncJavaExample.java similarity index 57% rename from examples/src/main/java/org/partiql/examples/PartiQLCompilerPipelineJavaExample.java rename to examples/src/main/java/org/partiql/examples/PartiQLCompilerPipelineAsyncJavaExample.java index 1a989311e0..b267b4c4ea 100644 --- a/examples/src/main/java/org/partiql/examples/PartiQLCompilerPipelineJavaExample.java +++ b/examples/src/main/java/org/partiql/examples/PartiQLCompilerPipelineAsyncJavaExample.java @@ -2,17 +2,25 @@ import com.amazon.ion.IonSystem; import com.amazon.ion.system.IonSystemBuilder; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; import kotlin.OptIn; +import kotlin.coroutines.EmptyCoroutineContext; +import kotlinx.coroutines.CoroutineScopeKt; +import kotlinx.coroutines.CoroutineStart; +import kotlinx.coroutines.future.FutureKt; import org.jetbrains.annotations.NotNull; import org.partiql.annotations.ExperimentalPartiQLCompilerPipeline; import org.partiql.examples.util.Example; -import org.partiql.lang.compiler.PartiQLCompiler; -import org.partiql.lang.compiler.PartiQLCompilerBuilder; -import org.partiql.lang.compiler.PartiQLCompilerPipeline; +import org.partiql.lang.compiler.PartiQLCompilerAsync; +import org.partiql.lang.compiler.PartiQLCompilerAsyncBuilder; +import org.partiql.lang.compiler.PartiQLCompilerPipelineAsync; import org.partiql.lang.eval.Bindings; 
import org.partiql.lang.eval.EvaluationSession; import org.partiql.lang.eval.ExprValue; import org.partiql.lang.eval.PartiQLResult; +import org.partiql.lang.eval.PartiQLStatementAsync; import org.partiql.lang.eval.ProjectionIterationBehavior; import org.partiql.lang.planner.EvaluatorOptions; import org.partiql.lang.planner.GlobalResolutionResult; @@ -25,14 +33,14 @@ import java.io.PrintStream; /** - * This is an example of using PartiQLCompilerPipeline in Java. + * This is an example of using PartiQLCompilerPipelineAsync in Java. * It is an experimental feature and is marked as such, with @OptIn, in this example. - * Unfortunately, it seems like the Java does not recognize the Optin annotation specified in Kotlin. + * Unfortunately, it seems like the Java does not recognize the OptIn annotation specified in Kotlin. * Java users will be able to access the experimental APIs freely, and not be warned at all. */ -public class PartiQLCompilerPipelineJavaExample extends Example { +public class PartiQLCompilerPipelineAsyncJavaExample extends Example { - public PartiQLCompilerPipelineJavaExample(@NotNull PrintStream out) { + public PartiQLCompilerPipelineAsyncJavaExample(@NotNull PrintStream out) { super(out); } @@ -49,10 +57,7 @@ public void run() { "{name: \"mary\", age: 19}" + "]"; - final Bindings globalVariables = Bindings.lazyBindingsBuilder().addBinding("myTable", () -> { - ExprValue exprValue = ExprValue.of(ion.singleValue(myTable)); - return exprValue; - }).build(); + final Bindings globalVariables = Bindings.lazyBindingsBuilder().addBinding("myTable", () -> ExprValue.of(ion.singleValue(myTable))).build(); final EvaluationSession session = EvaluationSession.builder() .globals(globalVariables) @@ -79,17 +84,41 @@ public void run() { final PartiQLPlanner planner = PartiQLPlannerBuilder.standard().globalVariableResolver(globalVariableResolver).build(); @OptIn(markerClass = ExperimentalPartiQLCompilerPipeline.class) - final PartiQLCompiler compiler = 
PartiQLCompilerBuilder.standard().options(evaluatorOptions).build(); + final PartiQLCompilerAsync compiler = PartiQLCompilerAsyncBuilder.standard().options(evaluatorOptions).build(); @OptIn(markerClass = ExperimentalPartiQLCompilerPipeline.class) - final PartiQLCompilerPipeline pipeline = new PartiQLCompilerPipeline( + final PartiQLCompilerPipelineAsync pipeline = new PartiQLCompilerPipelineAsync( parser, planner, compiler ); String query = "SELECT t.name FROM myTable AS t WHERE t.age > 20"; print("PartiQL query:", query); - PartiQLResult result = pipeline.compile(query).eval(session); + + // Calling Kotlin coroutines from Java requires some additional libraries from `kotlinx.coroutines.future` + // to return a `java.util.concurrent.CompletableFuture`. If a use case arises to call the + // `PartiQLCompilerPipelineAsync` APIs directly from Java, we can add Kotlin functions that directly return + // Java's async libraries (e.g. in https://stackoverflow.com/a/52887677). + CompletableFuture statementFuture = FutureKt.future( + CoroutineScopeKt.CoroutineScope(EmptyCoroutineContext.INSTANCE), + EmptyCoroutineContext.INSTANCE, + CoroutineStart.DEFAULT, + (scope, continuation) -> pipeline.compile(query, continuation) + ); + + PartiQLResult result; + try { + PartiQLStatementAsync statement = statementFuture.get(); + CompletableFuture resultFuture = FutureKt.future( + CoroutineScopeKt.CoroutineScope(EmptyCoroutineContext.INSTANCE), + EmptyCoroutineContext.INSTANCE, + CoroutineStart.DEFAULT, + (scope, continuation) -> statement.eval(session, continuation) + ); + result = resultFuture.get(); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } ExprValue exprValue = null; if (result instanceof PartiQLResult.Value) { exprValue = ((PartiQLResult.Value) result).getValue(); @@ -100,4 +129,4 @@ public void run() { print("result", exprValue); } -} \ No newline at end of file +} diff --git 
a/examples/src/main/kotlin/org/partiql/examples/PartiQLCompilerPipelineExample.kt b/examples/src/main/kotlin/org/partiql/examples/PartiQLCompilerPipelineAsyncExample.kt similarity index 89% rename from examples/src/main/kotlin/org/partiql/examples/PartiQLCompilerPipelineExample.kt rename to examples/src/main/kotlin/org/partiql/examples/PartiQLCompilerPipelineAsyncExample.kt index a4186c7a0d..e8247e1a79 100644 --- a/examples/src/main/kotlin/org/partiql/examples/PartiQLCompilerPipelineExample.kt +++ b/examples/src/main/kotlin/org/partiql/examples/PartiQLCompilerPipelineAsyncExample.kt @@ -1,9 +1,10 @@ package org.partiql.examples import com.amazon.ion.system.IonSystemBuilder +import kotlinx.coroutines.runBlocking import org.partiql.annotations.ExperimentalPartiQLCompilerPipeline import org.partiql.examples.util.Example -import org.partiql.lang.compiler.PartiQLCompilerPipeline +import org.partiql.lang.compiler.PartiQLCompilerPipelineAsync import org.partiql.lang.eval.Bindings import org.partiql.lang.eval.EvaluationSession import org.partiql.lang.eval.ExprValue @@ -20,7 +21,7 @@ import java.io.PrintStream * One way to do so is to add the `Optin(Experimental::class) before the class. where is the feature name. 
* Also see: https://kotlinlang.org/docs/opt-in-requirements.html#module-wide-opt-in */ -class PartiQLCompilerPipelineExample(out: PrintStream) : Example(out) { +class PartiQLCompilerPipelineAsyncExample(out: PrintStream) : Example(out) { private val myIonSystem = IonSystemBuilder.standard().build() @@ -59,7 +60,7 @@ class PartiQLCompilerPipelineExample(out: PrintStream) : Example(out) { .build() @OptIn(ExperimentalPartiQLCompilerPipeline::class) - private val partiQLCompilerPipeline = PartiQLCompilerPipeline.build { + private val partiQLCompilerPipeline = PartiQLCompilerPipelineAsync.build { planner .globalVariableResolver(globalVariableResolver) compiler @@ -71,7 +72,10 @@ class PartiQLCompilerPipelineExample(out: PrintStream) : Example(out) { print("PartiQL query:", query) @OptIn(ExperimentalPartiQLCompilerPipeline::class) - val exprValue = when (val result = partiQLCompilerPipeline.compile(query).eval(session)) { + val result = runBlocking { + partiQLCompilerPipeline.compile(query).eval(session) + } + val exprValue = when (result) { is PartiQLResult.Value -> result.value is PartiQLResult.Delete, is PartiQLResult.Explain.Domain, diff --git a/examples/src/main/kotlin/org/partiql/examples/util/Main.kt b/examples/src/main/kotlin/org/partiql/examples/util/Main.kt index 122a82f3ea..3d74fbfc08 100644 --- a/examples/src/main/kotlin/org/partiql/examples/util/Main.kt +++ b/examples/src/main/kotlin/org/partiql/examples/util/Main.kt @@ -12,8 +12,8 @@ import org.partiql.examples.EvaluationWithLazyBindings import org.partiql.examples.ParserErrorExample import org.partiql.examples.ParserExample import org.partiql.examples.ParserJavaExample -import org.partiql.examples.PartiQLCompilerPipelineExample -import org.partiql.examples.PartiQLCompilerPipelineJavaExample +import org.partiql.examples.PartiQLCompilerPipelineAsyncExample +import org.partiql.examples.PartiQLCompilerPipelineAsyncJavaExample import org.partiql.examples.PartialEvaluationVisitorTransformExample import 
org.partiql.examples.PreventJoinVisitorExample import org.partiql.examples.S3JavaExample @@ -26,7 +26,9 @@ private val examples = mapOf( S3JavaExample::class.java.simpleName to S3JavaExample(System.out), EvaluationJavaExample::class.java.simpleName to EvaluationJavaExample(System.out), ParserJavaExample::class.java.simpleName to ParserJavaExample(System.out), - PartiQLCompilerPipelineJavaExample::class.java.simpleName to PartiQLCompilerPipelineJavaExample(System.out), + PartiQLCompilerPipelineAsyncJavaExample::class.java.simpleName to PartiQLCompilerPipelineAsyncJavaExample( + System.out + ), // Kotlin Examples CsvExprValueExample::class.java.simpleName to CsvExprValueExample(System.out), @@ -39,7 +41,7 @@ private val examples = mapOf( PartialEvaluationVisitorTransformExample::class.java.simpleName to PartialEvaluationVisitorTransformExample(System.out), PreventJoinVisitorExample::class.java.simpleName to PreventJoinVisitorExample(System.out), SimpleExpressionEvaluation::class.java.simpleName to SimpleExpressionEvaluation(System.out), - PartiQLCompilerPipelineExample::class.java.simpleName to PartiQLCompilerPipelineExample(System.out) + PartiQLCompilerPipelineAsyncExample::class.java.simpleName to PartiQLCompilerPipelineAsyncExample(System.out) ) fun main(args: Array) { diff --git a/examples/src/test/kotlin/org/partiql/examples/PartiQLCompilerPipelineExampleTest.kt b/examples/src/test/kotlin/org/partiql/examples/PartiQLCompilerPipelineAsyncExampleTest.kt similarity index 81% rename from examples/src/test/kotlin/org/partiql/examples/PartiQLCompilerPipelineExampleTest.kt rename to examples/src/test/kotlin/org/partiql/examples/PartiQLCompilerPipelineAsyncExampleTest.kt index 47141203d5..3ce825937f 100644 --- a/examples/src/test/kotlin/org/partiql/examples/PartiQLCompilerPipelineExampleTest.kt +++ b/examples/src/test/kotlin/org/partiql/examples/PartiQLCompilerPipelineAsyncExampleTest.kt @@ -3,8 +3,8 @@ package org.partiql.examples import 
org.partiql.examples.util.Example import java.io.PrintStream -class PartiQLCompilerPipelineExampleTest : BaseExampleTest() { - override fun example(out: PrintStream): Example = PartiQLCompilerPipelineExample(out) +class PartiQLCompilerPipelineAsyncExampleTest : BaseExampleTest() { + override fun example(out: PrintStream): Example = PartiQLCompilerPipelineAsyncExample(out) override val expected = """ |PartiQL query: diff --git a/examples/src/test/kotlin/org/partiql/examples/PartiQLCompilerPipelineJavaExampleTest.kt b/examples/src/test/kotlin/org/partiql/examples/PartiQLCompilerPipelineAsyncJavaExampleTest.kt similarity index 67% rename from examples/src/test/kotlin/org/partiql/examples/PartiQLCompilerPipelineJavaExampleTest.kt rename to examples/src/test/kotlin/org/partiql/examples/PartiQLCompilerPipelineAsyncJavaExampleTest.kt index 74a91ec02c..8309e2b044 100644 --- a/examples/src/test/kotlin/org/partiql/examples/PartiQLCompilerPipelineJavaExampleTest.kt +++ b/examples/src/test/kotlin/org/partiql/examples/PartiQLCompilerPipelineAsyncJavaExampleTest.kt @@ -3,8 +3,9 @@ package org.partiql.examples import org.partiql.examples.util.Example import java.io.PrintStream -class PartiQLCompilerPipelineJavaExampleTest : BaseExampleTest() { - override fun example(out: PrintStream): Example = PartiQLCompilerPipelineJavaExample(out) +class PartiQLCompilerPipelineAsyncJavaExampleTest : BaseExampleTest() { + override fun example(out: PrintStream): Example = + PartiQLCompilerPipelineAsyncJavaExample(out) override val expected = """ |PartiQL query: diff --git a/lib/sprout/src/main/kotlin/org/partiql/sprout/generator/target/kotlin/poems/KotlinVisitorPoem.kt b/lib/sprout/src/main/kotlin/org/partiql/sprout/generator/target/kotlin/poems/KotlinVisitorPoem.kt index de8aa6fbb3..59247f1081 100644 --- a/lib/sprout/src/main/kotlin/org/partiql/sprout/generator/target/kotlin/poems/KotlinVisitorPoem.kt +++ 
b/lib/sprout/src/main/kotlin/org/partiql/sprout/generator/target/kotlin/poems/KotlinVisitorPoem.kt @@ -183,6 +183,7 @@ class KotlinVisitorPoem(symbols: KotlinSymbols) : KotlinPoem(symbols) { addFunction(visit) } } + .addKdoc("WARNING: This interface should not be implemented or extended by code outside of this library. Please extend [$baseVisitorName].") .build() return FileSpec.builder(visitorPackageName, visitorName).addType(visitor).build() } diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/sql/Sql.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/sql/Sql.kt index b9065fe8fe..50d5934cca 100644 --- a/partiql-ast/src/main/kotlin/org/partiql/ast/sql/Sql.kt +++ b/partiql-ast/src/main/kotlin/org/partiql/ast/sql/Sql.kt @@ -1,15 +1,41 @@ package org.partiql.ast.sql import org.partiql.ast.AstNode +import org.partiql.ast.sql.internal.InternalSqlDialect +import org.partiql.ast.sql.internal.InternalSqlLayout + +/** + * No argument uses optimized internal. Leaving older ones for backwards-compatibility. 
+ */ +public fun AstNode.sql(): String { + val head = InternalSqlDialect.PARTIQL.apply(this) + return InternalSqlLayout.format(head) +} /** * Pretty-print this [AstNode] as SQL text with the given [SqlLayout] */ -@JvmOverloads +@Deprecated("To be removed in the next major version") public fun AstNode.sql( layout: SqlLayout = SqlLayout.DEFAULT, +): String = SqlDialect.PARTIQL.apply(this).sql(layout) + +/** + * Pretty-print this [AstNode] as SQL text with the given [SqlDialect] + */ +@Deprecated("To be removed in the next major version") +public fun AstNode.sql( dialect: SqlDialect = SqlDialect.PARTIQL, -): String = accept(dialect, SqlBlock.Nil).sql(layout) +): String = dialect.apply(this).sql(SqlLayout.DEFAULT) + +/** + * Pretty-print this [AstNode] as SQL text with the given [SqlLayout] and [SqlDialect] + */ +@Deprecated("To be removed in the next major version") +public fun AstNode.sql( + layout: SqlLayout, + dialect: SqlDialect, +): String = dialect.apply(this).sql(layout) // a <> b <-> a concat b diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/sql/SqlBlock.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/sql/SqlBlock.kt index c163e8998e..e712ac0f32 100644 --- a/partiql-ast/src/main/kotlin/org/partiql/ast/sql/SqlBlock.kt +++ b/partiql-ast/src/main/kotlin/org/partiql/ast/sql/SqlBlock.kt @@ -6,11 +6,13 @@ package org.partiql.ast.sql * @param layout SQL formatting ruleset * @return SQL text */ +@Deprecated("To be removed in the next major version") public fun SqlBlock.sql(layout: SqlLayout = SqlLayout.DEFAULT): String = layout.format(this) /** * Representation of some textual corpus; akin to Wadler's "A prettier printer" Document type. 
*/ +@Deprecated("This will be changed in the next major version") sealed interface SqlBlock { public override fun toString(): String @@ -54,6 +56,7 @@ sealed interface SqlBlock { } } +@Deprecated("This will be changed in the next major version") public interface BlockVisitor { public fun visit(block: SqlBlock, ctx: C): R @@ -69,6 +72,7 @@ public interface BlockVisitor { public fun visitLink(block: SqlBlock.Link, ctx: C): R } +@Deprecated("This will be changed in the next major version") public abstract class BlockBaseVisitor : BlockVisitor { public abstract fun defaultReturn(block: SqlBlock, ctx: C): R diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/sql/SqlDialect.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/sql/SqlDialect.kt index 05dc0a9610..6a589a0556 100644 --- a/partiql-ast/src/main/kotlin/org/partiql/ast/sql/SqlDialect.kt +++ b/partiql-ast/src/main/kotlin/org/partiql/ast/sql/SqlDialect.kt @@ -26,6 +26,7 @@ import java.io.PrintStream /** * SqlDialect represents the base behavior for transforming an [AstNode] tree into a [SqlBlock] tree. 
*/ +@Deprecated("This will be changed in the next major version") @Suppress("PARAMETER_NAME_CHANGED_ON_OVERRIDE") public abstract class SqlDialect : AstBaseVisitor() { @@ -40,12 +41,12 @@ public abstract class SqlDialect : AstBaseVisitor() { public val PARTIQL = object : SqlDialect() {} } - override fun defaultReturn(node: AstNode, head: SqlBlock) = + override fun defaultReturn(node: AstNode, head: SqlBlock): SqlBlock = throw UnsupportedOperationException("Cannot print $node") // STATEMENTS - override fun visitStatementQuery(node: Statement.Query, head: SqlBlock) = visitExpr(node.expr, head) + override fun visitStatementQuery(node: Statement.Query, head: SqlBlock): SqlBlock = visitExpr(node.expr, head) // IDENTIFIERS & PATHS @@ -55,7 +56,7 @@ public abstract class SqlDialect : AstBaseVisitor() { * @param node * @param head */ - public open fun visitExprWrapped(node: Expr, head: SqlBlock) = when (node) { + public open fun visitExprWrapped(node: Expr, head: SqlBlock): SqlBlock = when (node) { is Expr.SFW -> { var h = head h = h concat "(" @@ -66,7 +67,7 @@ public abstract class SqlDialect : AstBaseVisitor() { else -> visitExpr(node, head) } - override fun visitIdentifierSymbol(node: Identifier.Symbol, head: SqlBlock) = head concat r(node.sql()) + override fun visitIdentifierSymbol(node: Identifier.Symbol, head: SqlBlock): SqlBlock = head concat r(node.sql()) override fun visitIdentifierQualified(node: Identifier.Qualified, head: SqlBlock): SqlBlock { val path = node.steps.fold(node.root.sql()) { p, step -> p + "." 
+ step.sql() } @@ -116,93 +117,93 @@ public abstract class SqlDialect : AstBaseVisitor() { } // cannot write path step outside the context of a path as we don't want it to reflow - override fun visitPathStep(node: Path.Step, head: SqlBlock) = error("path step cannot be written directly") + override fun visitPathStep(node: Path.Step, head: SqlBlock): SqlBlock = error("path step cannot be written directly") - override fun visitPathStepSymbol(node: Path.Step.Symbol, head: SqlBlock) = visitPathStep(node, head) + override fun visitPathStepSymbol(node: Path.Step.Symbol, head: SqlBlock): SqlBlock = visitPathStep(node, head) - override fun visitPathStepIndex(node: Path.Step.Index, head: SqlBlock) = visitPathStep(node, head) + override fun visitPathStepIndex(node: Path.Step.Index, head: SqlBlock): SqlBlock = visitPathStep(node, head) // TYPES - override fun visitTypeNullType(node: Type.NullType, head: SqlBlock) = head concat r("NULL") + override fun visitTypeNullType(node: Type.NullType, head: SqlBlock): SqlBlock = head concat r("NULL") - override fun visitTypeMissing(node: Type.Missing, head: SqlBlock) = head concat r("MISSING") + override fun visitTypeMissing(node: Type.Missing, head: SqlBlock): SqlBlock = head concat r("MISSING") - override fun visitTypeBool(node: Type.Bool, head: SqlBlock) = head concat r("BOOL") + override fun visitTypeBool(node: Type.Bool, head: SqlBlock): SqlBlock = head concat r("BOOL") - override fun visitTypeTinyint(node: Type.Tinyint, head: SqlBlock) = head concat r("TINYINT") + override fun visitTypeTinyint(node: Type.Tinyint, head: SqlBlock): SqlBlock = head concat r("TINYINT") - override fun visitTypeSmallint(node: Type.Smallint, head: SqlBlock) = head concat r("SMALLINT") + override fun visitTypeSmallint(node: Type.Smallint, head: SqlBlock): SqlBlock = head concat r("SMALLINT") - override fun visitTypeInt2(node: Type.Int2, head: SqlBlock) = head concat r("INT2") + override fun visitTypeInt2(node: Type.Int2, head: SqlBlock): SqlBlock = head 
concat r("INT2") - override fun visitTypeInt4(node: Type.Int4, head: SqlBlock) = head concat r("INT4") + override fun visitTypeInt4(node: Type.Int4, head: SqlBlock): SqlBlock = head concat r("INT4") - override fun visitTypeBigint(node: Type.Bigint, head: SqlBlock) = head concat r("BIGINT") + override fun visitTypeBigint(node: Type.Bigint, head: SqlBlock): SqlBlock = head concat r("BIGINT") - override fun visitTypeInt8(node: Type.Int8, head: SqlBlock) = head concat r("INT8") + override fun visitTypeInt8(node: Type.Int8, head: SqlBlock): SqlBlock = head concat r("INT8") - override fun visitTypeInt(node: Type.Int, head: SqlBlock) = head concat r("INT") + override fun visitTypeInt(node: Type.Int, head: SqlBlock): SqlBlock = head concat r("INT") - override fun visitTypeReal(node: Type.Real, head: SqlBlock) = head concat r("REAL") + override fun visitTypeReal(node: Type.Real, head: SqlBlock): SqlBlock = head concat r("REAL") - override fun visitTypeFloat32(node: Type.Float32, head: SqlBlock) = head concat r("FLOAT32") + override fun visitTypeFloat32(node: Type.Float32, head: SqlBlock): SqlBlock = head concat r("FLOAT32") - override fun visitTypeFloat64(node: Type.Float64, head: SqlBlock) = head concat r("DOUBLE PRECISION") + override fun visitTypeFloat64(node: Type.Float64, head: SqlBlock): SqlBlock = head concat r("DOUBLE PRECISION") - override fun visitTypeDecimal(node: Type.Decimal, head: SqlBlock) = + override fun visitTypeDecimal(node: Type.Decimal, head: SqlBlock): SqlBlock = head concat type("DECIMAL", node.precision, node.scale) - override fun visitTypeNumeric(node: Type.Numeric, head: SqlBlock) = + override fun visitTypeNumeric(node: Type.Numeric, head: SqlBlock): SqlBlock = head concat type("NUMERIC", node.precision, node.scale) - override fun visitTypeChar(node: Type.Char, head: SqlBlock) = head concat type("CHAR", node.length) + override fun visitTypeChar(node: Type.Char, head: SqlBlock): SqlBlock = head concat type("CHAR", node.length) - override fun 
visitTypeVarchar(node: Type.Varchar, head: SqlBlock) = head concat type("VARCHAR", node.length) + override fun visitTypeVarchar(node: Type.Varchar, head: SqlBlock): SqlBlock = head concat type("VARCHAR", node.length) - override fun visitTypeString(node: Type.String, head: SqlBlock) = head concat r("STRING") + override fun visitTypeString(node: Type.String, head: SqlBlock): SqlBlock = head concat r("STRING") - override fun visitTypeSymbol(node: Type.Symbol, head: SqlBlock) = head concat r("SYMBOL") + override fun visitTypeSymbol(node: Type.Symbol, head: SqlBlock): SqlBlock = head concat r("SYMBOL") - override fun visitTypeBit(node: Type.Bit, head: SqlBlock) = head concat type("BIT", node.length) + override fun visitTypeBit(node: Type.Bit, head: SqlBlock): SqlBlock = head concat type("BIT", node.length) - override fun visitTypeBitVarying(node: Type.BitVarying, head: SqlBlock) = head concat type("BINARY", node.length) + override fun visitTypeBitVarying(node: Type.BitVarying, head: SqlBlock): SqlBlock = head concat type("BINARY", node.length) - override fun visitTypeByteString(node: Type.ByteString, head: SqlBlock) = head concat type("BYTE", node.length) + override fun visitTypeByteString(node: Type.ByteString, head: SqlBlock): SqlBlock = head concat type("BYTE", node.length) - override fun visitTypeBlob(node: Type.Blob, head: SqlBlock) = head concat type("BLOB", node.length) + override fun visitTypeBlob(node: Type.Blob, head: SqlBlock): SqlBlock = head concat type("BLOB", node.length) - override fun visitTypeClob(node: Type.Clob, head: SqlBlock) = head concat type("CLOB", node.length) + override fun visitTypeClob(node: Type.Clob, head: SqlBlock): SqlBlock = head concat type("CLOB", node.length) - override fun visitTypeBag(node: Type.Bag, head: SqlBlock) = head concat r("BAG") + override fun visitTypeBag(node: Type.Bag, head: SqlBlock): SqlBlock = head concat r("BAG") - override fun visitTypeList(node: Type.List, head: SqlBlock) = head concat r("LIST") + override fun 
visitTypeList(node: Type.List, head: SqlBlock): SqlBlock = head concat r("LIST") - override fun visitTypeSexp(node: Type.Sexp, head: SqlBlock) = head concat r("SEXP") + override fun visitTypeSexp(node: Type.Sexp, head: SqlBlock): SqlBlock = head concat r("SEXP") - override fun visitTypeTuple(node: Type.Tuple, head: SqlBlock) = head concat r("TUPLE") + override fun visitTypeTuple(node: Type.Tuple, head: SqlBlock): SqlBlock = head concat r("TUPLE") - override fun visitTypeStruct(node: Type.Struct, head: SqlBlock) = head concat r("STRUCT") + override fun visitTypeStruct(node: Type.Struct, head: SqlBlock): SqlBlock = head concat r("STRUCT") - override fun visitTypeAny(node: Type.Any, head: SqlBlock) = head concat r("ANY") + override fun visitTypeAny(node: Type.Any, head: SqlBlock): SqlBlock = head concat r("ANY") - override fun visitTypeDate(node: Type.Date, head: SqlBlock) = head concat r("DATE") + override fun visitTypeDate(node: Type.Date, head: SqlBlock): SqlBlock = head concat r("DATE") override fun visitTypeTime(node: Type.Time, head: SqlBlock): SqlBlock = head concat type("TIME", node.precision) - override fun visitTypeTimeWithTz(node: Type.TimeWithTz, head: SqlBlock) = + override fun visitTypeTimeWithTz(node: Type.TimeWithTz, head: SqlBlock): SqlBlock = head concat type("TIME WITH TIMEZONE", node.precision, gap = true) - override fun visitTypeTimestamp(node: Type.Timestamp, head: SqlBlock) = + override fun visitTypeTimestamp(node: Type.Timestamp, head: SqlBlock): SqlBlock = head concat type("TIMESTAMP", node.precision) - override fun visitTypeTimestampWithTz(node: Type.TimestampWithTz, head: SqlBlock) = + override fun visitTypeTimestampWithTz(node: Type.TimestampWithTz, head: SqlBlock): SqlBlock = head concat type("TIMESTAMP WITH TIMEZONE", node.precision, gap = true) - override fun visitTypeInterval(node: Type.Interval, head: SqlBlock) = head concat type("INTERVAL", node.precision) + override fun visitTypeInterval(node: Type.Interval, head: SqlBlock): SqlBlock 
= head concat type("INTERVAL", node.precision) // unsupported - override fun visitTypeCustom(node: Type.Custom, head: SqlBlock) = defaultReturn(node, head) + override fun visitTypeCustom(node: Type.Custom, head: SqlBlock): SqlBlock = defaultReturn(node, head) // Expressions @@ -230,13 +231,15 @@ public abstract class SqlDialect : AstBaseVisitor() { override fun visitExprUnary(node: Expr.Unary, head: SqlBlock): SqlBlock { val op = when (node.op) { - Expr.Unary.Op.NOT -> "NOT " - Expr.Unary.Op.POS -> "+" - Expr.Unary.Op.NEG -> "-" + Expr.Unary.Op.NOT -> "NOT (" + Expr.Unary.Op.POS -> "+(" + Expr.Unary.Op.NEG -> "-(" } var h = head h = h concat r(op) - return visitExprWrapped(node.expr, h) + h = visitExprWrapped(node.expr, h) + h = h concat r(")") + return h } override fun visitExprBinary(node: Expr.Binary, head: SqlBlock): SqlBlock { @@ -274,7 +277,7 @@ public abstract class SqlDialect : AstBaseVisitor() { return h } - override fun visitExprSessionAttribute(node: Expr.SessionAttribute, head: SqlBlock) = + override fun visitExprSessionAttribute(node: Expr.SessionAttribute, head: SqlBlock): SqlBlock = head concat r(node.attribute.name) override fun visitExprPath(node: Expr.Path, head: SqlBlock): SqlBlock { @@ -283,7 +286,7 @@ public abstract class SqlDialect : AstBaseVisitor() { return h } - override fun visitExprPathStepSymbol(node: Expr.Path.Step.Symbol, head: SqlBlock) = + override fun visitExprPathStepSymbol(node: Expr.Path.Step.Symbol, head: SqlBlock): SqlBlock = head concat r(".${node.symbol.sql()}") override fun visitExprPathStepIndex(node: Expr.Path.Step.Index, head: SqlBlock): SqlBlock { @@ -296,9 +299,9 @@ public abstract class SqlDialect : AstBaseVisitor() { return h } - override fun visitExprPathStepWildcard(node: Expr.Path.Step.Wildcard, head: SqlBlock) = head concat r("[*]") + override fun visitExprPathStepWildcard(node: Expr.Path.Step.Wildcard, head: SqlBlock): SqlBlock = head concat r("[*]") - override fun visitExprPathStepUnpivot(node: 
Expr.Path.Step.Unpivot, head: SqlBlock) = head concat r(".*") + override fun visitExprPathStepUnpivot(node: Expr.Path.Step.Unpivot, head: SqlBlock): SqlBlock = head concat r(".*") override fun visitExprCall(node: Expr.Call, head: SqlBlock): SqlBlock { var h = head @@ -320,11 +323,11 @@ public abstract class SqlDialect : AstBaseVisitor() { return h } - override fun visitExprParameter(node: Expr.Parameter, head: SqlBlock) = head concat r("?") + override fun visitExprParameter(node: Expr.Parameter, head: SqlBlock): SqlBlock = head concat r("?") - override fun visitExprValues(node: Expr.Values, head: SqlBlock) = head concat list("VALUES (") { node.rows } + override fun visitExprValues(node: Expr.Values, head: SqlBlock): SqlBlock = head concat list("VALUES (") { node.rows } - override fun visitExprValuesRow(node: Expr.Values.Row, head: SqlBlock) = head concat list { node.items } + override fun visitExprValuesRow(node: Expr.Values.Row, head: SqlBlock): SqlBlock = head concat list { node.items } override fun visitExprCollection(node: Expr.Collection, head: SqlBlock): SqlBlock { val (start, end) = when (node.type) { @@ -337,7 +340,7 @@ public abstract class SqlDialect : AstBaseVisitor() { return head concat list(start, end) { node.values } } - override fun visitExprStruct(node: Expr.Struct, head: SqlBlock) = head concat list("{", "}") { node.fields } + override fun visitExprStruct(node: Expr.Struct, head: SqlBlock): SqlBlock = head concat list("{", "}") { node.fields } override fun visitExprStructField(node: Expr.Struct.Field, head: SqlBlock): SqlBlock { var h = head @@ -698,7 +701,7 @@ public abstract class SqlDialect : AstBaseVisitor() { // LET - override fun visitLet(node: Let, head: SqlBlock) = head concat list("LET ", "") { node.bindings } + override fun visitLet(node: Let, head: SqlBlock): SqlBlock = head concat list("LET ", "") { node.bindings } override fun visitLetBinding(node: Let.Binding, head: SqlBlock): SqlBlock { var h = head @@ -750,7 +753,7 @@ public 
abstract class SqlDialect : AstBaseVisitor() { // ORDER BY - override fun visitOrderBy(node: OrderBy, head: SqlBlock) = head concat list("ORDER BY ", "") { node.sorts } + override fun visitOrderBy(node: OrderBy, head: SqlBlock): SqlBlock = head concat list("ORDER BY ", "") { node.sorts } override fun visitSort(node: Sort, head: SqlBlock): SqlBlock { var h = head diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/sql/SqlLayout.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/sql/SqlLayout.kt index 4bfd1dae41..013a332225 100644 --- a/partiql-ast/src/main/kotlin/org/partiql/ast/sql/SqlLayout.kt +++ b/partiql-ast/src/main/kotlin/org/partiql/ast/sql/SqlLayout.kt @@ -3,6 +3,7 @@ package org.partiql.ast.sql /** * [SqlLayout] determines how an [SqlBlock] tree is transformed in SQL text. */ +@Deprecated("This will be changed in the next major version") public abstract class SqlLayout { abstract val indent: Indent diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/sql/internal/InternalSqlBlock.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/sql/internal/InternalSqlBlock.kt new file mode 100644 index 0000000000..272ee0631a --- /dev/null +++ b/partiql-ast/src/main/kotlin/org/partiql/ast/sql/internal/InternalSqlBlock.kt @@ -0,0 +1,58 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + */ + +package org.partiql.ast.sql.internal + +/** + * Representation of some textual elements as a token (singly-linked) list. 
+ */ +internal sealed class InternalSqlBlock { + + /** + * Next token (if any) in the list. + */ + internal var next: InternalSqlBlock? = null + + /** + * A newline / line break token. + */ + internal class NL : InternalSqlBlock() + + /** + * A raw text token. Cannot be broken. + */ + internal class Text(val text: String) : InternalSqlBlock() + + /** + * A nest token representing a (possibly indented) token sublist. + * + * @property prefix A prefix character such as '{', '(', or '['. + * @property postfix A postfix character such as '}', ')', or ']'. + * @property child + */ + internal class Nest( + val prefix: String?, + val postfix: String?, + val child: InternalSqlBlock, + ) : InternalSqlBlock() + + companion object { + + /** + * Helper function to create root node (empty). + */ + @JvmStatic + internal fun root(): InternalSqlBlock = Text("") + } +} diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/sql/internal/InternalSqlDialect.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/sql/internal/InternalSqlDialect.kt new file mode 100644 index 0000000000..4ecfb9d569 --- /dev/null +++ b/partiql-ast/src/main/kotlin/org/partiql/ast/sql/internal/InternalSqlDialect.kt @@ -0,0 +1,852 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License.
+ */ + +package org.partiql.ast.sql.internal + +import org.partiql.ast.AstNode +import org.partiql.ast.Exclude +import org.partiql.ast.Expr +import org.partiql.ast.From +import org.partiql.ast.GroupBy +import org.partiql.ast.Identifier +import org.partiql.ast.Let +import org.partiql.ast.OrderBy +import org.partiql.ast.Path +import org.partiql.ast.Select +import org.partiql.ast.SetOp +import org.partiql.ast.SetQuantifier +import org.partiql.ast.Sort +import org.partiql.ast.Statement +import org.partiql.ast.Type +import org.partiql.ast.visitor.AstBaseVisitor +import org.partiql.value.MissingValue +import org.partiql.value.NullValue +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.io.PartiQLValueTextWriter +import java.io.ByteArrayOutputStream +import java.io.PrintStream + +/** + * SqlDialect represents the base behavior for transforming an [AstNode] tree into an [InternalSqlBlock] tree. + */ +@Suppress("PARAMETER_NAME_CHANGED_ON_OVERRIDE") +internal abstract class InternalSqlDialect : AstBaseVisitor() { + + /** + * Default entry-point, can also be used. + */ + internal fun apply(node: AstNode): InternalSqlBlock { + val head = InternalSqlBlock.root() + val tail = head + node.accept(this, tail) + return head + } + + internal companion object { + + @JvmStatic + val PARTIQL = object : InternalSqlDialect() {} + } + + override fun defaultReturn(node: AstNode, tail: InternalSqlBlock): InternalSqlBlock = + throw UnsupportedOperationException("Cannot print $node") + + // STATEMENTS + + override fun visitStatementQuery(node: Statement.Query, tail: InternalSqlBlock): InternalSqlBlock = visitExpr(node.expr, tail) + + // IDENTIFIERS & PATHS + + /** + * Default behavior is to wrap all SFW queries with parentheses.
+ * + * @param node + * @param tail + */ + open fun visitExprWrapped(node: Expr, tail: InternalSqlBlock): InternalSqlBlock = when (node) { + is Expr.SFW -> { + var t = tail + t = t concat "(" + t = visitExprSFW(node, t) + t = t concat ")" + t + } + else -> visitExpr(node, tail) + } + + override fun visitIdentifierSymbol(node: Identifier.Symbol, tail: InternalSqlBlock): InternalSqlBlock = tail concat node.sql() + + override fun visitIdentifierQualified(node: Identifier.Qualified, tail: InternalSqlBlock): InternalSqlBlock { + val path = node.steps.fold(node.root.sql()) { p, step -> p + "." + step.sql() } + return tail concat path + } + + override fun visitPath(node: Path, tail: InternalSqlBlock): InternalSqlBlock { + val path = node.steps.fold(node.root.sql()) { p, step -> + when (step) { + is Path.Step.Index -> p + "[${step.index}]" + is Path.Step.Symbol -> p + "." + step.symbol.sql() + } + } + return tail concat path + } + + override fun visitExclude(node: Exclude, tail: InternalSqlBlock): InternalSqlBlock { + var t = tail + t = t concat " EXCLUDE " + t = t concat list(start = null, end = null) { node.items } + return t + } + + override fun visitExcludeItem(node: Exclude.Item, tail: InternalSqlBlock): InternalSqlBlock { + var t = tail + t = visitExprVar(node.root, t) + t = t concat list(delimiter = null, start = null, end = null) { node.steps } + return t + } + + override fun visitExcludeStepCollIndex(node: Exclude.Step.CollIndex, tail: InternalSqlBlock): InternalSqlBlock { + return tail concat "[${node.index}]" + } + + override fun visitExcludeStepStructWildcard(node: Exclude.Step.StructWildcard, tail: InternalSqlBlock): InternalSqlBlock { + return tail concat ".*" + } + + override fun visitExcludeStepStructField(node: Exclude.Step.StructField, tail: InternalSqlBlock): InternalSqlBlock { + var t = tail concat "." 
+ t = visitIdentifierSymbol(node.symbol, t) + return t + } + + override fun visitExcludeStepCollWildcard(node: Exclude.Step.CollWildcard, tail: InternalSqlBlock): InternalSqlBlock { + return tail concat "[*]" + } + + // cannot write path step outside the context of a path as we don't want it to reflow + override fun visitPathStep(node: Path.Step, tail: InternalSqlBlock): InternalSqlBlock = + error("path step cannot be written directly") + + override fun visitPathStepSymbol(node: Path.Step.Symbol, tail: InternalSqlBlock): InternalSqlBlock = visitPathStep(node, tail) + + override fun visitPathStepIndex(node: Path.Step.Index, tail: InternalSqlBlock): InternalSqlBlock = visitPathStep(node, tail) + + // TYPES + + override fun visitTypeNullType(node: Type.NullType, tail: InternalSqlBlock): InternalSqlBlock = tail concat "NULL" + + override fun visitTypeMissing(node: Type.Missing, tail: InternalSqlBlock): InternalSqlBlock = tail concat "MISSING" + + override fun visitTypeBool(node: Type.Bool, tail: InternalSqlBlock): InternalSqlBlock = tail concat "BOOL" + + override fun visitTypeTinyint(node: Type.Tinyint, tail: InternalSqlBlock): InternalSqlBlock = tail concat "TINYINT" + + override fun visitTypeSmallint(node: Type.Smallint, tail: InternalSqlBlock): InternalSqlBlock = tail concat "SMALLINT" + + override fun visitTypeInt2(node: Type.Int2, tail: InternalSqlBlock): InternalSqlBlock = tail concat "INT2" + + override fun visitTypeInt4(node: Type.Int4, tail: InternalSqlBlock): InternalSqlBlock = tail concat "INT4" + + override fun visitTypeBigint(node: Type.Bigint, tail: InternalSqlBlock): InternalSqlBlock = tail concat "BIGINT" + + override fun visitTypeInt8(node: Type.Int8, tail: InternalSqlBlock): InternalSqlBlock = tail concat "INT8" + + override fun visitTypeInt(node: Type.Int, tail: InternalSqlBlock): InternalSqlBlock = tail concat "INT" + + override fun visitTypeReal(node: Type.Real, tail: InternalSqlBlock): InternalSqlBlock = tail concat "REAL" + + override fun 
visitTypeFloat32(node: Type.Float32, tail: InternalSqlBlock): InternalSqlBlock = tail concat "FLOAT32" + + override fun visitTypeFloat64(node: Type.Float64, tail: InternalSqlBlock): InternalSqlBlock = tail concat "DOUBLE PRECISION" + + override fun visitTypeDecimal(node: Type.Decimal, tail: InternalSqlBlock): InternalSqlBlock = + tail concat type("DECIMAL", node.precision, node.scale) + + override fun visitTypeNumeric(node: Type.Numeric, tail: InternalSqlBlock): InternalSqlBlock = + tail concat type("NUMERIC", node.precision, node.scale) + + override fun visitTypeChar(node: Type.Char, tail: InternalSqlBlock): InternalSqlBlock = tail concat type("CHAR", node.length) + + override fun visitTypeVarchar(node: Type.Varchar, tail: InternalSqlBlock): InternalSqlBlock = + tail concat type("VARCHAR", node.length) + + override fun visitTypeString(node: Type.String, tail: InternalSqlBlock): InternalSqlBlock = tail concat "STRING" + + override fun visitTypeSymbol(node: Type.Symbol, tail: InternalSqlBlock): InternalSqlBlock = tail concat "SYMBOL" + + override fun visitTypeBit(node: Type.Bit, tail: InternalSqlBlock): InternalSqlBlock = tail concat type("BIT", node.length) + + override fun visitTypeBitVarying(node: Type.BitVarying, tail: InternalSqlBlock): InternalSqlBlock = + tail concat type("BINARY", node.length) + + override fun visitTypeByteString(node: Type.ByteString, tail: InternalSqlBlock): InternalSqlBlock = + tail concat type("BYTE", node.length) + + override fun visitTypeBlob(node: Type.Blob, tail: InternalSqlBlock): InternalSqlBlock = tail concat type("BLOB", node.length) + + override fun visitTypeClob(node: Type.Clob, tail: InternalSqlBlock): InternalSqlBlock = tail concat type("CLOB", node.length) + + override fun visitTypeBag(node: Type.Bag, tail: InternalSqlBlock): InternalSqlBlock = tail concat "BAG" + + override fun visitTypeList(node: Type.List, tail: InternalSqlBlock): InternalSqlBlock = tail concat "LIST" + + override fun visitTypeSexp(node: Type.Sexp, tail: 
InternalSqlBlock): InternalSqlBlock = tail concat "SEXP" + + override fun visitTypeTuple(node: Type.Tuple, tail: InternalSqlBlock): InternalSqlBlock = tail concat "TUPLE" + + override fun visitTypeStruct(node: Type.Struct, tail: InternalSqlBlock): InternalSqlBlock = tail concat "STRUCT" + + override fun visitTypeAny(node: Type.Any, tail: InternalSqlBlock): InternalSqlBlock = tail concat "ANY" + + override fun visitTypeDate(node: Type.Date, tail: InternalSqlBlock): InternalSqlBlock = tail concat "DATE" + + override fun visitTypeTime(node: Type.Time, tail: InternalSqlBlock): InternalSqlBlock = tail concat type("TIME", node.precision) + + override fun visitTypeTimeWithTz(node: Type.TimeWithTz, tail: InternalSqlBlock): InternalSqlBlock = + tail concat type("TIME WITH TIMEZONE", node.precision, gap = true) + + override fun visitTypeTimestamp(node: Type.Timestamp, tail: InternalSqlBlock): InternalSqlBlock = + tail concat type("TIMESTAMP", node.precision) + + override fun visitTypeTimestampWithTz(node: Type.TimestampWithTz, tail: InternalSqlBlock): InternalSqlBlock = + tail concat type("TIMESTAMP WITH TIMEZONE", node.precision, gap = true) + + override fun visitTypeInterval(node: Type.Interval, tail: InternalSqlBlock): InternalSqlBlock = + tail concat type("INTERVAL", node.precision) + + // unsupported + override fun visitTypeCustom(node: Type.Custom, tail: InternalSqlBlock): InternalSqlBlock = defaultReturn(node, tail) + + // Expressions + + @OptIn(PartiQLValueExperimental::class) + override fun visitExprLit(node: Expr.Lit, tail: InternalSqlBlock): InternalSqlBlock { + // Simplified PartiQL Value writing, as this intentionally omits formatting + val value = when (node.value) { + is MissingValue -> "MISSING" // force uppercase + is NullValue -> "NULL" // force uppercase + else -> { + val buffer = ByteArrayOutputStream() + val valueWriter = PartiQLValueTextWriter(PrintStream(buffer), false) + valueWriter.append(node.value) + buffer.toString() + } + } + return tail concat 
value + } + + override fun visitExprIon(node: Expr.Ion, tail: InternalSqlBlock): InternalSqlBlock { + // simplified Ion value writing, as this intentionally omits formatting + val value = node.value.toString() + return tail concat "`$value`" + } + + override fun visitExprUnary(node: Expr.Unary, tail: InternalSqlBlock): InternalSqlBlock { + val op = when (node.op) { + Expr.Unary.Op.NOT -> "NOT (" + Expr.Unary.Op.POS -> "+(" + Expr.Unary.Op.NEG -> "-(" + } + var t = tail + t = t concat op + t = visitExprWrapped(node.expr, t) + t = t concat ")" + return t + } + + override fun visitExprBinary(node: Expr.Binary, tail: InternalSqlBlock): InternalSqlBlock { + val op = when (node.op) { + Expr.Binary.Op.PLUS -> "+" + Expr.Binary.Op.MINUS -> "-" + Expr.Binary.Op.TIMES -> "*" + Expr.Binary.Op.DIVIDE -> "/" + Expr.Binary.Op.MODULO -> "%" + Expr.Binary.Op.CONCAT -> "||" + Expr.Binary.Op.AND -> "AND" + Expr.Binary.Op.OR -> "OR" + Expr.Binary.Op.EQ -> "=" + Expr.Binary.Op.NE -> "<>" + Expr.Binary.Op.GT -> ">" + Expr.Binary.Op.GTE -> ">=" + Expr.Binary.Op.LT -> "<" + Expr.Binary.Op.LTE -> "<=" + Expr.Binary.Op.BITWISE_AND -> "&" + } + var t = tail + t = visitExprWrapped(node.lhs, t) + t = t concat " $op " + t = visitExprWrapped(node.rhs, t) + return t + } + + override fun visitExprVar(node: Expr.Var, tail: InternalSqlBlock): InternalSqlBlock { + var t = tail + // Prepend @ + if (node.scope == Expr.Var.Scope.LOCAL) { + t = t concat "@" + } + t = visitIdentifier(node.identifier, t) + return t + } + + override fun visitExprSessionAttribute(node: Expr.SessionAttribute, tail: InternalSqlBlock): InternalSqlBlock = + tail concat node.attribute.name + + override fun visitExprPath(node: Expr.Path, tail: InternalSqlBlock): InternalSqlBlock { + var t = visitExprWrapped(node.root, tail) + t = node.steps.fold(t) { b, step -> visitExprPathStep(step, b) } + return t + } + + override fun visitExprPathStepSymbol(node: Expr.Path.Step.Symbol, tail: InternalSqlBlock): InternalSqlBlock = + tail concat 
".${node.symbol.sql()}" + + override fun visitExprPathStepIndex(node: Expr.Path.Step.Index, tail: InternalSqlBlock): InternalSqlBlock { + var t = tail + val key = node.key + // use [ ] syntax + t = t concat "[" + t = visitExprWrapped(key, t) + t = t concat "]" + return t + } + + override fun visitExprPathStepWildcard(node: Expr.Path.Step.Wildcard, tail: InternalSqlBlock): InternalSqlBlock = tail concat "[*]" + + override fun visitExprPathStepUnpivot(node: Expr.Path.Step.Unpivot, tail: InternalSqlBlock): InternalSqlBlock = tail concat ".*" + + override fun visitExprCall(node: Expr.Call, tail: InternalSqlBlock): InternalSqlBlock { + var t = tail + t = visitIdentifier(node.function, t) + t = t concat list { node.args } + return t + } + + override fun visitExprAgg(node: Expr.Agg, tail: InternalSqlBlock): InternalSqlBlock { + var t = tail + val f = node.function + // Special case + if (f is Identifier.Symbol && f.symbol == "COUNT_STAR") { + return t concat "COUNT(*)" + } + val start = if (node.setq != null) "(${node.setq!!.name} " else "(" + t = visitIdentifier(f, t) + t = t concat list(start) { node.args } + return t + } + + override fun visitExprParameter(node: Expr.Parameter, tail: InternalSqlBlock): InternalSqlBlock = tail concat "?" 
+ + override fun visitExprValues(node: Expr.Values, tail: InternalSqlBlock): InternalSqlBlock = + tail concat list("VALUES (") { node.rows } + + override fun visitExprValuesRow(node: Expr.Values.Row, tail: InternalSqlBlock): InternalSqlBlock = tail concat list { node.items } + + override fun visitExprCollection(node: Expr.Collection, tail: InternalSqlBlock): InternalSqlBlock { + val (start, end) = when (node.type) { + Expr.Collection.Type.BAG -> "<<" to ">>" + Expr.Collection.Type.ARRAY -> "[" to "]" + Expr.Collection.Type.VALUES -> "VALUES (" to ")" + Expr.Collection.Type.LIST -> "(" to ")" + Expr.Collection.Type.SEXP -> "SEXP (" to ")" + } + return tail concat list(start, end) { node.values } + } + + override fun visitExprStruct(node: Expr.Struct, tail: InternalSqlBlock): InternalSqlBlock = + tail concat list("{", "}") { node.fields } + + override fun visitExprStructField(node: Expr.Struct.Field, tail: InternalSqlBlock): InternalSqlBlock { + var t = tail + t = visitExprWrapped(node.name, t) + t = t concat ": " + t = visitExprWrapped(node.value, t) + return t + } + + override fun visitExprLike(node: Expr.Like, tail: InternalSqlBlock): InternalSqlBlock { + var t = tail + t = visitExprWrapped(node.value, t) + t = t concat if (node.not == true) " NOT LIKE " else " LIKE " + t = visitExprWrapped(node.pattern, t) + if (node.escape != null) { + t = t concat " ESCAPE " + t = visitExprWrapped(node.escape!!, t) + } + return t + } + + override fun visitExprBetween(node: Expr.Between, tail: InternalSqlBlock): InternalSqlBlock { + var t = tail + t = visitExprWrapped(node.value, t) + t = t concat if (node.not == true) " NOT BETWEEN " else " BETWEEN " + t = visitExprWrapped(node.from, t) + t = t concat " AND " + t = visitExprWrapped(node.to, t) + return t + } + + override fun visitExprInCollection(node: Expr.InCollection, tail: InternalSqlBlock): InternalSqlBlock { + var t = tail + t = visitExprWrapped(node.lhs, t) + t = t concat if (node.not == true) " NOT IN " else " IN " + t 
= visitExprWrapped(node.rhs, t) + return t + } + + override fun visitExprIsType(node: Expr.IsType, tail: InternalSqlBlock): InternalSqlBlock { + var t = tail + t = visitExprWrapped(node.value, t) + t = t concat if (node.not == true) " IS NOT " else " IS " + t = visitType(node.type, t) + return t + } + + override fun visitExprCase(node: Expr.Case, tail: InternalSqlBlock): InternalSqlBlock { + var t = tail + t = t concat "CASE" + t = when (node.expr) { + null -> t + else -> visitExprWrapped(node.expr!!, t concat " ") + } + // WHEN(s) + t = node.branches.fold(t) { acc, branch -> visitExprCaseBranch(branch, acc) } + // ELSE + t = when (node.default) { + null -> t + else -> { + t = t concat " ELSE " + visitExprWrapped(node.default!!, t) + } + } + t = t concat " END" + return t + } + + override fun visitExprCaseBranch(node: Expr.Case.Branch, tail: InternalSqlBlock): InternalSqlBlock { + var t = tail + t = t concat " WHEN " + t = visitExprWrapped(node.condition, t) + t = t concat " THEN " + t = visitExprWrapped(node.expr, t) + return t + } + + override fun visitExprCoalesce(node: Expr.Coalesce, tail: InternalSqlBlock): InternalSqlBlock { + var t = tail + t = t concat "COALESCE" + t = t concat list { node.args } + return t + } + + override fun visitExprNullIf(node: Expr.NullIf, tail: InternalSqlBlock): InternalSqlBlock { + val args = listOf(node.value, node.nullifier) + var t = tail + t = t concat "NULLIF" + t = t concat list { args } + return t + } + + override fun visitExprSubstring(node: Expr.Substring, tail: InternalSqlBlock): InternalSqlBlock { + var t = tail + t = t concat "SUBSTRING(" + t = visitExprWrapped(node.value, t) + if (node.start != null) { + t = t concat " FROM " + t = visitExprWrapped(node.start!!, t) + } + if (node.length != null) { + t = t concat " FOR " + t = visitExprWrapped(node.length!!, t) + } + t = t concat ")" + return t + } + + override fun visitExprPosition(node: Expr.Position, tail: InternalSqlBlock): InternalSqlBlock { + var t = tail + t = t 
concat "POSITION(" + t = visitExprWrapped(node.lhs, t) + t = t concat " IN " + t = visitExprWrapped(node.rhs, t) + t = t concat ")" + return t + } + + override fun visitExprTrim(node: Expr.Trim, tail: InternalSqlBlock): InternalSqlBlock { + var t = tail + t = t concat "TRIM(" + // [LEADING|TRAILING|BOTH] + if (node.spec != null) { + t = t concat "${node.spec!!.name} " + } + // [<chars> FROM] + if (node.chars != null) { + t = visitExprWrapped(node.chars!!, t) + t = t concat " FROM " + } + t = visitExprWrapped(node.value, t) + t = t concat ")" + return t + } + + override fun visitExprOverlay(node: Expr.Overlay, tail: InternalSqlBlock): InternalSqlBlock { + var t = tail + t = t concat "OVERLAY(" + t = visitExprWrapped(node.value, t) + t = t concat " PLACING " + t = visitExprWrapped(node.overlay, t) + t = t concat " FROM " + t = visitExprWrapped(node.start, t) + if (node.length != null) { + t = t concat " FOR " + t = visitExprWrapped(node.length!!, t) + } + t = t concat ")" + return t + } + + override fun visitExprExtract(node: Expr.Extract, tail: InternalSqlBlock): InternalSqlBlock { + var t = tail + t = t concat "EXTRACT(" + t = t concat node.field.name + t = t concat " FROM " + t = visitExprWrapped(node.source, t) + t = t concat ")" + return t + } + + override fun visitExprCast(node: Expr.Cast, tail: InternalSqlBlock): InternalSqlBlock { + var t = tail + t = t concat "CAST(" + t = visitExprWrapped(node.value, t) + t = t concat " AS " + t = visitType(node.asType, t) + t = t concat ")" + return t + } + + override fun visitExprCanCast(node: Expr.CanCast, tail: InternalSqlBlock): InternalSqlBlock { + var t = tail + t = t concat "CAN_CAST(" + t = visitExprWrapped(node.value, t) + t = t concat " AS " + t = visitType(node.asType, t) + t = t concat ")" + return t + } + + override fun visitExprCanLosslessCast(node: Expr.CanLosslessCast, tail: InternalSqlBlock): InternalSqlBlock { + var t = tail + t = t concat "CAN_LOSSLESS_CAST(" + t = visitExprWrapped(node.value, t) + t = t concat 
" AS " + t = visitType(node.asType, t) + t = t concat ")" + return t + } + + override fun visitExprDateAdd(node: Expr.DateAdd, tail: InternalSqlBlock): InternalSqlBlock { + var t = tail + t = t concat "DATE_ADD(" + t = t concat node.field.name + t = t concat ", " + t = visitExprWrapped(node.lhs, t) + t = t concat ", " + t = visitExprWrapped(node.rhs, t) + t = t concat ")" + return t + } + + override fun visitExprDateDiff(node: Expr.DateDiff, tail: InternalSqlBlock): InternalSqlBlock { + var t = tail + t = t concat "DATE_DIFF(" + t = t concat node.field.name + t = t concat ", " + t = visitExprWrapped(node.lhs, t) + t = t concat ", " + t = visitExprWrapped(node.rhs, t) + t = t concat ")" + return t + } + + override fun visitExprBagOp(node: Expr.BagOp, tail: InternalSqlBlock): InternalSqlBlock { + // [OUTER] [UNION|INTERSECT|EXCEPT] [ALL|DISTINCT] + val op = mutableListOf<String>() + when (node.outer) { + true -> op.add("OUTER") + else -> {} + } + when (node.type.type) { + SetOp.Type.UNION -> op.add("UNION") + SetOp.Type.INTERSECT -> op.add("INTERSECT") + SetOp.Type.EXCEPT -> op.add("EXCEPT") + } + when (node.type.setq) { + SetQuantifier.ALL -> op.add("ALL") + SetQuantifier.DISTINCT -> op.add("DISTINCT") + null -> {} + } + var t = tail + t = visitExprWrapped(node.lhs, t) + t = t concat " ${op.joinToString(" ")} " + t = visitExprWrapped(node.rhs, t) + return t + } + + // SELECT-FROM-WHERE + + override fun visitExprSFW(node: Expr.SFW, tail: InternalSqlBlock): InternalSqlBlock { + var t = tail + // SELECT + t = visit(node.select, t) + // EXCLUDE + t = node.exclude?.let { visit(it, t) } ?: t + // FROM + t = visit(node.from, t concat " FROM ") + // LET + t = if (node.let != null) visitLet(node.let!!, t concat " ") else t + // WHERE + t = if (node.where != null) visitExprWrapped(node.where!!, t concat " WHERE ") else t + // GROUP BY + t = if (node.groupBy != null) visitGroupBy(node.groupBy!!, t concat " ") else t + // HAVING + t = if (node.having != null) 
visitExprWrapped(node.having!!, t concat " HAVING ") else t + // SET OP + t = if (node.setOp != null) visitExprSFWSetOp(node.setOp!!, t concat " ") else t + // ORDER BY + t = if (node.orderBy != null) visitOrderBy(node.orderBy!!, t concat " ") else t + // LIMIT + t = if (node.limit != null) visitExprWrapped(node.limit!!, t concat " LIMIT ") else t + // OFFSET + t = if (node.offset != null) visitExprWrapped(node.offset!!, t concat " OFFSET ") else t + return t + } + + // SELECT + + override fun visitSelectStar(node: Select.Star, tail: InternalSqlBlock): InternalSqlBlock { + val select = when (node.setq) { + SetQuantifier.ALL -> "SELECT ALL *" + SetQuantifier.DISTINCT -> "SELECT DISTINCT *" + null -> "SELECT *" + } + return tail concat select + } + + override fun visitSelectProject(node: Select.Project, tail: InternalSqlBlock): InternalSqlBlock { + val select = when (node.setq) { + SetQuantifier.ALL -> "SELECT ALL " + SetQuantifier.DISTINCT -> "SELECT DISTINCT " + null -> "SELECT " + } + return tail concat list(select, "") { node.items } + } + + override fun visitSelectProjectItemAll(node: Select.Project.Item.All, tail: InternalSqlBlock): InternalSqlBlock { + var t = tail + t = visitExprWrapped(node.expr, t) + t = t concat ".*" + return t + } + + override fun visitSelectProjectItemExpression(node: Select.Project.Item.Expression, tail: InternalSqlBlock): InternalSqlBlock { + var t = tail + t = visitExprWrapped(node.expr, t) + t = if (node.asAlias != null) t concat " AS ${node.asAlias!!.sql()}" else t + return t + } + + override fun visitSelectPivot(node: Select.Pivot, tail: InternalSqlBlock): InternalSqlBlock { + var t = tail + t = t concat "PIVOT " + t = visitExprWrapped(node.key, t) + t = t concat " AT " + t = visitExprWrapped(node.value, t) + return t + } + + override fun visitSelectValue(node: Select.Value, tail: InternalSqlBlock): InternalSqlBlock { + val select = when (node.setq) { + SetQuantifier.ALL -> "SELECT ALL VALUE " + SetQuantifier.DISTINCT -> "SELECT 
DISTINCT VALUE " + null -> "SELECT VALUE " + } + var t = tail + t = t concat select + t = visitExprWrapped(node.constructor, t) + return t + } + + // FROM + + override fun visitFromValue(node: From.Value, tail: InternalSqlBlock): InternalSqlBlock { + var t = tail + t = when (node.type) { + From.Value.Type.SCAN -> t + From.Value.Type.UNPIVOT -> t concat "UNPIVOT " + } + t = visitExprWrapped(node.expr, t) + t = if (node.asAlias != null) t concat " AS ${node.asAlias!!.sql()}" else t + t = if (node.atAlias != null) t concat " AT ${node.atAlias!!.sql()}" else t + t = if (node.byAlias != null) t concat " BY ${node.byAlias!!.sql()}" else t + return t + } + + override fun visitFromJoin(node: From.Join, tail: InternalSqlBlock): InternalSqlBlock { + var t = tail + t = visitFrom(node.lhs, t) + t = t concat when (node.type) { + From.Join.Type.INNER -> " INNER JOIN " + From.Join.Type.LEFT -> " LEFT JOIN " + From.Join.Type.LEFT_OUTER -> " LEFT OUTER JOIN " + From.Join.Type.RIGHT -> " RIGHT JOIN " + From.Join.Type.RIGHT_OUTER -> " RIGHT OUTER JOIN " + From.Join.Type.FULL -> " FULL JOIN " + From.Join.Type.FULL_OUTER -> " FULL OUTER JOIN " + From.Join.Type.CROSS -> " CROSS JOIN " + From.Join.Type.COMMA -> ", " + null -> " JOIN " + } + t = visitFrom(node.rhs, t) + t = if (node.condition != null) visit(node.condition!!, t concat " ON ") else t + return t + } + + // LET + + override fun visitLet(node: Let, tail: InternalSqlBlock): InternalSqlBlock = tail concat list("LET ", "") { node.bindings } + + override fun visitLetBinding(node: Let.Binding, tail: InternalSqlBlock): InternalSqlBlock { + var t = tail + t = visitExprWrapped(node.expr, t) + t = t concat " AS ${node.asAlias.sql()}" + return t + } + + // GROUP BY + + override fun visitGroupBy(node: GroupBy, tail: InternalSqlBlock): InternalSqlBlock { + var t = tail + t = t concat when (node.strategy) { + GroupBy.Strategy.FULL -> "GROUP BY " + GroupBy.Strategy.PARTIAL -> "GROUP PARTIAL BY " + } + t = t concat list("", "") { node.keys } 
+ t = if (node.asAlias != null) t concat " GROUP AS ${node.asAlias!!.sql()}" else t + return t + } + + override fun visitGroupByKey(node: GroupBy.Key, tail: InternalSqlBlock): InternalSqlBlock { + var t = tail + t = visitExprWrapped(node.expr, t) + t = if (node.asAlias != null) t concat " AS ${node.asAlias!!.sql()}" else t + return t + } + + // SET OPERATORS + + override fun visitSetOp(node: SetOp, tail: InternalSqlBlock): InternalSqlBlock { + val op = when (node.setq) { + null -> node.type.name + else -> "${node.type.name} ${node.setq!!.name}" + } + return tail concat op + } + + override fun visitExprSFWSetOp(node: Expr.SFW.SetOp, tail: InternalSqlBlock): InternalSqlBlock { + var t = tail + t = visitSetOp(node.type, t) + t = t concat InternalSqlBlock.Nest( + prefix = " (", + postfix = ")", + child = InternalSqlBlock.root().apply { visitExprSFW(node.operand, this) }, + ) + return t + } + + // ORDER BY + + override fun visitOrderBy(node: OrderBy, tail: InternalSqlBlock): InternalSqlBlock = + tail concat list("ORDER BY ", "") { node.sorts } + + override fun visitSort(node: Sort, tail: InternalSqlBlock): InternalSqlBlock { + var t = tail + t = visitExprWrapped(node.expr, t) + t = when (node.dir) { + Sort.Dir.ASC -> t concat " ASC" + Sort.Dir.DESC -> t concat " DESC" + null -> t + } + t = when (node.nulls) { + Sort.Nulls.FIRST -> t concat " NULLS FIRST" + Sort.Nulls.LAST -> t concat " NULLS LAST" + null -> t + } + return t + } + + // --- Block Constructor Helpers + + private infix fun InternalSqlBlock.concat(rhs: String): InternalSqlBlock { + next = InternalSqlBlock.Text(rhs) + return next!! + } + + private infix fun InternalSqlBlock.concat(rhs: InternalSqlBlock): InternalSqlBlock { + next = rhs + return next!! 
+ } + + private fun type(symbol: String, vararg args: Int?, gap: Boolean = false): InternalSqlBlock { + val p = args.filterNotNull() + val t = when { + p.isEmpty() -> symbol + else -> { + val a = p.joinToString(",") + when (gap) { + true -> "$symbol ($a)" + else -> "$symbol($a)" + } + } + } + // types are modeled as text; as we don't want to reflow + return InternalSqlBlock.Text(t) + } + + private fun list( + start: String? = "(", + end: String? = ")", + delimiter: String? = ", ", + children: () -> List<AstNode>, + ): InternalSqlBlock { + val kids = children() + val h = InternalSqlBlock.root() + var t = h + kids.forEachIndexed { i, child -> + t = child.accept(this, t) + t = if (delimiter != null && (i + 1) < kids.size) t concat delimiter else t + } + return InternalSqlBlock.Nest( + prefix = start, + postfix = end, + child = h, + ) + } + + private fun Identifier.Symbol.sql() = when (caseSensitivity) { + Identifier.CaseSensitivity.SENSITIVE -> "\"$symbol\"" + Identifier.CaseSensitivity.INSENSITIVE -> symbol // verbatim .. + } +} diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/sql/internal/InternalSqlLayout.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/sql/internal/InternalSqlLayout.kt new file mode 100644 index 0000000000..8d7d858c41 --- /dev/null +++ b/partiql-ast/src/main/kotlin/org/partiql/ast/sql/internal/InternalSqlLayout.kt @@ -0,0 +1,36 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. 
+ */ + +package org.partiql.ast.sql.internal + +internal object InternalSqlLayout { + + internal fun format(head: InternalSqlBlock): String { + val sb = StringBuilder() + var curr: InternalSqlBlock? = head + while (curr != null) { + when (curr) { + is InternalSqlBlock.NL -> sb.appendLine() + is InternalSqlBlock.Text -> sb.append(curr.text) + is InternalSqlBlock.Nest -> { + if (curr.prefix != null) sb.append(curr.prefix) + sb.append(format(curr.child)) + if (curr.postfix != null) sb.append(curr.postfix) + } + } + curr = curr.next + } + return sb.toString() + } +} diff --git a/partiql-ast/src/test/kotlin/org/partiql/ast/sql/SqlDialectTest.kt b/partiql-ast/src/test/kotlin/org/partiql/ast/sql/SqlDialectTest.kt index 28b161bef6..039fbd4112 100644 --- a/partiql-ast/src/test/kotlin/org/partiql/ast/sql/SqlDialectTest.kt +++ b/partiql-ast/src/test/kotlin/org/partiql/ast/sql/SqlDialectTest.kt @@ -227,24 +227,63 @@ class SqlDialectTest { @JvmStatic fun exprOperators() = listOf( - expect("NOT NULL") { + expect("NOT (NULL)") { exprUnary { op = Expr.Unary.Op.NOT expr = NULL } }, - expect("+NULL") { + expect("+(NULL)") { exprUnary { op = Expr.Unary.Op.POS expr = NULL } }, - expect("-NULL") { + expect("-(NULL)") { exprUnary { op = Expr.Unary.Op.NEG expr = NULL } }, + expect("NOT (NOT (NULL))") { + exprUnary { + op = Expr.Unary.Op.NOT + expr = exprUnary { + op = Expr.Unary.Op.NOT + expr = NULL + } + } + }, + expect("+(+(NULL))") { + exprUnary { + op = Expr.Unary.Op.POS + expr = exprUnary { + op = Expr.Unary.Op.POS + expr = NULL + } + } + }, + expect("-(-(NULL))") { + exprUnary { + op = Expr.Unary.Op.NEG + expr = exprUnary { + op = Expr.Unary.Op.NEG + expr = NULL + } + } + }, + expect("+(-(+(NULL)))") { + exprUnary { + op = Expr.Unary.Op.POS + expr = exprUnary { + op = Expr.Unary.Op.NEG + expr = exprUnary { + op = Expr.Unary.Op.POS + expr = NULL + } + } + } + }, expect("NULL + NULL") { exprBinary { op = Expr.Binary.Op.PLUS diff --git a/partiql-cli/build.gradle.kts 
b/partiql-cli/build.gradle.kts index c3ba7c3720..562f5a188f 100644 --- a/partiql-cli/build.gradle.kts +++ b/partiql-cli/build.gradle.kts @@ -38,7 +38,7 @@ dependencies { implementation(Deps.joda) implementation(Deps.picoCli) implementation(Deps.kotlinReflect) - + implementation(Deps.kotlinxCoroutines) testImplementation(Deps.mockito) } diff --git a/partiql-cli/src/main/kotlin/org/partiql/cli/Main.kt b/partiql-cli/src/main/kotlin/org/partiql/cli/Main.kt index 147d335012..ab79fab3d2 100644 --- a/partiql-cli/src/main/kotlin/org/partiql/cli/Main.kt +++ b/partiql-cli/src/main/kotlin/org/partiql/cli/Main.kt @@ -15,14 +15,16 @@ package org.partiql.cli -import AstPrinter import com.amazon.ion.system.IonSystemBuilder +import com.amazon.ion.system.IonTextWriterBuilder import org.partiql.cli.pico.PartiQLCommand import org.partiql.cli.shell.info import org.partiql.lang.eval.EvaluationSession import org.partiql.parser.PartiQLParser +import org.partiql.plan.Statement import org.partiql.plan.debug.PlanPrinter import org.partiql.planner.PartiQLPlanner +import org.partiql.plugins.local.toIon import picocli.CommandLine import java.io.PrintStream import java.nio.file.Paths @@ -78,6 +80,16 @@ object Debug { out.info("-- Plan ----------") PlanPrinter.append(out, result.statement) + when (val plan = result.statement) { + is Statement.Query -> { + out.info("-- Schema ----------") + val outputSchema = java.lang.StringBuilder() + val ionWriter = IonTextWriterBuilder.minimal().withPrettyPrinting().build(outputSchema) + plan.root.type.toIon().writeTo(ionWriter) + out.info(outputSchema.toString()) + } + } + return "OK" } } diff --git a/partiql-cli/src/main/kotlin/org/partiql/cli/pipeline/AbstractPipeline.kt b/partiql-cli/src/main/kotlin/org/partiql/cli/pipeline/AbstractPipeline.kt index c2fdabe5f4..8b0b7bbbe6 100644 --- a/partiql-cli/src/main/kotlin/org/partiql/cli/pipeline/AbstractPipeline.kt +++ b/partiql-cli/src/main/kotlin/org/partiql/cli/pipeline/AbstractPipeline.kt @@ -20,6 +20,7 @@ 
import com.amazon.ionelement.api.ionInt import com.amazon.ionelement.api.ionString import com.amazon.ionelement.api.ionStructOf import com.amazon.ionelement.api.toIonValue +import kotlinx.coroutines.runBlocking import org.partiql.annotations.ExperimentalPartiQLCompilerPipeline import org.partiql.cli.Debug import org.partiql.cli.functions.QueryDDB @@ -28,8 +29,8 @@ import org.partiql.cli.functions.ReadFile_2 import org.partiql.cli.functions.WriteFile_1 import org.partiql.cli.functions.WriteFile_2 import org.partiql.lang.CompilerPipeline -import org.partiql.lang.compiler.PartiQLCompilerBuilder -import org.partiql.lang.compiler.PartiQLCompilerPipeline +import org.partiql.lang.compiler.PartiQLCompilerAsyncBuilder +import org.partiql.lang.compiler.PartiQLCompilerPipelineAsync import org.partiql.lang.eval.CompileOptions import org.partiql.lang.eval.EvaluationSession import org.partiql.lang.eval.ExprFunction @@ -49,7 +50,7 @@ import java.nio.file.Path import java.time.ZoneOffset /** - * A means by which we can run both the EvaluatingCompiler and PartiQLCompilerPipeline + * A means by which we can run both the EvaluatingCompiler and [PartiQLCompilerPipelineAsync]. 
*/ internal sealed class AbstractPipeline(open val options: PipelineOptions) { @@ -162,7 +163,7 @@ internal sealed class AbstractPipeline(open val options: PipelineOptions) { } /** - * Wraps the PartiQLCompilerPipeline + * Wraps the [PartiQLCompilerPipelineAsync] */ @OptIn(ExperimentalPartiQLCompilerPipeline::class) class PipelineExperimental(options: PipelineOptions) : AbstractPipeline(options) { @@ -182,17 +183,19 @@ internal sealed class AbstractPipeline(open val options: PipelineOptions) { override fun compile(input: String, session: EvaluationSession): PartiQLResult { val globalVariableResolver = createGlobalVariableResolver(session) - val pipeline = PartiQLCompilerPipeline( + val pipeline = PartiQLCompilerPipelineAsync( parser = options.parser, planner = PartiQLPlannerBuilder.standard() .options(plannerOptions) .globalVariableResolver(globalVariableResolver) .build(), - compiler = PartiQLCompilerBuilder.standard() + compiler = PartiQLCompilerAsyncBuilder.standard() .options(evaluatorOptions) .build(), ) - return pipeline.compile(input).eval(session) + return runBlocking { + pipeline.compile(input).eval(session) + } } private fun createGlobalVariableResolver(session: EvaluationSession) = GlobalVariableResolver { diff --git a/partiql-eval/src/test/kotlin/org/partiql/eval/internal/PartiQLEngineDefaultTest.kt b/partiql-eval/src/test/kotlin/org/partiql/eval/internal/PartiQLEngineDefaultTest.kt index 766c875efe..d8dc6ecb9f 100644 --- a/partiql-eval/src/test/kotlin/org/partiql/eval/internal/PartiQLEngineDefaultTest.kt +++ b/partiql-eval/src/test/kotlin/org/partiql/eval/internal/PartiQLEngineDefaultTest.kt @@ -666,7 +666,7 @@ class PartiQLEngineDefaultTest { END ; """.trimIndent(), - expected = nullValue() + expected = stringValue(null) ), SuccessTestCase( input = """ @@ -849,7 +849,7 @@ class PartiQLEngineDefaultTest { END ; """.trimIndent(), - expected = nullValue() + expected = stringValue(null) ), SuccessTestCase( input = """ diff --git 
a/partiql-lang/build.gradle.kts b/partiql-lang/build.gradle.kts index 6c4842dbc9..c712a25da9 100644 --- a/partiql-lang/build.gradle.kts +++ b/partiql-lang/build.gradle.kts @@ -15,7 +15,7 @@ plugins { id(Plugins.conventions) - id(Plugins.jmh) version Versions.jmh + id(Plugins.jmh) version Versions.jmhGradlePlugin id(Plugins.library) id(Plugins.publish) id(Plugins.testFixtures) @@ -40,6 +40,7 @@ dependencies { implementation(Deps.antlrRuntime) implementation(Deps.csv) implementation(Deps.kotlinReflect) + implementation(Deps.kotlinxCoroutines) testImplementation(testFixtures(project(":partiql-planner"))) testImplementation(project(":plugins:partiql-memory")) @@ -48,6 +49,7 @@ dependencies { testImplementation(Deps.junit4Params) testImplementation(Deps.junitVintage) // Enables JUnit4 testImplementation(Deps.mockk) + testImplementation(Deps.kotlinxCoroutinesTest) testFixturesImplementation(project(":lib:isl")) testFixturesImplementation(Deps.kotlinTest) @@ -59,6 +61,18 @@ dependencies { testFixturesImplementation(Deps.junitParams) testFixturesImplementation(Deps.junitVintage) // Enables JUnit4 testFixturesImplementation(Deps.mockk) + + // The JMH gradle plugin that we currently use is 0.5.3, which uses JMH version 1.25. The JMH gradle plugin has a + // newer version (see https://github.com/melix/jmh-gradle-plugin/releases) which upgrades the JMH version. We can't + // use that newer plugin version until we upgrade our gradle version to 8.0+. JMH version 1.25 does not support + // creating CPU flamegraphs using the JMH benchmarks, hence why the newer version dependency is specified here. 
+ // + // When we upgrade gradle to 8.0+, we can upgrade the gradle plugin to the latest and remove this dependency block + dependencies { + jmh(Deps.jmhCore) + jmh(Deps.jmhGeneratorAnnprocess) + jmh(Deps.jmhGeneratorBytecode) + } } publish { diff --git a/partiql-lang/src/jmh/kotlin/org/partiql/jmh/benchmarks/PartiQLCompilerPipelineAsyncBenchmark.kt b/partiql-lang/src/jmh/kotlin/org/partiql/jmh/benchmarks/PartiQLCompilerPipelineAsyncBenchmark.kt new file mode 100644 index 0000000000..5e9b2f3c90 --- /dev/null +++ b/partiql-lang/src/jmh/kotlin/org/partiql/jmh/benchmarks/PartiQLCompilerPipelineAsyncBenchmark.kt @@ -0,0 +1,380 @@ +package org.partiql.jmh.benchmarks + +import com.amazon.ion.IonSystem +import com.amazon.ion.system.IonSystemBuilder +import kotlinx.coroutines.runBlocking +import org.openjdk.jmh.annotations.Benchmark +import org.openjdk.jmh.annotations.BenchmarkMode +import org.openjdk.jmh.annotations.Fork +import org.openjdk.jmh.annotations.Measurement +import org.openjdk.jmh.annotations.Mode +import org.openjdk.jmh.annotations.OutputTimeUnit +import org.openjdk.jmh.annotations.Scope +import org.openjdk.jmh.annotations.State +import org.openjdk.jmh.annotations.Warmup +import org.openjdk.jmh.infra.Blackhole +import org.partiql.annotations.ExperimentalPartiQLCompilerPipeline +import org.partiql.jmh.utils.FORK_VALUE_RECOMMENDED +import org.partiql.jmh.utils.MEASUREMENT_ITERATION_VALUE_RECOMMENDED +import org.partiql.jmh.utils.MEASUREMENT_TIME_VALUE_RECOMMENDED +import org.partiql.jmh.utils.WARMUP_ITERATION_VALUE_RECOMMENDED +import org.partiql.jmh.utils.WARMUP_TIME_VALUE_RECOMMENDED +import org.partiql.lang.compiler.PartiQLCompilerPipelineAsync +import org.partiql.lang.eval.Bindings +import org.partiql.lang.eval.EvaluationSession +import org.partiql.lang.eval.ExprValue +import org.partiql.lang.eval.PartiQLResult +import org.partiql.lang.planner.GlobalResolutionResult +import org.partiql.lang.syntax.PartiQLParserBuilder +import java.util.concurrent.TimeUnit + 
+@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.MICROSECONDS) +open class PartiQLCompilerPipelineAsyncBenchmark { + companion object { + private const val FORK_VALUE: Int = FORK_VALUE_RECOMMENDED + private const val MEASUREMENT_ITERATION_VALUE: Int = MEASUREMENT_ITERATION_VALUE_RECOMMENDED + private const val MEASUREMENT_TIME_VALUE: Int = MEASUREMENT_TIME_VALUE_RECOMMENDED + private const val WARMUP_ITERATION_VALUE: Int = WARMUP_ITERATION_VALUE_RECOMMENDED + private const val WARMUP_TIME_VALUE: Int = WARMUP_TIME_VALUE_RECOMMENDED + } + + @State(Scope.Thread) + @OptIn(ExperimentalPartiQLCompilerPipeline::class) + open class MyState { + private val parser = PartiQLParserBuilder.standard().build() + private val myIonSystem: IonSystem = IonSystemBuilder.standard().build() + + private fun tableWithRows(numRows: Int): ExprValue { + val allRows = (1..numRows).joinToString { index -> + """ + { + "id": $index, + "someString": "some string foo $index", + "someDecimal": $index.00, + "someBlob": {{ dHdvIHBhZGRpbmcgY2hhcmFjdGVycw== }}, + "someTimestamp": 2007-02-23T12:14:15.${index}Z + } + """.trimIndent() + } + val data = "[ $allRows ]" + return ExprValue.of( + myIonSystem.singleValue(data) + ) + } + + private val bindings = Bindings.ofMap( + mapOf( + "t1" to tableWithRows(1), + "t10" to tableWithRows(10), + "t100" to tableWithRows(100), + "t1000" to tableWithRows(1000), + "t10000" to tableWithRows(10000), + "t100000" to tableWithRows(100000), + ) + ) + + private val parameters = listOf( + ExprValue.newInt(5), // WHERE `id` > 5 + ExprValue.newInt(1000000), // LIMIT 1000000 + ExprValue.newInt(3), // OFFSET 3 * 2 + ExprValue.newInt(2), // ------------^ + ) + val session = EvaluationSession.build { + globals(bindings) + parameters(parameters) + } + + val pipeline = PartiQLCompilerPipelineAsync.build { + planner.globalVariableResolver { + val value = session.globals[it] + if (value != null) { + GlobalResolutionResult.GlobalVariable(it.name) + } else { + 
GlobalResolutionResult.Undefined + } + } + } + + val query1 = parser.parseAstStatement( + """ + SELECT * FROM t100000 + """.trimIndent() + ) + val query2 = parser.parseAstStatement( + """ + SELECT * + FROM t100000 + WHERE t100000.someTimestamp < UTCNOW() + """.trimIndent() + ) + val query3 = parser.parseAstStatement( + """ + SELECT * + FROM t100000 + WHERE t100000.someTimestamp < UTCNOW() + LIMIT ${Int.MAX_VALUE} + """.trimIndent() + ) + val query4 = parser.parseAstStatement( + """ + SELECT * + FROM t100000 + WHERE t100000.someTimestamp < UTCNOW() + ORDER BY t100000.id DESC + """.trimIndent() + ) + val query5 = parser.parseAstStatement( + """ + SELECT * + FROM t100000 + WHERE t100000.someTimestamp < UTCNOW() AND t100000.id > ? + LIMIT ? + OFFSET ? * ? + """.trimIndent() + ) + val query6 = parser.parseAstStatement( + """ + SELECT * + FROM t100000 + WHERE t100000.someTimestamp < UTCNOW() AND t100000.id > ? + ORDER BY t100000.id DESC + LIMIT ? + OFFSET ? * ? + """.trimIndent() + ) + val query7 = parser.parseAstStatement( + """ + SELECT * + FROM t10000 + WHERE t10000.someTimestamp < UTCNOW() AND t10000.id > ? + ORDER BY t10000.id DESC + LIMIT ? + OFFSET ? * ? + """.trimIndent() + ) + val query8 = parser.parseAstStatement( + """ + SELECT * + FROM t1000 + WHERE t1000.someTimestamp < UTCNOW() AND t1000.id > ? + ORDER BY t1000.id DESC + LIMIT ? + OFFSET ? * ? + """.trimIndent() + ) + val query9 = parser.parseAstStatement( + """ + SELECT * + FROM t100 + WHERE t100.someTimestamp < UTCNOW() AND t100.id > ? + ORDER BY t100.id DESC + LIMIT ? + OFFSET ? * ? + """.trimIndent() + ) + val query10 = parser.parseAstStatement( + """ + SELECT * + FROM t10 + WHERE t10.someTimestamp < UTCNOW() AND t10.id > ? + ORDER BY t10.id DESC + LIMIT ? + OFFSET ? * ? + """.trimIndent() + ) + val query11 = parser.parseAstStatement( + """ + SELECT * + FROM t1 + WHERE t1.someTimestamp < UTCNOW() AND t1.id > ? + ORDER BY t1.id DESC + LIMIT ? + OFFSET ? * ? 
+ """.trimIndent() + ) + + val statement1 = runBlocking { pipeline.compile(query1) } + val statement2 = runBlocking { pipeline.compile(query2) } + val statement3 = runBlocking { pipeline.compile(query3) } + val statement4 = runBlocking { pipeline.compile(query4) } + val statement5 = runBlocking { pipeline.compile(query5) } + val statement6 = runBlocking { pipeline.compile(query6) } + val statement7 = runBlocking { pipeline.compile(query7) } + val statement8 = runBlocking { pipeline.compile(query8) } + val statement9 = runBlocking { pipeline.compile(query9) } + val statement10 = runBlocking { pipeline.compile(query10) } + val statement11 = runBlocking { pipeline.compile(query11) } + } + + @OptIn(ExperimentalPartiQLCompilerPipeline::class) + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun testCompileQuery1(state: MyState, blackhole: Blackhole) = runBlocking { + val statement = state.pipeline.compile(state.query1) + blackhole.consume(statement) + } + + @OptIn(ExperimentalPartiQLCompilerPipeline::class) + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun testCompileQuery2(state: MyState, blackhole: Blackhole) = runBlocking { + val statement = state.pipeline.compile(state.query2) + blackhole.consume(statement) + } + + @OptIn(ExperimentalPartiQLCompilerPipeline::class) + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun testCompileQuery3(state: MyState, blackhole: Blackhole) = runBlocking { + val statement = state.pipeline.compile(state.query3) + blackhole.consume(statement) + } + + 
@OptIn(ExperimentalPartiQLCompilerPipeline::class) + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun testCompileQuery4(state: MyState, blackhole: Blackhole) = runBlocking { + val statement = state.pipeline.compile(state.query4) + blackhole.consume(statement) + } + + @OptIn(ExperimentalPartiQLCompilerPipeline::class) + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun testCompileQuery5(state: MyState, blackhole: Blackhole) = runBlocking { + val statement = state.pipeline.compile(state.query5) + blackhole.consume(statement) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun testEvalQuery1(state: MyState, blackhole: Blackhole) = runBlocking { + val result = state.statement1.eval(state.session) + val exprValue = (result as PartiQLResult.Value).value + blackhole.consume(exprValue) + blackhole.consume(exprValue.iterator().forEach { }) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun testEvalQuery2(state: MyState, blackhole: Blackhole) = runBlocking { + val result = state.statement2.eval(state.session) + val exprValue = (result as PartiQLResult.Value).value + blackhole.consume(exprValue) + blackhole.consume(exprValue.iterator().forEach { }) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = 
WARMUP_TIME_VALUE) + fun testEvalQuery3(state: MyState, blackhole: Blackhole) = runBlocking { + val result = state.statement3.eval(state.session) + val exprValue = (result as PartiQLResult.Value).value + blackhole.consume(exprValue) + blackhole.consume(exprValue.iterator().forEach { }) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun testEvalQuery4(state: MyState, blackhole: Blackhole) = runBlocking { + val result = state.statement4.eval(state.session) + val exprValue = (result as PartiQLResult.Value).value + blackhole.consume(exprValue) + blackhole.consume(exprValue.iterator().forEach { }) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun testEvalQuery5(state: MyState, blackhole: Blackhole) = runBlocking { + val result = state.statement5.eval(state.session) + val exprValue = (result as PartiQLResult.Value).value + blackhole.consume(exprValue) + blackhole.consume(exprValue.iterator().forEach { }) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun testEvalQuery6(state: MyState, blackhole: Blackhole) = runBlocking { + val result = state.statement6.eval(state.session) + val exprValue = (result as PartiQLResult.Value).value + blackhole.consume(exprValue) + blackhole.consume(exprValue.iterator().forEach { }) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun testEvalQuery7(state: MyState, blackhole: Blackhole) = runBlocking { + 
val result = state.statement7.eval(state.session) + val exprValue = (result as PartiQLResult.Value).value + blackhole.consume(exprValue) + blackhole.consume(exprValue.iterator().forEach { }) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun testEvalQuery8(state: MyState, blackhole: Blackhole) = runBlocking { + val result = state.statement8.eval(state.session) + val exprValue = (result as PartiQLResult.Value).value + blackhole.consume(exprValue) + blackhole.consume(exprValue.iterator().forEach { }) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun testEvalQuery9(state: MyState, blackhole: Blackhole) = runBlocking { + val result = state.statement9.eval(state.session) + val exprValue = (result as PartiQLResult.Value).value + blackhole.consume(exprValue) + blackhole.consume(exprValue.iterator().forEach { }) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun testEvalQuery10(state: MyState, blackhole: Blackhole) = runBlocking { + val result = state.statement10.eval(state.session) + val exprValue = (result as PartiQLResult.Value).value + blackhole.consume(exprValue) + blackhole.consume(exprValue.iterator().forEach { }) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun testEvalQuery11(state: MyState, blackhole: Blackhole) = runBlocking { + val result = state.statement11.eval(state.session) + val exprValue = (result as 
PartiQLResult.Value).value + blackhole.consume(exprValue) + blackhole.consume(exprValue.iterator().forEach { }) + } +} diff --git a/partiql-lang/src/jmh/kotlin/org/partiql/jmh/benchmarks/PartiQLCompilerPipelineBenchmark.kt b/partiql-lang/src/jmh/kotlin/org/partiql/jmh/benchmarks/PartiQLCompilerPipelineBenchmark.kt new file mode 100644 index 0000000000..4647524908 --- /dev/null +++ b/partiql-lang/src/jmh/kotlin/org/partiql/jmh/benchmarks/PartiQLCompilerPipelineBenchmark.kt @@ -0,0 +1,381 @@ +package org.partiql.jmh.benchmarks + +import com.amazon.ion.IonSystem +import com.amazon.ion.system.IonSystemBuilder +import kotlinx.coroutines.runBlocking +import org.openjdk.jmh.annotations.Benchmark +import org.openjdk.jmh.annotations.BenchmarkMode +import org.openjdk.jmh.annotations.Fork +import org.openjdk.jmh.annotations.Measurement +import org.openjdk.jmh.annotations.Mode +import org.openjdk.jmh.annotations.OutputTimeUnit +import org.openjdk.jmh.annotations.Scope +import org.openjdk.jmh.annotations.State +import org.openjdk.jmh.annotations.Warmup +import org.openjdk.jmh.infra.Blackhole +import org.partiql.annotations.ExperimentalPartiQLCompilerPipeline +import org.partiql.jmh.utils.FORK_VALUE_RECOMMENDED +import org.partiql.jmh.utils.MEASUREMENT_ITERATION_VALUE_RECOMMENDED +import org.partiql.jmh.utils.MEASUREMENT_TIME_VALUE_RECOMMENDED +import org.partiql.jmh.utils.WARMUP_ITERATION_VALUE_RECOMMENDED +import org.partiql.jmh.utils.WARMUP_TIME_VALUE_RECOMMENDED +import org.partiql.lang.compiler.PartiQLCompilerPipeline +import org.partiql.lang.eval.Bindings +import org.partiql.lang.eval.EvaluationSession +import org.partiql.lang.eval.ExprValue +import org.partiql.lang.eval.PartiQLResult +import org.partiql.lang.planner.GlobalResolutionResult +import org.partiql.lang.syntax.PartiQLParserBuilder +import java.util.concurrent.TimeUnit + +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.MICROSECONDS) +@Deprecated("To be removed in the next major version once the 
synchronous physical plan compiler is removed.") +open class PartiQLCompilerPipelineBenchmark { + companion object { + private const val FORK_VALUE: Int = FORK_VALUE_RECOMMENDED + private const val MEASUREMENT_ITERATION_VALUE: Int = MEASUREMENT_ITERATION_VALUE_RECOMMENDED + private const val MEASUREMENT_TIME_VALUE: Int = MEASUREMENT_TIME_VALUE_RECOMMENDED + private const val WARMUP_ITERATION_VALUE: Int = WARMUP_ITERATION_VALUE_RECOMMENDED + private const val WARMUP_TIME_VALUE: Int = WARMUP_TIME_VALUE_RECOMMENDED + } + + @State(Scope.Thread) + @OptIn(ExperimentalPartiQLCompilerPipeline::class) + open class MyState { + private val parser = PartiQLParserBuilder.standard().build() + private val myIonSystem: IonSystem = IonSystemBuilder.standard().build() + + private fun tableWithRows(numRows: Int): ExprValue { + val allRows = (1..numRows).joinToString { index -> + """ + { + "id": $index, + "someString": "some string foo $index", + "someDecimal": $index.00, + "someBlob": {{ dHdvIHBhZGRpbmcgY2hhcmFjdGVycw== }}, + "someTimestamp": 2007-02-23T12:14:15.${index}Z + } + """.trimIndent() + } + val data = "[ $allRows ]" + return ExprValue.of( + myIonSystem.singleValue(data) + ) + } + + private val bindings = Bindings.ofMap( + mapOf( + "t1" to tableWithRows(1), + "t10" to tableWithRows(10), + "t100" to tableWithRows(100), + "t1000" to tableWithRows(1000), + "t10000" to tableWithRows(10000), + "t100000" to tableWithRows(100000), + ) + ) + + private val parameters = listOf( + ExprValue.newInt(5), // WHERE `id` > 5 + ExprValue.newInt(1000000), // LIMIT 1000000 + ExprValue.newInt(3), // OFFSET 3 * 2 + ExprValue.newInt(2), // ------------^ + ) + val session = EvaluationSession.build { + globals(bindings) + parameters(parameters) + } + + val pipeline = PartiQLCompilerPipeline.build { + planner.globalVariableResolver { + val value = session.globals[it] + if (value != null) { + GlobalResolutionResult.GlobalVariable(it.name) + } else { + GlobalResolutionResult.Undefined + } + } + } + + 
val query1 = parser.parseAstStatement( + """ + SELECT * FROM t100000 + """.trimIndent() + ) + val query2 = parser.parseAstStatement( + """ + SELECT * + FROM t100000 + WHERE t100000.someTimestamp < UTCNOW() + """.trimIndent() + ) + val query3 = parser.parseAstStatement( + """ + SELECT * + FROM t100000 + WHERE t100000.someTimestamp < UTCNOW() + LIMIT ${Int.MAX_VALUE} + """.trimIndent() + ) + val query4 = parser.parseAstStatement( + """ + SELECT * + FROM t100000 + WHERE t100000.someTimestamp < UTCNOW() + ORDER BY t100000.id DESC + """.trimIndent() + ) + val query5 = parser.parseAstStatement( + """ + SELECT * + FROM t100000 + WHERE t100000.someTimestamp < UTCNOW() AND t100000.id > ? + LIMIT ? + OFFSET ? * ? + """.trimIndent() + ) + private val query6 = parser.parseAstStatement( + """ + SELECT * + FROM t100000 + WHERE t100000.someTimestamp < UTCNOW() AND t100000.id > ? + ORDER BY t100000.id DESC + LIMIT ? + OFFSET ? * ? + """.trimIndent() + ) + private val query7 = parser.parseAstStatement( + """ + SELECT * + FROM t10000 + WHERE t10000.someTimestamp < UTCNOW() AND t10000.id > ? + ORDER BY t10000.id DESC + LIMIT ? + OFFSET ? * ? + """.trimIndent() + ) + private val query8 = parser.parseAstStatement( + """ + SELECT * + FROM t1000 + WHERE t1000.someTimestamp < UTCNOW() AND t1000.id > ? + ORDER BY t1000.id DESC + LIMIT ? + OFFSET ? * ? + """.trimIndent() + ) + private val query9 = parser.parseAstStatement( + """ + SELECT * + FROM t100 + WHERE t100.someTimestamp < UTCNOW() AND t100.id > ? + ORDER BY t100.id DESC + LIMIT ? + OFFSET ? * ? + """.trimIndent() + ) + private val query10 = parser.parseAstStatement( + """ + SELECT * + FROM t10 + WHERE t10.someTimestamp < UTCNOW() AND t10.id > ? + ORDER BY t10.id DESC + LIMIT ? + OFFSET ? * ? + """.trimIndent() + ) + private val query11 = parser.parseAstStatement( + """ + SELECT * + FROM t1 + WHERE t1.someTimestamp < UTCNOW() AND t1.id > ? + ORDER BY t1.id DESC + LIMIT ? + OFFSET ? * ? 
+ """.trimIndent() + ) + + val statement1 = pipeline.compile(query1) + val statement2 = pipeline.compile(query2) + val statement3 = pipeline.compile(query3) + val statement4 = pipeline.compile(query4) + val statement5 = pipeline.compile(query5) + val statement6 = pipeline.compile(query6) + val statement7 = pipeline.compile(query7) + val statement8 = pipeline.compile(query8) + val statement9 = pipeline.compile(query9) + val statement10 = pipeline.compile(query10) + val statement11 = pipeline.compile(query11) + } + + @OptIn(ExperimentalPartiQLCompilerPipeline::class) + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun testCompileQuery1(state: MyState, blackhole: Blackhole) = runBlocking { + val statement = state.pipeline.compile(state.query1) + blackhole.consume(statement) + } + + @OptIn(ExperimentalPartiQLCompilerPipeline::class) + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun testCompileQuery2(state: MyState, blackhole: Blackhole) = runBlocking { + val statement = state.pipeline.compile(state.query2) + blackhole.consume(statement) + } + + @OptIn(ExperimentalPartiQLCompilerPipeline::class) + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun testCompileQuery3(state: MyState, blackhole: Blackhole) = runBlocking { + val statement = state.pipeline.compile(state.query3) + blackhole.consume(statement) + } + + @OptIn(ExperimentalPartiQLCompilerPipeline::class) + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = 
WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun testCompileQuery4(state: MyState, blackhole: Blackhole) { + val statement = state.pipeline.compile(state.query4) + blackhole.consume(statement) + } + + @OptIn(ExperimentalPartiQLCompilerPipeline::class) + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun testCompileQuery5(state: MyState, blackhole: Blackhole) { + val statement = state.pipeline.compile(state.query5) + blackhole.consume(statement) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun testEvalQuery1(state: MyState, blackhole: Blackhole) { + val result = state.statement1.eval(state.session) + val exprValue = (result as PartiQLResult.Value).value + blackhole.consume(exprValue) + blackhole.consume(exprValue.iterator().forEach { }) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun testEvalQuery2(state: MyState, blackhole: Blackhole) { + val result = state.statement2.eval(state.session) + val exprValue = (result as PartiQLResult.Value).value + blackhole.consume(exprValue) + blackhole.consume(exprValue.iterator().forEach { }) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun testEvalQuery3(state: MyState, blackhole: Blackhole) { + val result = state.statement3.eval(state.session) + val exprValue = (result as PartiQLResult.Value).value + blackhole.consume(exprValue) + blackhole.consume(exprValue.iterator().forEach { }) + 
} + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun testEvalQuery4(state: MyState, blackhole: Blackhole) { + val result = state.statement4.eval(state.session) + val exprValue = (result as PartiQLResult.Value).value + blackhole.consume(exprValue) + blackhole.consume(exprValue.iterator().forEach { }) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun testEvalQuery5(state: MyState, blackhole: Blackhole) { + val result = state.statement5.eval(state.session) + val exprValue = (result as PartiQLResult.Value).value + blackhole.consume(exprValue) + blackhole.consume(exprValue.iterator().forEach { }) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun testEvalQuery6(state: MyState, blackhole: Blackhole) { + val result = state.statement6.eval(state.session) + val exprValue = (result as PartiQLResult.Value).value + blackhole.consume(exprValue) + blackhole.consume(exprValue.iterator().forEach { }) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun testEvalQuery7(state: MyState, blackhole: Blackhole) { + val result = state.statement7.eval(state.session) + val exprValue = (result as PartiQLResult.Value).value + blackhole.consume(exprValue) + blackhole.consume(exprValue.iterator().forEach { }) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + 
@Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun testEvalQuery8(state: MyState, blackhole: Blackhole) { + val result = state.statement8.eval(state.session) + val exprValue = (result as PartiQLResult.Value).value + blackhole.consume(exprValue) + blackhole.consume(exprValue.iterator().forEach { }) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun testEvalQuery9(state: MyState, blackhole: Blackhole) { + val result = state.statement9.eval(state.session) + val exprValue = (result as PartiQLResult.Value).value + blackhole.consume(exprValue) + blackhole.consume(exprValue.iterator().forEach { }) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun testEvalQuery10(state: MyState, blackhole: Blackhole) { + val result = state.statement10.eval(state.session) + val exprValue = (result as PartiQLResult.Value).value + blackhole.consume(exprValue) + blackhole.consume(exprValue.iterator().forEach { }) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun testEvalQuery11(state: MyState, blackhole: Blackhole) { + val result = state.statement11.eval(state.session) + val exprValue = (result as PartiQLResult.Value).value + blackhole.consume(exprValue) + blackhole.consume(exprValue.iterator().forEach { }) + } +} diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/compiler/PartiQLCompiler.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/compiler/PartiQLCompiler.kt index b2c7493583..1f08f2311b 100644 --- 
a/partiql-lang/src/main/kotlin/org/partiql/lang/compiler/PartiQLCompiler.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/compiler/PartiQLCompiler.kt @@ -23,15 +23,18 @@ import org.partiql.lang.planner.PartiQLPlanner * [PartiQLCompiler] is responsible for transforming a [PartiqlPhysical.Plan] into an executable [PartiQLStatement]. */ @ExperimentalPartiQLCompilerPipeline +@Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("PartiQLCompilerAsync")) interface PartiQLCompiler { /** * Compiles the [PartiqlPhysical.Plan] to an executable [PartiQLStatement]. */ + @Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("PartiQLCompilerAsync.compile")) fun compile(statement: PartiqlPhysical.Plan): PartiQLStatement /** * Compiles the [PartiqlPhysical.Statement.Explain] with the details provided in [details] */ + @Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("PartiQLCompilerAsync.compile")) fun compile(statement: PartiqlPhysical.Plan, details: PartiQLPlanner.PlanningDetails): PartiQLStatement } diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/compiler/PartiQLCompilerAsync.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/compiler/PartiQLCompilerAsync.kt new file mode 100644 index 0000000000..c8a2ba10d8 --- /dev/null +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/compiler/PartiQLCompilerAsync.kt @@ -0,0 +1,37 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific + * language governing permissions and limitations under the License. + */ + +package org.partiql.lang.compiler + +import org.partiql.annotations.ExperimentalPartiQLCompilerPipeline +import org.partiql.lang.domains.PartiqlPhysical +import org.partiql.lang.eval.PartiQLStatementAsync +import org.partiql.lang.planner.PartiQLPlanner + +/** + * [PartiQLCompilerAsync] is responsible for transforming a [PartiqlPhysical.Plan] into an executable [PartiQLStatementAsync]. + */ +@ExperimentalPartiQLCompilerPipeline +interface PartiQLCompilerAsync { + + /** + * Compiles the [PartiqlPhysical.Plan] to an executable [PartiQLStatementAsync]. + */ + suspend fun compile(statement: PartiqlPhysical.Plan): PartiQLStatementAsync + + /** + * Compiles the [PartiqlPhysical.Statement.Explain] with the details provided in [details] + */ + suspend fun compile(statement: PartiqlPhysical.Plan, details: PartiQLPlanner.PlanningDetails): PartiQLStatementAsync +} diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/compiler/PartiQLCompilerAsyncBuilder.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/compiler/PartiQLCompilerAsyncBuilder.kt new file mode 100644 index 0000000000..321bbf0949 --- /dev/null +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/compiler/PartiQLCompilerAsyncBuilder.kt @@ -0,0 +1,150 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. 
+ */ + +package org.partiql.lang.compiler + +import org.partiql.annotations.ExperimentalPartiQLCompilerPipeline +import org.partiql.annotations.ExperimentalWindowFunctions +import org.partiql.lang.eval.ExprFunction +import org.partiql.lang.eval.ThunkReturnTypeAssertions +import org.partiql.lang.eval.TypingMode +import org.partiql.lang.eval.builtins.DynamicLookupExprFunction +import org.partiql.lang.eval.builtins.SCALAR_BUILTINS_DEFAULT +import org.partiql.lang.eval.builtins.definitionalBuiltins +import org.partiql.lang.eval.builtins.storedprocedure.StoredProcedure +import org.partiql.lang.eval.physical.operators.AggregateOperatorFactoryDefaultAsync +import org.partiql.lang.eval.physical.operators.FilterRelationalOperatorFactoryDefaultAsync +import org.partiql.lang.eval.physical.operators.JoinRelationalOperatorFactoryDefaultAsync +import org.partiql.lang.eval.physical.operators.LetRelationalOperatorFactoryDefaultAsync +import org.partiql.lang.eval.physical.operators.LimitRelationalOperatorFactoryDefaultAsync +import org.partiql.lang.eval.physical.operators.OffsetRelationalOperatorFactoryDefaultAsync +import org.partiql.lang.eval.physical.operators.RelationalOperatorFactory +import org.partiql.lang.eval.physical.operators.ScanRelationalOperatorFactoryDefaultAsync +import org.partiql.lang.eval.physical.operators.SortOperatorFactoryDefaultAsync +import org.partiql.lang.eval.physical.operators.UnpivotOperatorFactoryDefaultAsync +import org.partiql.lang.eval.physical.operators.WindowRelationalOperatorFactoryDefaultAsync +import org.partiql.lang.planner.EvaluatorOptions +import org.partiql.lang.types.CustomType + +/** + * Builder class to instantiate a [PartiQLCompilerAsync]. 
+ * + * Example usages: + * + * ``` + * // Default + * val compiler = PartiQLCompilerAsyncBuilder.standard().build() + * + * // Fluent builder + * val compiler = PartiQLCompilerAsyncBuilder.standard() + * .customFunctions(myCustomFunctionList) + * .build() + * ``` + */ + +@ExperimentalPartiQLCompilerPipeline +class PartiQLCompilerAsyncBuilder private constructor() { + + private var options: EvaluatorOptions = EvaluatorOptions.standard() + private var customTypes: List = emptyList() + private var customFunctions: List = emptyList() + private var customProcedures: List = emptyList() + private var customOperatorFactories: List = emptyList() + + companion object { + + /** + * A collection of all the default relational operator implementations provided by PartiQL. + * + * By default, the query planner will select these as the implementations for all relational operators, but + * alternate implementations may be provided and chosen by physical plan passes. + * + * @see [org.partiql.lang.planner.PlannerPipeline.Builder.addPhysicalPlanPass] + * @see [org.partiql.lang.planner.PlannerPipeline.Builder.addRelationalOperatorFactory] + */ + + private val DEFAULT_RELATIONAL_OPERATOR_FACTORIES = listOf( + AggregateOperatorFactoryDefaultAsync, + SortOperatorFactoryDefaultAsync, + UnpivotOperatorFactoryDefaultAsync, + FilterRelationalOperatorFactoryDefaultAsync, + ScanRelationalOperatorFactoryDefaultAsync, + JoinRelationalOperatorFactoryDefaultAsync, + OffsetRelationalOperatorFactoryDefaultAsync, + LimitRelationalOperatorFactoryDefaultAsync, + LetRelationalOperatorFactoryDefaultAsync, + // Notice here we will not propagate the optin requirement to the user + @OptIn(ExperimentalWindowFunctions::class) + WindowRelationalOperatorFactoryDefaultAsync, + ) + + @JvmStatic + fun standard() = PartiQLCompilerAsyncBuilder() + } + + fun build(): PartiQLCompilerAsync { + if (options.thunkOptions.thunkReturnTypeAssertions == ThunkReturnTypeAssertions.ENABLED) { + 
TODO("ThunkReturnTypeAssertions.ENABLED requires a static type pass") + } + return PartiQLCompilerAsyncDefault( + evaluatorOptions = options, + customTypedOpParameters = customTypes.associateBy( + keySelector = { it.name }, + valueTransform = { it.typedOpParameter } + ), + functions = allFunctions(options.typingMode), + procedures = customProcedures.associateBy( + keySelector = { it.signature.name }, + valueTransform = { it } + ), + operatorFactories = allOperatorFactories() + ) + } + + fun options(options: EvaluatorOptions) = this.apply { + this.options = options + } + + fun customFunctions(customFunctions: List) = this.apply { + this.customFunctions = customFunctions + } + + fun customTypes(customTypes: List) = this.apply { + this.customTypes = customTypes + } + + fun customProcedures(customProcedures: List) = this.apply { + this.customProcedures = customProcedures + } + + fun customOperatorFactories(customOperatorFactories: List) = this.apply { + this.customOperatorFactories = customOperatorFactories + } + + // --- Internal ---------------------------------- + + private fun allFunctions(typingMode: TypingMode): List { + val definitionalBuiltins = definitionalBuiltins(typingMode) + val builtins = SCALAR_BUILTINS_DEFAULT + return definitionalBuiltins + builtins + customFunctions + DynamicLookupExprFunction() + } + + private fun allOperatorFactories() = (DEFAULT_RELATIONAL_OPERATOR_FACTORIES + customOperatorFactories).apply { + groupBy { it.key }.entries.firstOrNull { it.value.size > 1 }?.let { + error( + "More than one BindingsOperatorFactory for ${it.key.operator} named '${it.value}' was specified." 
+ ) + } + }.associateBy { it.key } +} diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/compiler/PartiQLCompilerAsyncDefault.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/compiler/PartiQLCompilerAsyncDefault.kt new file mode 100644 index 0000000000..0b98276430 --- /dev/null +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/compiler/PartiQLCompilerAsyncDefault.kt @@ -0,0 +1,159 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + */ + +package org.partiql.lang.compiler + +import org.partiql.annotations.ExperimentalPartiQLCompilerPipeline +import org.partiql.lang.domains.PartiqlAst +import org.partiql.lang.domains.PartiqlLogical +import org.partiql.lang.domains.PartiqlLogicalResolved +import org.partiql.lang.domains.PartiqlPhysical +import org.partiql.lang.errors.PartiQLException +import org.partiql.lang.eval.ExprFunction +import org.partiql.lang.eval.PartiQLResult +import org.partiql.lang.eval.PartiQLStatementAsync +import org.partiql.lang.eval.builtins.storedprocedure.StoredProcedure +import org.partiql.lang.eval.physical.PhysicalBexprToThunkConverterAsync +import org.partiql.lang.eval.physical.PhysicalPlanCompilerAsync +import org.partiql.lang.eval.physical.PhysicalPlanCompilerAsyncImpl +import org.partiql.lang.eval.physical.PhysicalPlanThunkAsync +import org.partiql.lang.eval.physical.operators.RelationalOperatorFactory +import org.partiql.lang.eval.physical.operators.RelationalOperatorFactoryKey +import 
org.partiql.lang.planner.EvaluatorOptions +import org.partiql.lang.planner.PartiQLPlanner +import org.partiql.lang.types.TypedOpParameter + +@ExperimentalPartiQLCompilerPipeline +internal class PartiQLCompilerAsyncDefault( + evaluatorOptions: EvaluatorOptions, + customTypedOpParameters: Map, + functions: List, + procedures: Map, + operatorFactories: Map +) : PartiQLCompilerAsync { + + private lateinit var exprConverter: PhysicalPlanCompilerAsyncImpl + private val bexprConverter = PhysicalBexprToThunkConverterAsync( + exprConverter = object : PhysicalPlanCompilerAsync { + override suspend fun convert(expr: PartiqlPhysical.Expr): PhysicalPlanThunkAsync = exprConverter.convert(expr) + }, + relationalOperatorFactory = operatorFactories + ) + + init { + exprConverter = PhysicalPlanCompilerAsyncImpl( + functions = functions, + customTypedOpParameters = customTypedOpParameters, + procedures = procedures, + evaluatorOptions = evaluatorOptions, + bexperConverter = bexprConverter + ) + } + + override suspend fun compile(statement: PartiqlPhysical.Plan): PartiQLStatementAsync { + return when (val stmt = statement.stmt) { + is PartiqlPhysical.Statement.Dml -> compileDml(stmt, statement.locals.size) + is PartiqlPhysical.Statement.Exec, + is PartiqlPhysical.Statement.Query -> { + val expression = exprConverter.compile(statement) + PartiQLStatementAsync { expression.eval(it) } + } + is PartiqlPhysical.Statement.Explain -> throw PartiQLException("Unable to compile EXPLAIN without details.") + } + } + + override suspend fun compile(statement: PartiqlPhysical.Plan, details: PartiQLPlanner.PlanningDetails): PartiQLStatementAsync { + return when (val stmt = statement.stmt) { + is PartiqlPhysical.Statement.Dml -> compileDml(stmt, statement.locals.size) + is PartiqlPhysical.Statement.Exec, + is PartiqlPhysical.Statement.Query -> compile(statement) + is PartiqlPhysical.Statement.Explain -> PartiQLStatementAsync { compileExplain(stmt, details) } + } + } + + // --- INTERNAL 
------------------- + + private enum class ExplainDomains { + AST, + AST_NORMALIZED, + LOGICAL, + LOGICAL_RESOLVED, + PHYSICAL, + PHYSICAL_TRANSFORMED + } + + private suspend fun compileDml(dml: PartiqlPhysical.Statement.Dml, localsSize: Int): PartiQLStatementAsync { + val rows = exprConverter.compile(dml.rows, localsSize) + return PartiQLStatementAsync { session -> + when (dml.operation) { + is PartiqlPhysical.DmlOperation.DmlReplace -> PartiQLResult.Replace(dml.uniqueId.text, (rows.eval(session) as PartiQLResult.Value).value) + is PartiqlPhysical.DmlOperation.DmlInsert -> PartiQLResult.Insert(dml.uniqueId.text, (rows.eval(session) as PartiQLResult.Value).value) + is PartiqlPhysical.DmlOperation.DmlDelete -> PartiQLResult.Delete(dml.uniqueId.text, (rows.eval(session) as PartiQLResult.Value).value) + is PartiqlPhysical.DmlOperation.DmlUpdate -> TODO("DML Update compilation not supported yet.") + } + } + } + + private fun compileExplain(statement: PartiqlPhysical.Statement.Explain, details: PartiQLPlanner.PlanningDetails): PartiQLResult.Explain.Domain { + return when (val target = statement.target) { + is PartiqlPhysical.ExplainTarget.Domain -> compileExplainDomain(target, details) + } + } + + private fun compileExplainDomain(statement: PartiqlPhysical.ExplainTarget.Domain, details: PartiQLPlanner.PlanningDetails): PartiQLResult.Explain.Domain { + val format = statement.format?.text + val type = statement.type?.text?.uppercase() ?: ExplainDomains.AST.name + val domain = try { + ExplainDomains.valueOf(type) + } catch (ex: IllegalArgumentException) { + throw PartiQLException("Illegal argument: $type") + } + return when (domain) { + ExplainDomains.AST -> { + val explain = details.ast!! as PartiqlAst.Statement.Explain + val target = explain.target as PartiqlAst.ExplainTarget.Domain + PartiQLResult.Explain.Domain(target.statement, format) + } + ExplainDomains.AST_NORMALIZED -> { + val explain = details.astNormalized!! 
as PartiqlAst.Statement.Explain + val target = explain.target as PartiqlAst.ExplainTarget.Domain + PartiQLResult.Explain.Domain(target.statement, format) + } + ExplainDomains.LOGICAL -> { + val explain = details.logical!!.stmt as PartiqlLogical.Statement.Explain + val target = explain.target as PartiqlLogical.ExplainTarget.Domain + val plan = details.logical.copy(stmt = target.statement) + PartiQLResult.Explain.Domain(plan, format) + } + ExplainDomains.LOGICAL_RESOLVED -> { + val explain = details.logicalResolved!!.stmt as PartiqlLogicalResolved.Statement.Explain + val target = explain.target as PartiqlLogicalResolved.ExplainTarget.Domain + val plan = details.logicalResolved.copy(stmt = target.statement) + PartiQLResult.Explain.Domain(plan, format) + } + ExplainDomains.PHYSICAL -> { + val explain = details.physical!!.stmt as PartiqlPhysical.Statement.Explain + val target = explain.target as PartiqlPhysical.ExplainTarget.Domain + val plan = details.physical.copy(stmt = target.statement) + PartiQLResult.Explain.Domain(plan, format) + } + ExplainDomains.PHYSICAL_TRANSFORMED -> { + val explain = details.physicalTransformed!!.stmt as PartiqlPhysical.Statement.Explain + val target = explain.target as PartiqlPhysical.ExplainTarget.Domain + val plan = details.physicalTransformed.copy(stmt = target.statement) + PartiQLResult.Explain.Domain(plan, format) + } + } + } +} diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/compiler/PartiQLCompilerBuilder.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/compiler/PartiQLCompilerBuilder.kt index 5ade0ad938..6c3bbb87f7 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/compiler/PartiQLCompilerBuilder.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/compiler/PartiQLCompilerBuilder.kt @@ -54,6 +54,7 @@ import org.partiql.lang.types.CustomType */ @ExperimentalPartiQLCompilerPipeline +@Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("PartiQLCompilerAsyncBuilder")) class 
PartiQLCompilerBuilder private constructor() { private var options: EvaluatorOptions = EvaluatorOptions.standard() @@ -90,9 +91,11 @@ class PartiQLCompilerBuilder private constructor() { ) @JvmStatic + @Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("PartiQLCompilerAsyncBuilder.standard")) fun standard() = PartiQLCompilerBuilder() } + @Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("PartiQLCompilerAsyncBuilder.build")) fun build(): PartiQLCompiler { if (options.thunkOptions.thunkReturnTypeAssertions == ThunkReturnTypeAssertions.ENABLED) { TODO("ThunkReturnTypeAssertions.ENABLED requires a static type pass") @@ -112,22 +115,27 @@ class PartiQLCompilerBuilder private constructor() { ) } + @Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("PartiQLCompilerAsyncBuilder.options")) fun options(options: EvaluatorOptions) = this.apply { this.options = options } + @Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("PartiQLCompilerAsyncBuilder.customFunctions")) fun customFunctions(customFunctions: List) = this.apply { this.customFunctions = customFunctions } + @Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("PartiQLCompilerAsyncBuilder.customTypes")) fun customTypes(customTypes: List) = this.apply { this.customTypes = customTypes } + @Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("PartiQLCompilerAsyncBuilder.customProcedures")) fun customProcedures(customProcedures: List) = this.apply { this.customProcedures = customProcedures } + @Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("PartiQLCompilerAsyncBuilder.customOperatorFactories")) fun customOperatorFactories(customOperatorFactories: List) = this.apply { this.customOperatorFactories = customOperatorFactories } diff --git 
a/partiql-lang/src/main/kotlin/org/partiql/lang/compiler/PartiQLCompilerDefault.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/compiler/PartiQLCompilerDefault.kt index c38c822b79..fbd6e426d5 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/compiler/PartiQLCompilerDefault.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/compiler/PartiQLCompilerDefault.kt @@ -36,6 +36,7 @@ import org.partiql.lang.planner.PartiQLPlanner import org.partiql.lang.types.TypedOpParameter @ExperimentalPartiQLCompilerPipeline +@Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("PartiQLCompilerAsyncDefault")) internal class PartiQLCompilerDefault( private val evaluatorOptions: EvaluatorOptions, private val customTypedOpParameters: Map, diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/compiler/PartiQLCompilerPipeline.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/compiler/PartiQLCompilerPipeline.kt index 1547f0a010..f0eb009375 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/compiler/PartiQLCompilerPipeline.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/compiler/PartiQLCompilerPipeline.kt @@ -42,6 +42,7 @@ import org.partiql.lang.syntax.PartiQLParserBuilder * ``` */ @ExperimentalPartiQLCompilerPipeline +@Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("PartiQLCompilerPipelineAsync")) class PartiQLCompilerPipeline( private val parser: Parser, private val planner: PartiQLPlanner, @@ -54,6 +55,7 @@ class PartiQLCompilerPipeline( * Returns a [PartiQLCompilerPipeline] with default parser, planner, and compiler configurations. 
*/ @JvmStatic + @Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("PartiQLCompilerPipelineAsync.standard")) fun standard() = PartiQLCompilerPipeline( parser = PartiQLParserBuilder.standard().build(), planner = PartiQLPlannerBuilder.standard().build(), @@ -75,6 +77,7 @@ class PartiQLCompilerPipeline( * } * ``` */ + @Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("PartiQLCompilerPipelineAsync.build")) fun build(block: Builder.() -> Unit): PartiQLCompilerPipeline { val builder = Builder() block.invoke(builder) @@ -89,6 +92,7 @@ class PartiQLCompilerPipeline( /** * Compiles a PartiQL query into an executable [PartiQLStatement]. */ + @Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("PartiQLCompilerPipelineAsync.compile")) fun compile(statement: String): PartiQLStatement { val ast = parser.parseAstStatement(statement) return compile(ast) @@ -97,6 +101,7 @@ class PartiQLCompilerPipeline( /** * Compiles a [PartiqlAst.Statement] representation of a query into an executable [PartiQLStatement]. */ + @Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("PartiQLCompilerPipelineAsync.compile")) fun compile(statement: PartiqlAst.Statement): PartiQLStatement { val result = planner.plan(statement) if (result is PartiQLPlanner.Result.Error) { @@ -110,10 +115,12 @@ class PartiQLCompilerPipeline( * Compiles a [PartiqlPhysical.Plan] representation of a query into an executable [PartiQLStatement]. 
*/ @JvmOverloads + @Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("PartiQLCompilerPipelineAsync.compile")) fun compile(statement: PartiqlPhysical.Plan, details: PartiQLPlanner.PlanningDetails = PartiQLPlanner.PlanningDetails()): PartiQLStatement { return compiler.compile(statement, details) } + @Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("PartiQLCompilerPipelineAsync.Builder")) class Builder internal constructor() { var parser = PartiQLParserBuilder.standard() var planner = PartiQLPlannerBuilder.standard() diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/compiler/PartiQLCompilerPipelineAsync.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/compiler/PartiQLCompilerPipelineAsync.kt new file mode 100644 index 0000000000..18e11870fd --- /dev/null +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/compiler/PartiQLCompilerPipelineAsync.kt @@ -0,0 +1,123 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. 
+ */ + +package org.partiql.lang.compiler + +import org.partiql.annotations.ExperimentalPartiQLCompilerPipeline +import org.partiql.lang.domains.PartiqlAst +import org.partiql.lang.domains.PartiqlPhysical +import org.partiql.lang.errors.PartiQLException +import org.partiql.lang.eval.PartiQLStatementAsync +import org.partiql.lang.planner.PartiQLPlanner +import org.partiql.lang.planner.PartiQLPlannerBuilder +import org.partiql.lang.syntax.Parser +import org.partiql.lang.syntax.PartiQLParserBuilder + +/** + * [PartiQLCompilerPipelineAsync] is the top-level class for embedded usage of PartiQL. + * + * Example usage: + * ``` + * // Within a coroutine scope or `suspend fun` + * val pipeline = PartiQLCompilerPipelineAsync.standard() + * val session = // session bindings + * val statement = pipeline.compile("-- some PartiQL query!") + * val result = statement.eval(session) + * when (result) { + * is PartiQLResult.Value -> handle(result) // Query Result + * is PartiQLResult.Insert -> handle(result) // DML `Insert` + * is PartiQLResult.Delete -> handle(result) // DML `Delete` + * ... + * } + * ``` + */ +@ExperimentalPartiQLCompilerPipeline +class PartiQLCompilerPipelineAsync( + private val parser: Parser, + private val planner: PartiQLPlanner, + private val compiler: PartiQLCompilerAsync +) { + + companion object { + + /** + * Returns a [PartiQLCompilerPipelineAsync] with default parser, planner, and compiler configurations. + */ + @JvmStatic + fun standard() = PartiQLCompilerPipelineAsync( + parser = PartiQLParserBuilder.standard().build(), + planner = PartiQLPlannerBuilder.standard().build(), + compiler = PartiQLCompilerAsyncBuilder.standard().build() + ) + + /** + * Builder utility for pipeline creation. 
+ * + * Example usage: + * ``` + * val pipeline = PartiQLCompilerPipelineAsync.build { + * planner.options(plannerOptions) + * .globalVariableResolver(globalVariableResolver) + * compiler.ionSystem(ION) + * .options(evaluatorOptions) + * .customTypes(myCustomTypes) + * .customFunctions(myCustomFunctions) + * } + * ``` + */ + fun build(block: Builder.() -> Unit): PartiQLCompilerPipelineAsync { + val builder = Builder() + block.invoke(builder) + return PartiQLCompilerPipelineAsync( + parser = builder.parser.build(), + planner = builder.planner.build(), + compiler = builder.compiler.build(), + ) + } + } + + /** + * Compiles a PartiQL query into an executable [PartiQLStatementAsync]. + */ + suspend fun compile(statement: String): PartiQLStatementAsync { + val ast = parser.parseAstStatement(statement) + return compile(ast) + } + + /** + * Compiles a [PartiqlAst.Statement] representation of a query into an executable [PartiQLStatementAsync]. + */ + suspend fun compile(statement: PartiqlAst.Statement): PartiQLStatementAsync { + val result = planner.plan(statement) + if (result is PartiQLPlanner.Result.Error) { + throw PartiQLException(result.problems.toString()) + } + val plan = (result as PartiQLPlanner.Result.Success).plan + return compile(plan, result.details) + } + + /** + * Compiles a [PartiqlPhysical.Plan] representation of a query into an executable [PartiQLStatementAsync]. 
+ */ + @JvmOverloads + suspend fun compile(statement: PartiqlPhysical.Plan, details: PartiQLPlanner.PlanningDetails = PartiQLPlanner.PlanningDetails()): PartiQLStatementAsync { + return compiler.compile(statement, details) + } + + class Builder internal constructor() { + var parser = PartiQLParserBuilder.standard() + var planner = PartiQLPlannerBuilder.standard() + var compiler = PartiQLCompilerAsyncBuilder.standard() + } +} diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/ExpressionAsync.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/ExpressionAsync.kt new file mode 100644 index 0000000000..c3bf60291f --- /dev/null +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/ExpressionAsync.kt @@ -0,0 +1,25 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + */ + +package org.partiql.lang.eval + +/** + * An expression that can be evaluated to [ExprValue]. 
+ */ +internal interface ExpressionAsync { + /** + * Evaluates the [ExpressionAsync] with the given Session + */ + suspend fun eval(session: EvaluationSession): PartiQLResult +} diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/PartiQLStatement.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/PartiQLStatement.kt index 87e37d8600..43dfdb51c9 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/PartiQLStatement.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/PartiQLStatement.kt @@ -17,7 +17,8 @@ package org.partiql.lang.eval /** * A compiled PartiQL statement */ +@Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("PartiQLStatementAsync")) fun interface PartiQLStatement { - + @Deprecated("To be removed in next major version.", replaceWith = ReplaceWith("PartiQLStatementAsync.eval")) fun eval(session: EvaluationSession): PartiQLResult } diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/PartiQLStatementAsync.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/PartiQLStatementAsync.kt new file mode 100644 index 0000000000..af3fbb6b0d --- /dev/null +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/PartiQLStatementAsync.kt @@ -0,0 +1,23 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + */ + +package org.partiql.lang.eval + +/** + * A compiled PartiQL statement intended to be evaluated from a Kotlin coroutine. 
+ */ +fun interface PartiQLStatementAsync { + + suspend fun eval(session: EvaluationSession): PartiQLResult +} diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/ThunkAsync.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/ThunkAsync.kt new file mode 100644 index 0000000000..c314eb68aa --- /dev/null +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/ThunkAsync.kt @@ -0,0 +1,613 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + */ + +package org.partiql.lang.eval + +import com.amazon.ionelement.api.MetaContainer +import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.flow.toList +import org.partiql.errors.ErrorBehaviorInPermissiveMode +import org.partiql.errors.ErrorCode +import org.partiql.errors.Property +import org.partiql.lang.ast.SourceLocationMeta +import org.partiql.lang.ast.StaticTypeMeta +import org.partiql.lang.domains.staticType +import org.partiql.lang.types.StaticTypeUtils.isInstance + +/** + * A thunk with no parameters other than the current environment. + * + * See https://en.wikipedia.org/wiki/Thunk + * + * @param TEnv The type of the environment. Generic so that the legacy AST compiler and the new compiler may use + * different types here. + */ +internal typealias ThunkAsync = suspend (TEnv) -> ExprValue + +/** + * A thunk taking a single argument and the current environment. + * + * See https://en.wikipedia.org/wiki/Thunk + * + * @param TEnv The type of the environment. 
Generic so that the legacy AST compiler and the new compiler may use + * different types here. + * @param TArg The type of the additional argument. + */ +internal typealias ThunkValueAsync = suspend (TEnv, TArg) -> ExprValue + +/** + * An extension method for creating [ThunkFactoryAsync] based on the type of [TypingMode] + * - when [TypingMode] is [TypingMode.LEGACY], creates [LegacyThunkFactoryAsync] + * - when [TypingMode] is [TypingMode.PERMISSIVE], creates [PermissiveThunkFactoryAsync] + */ +internal fun TypingMode.createThunkFactoryAsync( + thunkOptions: ThunkOptions +): ThunkFactoryAsync = when (this) { + TypingMode.LEGACY -> LegacyThunkFactoryAsync(thunkOptions) + TypingMode.PERMISSIVE -> PermissiveThunkFactoryAsync(thunkOptions) +} +/** + * Provides methods for constructing new thunks according to the specified [CompileOptions]. + */ +internal abstract class ThunkFactoryAsync( + val thunkOptions: ThunkOptions +) { + private fun checkEvaluationTimeType(thunkResult: ExprValue, metas: MetaContainer): ExprValue { + // When this check is enabled we throw an exception the [MetaContainer] does not have a + // [StaticTypeMeta]. This indicates a bug or unimplemented support for an AST node in + // [StaticTypeInferenceVisitorTransform]. + val staticType = metas.staticType?.type ?: error("Metas collection does not have a StaticTypeMeta") + if (!isInstance(thunkResult, staticType)) { + throw EvaluationException( + "Runtime type does not match the expected StaticType", + ErrorCode.EVALUATOR_VALUE_NOT_INSTANCE_OF_EXPECTED_TYPE, + errorContext = errorContextFrom(metas).apply { + this[Property.EXPECTED_STATIC_TYPE] = staticType.toString() + }, + internal = true + ) + } + return thunkResult + } + + /** + * If [ThunkReturnTypeAssertions.ENABLED] is set, wraps the receiver thunk in another thunk + * that verifies that the value returned from the receiver thunk matches the type found in the [StaticTypeMeta] + * contained within [metas]. 
+ * + * If [metas] contains does not contain [StaticTypeMeta], an [IllegalStateException] is thrown. This is to prevent + * confusion in the case [org.partiql.lang.eval.visitors.StaticTypeInferenceVisitorTransform] has a bug which + * prevents it from assigning a [StaticTypeMeta] or in case it is not run at all. + */ + protected suspend fun ThunkAsync.typeCheck(metas: MetaContainer): ThunkAsync = + when (thunkOptions.thunkReturnTypeAssertions) { + ThunkReturnTypeAssertions.DISABLED -> this + ThunkReturnTypeAssertions.ENABLED -> { + val wrapper: ThunkAsync = { env: TEnv -> + val thunkResult: ExprValue = this(env) + checkEvaluationTimeType(thunkResult, metas) + } + wrapper + } + } + + /** Same as [typeCheck] but works on a [ThunkEnvValue] instead of a [Thunk]. */ + protected suspend fun ThunkValueAsync.typeCheckEnvValue(metas: MetaContainer): ThunkValueAsync = + when (thunkOptions.thunkReturnTypeAssertions) { + ThunkReturnTypeAssertions.DISABLED -> this + ThunkReturnTypeAssertions.ENABLED -> { + val wrapper: ThunkValueAsync = { env: TEnv, value: ExprValue -> + val thunkResult: ExprValue = this(env, value) + checkEvaluationTimeType(thunkResult, metas) + } + wrapper + } + } + + /** + * Creates a [Thunk] which handles exceptions by wrapping them into an [EvaluationException] which uses + * [handleExceptionAsync] to handle exceptions appropriately. + * + * Literal lambdas passed to this function as [t] are inlined into the body of the function being returned, which + * reduces the need to create additional call contexts. The lambdas passed as [t] may not contain non-local returns + * (`crossinline`). + */ + internal suspend inline fun thunkEnvAsync(metas: MetaContainer, crossinline t: ThunkAsync): ThunkAsync { + val sourceLocationMeta = metas[SourceLocationMeta.TAG] as? 
SourceLocationMeta + + val thunkAsync: ThunkAsync = { env: TEnv -> + this.handleExceptionAsync(sourceLocationMeta) { + t(env) + } + } + return thunkAsync.typeCheck(metas) + } + + /** + * Defines the strategy for unknown propagation of 1-3 operands. + * + * This is the [TypingMode] specific implementation of unknown-propagation, used by the [thunkEnvOperands] + * functions. [getVal1], [getVal2] and [getVal2] are lambdas to allow for differences in short-circuiting. + * + * For all [TypingMode]s, if the values returned by [getVal1], [getVal2] and [getVal2] are all known, + * [compute] is invoked to perform the operation-specific computation. + * + * Note: this must be public due to a Kotlin compiler bug: https://youtrack.jetbrains.com/issue/KT-22625. + * This shouldn't matter though because this class is still `internal`. + */ + abstract suspend fun propagateUnknowns( + getVal1: suspend () -> ExprValue, + getVal2: (suspend () -> ExprValue)?, + getVal3: (suspend () -> ExprValue)?, + compute: (ExprValue, ExprValue?, ExprValue?) -> ExprValue + ): ExprValue + + /** + * Similar to the other [propagateUnknowns] overload, performs unknown propagation for a variadic sequence of + * operations. + * + * Note: this must be public due to a Kotlin compiler bug: https://youtrack.jetbrains.com/issue/KT-22625. + * This shouldn't matter though because this class is still `internal`. + */ + abstract suspend fun propagateUnknowns( + operands: Sequence, + compute: (List) -> ExprValue + ): ExprValue + + /** + * Creates a thunk that accepts three [Thunk] operands ([t1], [t2], and [t3]), evaluates them and propagates + * unknowns according to the current [TypingMode]. When possible, use this function or one of its overloads + * instead of [thunkEnvAsync] when the operation requires propagation of unknown values. 
+ * + * [t1], [t2] and [t3] are each evaluated in with short-circuiting depending on the current [TypingMode]: + * + * - In [TypingMode.PERMISSIVE] mode, the first `MISSING` returned from one of the thunks causes a short-circuit, + * and `MISSING` is returned immediately without evaluating the remaining thunks. If none of the thunks return + * `MISSING`, if any of them has returned `NULL`, `NULL` is returned. + * - In [TypingMode.LEGACY] mode, the first `NULL` or `MISSING` returned from one of the thunks causes a + * short-circuit, and returns `NULL` without evaluating the remaining thunks. + * + * In both modes, if none of the thunks returns `MISSING` or `NULL`, [compute] is invoked to perform the final + * computation on values of the operands which are guaranteed to be known. + * + * Overloads of this function exist that accept 1 and 2 arguments. We do not make [t2] and [t3] nullable with a + * default value of `null` instead of supplying those overloads primarily because [compute] has a different + * signature for each, but also because that would prevent [thunkEnvOperands] from being `inline`. + */ + internal suspend inline fun thunkEnvOperands( + metas: MetaContainer, + crossinline t1: ThunkAsync, + crossinline t2: ThunkAsync, + crossinline t3: ThunkAsync, + crossinline compute: (TEnv, ExprValue, ExprValue, ExprValue) -> ExprValue + ): ThunkAsync = + this.thunkEnvAsync(metas) { env -> + propagateUnknowns({ t1(env) }, { t2(env) }, { t3(env) }) { v1, v2, v3 -> + compute(env, v1, v2!!, v3!!) + } + }.typeCheck(metas) + + /** See the [thunkEnvOperands] with three [Thunk] operands. */ + internal suspend inline fun thunkEnvOperands( + metas: MetaContainer, + crossinline t1: ThunkAsync, + crossinline t2: ThunkAsync, + crossinline compute: (TEnv, ExprValue, ExprValue) -> ExprValue + ): ThunkAsync = + this.thunkEnvAsync(metas) { env -> + propagateUnknowns({ t1(env) }, { t2(env) }, null) { v1, v2, _ -> + compute(env, v1, v2!!) 
+ } + }.typeCheck(metas) + + /** See the [thunkEnvOperands] with three [Thunk] operands. */ + internal suspend inline fun thunkEnvOperands( + metas: MetaContainer, + crossinline t1: ThunkAsync, + crossinline compute: (TEnv, ExprValue) -> ExprValue + ): ThunkAsync = + this.thunkEnvAsync(metas) { env -> + propagateUnknowns({ t1(env) }, null, null) { v1, _, _ -> + compute(env, v1) + } + }.typeCheck(metas) + + /** See the [thunkEnvOperands] with a variadic list of [Thunk] operands. */ + internal suspend inline fun thunkEnvOperands( + metas: MetaContainer, + operandThunks: List>, + crossinline compute: (TEnv, List) -> ExprValue + ): ThunkAsync { + + return this.thunkEnvAsync(metas) { env -> + val operandSeq = flow { + operandThunks.forEach { emit(it(env)) } + } + propagateUnknowns(operandSeq.toList().asSequence()) { values -> + compute(env, values) + } + }.typeCheck(metas) + } + + /** Similar to [thunkEnvAsync], but creates a [ThunkEnvValue] instead. */ + internal suspend inline fun thunkEnvValue( + metas: MetaContainer, + crossinline t: ThunkValueAsync + ): ThunkValueAsync { + val sourceLocationMeta = metas[SourceLocationMeta.TAG] as? SourceLocationMeta + + val tVal: ThunkValueAsync = { env: TEnv, arg1: ExprValue -> + this.handleExceptionAsync(sourceLocationMeta) { + t(env, arg1) + } + } + return tVal.typeCheckEnvValue(metas) + } + + /** + * Similar to [thunkEnvAsync] but evaluates all [argThunks] and performs a fold using [op] as the operation. + * + * Also handles null propagation appropriately for NAryOp arithmetic operations. Each thunk in [argThunks] + * is evaluated in turn and: + * + * - for [TypingMode.LEGACY], the first unknown operand short-circuits, returning `NULL`. + * - for [TypingMode.PERMISSIVE], the first missing operand short-circuits, returning `MISSING`. Then, if one + * of the operands returned `NULL`, `NULL` is returned. + * + * For both modes, if all the operands are known, performs a fold over them with [op]. 
+ */ + internal abstract suspend fun thunkFold( + metas: MetaContainer, + argThunks: List>, + op: (ExprValue, ExprValue) -> ExprValue + ): ThunkAsync + + /** + * Similar to [thunkFold] but intended for comparison operators, i.e. `=`, `>`, `>=`, `<`, `<=`. + * + * The first argument of [op] is always the value of `argThunks[n]` and + * the second is always `argThunks[n + 1]` where `n` is 0 to `argThunks.size - 2`. + * + * - If [op] returns false, the thunk short circuits and the result of the thunk becomes `false`. + * - for [TypingMode.LEGACY], the first unknown operand short-circuits, returning `NULL`. + * - for [TypingMode.PERMISSIVE], the first missing operand short-circuits, returning `MISSING`. Then, if one + * of the operands returned `NULL`, `NULL` is returned. + * + * If [op] is true for all invocations then the result of the thunk becomes `true`, otherwise the result is `false`. + * + * The name of this function was inspired by Racket's `andmap` procedure. + */ + internal abstract suspend fun thunkAndMap( + metas: MetaContainer, + argThunks: List>, + op: (ExprValue, ExprValue) -> Boolean + ): ThunkAsync + + /** Populates [exception] with the line & column from the specified [SourceLocationMeta]. */ + protected fun populateErrorContext( + exception: EvaluationException, + sourceLocation: SourceLocationMeta? + ): EvaluationException { + // Only add source location data to the error context if it doesn't already exist + // in [errorContext]. + if (!exception.errorContext.hasProperty(Property.LINE_NUMBER)) { + sourceLocation?.let { fillErrorContext(exception.errorContext, sourceLocation) } + } + return exception + } + + /** + * Handles exceptions appropriately for a run-time [ThunkAsync]. + * + * - The [SourceLocationMeta] will be extracted from [MetaContainer] and included in any [EvaluationException] that + * is thrown, if present. + * - The location information is added to the [EvaluationException]'s `errorContext`, if it is not already present. 
+ * - Exceptions thrown by [block] that are not an [EvaluationException] cause an [EvaluationException] to be thrown + * with the original exception as the cause. + */ + abstract suspend fun handleExceptionAsync( + sourceLocation: SourceLocationMeta?, + block: suspend () -> ExprValue + ): ExprValue +} + +/** + * Provides methods for constructing new thunks according to the specified [CompileOptions] for [TypingMode.LEGACY] behaviour. + */ +internal class LegacyThunkFactoryAsync( + thunkOptions: ThunkOptions +) : ThunkFactoryAsync(thunkOptions) { + + override suspend fun propagateUnknowns( + getVal1: suspend () -> ExprValue, + getVal2: (suspend () -> ExprValue)?, + getVal3: (suspend () -> ExprValue)?, + compute: (ExprValue, ExprValue?, ExprValue?) -> ExprValue + ): ExprValue { + val val1 = getVal1() + return when { + val1.isUnknown() -> ExprValue.nullValue + else -> { + val val2 = getVal2?.let { it() } + when { + val2 == null -> compute(val1, null, null) + val2.isUnknown() -> ExprValue.nullValue + else -> { + val val3 = getVal3?.let { it() } + when { + val3 == null -> compute(val1, val2, null) + val3.isUnknown() -> ExprValue.nullValue + else -> compute(val1, val2, val3) + } + } + } + } + } + } + + override suspend fun propagateUnknowns( + operands: Sequence, + compute: (List) -> ExprValue + ): ExprValue { + // Because we need to short-circuit on the first unknown value and [operands] is a sequence, + // we can't use .map here. (non-local returns on `.map` are not allowed) + val argValues = mutableListOf() + operands.forEach { + when { + it.isUnknown() -> return ExprValue.nullValue + else -> argValues.add(it) + } + } + return compute(argValues) + } + + /** See [ThunkFactoryAsync.thunkFold]. 
*/ + override suspend fun thunkFold( + metas: MetaContainer, + argThunks: List>, + op: (ExprValue, ExprValue) -> ExprValue + ): ThunkAsync { + require(argThunks.isNotEmpty()) { "argThunks must not be empty" } + + val firstThunk = argThunks.first() + val otherThunks = argThunks.drop(1) + return thunkEnvAsync(metas) thunkBlock@{ env -> + val firstValue = firstThunk(env) + when { + // Short-circuit at first NULL or MISSING value and return NULL. + firstValue.isUnknown() -> ExprValue.nullValue + else -> { + otherThunks.fold(firstValue) { acc, curr -> + val currValue = curr(env) + if (currValue.type.isUnknown) { + return@thunkBlock ExprValue.nullValue + } + op(acc, currValue) + } + } + } + }.typeCheck(metas) + } + + /** See [ThunkFactoryAsync.thunkAndMap]. */ + override suspend fun thunkAndMap( + metas: MetaContainer, + argThunks: List>, + op: (ExprValue, ExprValue) -> Boolean + ): ThunkAsync { + require(argThunks.size >= 2) { "argThunks must have at least two elements" } + + val firstThunk = argThunks.first() + val otherThunks = argThunks.drop(1) + + return thunkEnvAsync(metas) thunkBlock@{ env -> + val firstValue = firstThunk(env) + when { + // If the first value is unknown, short circuit returning null. + firstValue.isUnknown() -> ExprValue.nullValue + else -> { + otherThunks.fold(firstValue) { lastValue, currentThunk -> + + val currentValue = currentThunk(env) + if (currentValue.isUnknown()) { + return@thunkBlock ExprValue.nullValue + } + + val result = op(lastValue, currentValue) + if (!result) { + return@thunkBlock ExprValue.newBoolean(false) + } + + currentValue + } + + ExprValue.newBoolean(true) + } + } + } + } + + /** + * Handles exceptions appropriately for a run-time [ThunkAsync] respecting [TypingMode.LEGACY] behaviour. + * + * - The [SourceLocationMeta] will be extracted from [MetaContainer] and included in any [EvaluationException] that + * is thrown, if present. 
+ * - The location information is added to the [EvaluationException]'s `errorContext`, if it is not already present. + * - Exceptions thrown by [block] that are not an [EvaluationException] cause an [EvaluationException] to be thrown + * with the original exception as the cause. + */ + override suspend fun handleExceptionAsync( + sourceLocation: SourceLocationMeta?, + block: suspend () -> ExprValue + ): ExprValue = + try { + block() + } catch (e: EvaluationException) { + throw populateErrorContext(e, sourceLocation) + } catch (e: Exception) { + thunkOptions.handleExceptionForLegacyMode(e, sourceLocation) + } +} + +/** + * Provides methods for constructing new thunks according to the specified [CompileOptions] and for + * [TypingMode.PERMISSIVE] behaviour. + */ +internal class PermissiveThunkFactoryAsync( + thunkOptions: ThunkOptions +) : ThunkFactoryAsync(thunkOptions) { + + override suspend fun propagateUnknowns( + getVal1: suspend () -> ExprValue, + getVal2: (suspend () -> ExprValue)?, + getVal3: (suspend () -> ExprValue)?, + compute: (ExprValue, ExprValue?, ExprValue?) -> ExprValue + ): ExprValue { + val val1 = getVal1() + return when (val1.type) { + ExprValueType.MISSING -> ExprValue.missingValue + else -> { + val val2 = getVal2?.let { it() } + when { + val2 == null -> nullOrCompute(val1, null, null, compute) + val2.type == ExprValueType.MISSING -> ExprValue.missingValue + else -> { + val val3 = getVal3?.let { it() } + when { + val3 == null -> nullOrCompute(val1, val2, null, compute) + val3.type == ExprValueType.MISSING -> ExprValue.missingValue + else -> nullOrCompute(val1, val2, val3, compute) + } + } + } + } + } + } + + override suspend fun propagateUnknowns( + operands: Sequence, + compute: (List) -> ExprValue + ): ExprValue { + + // Because we need to short-circuit on the first MISSING value and [operands] is a sequence, + // we can't use .map here. 
(non-local returns on `.map` are not allowed) + val argValues = mutableListOf() + operands.forEach { + when (it.type) { + ExprValueType.MISSING -> return ExprValue.missingValue + else -> argValues.add(it) + } + } + return when { + // if any result is `NULL`, propagate return null instead. + argValues.any { it.type == ExprValueType.NULL } -> ExprValue.nullValue + else -> compute(argValues) + } + } + + private fun nullOrCompute( + v1: ExprValue, + v2: ExprValue?, + v3: ExprValue?, + compute: (ExprValue, ExprValue?, ExprValue?) -> ExprValue + ): ExprValue = + when { + v1.type == ExprValueType.NULL || + (v2?.let { it.type == ExprValueType.NULL }) ?: false || + (v3?.let { it.type == ExprValueType.NULL }) ?: false -> ExprValue.nullValue + else -> compute(v1, v2, v3) + } + + /** See [ThunkFactoryAsync.thunkFold]. */ + override suspend fun thunkFold( + metas: MetaContainer, + argThunks: List>, + op: (ExprValue, ExprValue) -> ExprValue + ): ThunkAsync { + require(argThunks.isNotEmpty()) { "argThunks must not be empty" } + + return thunkEnvAsync(metas) { env -> + val values = argThunks.map { + val v = it(env) + when (v.type) { + // Short-circuit at first detected MISSING value. + ExprValueType.MISSING -> return@thunkEnvAsync ExprValue.missingValue + else -> v + } + } + when { + // Propagate NULL if any operand is NULL. + values.any { it.type == ExprValueType.NULL } -> ExprValue.nullValue + // compute the final value. + else -> values.reduce { first, second -> op(first, second) } + } + }.typeCheck(metas) + } + + /** See [ThunkFactoryAsync.thunkAndMap]. */ + override suspend fun thunkAndMap( + metas: MetaContainer, + argThunks: List>, + op: (ExprValue, ExprValue) -> Boolean + ): ThunkAsync { + require(argThunks.size >= 2) { "argThunks must have at least two elements" } + + return thunkEnvAsync(metas) thunkBlock@{ env -> + val values = argThunks.map { + val v = it(env) + when (v.type) { + // Short-circuit at first detected MISSING value. 
+ ExprValueType.MISSING -> return@thunkBlock ExprValue.missingValue + else -> v + } + } + when { + // Propagate NULL if any operand is NULL. + values.any { it.type == ExprValueType.NULL } -> ExprValue.nullValue + else -> { + (0..(values.size - 2)).forEach { i -> + if (!op(values[i], values[i + 1])) + return@thunkBlock ExprValue.newBoolean(false) + } + + return@thunkBlock ExprValue.newBoolean(true) + } + } + } + } + + /** + * Handles exceptions appropriately for a run-time [Thunk] respecting [TypingMode.PERMISSIVE] behaviour. + * + * - Exceptions thrown by [block] that are [EvaluationException] are caught and [ExprValue.missingValue] is returned. + * - Exceptions thrown by [block] that are not an [EvaluationException] cause an [EvaluationException] to be thrown + * with the original exception as the cause. + */ + override suspend fun handleExceptionAsync( + sourceLocation: SourceLocationMeta?, + block: suspend () -> ExprValue + ): ExprValue = + try { + block() + } catch (e: EvaluationException) { + thunkOptions.handleExceptionForPermissiveMode(e, sourceLocation) + when (e.errorCode.errorBehaviorInPermissiveMode) { + // Rethrows the exception as it does in LEGACY mode. 
+ ErrorBehaviorInPermissiveMode.THROW_EXCEPTION -> throw populateErrorContext(e, sourceLocation) + ErrorBehaviorInPermissiveMode.RETURN_MISSING -> ExprValue.missingValue + } + } catch (e: Exception) { + thunkOptions.handleExceptionForLegacyMode(e, sourceLocation) + } +} diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/PhysicalBexprToThunkConverter.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/PhysicalBexprToThunkConverter.kt index 4158f0c61f..9a89d9aa8e 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/PhysicalBexprToThunkConverter.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/PhysicalBexprToThunkConverter.kt @@ -1,6 +1,5 @@ package org.partiql.lang.eval.physical -import com.amazon.ionelement.api.BoolElement import com.amazon.ionelement.api.MetaContainer import org.partiql.annotations.ExperimentalWindowFunctions import org.partiql.lang.ast.SourceLocationMeta @@ -325,6 +324,3 @@ internal class PhysicalBexprToThunkConverter( return bindingsExpr.toRelationThunk(node.metas) } } - -private fun PartiqlPhysical.Expr.isLitTrue() = - this is PartiqlPhysical.Expr.Lit && this.value is BoolElement && this.value.booleanValue diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/PhysicalBexprToThunkConverterAsync.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/PhysicalBexprToThunkConverterAsync.kt new file mode 100644 index 0000000000..5dada878e5 --- /dev/null +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/PhysicalBexprToThunkConverterAsync.kt @@ -0,0 +1,362 @@ +package org.partiql.lang.eval.physical + +import com.amazon.ionelement.api.BoolElement +import com.amazon.ionelement.api.MetaContainer +import org.partiql.annotations.ExperimentalWindowFunctions +import org.partiql.lang.ast.SourceLocationMeta +import org.partiql.lang.domains.PartiqlPhysical +import org.partiql.lang.eval.ExprValue +import 
org.partiql.lang.eval.NaturalExprValueComparators +import org.partiql.lang.eval.ThunkAsync +import org.partiql.lang.eval.ThunkValueAsync +import org.partiql.lang.eval.physical.operators.AggregateOperatorFactoryAsync +import org.partiql.lang.eval.physical.operators.CompiledAggregateFunctionAsync +import org.partiql.lang.eval.physical.operators.CompiledGroupKeyAsync +import org.partiql.lang.eval.physical.operators.CompiledSortKeyAsync +import org.partiql.lang.eval.physical.operators.CompiledWindowFunctionAsync +import org.partiql.lang.eval.physical.operators.FilterRelationalOperatorFactoryAsync +import org.partiql.lang.eval.physical.operators.JoinRelationalOperatorFactoryAsync +import org.partiql.lang.eval.physical.operators.LetRelationalOperatorFactoryAsync +import org.partiql.lang.eval.physical.operators.LimitRelationalOperatorFactoryAsync +import org.partiql.lang.eval.physical.operators.OffsetRelationalOperatorFactoryAsync +import org.partiql.lang.eval.physical.operators.ProjectRelationalOperatorFactoryAsync +import org.partiql.lang.eval.physical.operators.RelationExpressionAsync +import org.partiql.lang.eval.physical.operators.RelationalOperatorFactory +import org.partiql.lang.eval.physical.operators.RelationalOperatorFactoryKey +import org.partiql.lang.eval.physical.operators.RelationalOperatorKind +import org.partiql.lang.eval.physical.operators.ScanRelationalOperatorFactoryAsync +import org.partiql.lang.eval.physical.operators.SortOperatorFactoryAsync +import org.partiql.lang.eval.physical.operators.UnpivotOperatorFactoryAsync +import org.partiql.lang.eval.physical.operators.WindowRelationalOperatorFactoryAsync +import org.partiql.lang.eval.physical.operators.valueExpressionAsync +import org.partiql.lang.eval.physical.window.createBuiltinWindowFunctionAsync +import org.partiql.lang.util.toIntExact + +/** Converts instances of [PartiqlPhysical.Bexpr] to any [T]. 
A `suspend` version of the physical plan converter + * interface is added since PIG currently does not output async functions. + */ +internal interface Converter { + suspend fun convert(node: PartiqlPhysical.Bexpr): T = when (node) { + is PartiqlPhysical.Bexpr.Project -> convertProject(node) + is PartiqlPhysical.Bexpr.Scan -> convertScan(node) + is PartiqlPhysical.Bexpr.Unpivot -> convertUnpivot(node) + is PartiqlPhysical.Bexpr.Filter -> convertFilter(node) + is PartiqlPhysical.Bexpr.Join -> convertJoin(node) + is PartiqlPhysical.Bexpr.Sort -> convertSort(node) + is PartiqlPhysical.Bexpr.Aggregate -> convertAggregate(node) + is PartiqlPhysical.Bexpr.Offset -> convertOffset(node) + is PartiqlPhysical.Bexpr.Limit -> convertLimit(node) + is PartiqlPhysical.Bexpr.Let -> convertLet(node) + is PartiqlPhysical.Bexpr.Window -> convertWindow(node) + } + + suspend fun convertProject(node: PartiqlPhysical.Bexpr.Project): T + suspend fun convertScan(node: PartiqlPhysical.Bexpr.Scan): T + suspend fun convertUnpivot(node: PartiqlPhysical.Bexpr.Unpivot): T + suspend fun convertFilter(node: PartiqlPhysical.Bexpr.Filter): T + suspend fun convertJoin(node: PartiqlPhysical.Bexpr.Join): T + suspend fun convertSort(node: PartiqlPhysical.Bexpr.Sort): T + suspend fun convertAggregate(node: PartiqlPhysical.Bexpr.Aggregate): T + suspend fun convertOffset(node: PartiqlPhysical.Bexpr.Offset): T + suspend fun convertLimit(node: PartiqlPhysical.Bexpr.Limit): T + suspend fun convertLet(node: PartiqlPhysical.Bexpr.Let): T + suspend fun convertWindow(node: PartiqlPhysical.Bexpr.Window): T +} + +/** A specialization of [ThunkAsync] that we use for evaluation of physical plans. */ +internal typealias PhysicalPlanThunkAsync = ThunkAsync + +/** A specialization of [ThunkValueAsync] that we use for evaluation of physical plans. 
*/ +internal typealias PhysicalPlanThunkValueAsync = ThunkValueAsync + +internal class PhysicalBexprToThunkConverterAsync( + private val exprConverter: PhysicalPlanCompilerAsync, + private val relationalOperatorFactory: Map +) : Converter { + + private fun PhysicalPlanThunkAsync.toValueExpr(sourceLocationMeta: SourceLocationMeta?) = + valueExpressionAsync(sourceLocationMeta) { state -> this(state) } + + private suspend fun RelationExpressionAsync.toRelationThunk(metas: MetaContainer) = + relationThunkAsync(metas) { state -> this.evaluate(state) } + + private inline fun findOperatorFactory( + operator: RelationalOperatorKind, + name: String + ): T { + val key = RelationalOperatorFactoryKey(operator, name) + val found = + relationalOperatorFactory[key] ?: error("Factory for operator ${key.operator} named '${key.name}' does not exist.") + return found as? T + ?: error( + "Internal error: Operator factory ${key.operator} named '${key.name}' does not derive from " + + T::class.java + "." + ) + } + + override suspend fun convertProject(node: PartiqlPhysical.Bexpr.Project): RelationThunkEnvAsync { + // recurse into children + val argExprs = node.args.map { exprConverter.convert(it).toValueExpr(it.metas.sourceLocationMeta) } + + // locate operator factory + val factory = findOperatorFactory(RelationalOperatorKind.PROJECT, node.i.name.text) + + // create operator implementation + val bindingsExpr = factory.create(node.i, node.binding.toSetVariableFunc(), argExprs) + + // wrap in thunk. 
+ return bindingsExpr.toRelationThunk(node.metas) + } + + override suspend fun convertAggregate(node: PartiqlPhysical.Bexpr.Aggregate): RelationThunkEnvAsync { + val source = this.convert(node.source) + + // Compile Arguments + val compiledFunctions = node.functionList.functions.map { func -> + val setAggregateVal = func.asVar.toSetVariableFunc() + val value = exprConverter.convert(func.arg).toValueExpr(func.arg.metas.sourceLocationMeta) + CompiledAggregateFunctionAsync(func.name.text, setAggregateVal, value, func.quantifier) + } + val compiledKeys = node.groupList.keys.map { key -> + val value = exprConverter.convert(key.expr).toValueExpr(key.expr.metas.sourceLocationMeta) + val function = key.asVar.toSetVariableFunc() + CompiledGroupKeyAsync(function, value, key.asVar) + } + + // Get Implementation + val factory = findOperatorFactory(RelationalOperatorKind.AGGREGATE, node.i.name.text) + val relationExpression = factory.create({ state -> source.invoke(state) }, node.strategy, compiledKeys, compiledFunctions) + return relationExpression.toRelationThunk(node.metas) + } + + override suspend fun convertScan(node: PartiqlPhysical.Bexpr.Scan): RelationThunkEnvAsync { + // recurse into children + val valueExpr = exprConverter.convert(node.expr).toValueExpr(node.expr.metas.sourceLocationMeta) + val asSetter = node.asDecl.toSetVariableFunc() + val atSetter = node.atDecl?.toSetVariableFunc() + val bySetter = node.byDecl?.toSetVariableFunc() + + // locate operator factory + val factory = findOperatorFactory(RelationalOperatorKind.SCAN, node.i.name.text) + + // create operator implementation + val bindingsExpr = factory.create( + impl = node.i, + expr = valueExpr, + setAsVar = asSetter, + setAtVar = atSetter, + setByVar = bySetter + ) + + // wrap in thunk + return bindingsExpr.toRelationThunk(node.metas) + } + + override suspend fun convertUnpivot(node: PartiqlPhysical.Bexpr.Unpivot): RelationThunkEnvAsync { + val valueExpr = 
exprConverter.convert(node.expr).toValueExpr(node.expr.metas.sourceLocationMeta) + val asSetter = node.asDecl.toSetVariableFunc() + val atSetter = node.atDecl?.toSetVariableFunc() + val bySetter = node.byDecl?.toSetVariableFunc() + + val factory = findOperatorFactory(RelationalOperatorKind.UNPIVOT, node.i.name.text) + + val bindingsExpr = factory.create( + expr = valueExpr, + setAsVar = asSetter, + setAtVar = atSetter, + setByVar = bySetter + ) + + return bindingsExpr.toRelationThunk(node.metas) + } + + override suspend fun convertFilter(node: PartiqlPhysical.Bexpr.Filter): RelationThunkEnvAsync { + // recurse into children + val predicateValueExpr = exprConverter.convert(node.predicate).toValueExpr(node.predicate.metas.sourceLocationMeta) + val sourceBindingsExpr = this.convert(node.source) + + // locate operator factory + val factory = findOperatorFactory(RelationalOperatorKind.FILTER, node.i.name.text) + + // create operator implementation + val bindingsExpr = factory.create(node.i, predicateValueExpr) { state -> sourceBindingsExpr.invoke(state) } + + // wrap in thunk + return bindingsExpr.toRelationThunk(node.metas) + } + + override suspend fun convertJoin(node: PartiqlPhysical.Bexpr.Join): RelationThunkEnvAsync { + // recurse into children + val leftBindingsExpr = this.convert(node.left) + val rightBindingsExpr = this.convert(node.right) + val predicateValueExpr = node.predicate?.let { predicate -> + exprConverter.convert(predicate) + .takeIf { !predicate.isLitTrue() } + ?.toValueExpr(predicate.metas.sourceLocationMeta) + } + + // locate operator factory + val factory = findOperatorFactory(RelationalOperatorKind.JOIN, node.i.name.text) + + // Compute a function to set the left-side variables to NULL. This is for use with RIGHT JOIN, when the left + // side of the join is empty or no rows match the predicate. 
+ val leftVariableIndexes = node.left.extractAccessibleVarDecls().map { it.index.value.toIntExact() } + val setLeftSideVariablesToNull: (EvaluatorState) -> Unit = { state -> + leftVariableIndexes.forEach { state.registers[it] = ExprValue.nullValue } + } + // Compute a function to set the right-side variables to NULL. This is for use with LEFT JOIN, when the right + // side of the join is empty or no rows match the predicate. + val rightVariableIndexes = node.right.extractAccessibleVarDecls().map { it.index.value.toIntExact() } + val setRightSideVariablesToNull: (EvaluatorState) -> Unit = { state -> + rightVariableIndexes.forEach { state.registers[it] = ExprValue.nullValue } + } + + return factory.create( + impl = node.i, + joinType = node.joinType, + leftBexpr = { state -> leftBindingsExpr(state) }, + rightBexpr = { state -> rightBindingsExpr(state) }, + predicateExpr = predicateValueExpr, + setLeftSideVariablesToNull = setLeftSideVariablesToNull, + setRightSideVariablesToNull = setRightSideVariablesToNull + ).toRelationThunk(node.metas) + } + + private fun PartiqlPhysical.Bexpr.extractAccessibleVarDecls(): List = + // This fold traverses a [PartiqlPhysical.Bexpr] node and extracts all variable declarations within + // It avoids recursing into sub-queries. + object : PartiqlPhysical.VisitorFold>() { + override fun visitVarDecl( + node: PartiqlPhysical.VarDecl, + accumulator: List + ): List = accumulator + node + + /** + * Avoids recursion into expressions, since these may contain sub-queries with other var-decls that we don't + * care about here. 
+ */ + override fun walkExpr( + node: PartiqlPhysical.Expr, + accumulator: List + ): List { + return accumulator + } + }.walkBexpr(this, emptyList()) + + override suspend fun convertOffset(node: PartiqlPhysical.Bexpr.Offset): RelationThunkEnvAsync { + // recurse into children + val rowCountExpr = exprConverter.convert(node.rowCount).toValueExpr(node.rowCount.metas.sourceLocationMeta) + val sourceBexpr = this.convert(node.source) + + // locate operator factory + val factory = findOperatorFactory(RelationalOperatorKind.OFFSET, node.i.name.text) + + // create operator implementation + val bindingsExpr = factory.create(node.i, rowCountExpr) { state -> sourceBexpr(state) } + // wrap in thunk + return bindingsExpr.toRelationThunk(node.metas) + } + + override suspend fun convertLimit(node: PartiqlPhysical.Bexpr.Limit): RelationThunkEnvAsync { + // recurse into children + val rowCountExpr = exprConverter.convert(node.rowCount).toValueExpr(node.rowCount.metas.sourceLocationMeta) + val sourceBexpr = this.convert(node.source) + + // locate operator factory + val factory = findOperatorFactory(RelationalOperatorKind.LIMIT, node.i.name.text) + + // create operator implementation + val bindingsExpr = factory.create(node.i, rowCountExpr) { state -> sourceBexpr(state) } + + // wrap in thunk + return bindingsExpr.toRelationThunk(node.metas) + } + + override suspend fun convertSort(node: PartiqlPhysical.Bexpr.Sort): RelationThunkEnvAsync { + // Compile Arguments + val source = this.convert(node.source) + val sortKeys = compileSortSpecsAsync(node.sortSpecs) + + // Get Implementation + val factory = findOperatorFactory(RelationalOperatorKind.SORT, node.i.name.text) + val bindingsExpr = factory.create(sortKeys) { state -> source(state) } + return bindingsExpr.toRelationThunk(node.metas) + } + + override suspend fun convertLet(node: PartiqlPhysical.Bexpr.Let): RelationThunkEnvAsync { + // recurse into children + val sourceBexpr = this.convert(node.source) + val compiledBindings = 
node.bindings.map { + VariableBindingAsync( + it.decl.toSetVariableFunc(), + exprConverter.convert(it.value).toValueExpr(it.value.metas.sourceLocationMeta) + ) + } + // locate operator factory + val factory = findOperatorFactory(RelationalOperatorKind.LET, node.i.name.text) + + // create operator implementation + val bindingsExpr = factory.create(node.i, { state -> sourceBexpr(state) }, compiledBindings) + + // wrap in thunk + return bindingsExpr.toRelationThunk(node.metas) + } + + /** + * Returns a list of [CompiledSortKeyAsync] with the aim of pre-computing the [NaturalExprValueComparators] prior to + * evaluation and leaving the [PartiqlPhysical.SortSpec]'s [PartiqlPhysical.Expr] to be evaluated later. + */ + private suspend fun compileSortSpecsAsync(specs: List): List = specs.map { spec -> + val comp = when (spec.orderingSpec ?: PartiqlPhysical.OrderingSpec.Asc()) { + is PartiqlPhysical.OrderingSpec.Asc -> + when (spec.nullsSpec) { + is PartiqlPhysical.NullsSpec.NullsFirst -> NaturalExprValueComparators.NULLS_FIRST_ASC + is PartiqlPhysical.NullsSpec.NullsLast -> NaturalExprValueComparators.NULLS_LAST_ASC + null -> NaturalExprValueComparators.NULLS_LAST_ASC + } + + is PartiqlPhysical.OrderingSpec.Desc -> + when (spec.nullsSpec) { + is PartiqlPhysical.NullsSpec.NullsFirst -> NaturalExprValueComparators.NULLS_FIRST_DESC + is PartiqlPhysical.NullsSpec.NullsLast -> NaturalExprValueComparators.NULLS_LAST_DESC + null -> NaturalExprValueComparators.NULLS_FIRST_DESC + } + } + val value = exprConverter.convert(spec.expr).toValueExpr(spec.expr.metas.sourceLocationMeta) + CompiledSortKeyAsync(comp, value) + } + + @OptIn(ExperimentalWindowFunctions::class) + override suspend fun convertWindow(node: PartiqlPhysical.Bexpr.Window): RelationThunkEnvAsync { + val source = this.convert(node.source) + + val windowPartitionList = node.windowSpecification.partitionBy + + val windowSortSpecList = node.windowSpecification.orderBy + + val compiledPartitionBy = 
windowPartitionList?.exprs?.map { + exprConverter.convert(it).toValueExpr(it.metas.sourceLocationMeta) + } ?: emptyList() + + val compiledOrderBy = windowSortSpecList?.sortSpecs?.let { compileSortSpecsAsync(it) } ?: emptyList() + + val compiledWindowFunctions = node.windowExpressionList.map { windowExpression -> + CompiledWindowFunctionAsync( + createBuiltinWindowFunctionAsync(windowExpression.funcName.text), + windowExpression.args.map { exprConverter.convert(it).toValueExpr(it.metas.sourceLocationMeta) }, + windowExpression.decl + ) + } + + // locate operator factory + val factory = findOperatorFactory(RelationalOperatorKind.WINDOW, node.i.name.text) + + // create operator implementation + val bindingsExpr = factory.create({ state -> source(state) }, compiledPartitionBy, compiledOrderBy, compiledWindowFunctions) + // wrap in thunk + return bindingsExpr.toRelationThunk(node.metas) + } +} + +internal fun PartiqlPhysical.Expr.isLitTrue() = + this is PartiqlPhysical.Expr.Lit && this.value is BoolElement && this.value.booleanValue diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/PhysicalPlanCompilerAsync.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/PhysicalPlanCompilerAsync.kt new file mode 100644 index 0000000000..25efe3f114 --- /dev/null +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/PhysicalPlanCompilerAsync.kt @@ -0,0 +1,13 @@ +package org.partiql.lang.eval.physical + +import org.partiql.lang.domains.PartiqlPhysical + +/** + * Simple API that defines a method to convert a [PartiqlPhysical.Expr] to a [PhysicalPlanThunkAsync]. + * + * Intended to prevent [PhysicalBexprToThunkConverterAsync] from having to take a direct dependency on + * [org.partiql.lang.eval.EvaluatingCompiler]. 
+ */ +internal interface PhysicalPlanCompilerAsync { + suspend fun convert(expr: PartiqlPhysical.Expr): PhysicalPlanThunkAsync +} diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/PhysicalPlanCompilerAsyncImpl.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/PhysicalPlanCompilerAsyncImpl.kt new file mode 100644 index 0000000000..89430245e7 --- /dev/null +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/PhysicalPlanCompilerAsyncImpl.kt @@ -0,0 +1,1899 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. 
+ */ + +package org.partiql.lang.eval.physical + +import com.amazon.ion.IonString +import com.amazon.ion.IonValue +import com.amazon.ion.Timestamp +import com.amazon.ion.system.IonSystemBuilder +import com.amazon.ionelement.api.MetaContainer +import com.amazon.ionelement.api.emptyMetaContainer +import com.amazon.ionelement.api.toIonValue +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.asFlow +import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.flow.map +import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.flow.withIndex +import org.partiql.errors.ErrorCode +import org.partiql.errors.Property +import org.partiql.errors.PropertyValueMap +import org.partiql.lang.ast.IsOrderedMeta +import org.partiql.lang.ast.SourceLocationMeta +import org.partiql.lang.ast.sourceLocation +import org.partiql.lang.domains.PartiqlPhysical +import org.partiql.lang.domains.staticType +import org.partiql.lang.domains.toBindingCase +import org.partiql.lang.eval.AnyOfCastTable +import org.partiql.lang.eval.ArityMismatchException +import org.partiql.lang.eval.BaseExprValue +import org.partiql.lang.eval.BindingCase +import org.partiql.lang.eval.BindingName +import org.partiql.lang.eval.CastFunc +import org.partiql.lang.eval.DEFAULT_COMPARATOR +import org.partiql.lang.eval.ErrorDetails +import org.partiql.lang.eval.EvaluationException +import org.partiql.lang.eval.EvaluationSession +import org.partiql.lang.eval.ExprFunction +import org.partiql.lang.eval.ExprValue +import org.partiql.lang.eval.ExprValueBagOp +import org.partiql.lang.eval.ExprValueType +import org.partiql.lang.eval.Expression +import org.partiql.lang.eval.ExpressionAsync +import org.partiql.lang.eval.FunctionNotFoundException +import org.partiql.lang.eval.Named +import org.partiql.lang.eval.PartiQLResult +import org.partiql.lang.eval.ProjectionIterationBehavior +import org.partiql.lang.eval.StructOrdering +import org.partiql.lang.eval.ThunkValueAsync +import 
org.partiql.lang.eval.TypedOpBehavior +import org.partiql.lang.eval.TypingMode +import org.partiql.lang.eval.booleanValue +import org.partiql.lang.eval.builtins.storedprocedure.StoredProcedure +import org.partiql.lang.eval.call +import org.partiql.lang.eval.cast +import org.partiql.lang.eval.compareTo +import org.partiql.lang.eval.createErrorSignaler +import org.partiql.lang.eval.createThunkFactoryAsync +import org.partiql.lang.eval.distinct +import org.partiql.lang.eval.err +import org.partiql.lang.eval.errorContextFrom +import org.partiql.lang.eval.errorIf +import org.partiql.lang.eval.exprEquals +import org.partiql.lang.eval.fillErrorContext +import org.partiql.lang.eval.impl.FunctionManager +import org.partiql.lang.eval.isNotUnknown +import org.partiql.lang.eval.isUnknown +import org.partiql.lang.eval.like.parsePattern +import org.partiql.lang.eval.longValue +import org.partiql.lang.eval.namedValue +import org.partiql.lang.eval.numberValue +import org.partiql.lang.eval.rangeOver +import org.partiql.lang.eval.relation.RelationType +import org.partiql.lang.eval.sourceLocationMeta +import org.partiql.lang.eval.stringValue +import org.partiql.lang.eval.syntheticColumnName +import org.partiql.lang.eval.time.Time +import org.partiql.lang.eval.timestampValue +import org.partiql.lang.eval.unnamedValue +import org.partiql.lang.planner.EvaluatorOptions +import org.partiql.lang.types.StaticTypeUtils.getRuntimeType +import org.partiql.lang.types.StaticTypeUtils.isInstance +import org.partiql.lang.types.StaticTypeUtils.staticTypeFromExprValue +import org.partiql.lang.types.TypedOpParameter +import org.partiql.lang.types.UnknownArguments +import org.partiql.lang.types.toTypedOpParameter +import org.partiql.lang.util.checkThreadInterrupted +import org.partiql.lang.util.codePointSequence +import org.partiql.lang.util.div +import org.partiql.lang.util.exprValue +import org.partiql.lang.util.isZero +import org.partiql.lang.util.minus +import org.partiql.lang.util.plus +import 
org.partiql.lang.util.rem +import org.partiql.lang.util.stringValue +import org.partiql.lang.util.times +import org.partiql.lang.util.toIntExact +import org.partiql.lang.util.totalMinutes +import org.partiql.lang.util.unaryMinus +import org.partiql.types.AnyOfType +import org.partiql.types.AnyType +import org.partiql.types.IntType +import org.partiql.types.SingleType +import org.partiql.types.StaticType +import org.partiql.types.UnsupportedTypeCheckException +import java.util.LinkedList +import java.util.TreeSet +import java.util.regex.Pattern + +/** + * A basic "compiler" that converts an instance of [PartiqlPhysical.Expr] to an [Expression]. + * + * This is a modified copy of the legacy `EvaluatingCompiler` class, which is now legacy. + * The primary differences between this class an `EvaluatingCompiler` are: + * + * - All references to `PartiqlPhysical` are replaced with `PartiqlPhysical`. + * - `EvaluatingCompiler` compiles "monolithic" SFW queries--this class compiles relational + * operators (in concert with [PhysicalBexprToThunkConverter]). + * + * This implementation produces a "compiled" form consisting of context-threaded + * code in the form of a tree of [PhysicalPlanThunkAsync]s. An overview of this technique can be found + * [here][1]. + * + * **Note:** *threaded* in this context is used in how the code gets *threaded* together for + * interpretation and **not** the concurrency primitive. That is to say this code is NOT thread + * safe. 
+ * + * [1]: https://www.complang.tuwien.ac.at/anton/lvas/sem06w/fest.pdf + */ +internal class PhysicalPlanCompilerAsyncImpl( + private val functions: List, + private val customTypedOpParameters: Map, + private val procedures: Map, + private val evaluatorOptions: EvaluatorOptions = EvaluatorOptions.standard(), + private val bexperConverter: PhysicalBexprToThunkConverterAsync, +) : PhysicalPlanCompilerAsync { + @Deprecated("Use constructor with List instead", level = DeprecationLevel.WARNING) + constructor( + functions: Map, + customTypedOpParameters: Map, + procedures: Map, + evaluatorOptions: EvaluatorOptions = EvaluatorOptions.standard(), + bexperConverter: PhysicalBexprToThunkConverterAsync + ) : this( + functions = functions.values.toList(), + customTypedOpParameters = customTypedOpParameters, + procedures = procedures, + evaluatorOptions = evaluatorOptions, + bexperConverter = bexperConverter + ) + + // TODO: remove this once we migrate from `IonValue` to `IonElement`. + private val ion = IonSystemBuilder.standard().build() + + private val errorSignaler = evaluatorOptions.typingMode.createErrorSignaler() + private val thunkFactory = evaluatorOptions.typingMode.createThunkFactoryAsync(evaluatorOptions.thunkOptions) + + private val functionManager = FunctionManager(functions) + + private fun Boolean.exprValue(): ExprValue = ExprValue.newBoolean(this) + private fun String.exprValue(): ExprValue = ExprValue.newString(this) + + /** + * Compiles a [PartiqlPhysical.Statement] tree to an [ExpressionAsync]. + * + * Checks [Thread.interrupted] before every expression and sub-expression is compiled + * and throws [InterruptedException] if [Thread.interrupted] it has been set in the + * hope that long-running compilations may be aborted by the caller. 
+ */ + suspend fun compile(plan: PartiqlPhysical.Plan): ExpressionAsync { + val thunk = compileAstStatement(plan.stmt) + + return object : ExpressionAsync { + override suspend fun eval(session: EvaluationSession): PartiQLResult { + val env = EvaluatorState( + session = session, + registers = Array(plan.locals.size) { ExprValue.missingValue } + ) + val value = thunk(env) + return PartiQLResult.Value(value = value) + } + } + } + + /** + * Compiles a [PartiqlPhysical.Expr] tree to an [ExpressionAsync]. + * + * Checks [Thread.interrupted] before every expression and sub-expression is compiled + * and throws [InterruptedException] if [Thread.interrupted] has been set in the + * hope that long-running compilations may be aborted by the caller. + */ + internal suspend fun compile(expr: PartiqlPhysical.Expr, localsSize: Int): ExpressionAsync { + val thunk = compileAstExpr(expr) + + return object : ExpressionAsync { + override suspend fun eval(session: EvaluationSession): PartiQLResult { + val env = EvaluatorState( + session = session, + registers = Array(localsSize) { ExprValue.missingValue } + ) + val value = thunk(env) + return PartiQLResult.Value(value = value) + } + } + } + + override suspend fun convert(expr: PartiqlPhysical.Expr): PhysicalPlanThunkAsync = this.compileAstExpr(expr) + + /** + * Compiles the specified [PartiqlPhysical.Statement] into a [PhysicalPlanThunkAsync]. + * + * This function will throw [InterruptedException] if [Thread.interrupted] has been set.
+ */ + private suspend fun compileAstStatement(ast: PartiqlPhysical.Statement): PhysicalPlanThunkAsync { + return when (ast) { + is PartiqlPhysical.Statement.Query -> compileAstExpr(ast.expr) + is PartiqlPhysical.Statement.Exec -> compileExec(ast) + is PartiqlPhysical.Statement.Dml, + is PartiqlPhysical.Statement.Explain -> { + val value = ExprValue.newBoolean(true) + thunkFactory.thunkEnvAsync(emptyMetaContainer()) { value } + } + } + } + + private suspend fun compileAstExpr(expr: PartiqlPhysical.Expr): PhysicalPlanThunkAsync { + checkThreadInterrupted() + val metas = expr.metas + + return when (expr) { + is PartiqlPhysical.Expr.Lit -> compileLit(expr, metas) + is PartiqlPhysical.Expr.Missing -> compileMissing(metas) + is PartiqlPhysical.Expr.LocalId -> compileLocalId(expr, metas) + is PartiqlPhysical.Expr.GlobalId -> compileGlobalId(expr) + is PartiqlPhysical.Expr.SimpleCase -> compileSimpleCase(expr, metas) + is PartiqlPhysical.Expr.SearchedCase -> compileSearchedCase(expr, metas) + is PartiqlPhysical.Expr.Path -> compilePath(expr, metas) + is PartiqlPhysical.Expr.Struct -> compileStruct(expr) + is PartiqlPhysical.Expr.Parameter -> compileParameter(expr, metas) + is PartiqlPhysical.Expr.Date -> compileDate(expr, metas) + is PartiqlPhysical.Expr.LitTime -> compileLitTime(expr, metas) + + // arithmetic operations + is PartiqlPhysical.Expr.Plus -> compilePlus(expr, metas) + is PartiqlPhysical.Expr.Times -> compileTimes(expr, metas) + is PartiqlPhysical.Expr.Minus -> compileMinus(expr, metas) + is PartiqlPhysical.Expr.Divide -> compileDivide(expr, metas) + is PartiqlPhysical.Expr.Modulo -> compileModulo(expr, metas) + is PartiqlPhysical.Expr.BitwiseAnd -> compileBitwiseAnd(expr, metas) + + // comparison operators + is PartiqlPhysical.Expr.And -> compileAnd(expr, metas) + is PartiqlPhysical.Expr.Between -> compileBetween(expr, metas) + is PartiqlPhysical.Expr.Eq -> compileEq(expr, metas) + is PartiqlPhysical.Expr.Gt -> compileGt(expr, metas) + is 
PartiqlPhysical.Expr.Gte -> compileGte(expr, metas) + is PartiqlPhysical.Expr.Lt -> compileLt(expr, metas) + is PartiqlPhysical.Expr.Lte -> compileLte(expr, metas) + is PartiqlPhysical.Expr.Like -> compileLike(expr, metas) + is PartiqlPhysical.Expr.InCollection -> compileIn(expr, metas) + + // logical operators + is PartiqlPhysical.Expr.Ne -> compileNe(expr, metas) + is PartiqlPhysical.Expr.Or -> compileOr(expr, metas) + + // unary + is PartiqlPhysical.Expr.Not -> compileNot(expr, metas) + is PartiqlPhysical.Expr.Pos -> compilePos(expr, metas) + is PartiqlPhysical.Expr.Neg -> compileNeg(expr, metas) + + // other operators + is PartiqlPhysical.Expr.Concat -> compileConcat(expr, metas) + is PartiqlPhysical.Expr.Call -> compileCall(expr, metas) + is PartiqlPhysical.Expr.NullIf -> compileNullIf(expr, metas) + is PartiqlPhysical.Expr.Coalesce -> compileCoalesce(expr, metas) + + // "typed" operators (RHS is a data type and not an expression) + is PartiqlPhysical.Expr.Cast -> compileCast(expr, metas) + is PartiqlPhysical.Expr.IsType -> compileIs(expr, metas) + is PartiqlPhysical.Expr.CanCast -> compileCanCast(expr, metas) + is PartiqlPhysical.Expr.CanLosslessCast -> compileCanLosslessCast(expr, metas) + + // sequence constructors + is PartiqlPhysical.Expr.List -> compileSeq(ExprValueType.LIST, expr.values, metas) + is PartiqlPhysical.Expr.Sexp -> compileSeq(ExprValueType.SEXP, expr.values, metas) + is PartiqlPhysical.Expr.Bag -> compileSeq(ExprValueType.BAG, expr.values, metas) + + // bag operators + is PartiqlPhysical.Expr.BagOp -> compileBagOp(expr, metas) + is PartiqlPhysical.Expr.BindingsToValues -> compileBindingsToValues(expr) + is PartiqlPhysical.Expr.Pivot -> compilePivot(expr, metas) + is PartiqlPhysical.Expr.GraphMatch -> TODO("Physical compilation of GraphMatch expression") + is PartiqlPhysical.Expr.Timestamp -> TODO() + } + } + + private suspend fun compileBindingsToValues(expr: PartiqlPhysical.Expr.BindingsToValues): PhysicalPlanThunkAsync { + val mapThunk = 
compileAstExpr(expr.exp) + val bexprThunk: RelationThunkEnvAsync = bexperConverter.convert(expr.query) + + val relationType = when (expr.metas.containsKey(IsOrderedMeta.tag)) { + true -> RelationType.LIST + false -> RelationType.BAG + } + + return thunkFactory.thunkEnvAsync(expr.metas) { env -> + // we create a snapshot for currentRegister to use during the evaluation + // this is to avoid issue when iterator planner result + val currentRegister = env.registers.clone() + val elements: Flow = flow { + env.load(currentRegister) + val relItr = bexprThunk(env) + while (relItr.nextRow()) { + emit(mapThunk(env)) + } + } + when (relationType) { + RelationType.LIST -> ExprValue.newList(elements.toList()) + RelationType.BAG -> ExprValue.newBag(elements.toList()) + } + } + } + + private suspend fun compileAstExprs(args: List) = args.map { compileAstExpr(it) } + + private suspend fun compileNullIf(expr: PartiqlPhysical.Expr.NullIf, metas: MetaContainer): PhysicalPlanThunkAsync { + val expr1Thunk = compileAstExpr(expr.expr1) + val expr2Thunk = compileAstExpr(expr.expr2) + + // Note: NULLIF does not propagate the unknown values and .exprEquals provides the correct semantics. + return thunkFactory.thunkEnvAsync(metas) { env -> + val expr1Value = expr1Thunk(env) + val expr2Value = expr2Thunk(env) + when { + expr1Value.exprEquals(expr2Value) -> ExprValue.nullValue + else -> expr1Value + } + } + } + + private suspend fun compileCoalesce(expr: PartiqlPhysical.Expr.Coalesce, metas: MetaContainer): PhysicalPlanThunkAsync { + val argThunks = compileAstExprs(expr.args) + + return thunkFactory.thunkEnvAsync(metas) { env -> + var nullFound = false + var knownValue: ExprValue? 
= null + for (thunk in argThunks) { + val argValue = thunk(env) + if (argValue.isNotUnknown()) { + knownValue = argValue + // No need to execute remaining thunks to save computation as first non-unknown value is found + break + } + if (argValue.type == ExprValueType.NULL) { + nullFound = true + } + } + when (knownValue) { + null -> when { + evaluatorOptions.typingMode == TypingMode.PERMISSIVE && !nullFound -> ExprValue.missingValue + else -> ExprValue.nullValue + } + else -> knownValue + } + } + } + + /** + * Returns a function that accepts an [ExprValue] as an argument and returns true it is `NULL`, `MISSING`, or + * within the range specified by [range]. + */ + private fun integerValueValidator( + range: LongRange + ): (ExprValue) -> Boolean = { value -> + when (value.type) { + ExprValueType.NULL, ExprValueType.MISSING -> true + ExprValueType.INT -> { + val longValue: Long = value.scalar.numberValue()?.toLong() + ?: error( + "ExprValue.numberValue() must not be `NULL` when its type is INT." + + "This indicates that the ExprValue instance has a bug." + ) + + // PRO-TIP: make sure to use the `Long` primitive type here with `.contains` otherwise + // Kotlin will use the version of `.contains` that treats [range] as a collection, and it will + // be very slow! + range.contains(longValue) + } + else -> error( + "The expression's static type was supposed to be INT but instead it was ${value.type}" + + "This may indicate the presence of a bug in the type inferencer." + ) + } + } + + /** + * For operators which could return integer type, check integer overflow in case of [TypingMode.PERMISSIVE]. + */ + private suspend fun checkIntegerOverflow(computeThunk: PhysicalPlanThunkAsync, metas: MetaContainer): PhysicalPlanThunkAsync = + when (val staticTypes = metas.staticType?.type?.getTypes()) { + // No staticType, can't validate integer size. 
+ null -> computeThunk + else -> { + when (evaluatorOptions.typingMode) { + TypingMode.LEGACY -> { + // integer size constraints have not been tested under [TypingMode.LEGACY] because the + // [StaticTypeInferenceVisitorTransform] doesn't support being used with legacy mode yet. + // throw an exception in case we encounter this untested scenario. This might work fine, but I + // wouldn't bet on it. + val hasConstrainedInteger = staticTypes.any { + it is IntType && it.rangeConstraint != IntType.IntRangeConstraint.UNCONSTRAINED + } + if (hasConstrainedInteger) { + TODO("Legacy mode doesn't support integer size constraints yet.") + } else { + computeThunk + } + } + TypingMode.PERMISSIVE -> { + val biggestIntegerType = staticTypes.filterIsInstance().maxByOrNull { + it.rangeConstraint.numBytes + } + when (biggestIntegerType) { + is IntType -> { + val validator = integerValueValidator(biggestIntegerType.rangeConstraint.validRange) + + thunkFactory.thunkEnvAsync(metas) { env -> + val naryResult = computeThunk(env) + errorSignaler.errorIf( + !validator(naryResult), + ErrorCode.EVALUATOR_INTEGER_OVERFLOW, + { ErrorDetails(metas, "Integer overflow", errorContextFrom(metas)) }, + { naryResult } + ) + } + } + // If there is no IntType StaticType, can't validate the integer size either. 
+ null -> computeThunk + else -> computeThunk + } + } + } + } + } + + private suspend fun compilePlus(expr: PartiqlPhysical.Expr.Plus, metas: MetaContainer): PhysicalPlanThunkAsync { + if (expr.operands.size < 2) { + error("Internal Error: PartiqlPhysical.Expr.Plus must have at least 2 arguments") + } + + val argThunks = compileAstExprs(expr.operands) + + val computeThunk = thunkFactory.thunkFold(metas, argThunks) { lValue, rValue -> + (lValue.numberValue() + rValue.numberValue()).exprValue() + } + + return checkIntegerOverflow(computeThunk, metas) + } + + private suspend fun compileMinus(expr: PartiqlPhysical.Expr.Minus, metas: MetaContainer): PhysicalPlanThunkAsync { + if (expr.operands.size < 2) { + error("Internal Error: PartiqlPhysical.Expr.Minus must have at least 2 arguments") + } + + val argThunks = compileAstExprs(expr.operands) + + val computeThunk = thunkFactory.thunkFold(metas, argThunks) { lValue, rValue -> + (lValue.numberValue() - rValue.numberValue()).exprValue() + } + + return checkIntegerOverflow(computeThunk, metas) + } + + private suspend fun compilePos(expr: PartiqlPhysical.Expr.Pos, metas: MetaContainer): PhysicalPlanThunkAsync { + val exprThunk = compileAstExpr(expr.expr) + + val computeThunk = thunkFactory.thunkEnvOperands(metas, exprThunk) { _, value -> + // Invoking .numberValue() here makes this essentially just a type check + value.numberValue() + // Original value is returned unmodified. 
+ value + } + + return checkIntegerOverflow(computeThunk, metas) + } + + private suspend fun compileNeg(expr: PartiqlPhysical.Expr.Neg, metas: MetaContainer): PhysicalPlanThunkAsync { + val exprThunk = compileAstExpr(expr.expr) + + val computeThunk = thunkFactory.thunkEnvOperands(metas, exprThunk) { _, value -> + (-value.numberValue()).exprValue() + } + + return checkIntegerOverflow(computeThunk, metas) + } + + private suspend fun compileTimes(expr: PartiqlPhysical.Expr.Times, metas: MetaContainer): PhysicalPlanThunkAsync { + val argThunks = compileAstExprs(expr.operands) + + val computeThunk = thunkFactory.thunkFold(metas, argThunks) { lValue, rValue -> + (lValue.numberValue() * rValue.numberValue()).exprValue() + } + + return checkIntegerOverflow(computeThunk, metas) + } + + private suspend fun compileDivide(expr: PartiqlPhysical.Expr.Divide, metas: MetaContainer): PhysicalPlanThunkAsync { + val argThunks = compileAstExprs(expr.operands) + + val computeThunk = thunkFactory.thunkFold(metas, argThunks) { lValue, rValue -> + val denominator = rValue.numberValue() + + errorSignaler.errorIf( + denominator.isZero(), + ErrorCode.EVALUATOR_DIVIDE_BY_ZERO, + { ErrorDetails(metas, "/ by zero") } + ) { + try { + (lValue.numberValue() / denominator).exprValue() + } catch (e: ArithmeticException) { + // Setting the internal flag as true as it is not clear what + // ArithmeticException may be thrown by the above + throw EvaluationException( + cause = e, + errorCode = ErrorCode.EVALUATOR_ARITHMETIC_EXCEPTION, + internal = true + ) + } + } + } + + return checkIntegerOverflow(computeThunk, metas) + } + + private suspend fun compileModulo(expr: PartiqlPhysical.Expr.Modulo, metas: MetaContainer): PhysicalPlanThunkAsync { + val argThunks = compileAstExprs(expr.operands) + + val computeThunk = thunkFactory.thunkFold(metas, argThunks) { lValue, rValue -> + val denominator = rValue.numberValue() + if (denominator.isZero()) { + err("% by zero", ErrorCode.EVALUATOR_MODULO_BY_ZERO, 
errorContextFrom(metas), internal = false) + } + + (lValue.numberValue() % denominator).exprValue() + } + + return checkIntegerOverflow(computeThunk, metas) + } + + private suspend fun compileBitwiseAnd(expr: PartiqlPhysical.Expr.BitwiseAnd, metas: MetaContainer): PhysicalPlanThunkAsync { + val argThunks = compileAstExprs(expr.operands) + + return thunkFactory.thunkFold(metas, argThunks) { lValue, rValue -> + (lValue.longValue() and rValue.longValue()).exprValue() + } + } + + private suspend fun compileEq(expr: PartiqlPhysical.Expr.Eq, metas: MetaContainer): PhysicalPlanThunkAsync { + val argThunks = compileAstExprs(expr.operands) + + return thunkFactory.thunkAndMap(metas, argThunks) { lValue, rValue -> + (lValue.exprEquals(rValue)) + } + } + + private suspend fun compileNe(expr: PartiqlPhysical.Expr.Ne, metas: MetaContainer): PhysicalPlanThunkAsync { + val argThunks = compileAstExprs(expr.operands) + + return thunkFactory.thunkFold(metas, argThunks) { lValue, rValue -> + ((!lValue.exprEquals(rValue)).exprValue()) + } + } + + private suspend fun compileLt(expr: PartiqlPhysical.Expr.Lt, metas: MetaContainer): PhysicalPlanThunkAsync { + val argThunks = compileAstExprs(expr.operands) + + return thunkFactory.thunkAndMap(metas, argThunks) { lValue, rValue -> lValue < rValue } + } + + private suspend fun compileLte(expr: PartiqlPhysical.Expr.Lte, metas: MetaContainer): PhysicalPlanThunkAsync { + val argThunks = compileAstExprs(expr.operands) + + return thunkFactory.thunkAndMap(metas, argThunks) { lValue, rValue -> lValue <= rValue } + } + + private suspend fun compileGt(expr: PartiqlPhysical.Expr.Gt, metas: MetaContainer): PhysicalPlanThunkAsync { + val argThunks = compileAstExprs(expr.operands) + + return thunkFactory.thunkAndMap(metas, argThunks) { lValue, rValue -> lValue > rValue } + } + + private suspend fun compileGte(expr: PartiqlPhysical.Expr.Gte, metas: MetaContainer): PhysicalPlanThunkAsync { + val argThunks = compileAstExprs(expr.operands) + + return 
thunkFactory.thunkAndMap(metas, argThunks) { lValue, rValue -> lValue >= rValue } + } + + private suspend fun compileBetween(expr: PartiqlPhysical.Expr.Between, metas: MetaContainer): PhysicalPlanThunkAsync { + val valueThunk = compileAstExpr(expr.value) + val fromThunk = compileAstExpr(expr.from) + val toThunk = compileAstExpr(expr.to) + + return thunkFactory.thunkEnvOperands(metas, valueThunk, fromThunk, toThunk) { _, v, f, t -> + (v >= f && v <= t).exprValue() + } + } + + /** + * `IN` can *almost* be thought of as being syntactic sugar for the `OR` operator. + * + * `a IN (b, c, d)` is equivalent to `a = b OR a = c OR a = d`. On deep inspection, there + * are important implications to this regarding propagation of unknown values. Specifically, the + * presence of any unknown in `b`, `c`, or `d` will result in unknown propagation iff `a` does not + * equal `b`, `c`, or `d`. i.e.: + * + * - `1 in (null, 2, 3)` -> `null` + * - `2 in (null, 2, 3)` -> `true` + * - `2 in (1, 2, 3)` -> `true` + * - `0 in (1, 2, 4)` -> `false` + * + * `IN` varies from the `OR` operator in that this behavior holds true when other types of expressions are + * used on the right side of `IN` such as sub-queries and variables whose value is that of a list or bag. + */ + private suspend fun compileIn(expr: PartiqlPhysical.Expr.InCollection, metas: MetaContainer): PhysicalPlanThunkAsync { + val args = expr.operands + val leftThunk = compileAstExpr(args[0]) + val rightOp = args[1] + + fun isOptimizedCase(values: List): Boolean = values.all { it is PartiqlPhysical.Expr.Lit && !it.value.isNull } + + suspend fun optimizedCase(values: List): PhysicalPlanThunkAsync { + // Put all the literals in the sequence into a pre-computed map to be checked later by the thunk. + // If the left-hand value is one of these we can short-circuit with a result of TRUE.
+ // This is the fastest possible case and allows for hundreds of literal values (or more) in the + // sequence without a huge performance penalty. + // NOTE: we cannot use a [HashSet<>] here because [ExprValue] does not implement [Object.hashCode] or + // [Object.equals]. + val precomputedLiteralsMap = values + .filterIsInstance() + .mapTo(TreeSet(DEFAULT_COMPARATOR)) { + ExprValue.of( + it.value.toIonValue(ion) + ) + } + + // the compiled thunk simply checks if the left side is contained on the right side. + // thunkEnvOperands takes care of unknown propagation for the left side; for the right, + // this unknown propagation does not apply since we've eliminated the possibility of unknowns above. + return thunkFactory.thunkEnvOperands(metas, leftThunk) { _, leftValue -> + precomputedLiteralsMap.contains(leftValue).exprValue() + } + } + + return when { + // We can significantly optimize this if rightArg is a sequence constructor which is composed of entirely + // of non-null literal values. + rightOp is PartiqlPhysical.Expr.List && isOptimizedCase(rightOp.values) -> optimizedCase(rightOp.values) + rightOp is PartiqlPhysical.Expr.Bag && isOptimizedCase(rightOp.values) -> optimizedCase(rightOp.values) + rightOp is PartiqlPhysical.Expr.Sexp && isOptimizedCase(rightOp.values) -> optimizedCase(rightOp.values) + // The unoptimized case... 
+ else -> { + val rightThunk = compileAstExpr(rightOp) + + // Legacy mode: + // Returns FALSE when the right side of IN is not a sequence + // Returns NULL if the right side is MISSING or any value on the right side is MISSING + // Permissive mode: + // Returns MISSING when the right side of IN is not a sequence + // Returns MISSING if the right side is MISSING or any value on the right side is MISSING + val (propagateMissingAs, propagateNotASeqAs) = when (evaluatorOptions.typingMode) { + TypingMode.LEGACY -> ExprValue.nullValue to ExprValue.newBoolean(false) + TypingMode.PERMISSIVE -> ExprValue.missingValue to ExprValue.missingValue + } + + // Note that standard unknown propagation applies to the left and right operands. Both [TypingMode]s + // are handled by [ThunkFactory.thunkEnvOperands] and that additional rules for unknown propagation are + // implemented within the thunk for the values within the sequence on the right side of IN. + thunkFactory.thunkEnvOperands(metas, leftThunk, rightThunk) { _, leftValue, rightValue -> + var nullSeen = false + var missingSeen = false + + when { + rightValue.type == ExprValueType.MISSING -> propagateMissingAs + !rightValue.type.isSequence -> propagateNotASeqAs + else -> { + rightValue.forEach { + when (it.type) { + ExprValueType.NULL -> nullSeen = true + ExprValueType.MISSING -> missingSeen = true + // short-circuit to TRUE on the first matching value + else -> if (it.exprEquals(leftValue)) { + return@thunkEnvOperands ExprValue.newBoolean(true) + } + } + } + // If we make it here then there was no match. Propagate MISSING, NULL or return false. + // Note that if both MISSING and NULL was encountered, MISSING takes precedence. 
+ when { + missingSeen -> propagateMissingAs + nullSeen -> ExprValue.nullValue + else -> ExprValue.newBoolean(false) + } + } + } + } + } + } + } + + private suspend fun compileNot(expr: PartiqlPhysical.Expr.Not, metas: MetaContainer): PhysicalPlanThunkAsync { + val argThunk = compileAstExpr(expr.expr) + + return thunkFactory.thunkEnvOperands(metas, argThunk) { _, value -> + (!value.booleanValue()).exprValue() + } + } + + private suspend fun compileAnd(expr: PartiqlPhysical.Expr.And, metas: MetaContainer): PhysicalPlanThunkAsync { + val argThunks = compileAstExprs(expr.operands) + + // can't use the null propagation supplied by [ThunkFactory.thunkEnv] here because AND short-circuits on + // false values and *NOT* on NULL or MISSING + return when (evaluatorOptions.typingMode) { + TypingMode.LEGACY -> thunkFactory.thunkEnvAsync(metas) thunk@{ env -> + var hasUnknowns = false + argThunks.forEach { currThunk -> + val currValue = currThunk(env) + when { + currValue.isUnknown() -> hasUnknowns = true + // Short circuit only if we encounter a known false value. + !currValue.booleanValue() -> return@thunk ExprValue.newBoolean(false) + } + } + + when (hasUnknowns) { + true -> ExprValue.nullValue + false -> ExprValue.newBoolean(true) + } + } + TypingMode.PERMISSIVE -> thunkFactory.thunkEnvAsync(metas) thunk@{ env -> + var hasNull = false + var hasMissing = false + argThunks.forEach { currThunk -> + val currValue = currThunk(env) + when (currValue.type) { + // Short circuit only if we encounter a known false value. 
+ ExprValueType.BOOL -> if (!currValue.booleanValue()) return@thunk ExprValue.newBoolean(false) + ExprValueType.NULL -> hasNull = true + // type mismatch, return missing + else -> hasMissing = true + } + } + + when { + hasMissing -> ExprValue.missingValue + hasNull -> ExprValue.nullValue + else -> ExprValue.newBoolean(true) + } + } + } + } + + private suspend fun compileOr(expr: PartiqlPhysical.Expr.Or, metas: MetaContainer): PhysicalPlanThunkAsync { + val argThunks = compileAstExprs(expr.operands) + + // can't use the null propagation supplied by [ThunkFactory.thunkEnv] here because OR short-circuits on + // true values and *NOT* on NULL or MISSING + return when (evaluatorOptions.typingMode) { + TypingMode.LEGACY -> + thunkFactory.thunkEnvAsync(metas) thunk@{ env -> + var hasUnknowns = false + argThunks.forEach { currThunk -> + val currValue = currThunk(env) + // How null-propagation works for OR is rather weird according to the SQL-92 spec. + // Nulls are propagated like other expressions only when none of the terms are TRUE. + // If any one of them is TRUE, then the entire expression evaluates to TRUE, i.e.: + // NULL OR TRUE -> TRUE + // NULL OR FALSE -> NULL + // (strange but true) + when { + currValue.isUnknown() -> hasUnknowns = true + currValue.booleanValue() -> return@thunk ExprValue.newBoolean(true) + } + } + + when (hasUnknowns) { + true -> ExprValue.nullValue + false -> ExprValue.newBoolean(false) + } + } + TypingMode.PERMISSIVE -> thunkFactory.thunkEnvAsync(metas) thunk@{ env -> + var hasNull = false + var hasMissing = false + argThunks.forEach { currThunk -> + val currValue = currThunk(env) + when (currValue.type) { + // Short circuit only if we encounter a known true value. + ExprValueType.BOOL -> if (currValue.booleanValue()) return@thunk ExprValue.newBoolean(true) + ExprValueType.NULL -> hasNull = true + else -> hasMissing = true // type mismatch, return missing. 
+ } + } + + when { + hasMissing -> ExprValue.missingValue + hasNull -> ExprValue.nullValue + else -> ExprValue.newBoolean(false) + } + } + } + } + + private suspend fun compileConcat(expr: PartiqlPhysical.Expr.Concat, metas: MetaContainer): PhysicalPlanThunkAsync { + val argThunks = compileAstExprs(expr.operands) + + return thunkFactory.thunkFold(metas, argThunks) { lValue, rValue -> + val lType = lValue.type + val rType = rValue.type + + if (lType.isText && rType.isText) { + // null/missing propagation is handled before getting here + (lValue.stringValue() + rValue.stringValue()).exprValue() + } else { + err( + "Wrong argument type for ||", + ErrorCode.EVALUATOR_CONCAT_FAILED_DUE_TO_INCOMPATIBLE_TYPE, + errorContextFrom(metas).also { + it[Property.ACTUAL_ARGUMENT_TYPES] = listOf(lType, rType).toString() + }, + internal = false + ) + } + } + } + + private suspend fun compileCall(expr: PartiqlPhysical.Expr.Call, metas: MetaContainer): PhysicalPlanThunkAsync { + val funcArgThunks = compileAstExprs(expr.args) + val arity = funcArgThunks.size + val name = expr.funcName.text + return thunkFactory.thunkEnvAsync(metas) { env -> + val args = funcArgThunks.map { thunk -> thunk(env) } + val argTypes = args.map { staticTypeFromExprValue(it) } + try { + val func = functionManager.get(name = name, arity = arity, args = argTypes) + val computeThunk = when (func.signature.unknownArguments) { + UnknownArguments.PROPAGATE -> thunkFactory.thunkEnvOperands(metas, funcArgThunks) { env, _ -> + func.call(env.session, args) + } + UnknownArguments.PASS_THRU -> thunkFactory.thunkEnvAsync(metas) { env -> + func.call(env.session, args) + } + } + checkIntegerOverflow(computeThunk, metas)(env) + } catch (e: FunctionNotFoundException) { + err( + "No such function: $name", + ErrorCode.EVALUATOR_NO_SUCH_FUNCTION, + errorContextFrom(metas).also { + it[Property.FUNCTION_NAME] = name + }, + internal = false + ) + } catch (e: ArityMismatchException) { + val (minArity, maxArity) = e.arity + val 
errorContext = errorContextFrom(metas).also { + it[Property.FUNCTION_NAME] = name + it[Property.EXPECTED_ARITY_MIN] = minArity + it[Property.EXPECTED_ARITY_MAX] = maxArity + it[Property.ACTUAL_ARITY] = arity + } + err( + "No function found with matching arity: $name", + ErrorCode.EVALUATOR_INCORRECT_NUMBER_OF_ARGUMENTS_TO_FUNC_CALL, + errorContext, + internal = false + ) + } + } + } + + private suspend fun compileLit(expr: PartiqlPhysical.Expr.Lit, metas: MetaContainer): PhysicalPlanThunkAsync { + val value = ExprValue.of(expr.value.toIonValue(ion)) + + return thunkFactory.thunkEnvAsync(metas) { value } + } + + private suspend fun compileMissing(metas: MetaContainer): PhysicalPlanThunkAsync = + thunkFactory.thunkEnvAsync(metas) { ExprValue.missingValue } + + private suspend fun compileGlobalId(expr: PartiqlPhysical.Expr.GlobalId): PhysicalPlanThunkAsync { + // TODO: we really should consider using something other than `Bindings` for global variables + // with the physical plan evaluator because `Bindings.get()` accepts a `BindingName` instance + // which contains the `case` property which is always set to `SENSITIVE` and is therefore redundant. 
+ val bindingName = BindingName(expr.uniqueId.text, BindingCase.SENSITIVE) + return thunkFactory.thunkEnvAsync(expr.metas) { env -> + env.session.globals[bindingName] ?: throwUndefinedVariableException(bindingName, expr.metas) + } + } + + @Suppress("UNUSED_PARAMETER") + private suspend fun compileLocalId(expr: PartiqlPhysical.Expr.LocalId, metas: MetaContainer): PhysicalPlanThunkAsync { + val localIndex = expr.index.value.toIntExact() + return thunkFactory.thunkEnvAsync(metas) { env -> + env.registers[localIndex] + } + } + + private fun compileParameter(expr: PartiqlPhysical.Expr.Parameter, metas: MetaContainer): PhysicalPlanThunkAsync { + val ordinal = expr.index.value.toInt() + val index = ordinal - 1 + + return { env -> + val params = env.session.parameters + if (params.size <= index) { + throw EvaluationException( + "Unbound parameter for ordinal: $ordinal", + ErrorCode.EVALUATOR_UNBOUND_PARAMETER, + errorContextFrom(metas).also { + it[Property.EXPECTED_PARAMETER_ORDINAL] = ordinal + it[Property.BOUND_PARAMETER_COUNT] = params.size + }, + internal = false + ) + } + params[index] + } + } + + /** + * Returns a lambda that implements the `IS` operator type check according to the current + * [TypedOpBehavior]. 
+ */ + private fun makeIsCheck( + staticType: SingleType, + typedOpParameter: TypedOpParameter, + metas: MetaContainer + ): (ExprValue) -> Boolean { + return when (evaluatorOptions.typedOpBehavior) { + TypedOpBehavior.HONOR_PARAMETERS -> { expValue: ExprValue -> + staticType.allTypes.any { + val matchesStaticType = try { + isInstance(expValue, it) + } catch (e: UnsupportedTypeCheckException) { + err( + e.message!!, + ErrorCode.UNIMPLEMENTED_FEATURE, + errorContextFrom(metas), + internal = true + ) + } + + when { + !matchesStaticType -> false + else -> when (val validator = typedOpParameter.validationThunk) { + null -> true + else -> validator(expValue) + } + } + } + } + } + } + + private suspend fun compileIs(expr: PartiqlPhysical.Expr.IsType, metas: MetaContainer): PhysicalPlanThunkAsync { + val expThunk = compileAstExpr(expr.value) + val typedOpParameter = expr.type.toTypedOpParameter(customTypedOpParameters) + if (typedOpParameter.staticType is AnyType) { + return thunkFactory.thunkEnvAsync(metas) { ExprValue.newBoolean(true) } + } + if (evaluatorOptions.typedOpBehavior == TypedOpBehavior.HONOR_PARAMETERS && expr.type is PartiqlPhysical.Type.FloatType && (expr.type as PartiqlPhysical.Type.FloatType).precision != null) { + err( + "FLOAT precision parameter is unsupported", + ErrorCode.SEMANTIC_FLOAT_PRECISION_UNSUPPORTED, + errorContextFrom(expr.type.metas), + internal = false + ) + } + + val typeMatchFunc = when (val staticType = typedOpParameter.staticType) { + is SingleType -> makeIsCheck(staticType, typedOpParameter, metas) + is AnyOfType -> staticType.types.map { childType -> + when (childType) { + is SingleType -> makeIsCheck(childType, typedOpParameter, metas) + else -> err( + "Union type cannot have ANY or nested AnyOf type for IS", + ErrorCode.SEMANTIC_UNION_TYPE_INVALID, + errorContextFrom(metas), + internal = true + ) + } + }.let { typeMatchFuncs -> + { expValue: ExprValue -> typeMatchFuncs.any { func -> func(expValue) } } + } + is AnyType -> throw 
IllegalStateException("Unexpected ANY type in IS compilation") + } + + return thunkFactory.thunkEnvAsync(metas) { env -> + val expValue = expThunk(env) + typeMatchFunc(expValue).exprValue() + } + } + + private suspend fun compileCastHelper(value: PartiqlPhysical.Expr, asType: PartiqlPhysical.Type, metas: MetaContainer): PhysicalPlanThunkAsync { + val expThunk = compileAstExpr(value) + val typedOpParameter = asType.toTypedOpParameter(customTypedOpParameters) + if (typedOpParameter.staticType is AnyType) { + return expThunk + } + if (evaluatorOptions.typedOpBehavior == TypedOpBehavior.HONOR_PARAMETERS && asType is PartiqlPhysical.Type.FloatType && asType.precision != null) { + err( + "FLOAT precision parameter is unsupported", + ErrorCode.SEMANTIC_FLOAT_PRECISION_UNSUPPORTED, + errorContextFrom(asType.metas), + internal = false + ) + } + + fun typeOpValidate( + value: ExprValue, + castOutput: ExprValue, + typeName: String, + locationMeta: SourceLocationMeta? + ) { + if (typedOpParameter.validationThunk?.let { it(castOutput) } == false) { + val errorContext = PropertyValueMap().also { + it[Property.CAST_FROM] = value.type.toString() + it[Property.CAST_TO] = typeName + } + + locationMeta?.let { fillErrorContext(errorContext, it) } + + throw EvaluationException( + "Validation failure for $asType", + ErrorCode.EVALUATOR_CAST_FAILED, + errorContext, + internal = false + ) + } + } + + fun singleTypeCastFunc(singleType: SingleType): CastFunc { + val locationMeta = metas.sourceLocationMeta + return { value -> + val castOutput = value.cast( + singleType, + evaluatorOptions.typedOpBehavior, + locationMeta, + evaluatorOptions.defaultTimezoneOffset + ) + typeOpValidate(value, castOutput, getRuntimeType(singleType).toString(), locationMeta) + castOutput + } + } + + fun compileSingleTypeCast(singleType: SingleType): PhysicalPlanThunkAsync { + val castFunc = singleTypeCastFunc(singleType) + // We do not use thunkFactory here because we want to explicitly avoid + // the optional 
evaluation-time type check for CAN_CAST below. + // Can cast needs that returns false if an + // exception is thrown during a normal cast operation. + return { env -> + val valueToCast = expThunk(env) + castFunc(valueToCast) + } + } + + fun compileCast(type: StaticType): PhysicalPlanThunkAsync = when (type) { + is SingleType -> compileSingleTypeCast(type) + is AnyOfType -> { + val locationMeta = metas.sourceLocationMeta + val castTable = AnyOfCastTable(type, metas, ::singleTypeCastFunc); + + // We do not use thunkFactory here because we want to explicitly avoid + // the optional evaluation-time type check for CAN_CAST below. + // note that this would interfere with the error handling for can_cast that returns false if an + // exception is thrown during a normal cast operation. + { env -> + val sourceValue = expThunk(env) + castTable.cast(sourceValue).also { + // TODO put the right type name here + typeOpValidate(sourceValue, it, "", locationMeta) + } + } + } + is AnyType -> throw IllegalStateException("Unreachable code") + } + + return compileCast(typedOpParameter.staticType) + } + + private suspend fun compileCast(expr: PartiqlPhysical.Expr.Cast, metas: MetaContainer): PhysicalPlanThunkAsync = + thunkFactory.thunkEnvAsync(metas, compileCastHelper(expr.value, expr.asType, metas)) + + private suspend fun compileCanCast(expr: PartiqlPhysical.Expr.CanCast, metas: MetaContainer): PhysicalPlanThunkAsync { + val typedOpParameter = expr.asType.toTypedOpParameter(customTypedOpParameters) + if (typedOpParameter.staticType is AnyType) { + return thunkFactory.thunkEnvAsync(metas) { ExprValue.newBoolean(true) } + } + + val expThunk = compileAstExpr(expr.value) + + // TODO consider making this more efficient by not directly delegating to CAST + // TODO consider also making the operand not double evaluated (e.g. 
having expThunk memoize) + val castThunkEnv = compileCastHelper(expr.value, expr.asType, expr.metas) + return thunkFactory.thunkEnvAsync(metas) { env -> + val sourceValue = expThunk(env) + try { + when { + // NULL/MISSING can cast to anything as themselves + sourceValue.isUnknown() -> ExprValue.newBoolean(true) + else -> { + val castedValue = castThunkEnv(env) + when { + // NULL/MISSING from cast is a permissive way to signal failure + castedValue.isUnknown() -> ExprValue.newBoolean(false) + else -> ExprValue.newBoolean(true) + } + } + } + } catch (e: EvaluationException) { + if (e.internal) { + throw e + } + ExprValue.newBoolean(false) + } + } + } + + private suspend fun compileCanLosslessCast(expr: PartiqlPhysical.Expr.CanLosslessCast, metas: MetaContainer): PhysicalPlanThunkAsync { + val typedOpParameter = expr.asType.toTypedOpParameter(customTypedOpParameters) + if (typedOpParameter.staticType is AnyType) { + return thunkFactory.thunkEnvAsync(metas) { ExprValue.newBoolean(true) } + } + + val expThunk = compileAstExpr(expr.value) + + // TODO consider making this more efficient by not directly delegating to CAST + val castThunkEnv = compileCastHelper(expr.value, expr.asType, expr.metas) + return thunkFactory.thunkEnvAsync(metas) { env -> + val sourceValue = expThunk(env) + val sourceType = staticTypeFromExprValue(sourceValue) + + suspend fun roundTrip(): ExprValue { + val castedValue = castThunkEnv(env) + + val locationMeta = metas.sourceLocationMeta + fun castFunc(singleType: SingleType) = + { value: ExprValue -> + value.cast( + singleType, + evaluatorOptions.typedOpBehavior, + locationMeta, + evaluatorOptions.defaultTimezoneOffset + ) + } + + val roundTripped = when (sourceType) { + is SingleType -> castFunc(sourceType)(castedValue) + is AnyOfType -> { + val castTable = AnyOfCastTable(sourceType, metas, ::castFunc) + castTable.cast(sourceValue) + } + // Should not be possible + is AnyType -> throw IllegalStateException("ANY type is not configured correctly in 
compiler") + } + + val lossless = sourceValue.exprEquals(roundTripped) + return ExprValue.newBoolean(lossless) + } + + try { + when (sourceValue.type) { + // NULL can cast to anything as itself + ExprValueType.NULL -> ExprValue.newBoolean(true) + + // Short-circuit timestamp -> date roundtrip if precision isn't [Timestamp.Precision.DAY] or + // [Timestamp.Precision.MONTH] or [Timestamp.Precision.YEAR] + ExprValueType.TIMESTAMP -> when (typedOpParameter.staticType) { + StaticType.DATE -> when (sourceValue.timestampValue().precision) { + Timestamp.Precision.DAY, Timestamp.Precision.MONTH, Timestamp.Precision.YEAR -> roundTrip() + else -> ExprValue.newBoolean(false) + } + StaticType.TIME -> ExprValue.newBoolean(false) + else -> roundTrip() + } + + // For all other cases, attempt a round-trip of the value through the source and dest types + else -> roundTrip() + } + } catch (e: EvaluationException) { + if (e.internal) { + throw e + } + ExprValue.newBoolean(false) + } + } + } + + private suspend fun compileSimpleCase(expr: PartiqlPhysical.Expr.SimpleCase, metas: MetaContainer): PhysicalPlanThunkAsync { + val valueThunk = compileAstExpr(expr.expr) + val branchThunks = expr.cases.pairs.map { Pair(compileAstExpr(it.first), compileAstExpr(it.second)) } + val elseThunk = when (val default = expr.default) { + null -> thunkFactory.thunkEnvAsync(metas) { ExprValue.nullValue } + else -> compileAstExpr(default) + } + + return thunkFactory.thunkEnvAsync(metas) thunk@{ env -> + val caseValue = valueThunk(env) + // if the case value is unknown then we can short-circuit to the elseThunk directly + when { + caseValue.isUnknown() -> elseThunk(env) + else -> { + branchThunks.forEach { bt -> + val branchValue = bt.first(env) + // Just skip any branch values that are unknown, which we consider the same as false here. 
+ when { + branchValue.isUnknown() -> { /* intentionally blank */ + } + else -> { + if (caseValue.exprEquals(branchValue)) { + return@thunk bt.second(env) + } + } + } + } + } + } + elseThunk(env) + } + } + + private suspend fun compileSearchedCase(expr: PartiqlPhysical.Expr.SearchedCase, metas: MetaContainer): PhysicalPlanThunkAsync { + val branchThunks = expr.cases.pairs.map { compileAstExpr(it.first) to compileAstExpr(it.second) } + val elseThunk = when (val default = expr.default) { + null -> thunkFactory.thunkEnvAsync(metas) { ExprValue.nullValue } + else -> compileAstExpr(default) + } + + return when (evaluatorOptions.typingMode) { + TypingMode.LEGACY -> thunkFactory.thunkEnvAsync(metas) thunk@{ env -> + branchThunks.forEach { bt -> + val conditionValue = bt.first(env) + // Any unknown value is considered the same as false. + // Note that .booleanValue() here will throw an EvaluationException if + // the data type is not boolean. + // TODO: .booleanValue does not have access to metas, so the EvaluationException is reported to be + // at the line & column of the CASE statement, not the predicate, unfortunately. + if (conditionValue.isNotUnknown() && conditionValue.booleanValue()) { + return@thunk bt.second(env) + } + } + elseThunk(env) + } + // Permissive mode propagates data type mismatches as MISSING, which is + // equivalent to false for searched CASE predicates. To simplify this, + // all we really need to do is consider any non-boolean result from the + // predicate to be false. 
+ TypingMode.PERMISSIVE -> thunkFactory.thunkEnvAsync(metas) thunk@{ env -> + branchThunks.forEach { bt -> + val conditionValue = bt.first(env) + if (conditionValue.type == ExprValueType.BOOL && conditionValue.booleanValue()) { + return@thunk bt.second(env) + } + } + elseThunk(env) + } + } + } + + private suspend fun compileStruct(expr: PartiqlPhysical.Expr.Struct): PhysicalPlanThunkAsync { + val structParts = compileStructParts(expr.parts) + + val ordering = if (expr.parts.none { it is PartiqlPhysical.StructPart.StructFields }) + StructOrdering.ORDERED + else + StructOrdering.UNORDERED + + return thunkFactory.thunkEnvAsync(expr.metas) { env -> + val columns = mutableListOf() + for (element in structParts) { + when (element) { + is CompiledStructPartAsync.Field -> { + val fieldName = element.nameThunk(env) + when (evaluatorOptions.typingMode) { + TypingMode.LEGACY -> + if (!fieldName.type.isText) { + err( + "Found struct field key to be of type ${fieldName.type}", + ErrorCode.EVALUATOR_NON_TEXT_STRUCT_FIELD_KEY, + errorContextFrom(expr.metas.sourceLocationMeta).also { pvm -> + pvm[Property.ACTUAL_TYPE] = fieldName.type.toString() + }, + internal = false + ) + } + TypingMode.PERMISSIVE -> + if (!fieldName.type.isText) { + continue + } + } + val fieldValue = element.valueThunk(env) + columns.add(fieldValue.namedValue(fieldName)) + } + is CompiledStructPartAsync.StructMerge -> { + for (projThunk in element.thunks) { + val value = projThunk(env) + if (value.type == ExprValueType.MISSING) continue + + val children = value.asSequence() + if (!children.any() || value.type.isSequence) { + val name = syntheticColumnName(columns.size).exprValue() + columns.add(value.namedValue(name)) + } else { + val valuesToProject = + when (evaluatorOptions.projectionIteration) { + ProjectionIterationBehavior.FILTER_MISSING -> { + value.filter { it.type != ExprValueType.MISSING } + } + ProjectionIterationBehavior.UNFILTERED -> value + } + for (childValue in valuesToProject) { + val 
namedFacet = childValue.asFacet(Named::class.java) + val name = namedFacet?.name + ?: syntheticColumnName(columns.size).exprValue() + columns.add(childValue.namedValue(name)) + } + } + } + } + } + } + createStructExprValue(columns.asSequence(), ordering) + } + } + + private suspend fun compileStructParts(projectItems: List): List = + projectItems.map { + when (it) { + is PartiqlPhysical.StructPart.StructField -> { + val fieldThunk = compileAstExpr(it.fieldName) + val valueThunk = compileAstExpr(it.value) + CompiledStructPartAsync.Field(fieldThunk, valueThunk) + } + is PartiqlPhysical.StructPart.StructFields -> { + CompiledStructPartAsync.StructMerge(listOf(compileAstExpr(it.partExpr))) + } + } + } + + private suspend fun compileSeq(seqType: ExprValueType, itemExprs: List, metas: MetaContainer): PhysicalPlanThunkAsync { + require(seqType.isSequence) { "seqType must be a sequence!" } + + val itemThunks = compileAstExprs(itemExprs) + + val makeItemThunkSequence = when (seqType) { + ExprValueType.BAG -> { env: EvaluatorState -> + itemThunks.asFlow().map { itemThunk -> + // call to unnamedValue() makes sure we don't expose any underlying value name/ordinal + itemThunk(env).unnamedValue() + } + } + else -> { env: EvaluatorState -> + itemThunks.asFlow().withIndex().map { indexedVal -> + indexedVal.value(env).namedValue(indexedVal.index.exprValue()) + } + } + } + + return thunkFactory.thunkEnvAsync(metas) { env -> + when (seqType) { + ExprValueType.BAG -> ExprValue.newBag(makeItemThunkSequence(env).toList()) + ExprValueType.LIST -> ExprValue.newList(makeItemThunkSequence(env).toList()) + ExprValueType.SEXP -> ExprValue.newSexp(makeItemThunkSequence(env).toList()) + else -> error("sequence type required") + } + } + } + + private suspend fun compilePath(expr: PartiqlPhysical.Expr.Path, metas: MetaContainer): PhysicalPlanThunkAsync { + val rootThunk = compileAstExpr(expr.root) + val remainingComponents = LinkedList() + + expr.steps.forEach { remainingComponents.addLast(it) } 
+ + val componentThunk = compilePathComponents(remainingComponents, metas) + + return thunkFactory.thunkEnvAsync(metas) { env -> + val rootValue = rootThunk(env) + componentThunk(env, rootValue) + } + } + + private suspend fun compilePathComponents( + remainingComponents: LinkedList, + pathMetas: MetaContainer + ): PhysicalPlanThunkValueAsync { + + val componentThunks = ArrayList>() + + while (!remainingComponents.isEmpty()) { + val pathComponent = remainingComponents.removeFirst() + val componentMetas = pathComponent.metas + componentThunks.add( + when (pathComponent) { + is PartiqlPhysical.PathStep.PathExpr -> { + val indexExpr = pathComponent.index + val caseSensitivity = pathComponent.case + when { + // If indexExpr is a literal string, there is no need to evaluate it--just compile a + // thunk that directly returns a bound value + indexExpr is PartiqlPhysical.Expr.Lit && indexExpr.value.toIonValue(ion) is IonString -> { + val lookupName = BindingName( + indexExpr.value.toIonValue(ion).stringValue()!!, + caseSensitivity.toBindingCase() + ) + thunkFactory.thunkEnvValue(componentMetas) { _, componentValue -> + componentValue.bindings[lookupName] ?: ExprValue.missingValue + } + } + else -> { + val indexThunk = compileAstExpr(indexExpr) + thunkFactory.thunkEnvValue(componentMetas) { env, componentValue -> + val indexValue = indexThunk(env) + when { + indexValue.type == ExprValueType.INT -> { + componentValue.ordinalBindings[indexValue.numberValue().toInt()] + } + indexValue.type.isText -> { + val lookupName = + BindingName(indexValue.stringValue(), caseSensitivity.toBindingCase()) + componentValue.bindings[lookupName] + } + else -> { + when (evaluatorOptions.typingMode) { + TypingMode.LEGACY -> err( + "Cannot convert index to int/string: $indexValue", + ErrorCode.EVALUATOR_INVALID_CONVERSION, + errorContextFrom(componentMetas), + internal = false + ) + TypingMode.PERMISSIVE -> ExprValue.missingValue + } + } + } ?: ExprValue.missingValue + } + } + } + } + is 
PartiqlPhysical.PathStep.PathUnpivot -> { + when { + !remainingComponents.isEmpty() -> { + val tempThunk = compilePathComponents(remainingComponents, pathMetas) + thunkFactory.thunkEnvValue(componentMetas) { env, componentValue -> + val mapped = componentValue.unpivot() + .flatMap { tempThunk(env, it).rangeOver() } + .asSequence() + ExprValue.newBag(mapped) + } + } + else -> + thunkFactory.thunkEnvValue(componentMetas) { _, componentValue -> + ExprValue.newBag(componentValue.unpivot().asSequence()) + } + } + } + // this is for `path[*].component` + is PartiqlPhysical.PathStep.PathWildcard -> { + when { + !remainingComponents.isEmpty() -> { + val hasMoreWildCards = + remainingComponents.filterIsInstance().any() + val tempThunk = compilePathComponents(remainingComponents, pathMetas) + + when { + !hasMoreWildCards -> thunkFactory.thunkEnvValue(componentMetas) { env, componentValue -> + val mapped = componentValue + .rangeOver() + .map { tempThunk(env, it) } + .asSequence() + + ExprValue.newBag(mapped) + } + else -> thunkFactory.thunkEnvValue(componentMetas) { env, componentValue -> + val mapped = componentValue + .rangeOver() + .flatMap { + val tempValue = tempThunk(env, it) + tempValue + } + .asSequence() + + ExprValue.newBag(mapped) + } + } + } + else -> { + thunkFactory.thunkEnvValue(componentMetas) { _, componentValue -> + val mapped = componentValue.rangeOver().asSequence() + ExprValue.newBag(mapped) + } + } + } + } + } + ) + } + return when (componentThunks.size) { + 1 -> componentThunks.first() + else -> thunkFactory.thunkEnvValue(pathMetas) { env, rootValue -> + componentThunks.fold(rootValue) { componentValue, componentThunk -> + componentThunk(env, componentValue) + } + } + } + } + + /** + * Given an AST node that represents a `LIKE` predicate, return an ExprThunk that evaluates a `LIKE` predicate. + * + * Three cases + * + * 1. All arguments are literals, then compile and run the pattern + * 1. 
Search pattern and escape pattern are literals, compile the pattern. Running the pattern deferred to evaluation time. + * 1. Pattern or escape (or both) are *not* literals, compile and running of pattern deferred to evaluation time. + * + * ``` + * LIKE [ESCAPE ] + * ``` + * + * @return a thunk that when provided with an environment evaluates the `LIKE` predicate + */ + private suspend fun compileLike(expr: PartiqlPhysical.Expr.Like, metas: MetaContainer): PhysicalPlanThunkAsync { + val valueExpr = expr.value + val patternExpr = expr.pattern + val escapeExpr = expr.escape + + val patternLocationMeta = patternExpr.metas.sourceLocation + val escapeLocationMeta = escapeExpr?.metas?.sourceLocation + + // This is so that null short-circuits can be supported. + fun getRegexPattern(pattern: ExprValue, escape: ExprValue?): (() -> Pattern)? { + val patternArgs = listOfNotNull(pattern, escape) + when { + patternArgs.any { it.type.isUnknown } -> return null + patternArgs.any { !it.type.isText } -> return { + err( + "LIKE expression must be given non-null strings as input", + ErrorCode.EVALUATOR_LIKE_INVALID_INPUTS, + errorContextFrom(metas).also { + it[Property.LIKE_PATTERN] = pattern.toString() + if (escape != null) it[Property.LIKE_ESCAPE] = escape.toString() + }, + internal = false + ) + } + else -> { + val (patternString: String, escapeChar: Int?) 
= + checkPattern(pattern.stringValue(), patternLocationMeta, escape?.stringValue(), escapeLocationMeta) + val likeRegexPattern = when { + patternString.isEmpty() -> Pattern.compile("") + else -> parsePattern(patternString, escapeChar) + } + return { likeRegexPattern } + } + } + } + + fun matchRegexPattern(value: ExprValue, likePattern: (() -> Pattern)?): ExprValue { + return when { + likePattern == null || value.type.isUnknown -> ExprValue.nullValue + !value.type.isText -> err( + "LIKE expression must be given non-null strings as input", + ErrorCode.EVALUATOR_LIKE_INVALID_INPUTS, + errorContextFrom(metas).also { + it[Property.LIKE_VALUE] = value.toString() + }, + internal = false + ) + else -> ExprValue.newBoolean(likePattern().matcher(value.stringValue()).matches()) + } + } + + val valueThunk = compileAstExpr(valueExpr) + + // If the pattern and escape expressions are literals then we can compile the pattern now and + // re-use it with every execution. Otherwise, we must re-compile the pattern every time. + return when { + patternExpr is PartiqlPhysical.Expr.Lit && (escapeExpr == null || escapeExpr is PartiqlPhysical.Expr.Lit) -> { + val patternParts = getRegexPattern( + ExprValue.of(patternExpr.value.toIonValue(ion)), + (escapeExpr as? PartiqlPhysical.Expr.Lit)?.value?.toIonValue(ion) + ?.let { ExprValue.of(it) } + ) + + // If valueExpr is also a literal then we can evaluate this at compile time and return a constant. 
+ if (valueExpr is PartiqlPhysical.Expr.Lit) { + val resultValue = matchRegexPattern( + ExprValue.of(valueExpr.value.toIonValue(ion)), + patternParts + ) + return thunkFactory.thunkEnvAsync(metas) { resultValue } + } else { + thunkFactory.thunkEnvOperands(metas, valueThunk) { _, value -> + matchRegexPattern(value, patternParts) + } + } + } + else -> { + val patternThunk = compileAstExpr(patternExpr) + when (escapeExpr) { + null -> { + // thunk that re-compiles the DFA every evaluation without a custom escape sequence + thunkFactory.thunkEnvOperands(metas, valueThunk, patternThunk) { _, value, pattern -> + val pps = getRegexPattern(pattern, null) + matchRegexPattern(value, pps) + } + } + else -> { + // thunk that re-compiles the pattern every evaluation but *with* a custom escape sequence + val escapeThunk = compileAstExpr(escapeExpr) + thunkFactory.thunkEnvOperands( + metas, + valueThunk, + patternThunk, + escapeThunk + ) { _, value, pattern, escape -> + val pps = getRegexPattern(pattern, escape) + matchRegexPattern(value, pps) + } + } + } + } + } + } + + /** + * Given the pattern and optional escape character in a `LIKE` predicate as [IonValue]s + * check their validity based on the SQL92 spec and return a triple that contains in order + * + * - the search pattern as a string + * - the escape character, possibly `null` + * - the length of the search pattern. The length of the search pattern is either + * - the length of the string representing the search pattern when no escape character is used + * - the length of the string representing the search pattern without counting uses of the escape character + * when an escape character is used + * + * A search pattern is valid when + * 1. pattern is not null + * 1. pattern contains characters where `_` means any 1 character and `%` means any string of length 0 or more + * 1. if the escape character is specified then pattern can be deterministically partitioned into character groups where + * 1. 
A length 1 character group consists of any character other than the ESCAPE character + * 1. A length 2 character group consists of the ESCAPE character followed by either `_` or `%` or the ESCAPE character itself + * + * @param pattern search pattern + * @param escape optional escape character provided in the `LIKE` predicate + * + * @return a triple that contains in order the search pattern as a [String], optionally the code point for the escape character if one was provided + * and the size of the search pattern excluding uses of the escape character + */ + private fun checkPattern( + pattern: String, + patternLocationMeta: SourceLocationMeta?, + escape: String?, + escapeLocationMeta: SourceLocationMeta? + ): Pair { + + escape?.let { + val escapeCharString = checkEscapeChar(escape, escapeLocationMeta) + val escapeCharCodePoint = escapeCharString.codePointAt(0) // escape is a string of length 1 + val validEscapedChars = setOf('_'.code, '%'.code, escapeCharCodePoint) + val iter = pattern.codePointSequence().iterator() + + while (iter.hasNext()) { + val current = iter.next() + if (current == escapeCharCodePoint && (!iter.hasNext() || !validEscapedChars.contains(iter.next()))) { + err( + "Invalid escape sequence : $pattern", + ErrorCode.EVALUATOR_LIKE_PATTERN_INVALID_ESCAPE_SEQUENCE, + errorContextFrom(patternLocationMeta).apply { + set(Property.LIKE_PATTERN, pattern) + set(Property.LIKE_ESCAPE, escapeCharString) + }, + internal = false + ) + } + } + return Pair(pattern, escapeCharCodePoint) + } + return Pair(pattern, null) + } + + /** + * Given an [IonValue] to be used as the escape character in a `LIKE` predicate check that it is + * a valid character based on the SQL Spec. + * + * + * A value is a valid escape when + * 1. it is 1 character long, and, + * 1. 
Cannot be null (SQL92 spec marks this cases as *unknown*) + * + * @param escape value provided as an escape character for a `LIKE` predicate + * + * @return the escape character as a [String] or throws an exception when the input is invalid + */ + private fun checkEscapeChar(escape: String, locationMeta: SourceLocationMeta?): String { + when (escape) { + "" -> { + err( + "Cannot use empty character as ESCAPE character in a LIKE predicate: $escape", + ErrorCode.EVALUATOR_LIKE_PATTERN_INVALID_ESCAPE_SEQUENCE, + errorContextFrom(locationMeta), + internal = false + ) + } + else -> { + if (escape.trim().length != 1) { + err( + "Escape character must have size 1 : $escape", + ErrorCode.EVALUATOR_LIKE_PATTERN_INVALID_ESCAPE_SEQUENCE, + errorContextFrom(locationMeta), + internal = false + ) + } + } + } + return escape + } + + private suspend fun compileExec(node: PartiqlPhysical.Statement.Exec): PhysicalPlanThunkAsync { + val metas = node.metas + val procedureName = node.procedureName.text + val procedure = procedures[procedureName] ?: err( + "No such stored procedure: $procedureName", + ErrorCode.EVALUATOR_NO_SUCH_PROCEDURE, + errorContextFrom(metas).also { + it[Property.PROCEDURE_NAME] = procedureName + }, + internal = false + ) + + val args = node.args + // Check arity + if (args.size !in procedure.signature.arity) { + val errorContext = errorContextFrom(metas).also { + it[Property.EXPECTED_ARITY_MIN] = procedure.signature.arity.first + it[Property.EXPECTED_ARITY_MAX] = procedure.signature.arity.last + } + + val message = when { + procedure.signature.arity.first == 1 && procedure.signature.arity.last == 1 -> + "${procedure.signature.name} takes a single argument, received: ${args.size}" + procedure.signature.arity.first == procedure.signature.arity.last -> + "${procedure.signature.name} takes exactly ${procedure.signature.arity.first} arguments, received: ${args.size}" + else -> + "${procedure.signature.name} takes between ${procedure.signature.arity.first} and " + + 
"${procedure.signature.arity.last} arguments, received: ${args.size}" + } + + throw EvaluationException( + message, + ErrorCode.EVALUATOR_INCORRECT_NUMBER_OF_ARGUMENTS_TO_PROCEDURE_CALL, + errorContext, + internal = false + ) + } + + // Compile the procedure's arguments + val argThunks = compileAstExprs(args) + + return thunkFactory.thunkEnvAsync(metas) { env -> + val procedureArgValues = argThunks.map { it(env) } + procedure.call(env.session, procedureArgValues) + } + } + + private suspend fun compileDate(expr: PartiqlPhysical.Expr.Date, metas: MetaContainer): PhysicalPlanThunkAsync = + thunkFactory.thunkEnvAsync(metas) { + ExprValue.newDate( + expr.year.value.toInt(), + expr.month.value.toInt(), + expr.day.value.toInt() + ) + } + + private suspend fun compileLitTime(expr: PartiqlPhysical.Expr.LitTime, metas: MetaContainer): PhysicalPlanThunkAsync = + thunkFactory.thunkEnvAsync(metas) { + // Add the default time zone if the type "TIME WITH TIME ZONE" does not have an explicitly specified time zone. 
+ ExprValue.newTime( + Time.of( + expr.value.hour.value.toInt(), + expr.value.minute.value.toInt(), + expr.value.second.value.toInt(), + expr.value.nano.value.toInt(), + expr.value.precision.value.toInt(), + if (expr.value.withTimeZone.value && expr.value.tzMinutes == null) evaluatorOptions.defaultTimezoneOffset.totalMinutes else expr.value.tzMinutes?.value?.toInt() + ) + ) + } + + private suspend fun compileBagOp(node: PartiqlPhysical.Expr.BagOp, metas: MetaContainer): PhysicalPlanThunkAsync { + val lhs = compileAstExpr(node.operands[0]) + val rhs = compileAstExpr(node.operands[1]) + val op = ExprValueBagOp.create(node.op, metas) + return thunkFactory.thunkEnvAsync(metas) { env -> + val l = lhs(env) + val r = rhs(env) + val result = when (node.quantifier) { + is PartiqlPhysical.SetQuantifier.All -> op.eval(l, r) + is PartiqlPhysical.SetQuantifier.Distinct -> op.eval(l, r).distinct() + } + ExprValue.newBag(result) + } + } + + private suspend fun compilePivot(expr: PartiqlPhysical.Expr.Pivot, metas: MetaContainer): PhysicalPlanThunkAsync { + val inputBExpr: RelationThunkEnvAsync = bexperConverter.convert(expr.input) + // The names are intentionally flipped for clarity; consider fixing this in the AST + val valueExpr = compileAstExpr(expr.key) + val keyExpr = compileAstExpr(expr.value) + return thunkFactory.thunkEnvAsync(metas) { env -> + val attributes: Flow = flow { + val relation = inputBExpr(env) + while (relation.nextRow()) { + val key = keyExpr.invoke(env) + if (key.type.isText) { + val value = valueExpr.invoke(env) + emit(value.namedValue(key)) + } + } + } + ExprValue.newStruct(attributes.toList(), StructOrdering.UNORDERED) + } + } + + /** A special wrapper for `UNPIVOT` values as a BAG. */ + private class UnpivotedExprValue(private val values: Iterable) : BaseExprValue() { + override val type = ExprValueType.BAG + override fun iterator() = values.iterator() + } + + /** Unpivots a `struct`, and synthesizes a synthetic singleton `struct` for other [ExprValue]. 
*/ + internal fun ExprValue.unpivot(): ExprValue = when { + // special case for our special UNPIVOT value to avoid double wrapping + this is UnpivotedExprValue -> this + // Wrap into a pseudo-BAG + type == ExprValueType.STRUCT || type == ExprValueType.MISSING -> UnpivotedExprValue(this) + // for non-struct, this wraps any value into a BAG with a synthetic name + else -> UnpivotedExprValue( + listOf( + this.namedValue(ExprValue.newString(syntheticColumnName(0))) + ) + ) + } + + private fun createStructExprValue(seq: Sequence, ordering: StructOrdering) = + ExprValue.newStruct( + when (evaluatorOptions.projectionIteration) { + ProjectionIterationBehavior.FILTER_MISSING -> seq.filter { it.type != ExprValueType.MISSING } + ProjectionIterationBehavior.UNFILTERED -> seq + }, + ordering + ) +} + +/** + * Represents an element in a select list that is to be projected into the final result. + * i.e. an expression, or a (project_all) node. + */ +private sealed class CompiledStructPartAsync { + + /** + * Represents a single compiled expression to be projected into the final result. + * Given `SELECT a + b as value FROM foo`: + * - `name` is "value" + * - `thunk` is compiled expression, i.e. `a + b` + */ + class Field(val nameThunk: PhysicalPlanThunkAsync, val valueThunk: PhysicalPlanThunkAsync) : CompiledStructPartAsync() + + /** + * Represents a wildcard ((path_project_all) node) expression to be projected into the final result. + * This covers two cases. For `SELECT foo.* FROM foo`, `exprThunks` contains a single compiled expression + * `foo`. + * + * For `SELECT * FROM foo, bar, bat`, `exprThunks` would contain a compiled expression for each of `foo`, `bar` and + * `bat`. 
+ */ + class StructMerge(val thunks: List) : CompiledStructPartAsync() +} diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/RelationThunk.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/RelationThunk.kt index 4990e764d5..362292363a 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/RelationThunk.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/RelationThunk.kt @@ -10,6 +10,7 @@ import org.partiql.lang.eval.fillErrorContext import org.partiql.lang.eval.relation.RelationIterator /** A thunk that returns a [RelationIterator], which is the result of evaluating a relational operator. */ +@Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("RelationThunkEnvAsync")) internal typealias RelationThunkEnv = (EvaluatorState) -> RelationIterator /** @@ -18,6 +19,7 @@ internal typealias RelationThunkEnv = (EvaluatorState) -> RelationIterator * This function is not currently in [ThunkFactory] to avoid complicating it further. If a need arises, it could be * moved. */ +@Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("relationThunkAsync")) internal inline fun relationThunk(metas: MetaContainer, crossinline t: RelationThunkEnv): RelationThunkEnv { val sourceLocationMeta = metas[SourceLocationMeta.TAG] as? 
SourceLocationMeta return { env: EvaluatorState -> diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/RelationThunkAsync.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/RelationThunkAsync.kt new file mode 100644 index 0000000000..884e136bff --- /dev/null +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/RelationThunkAsync.kt @@ -0,0 +1,45 @@ +package org.partiql.lang.eval.physical + +import com.amazon.ionelement.api.MetaContainer +import org.partiql.errors.ErrorCode +import org.partiql.errors.Property +import org.partiql.lang.ast.SourceLocationMeta +import org.partiql.lang.eval.EvaluationException +import org.partiql.lang.eval.ThunkFactory +import org.partiql.lang.eval.errorContextFrom +import org.partiql.lang.eval.fillErrorContext +import org.partiql.lang.eval.relation.RelationIterator + +/** A thunk that returns a [RelationIterator], which is the result of evaluating a relational operator. */ +internal typealias RelationThunkEnvAsync = suspend (EvaluatorState) -> RelationIterator + +/** + * Invokes [t] with error handling like is supplied by [ThunkFactory]. + * + * This function is not currently in [ThunkFactory] to avoid complicating it further. If a need arises, it could be + * moved. + */ +internal suspend inline fun relationThunkAsync(metas: MetaContainer, crossinline t: RelationThunkEnvAsync): RelationThunkEnvAsync { + val sourceLocationMeta = metas[SourceLocationMeta.TAG] as? SourceLocationMeta + return { env: EvaluatorState -> + try { + t(env) + } catch (e: EvaluationException) { + // Only add source location data to the error context if it doesn't already exist + // in [errorContext]. 
+ if (!e.errorContext.hasProperty(Property.LINE_NUMBER)) { + sourceLocationMeta?.let { fillErrorContext(e.errorContext, sourceLocationMeta) } + } + throw e + } catch (e: Exception) { + val message = e.message ?: "" + throw EvaluationException( + "Generic exception, $message", + errorCode = ErrorCode.EVALUATOR_GENERIC_EXCEPTION, + errorContext = errorContextFrom(sourceLocationMeta), + cause = e, + internal = true + ) + } + } +} diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/VariableBinding.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/VariableBinding.kt index 82934d7ede..bea2d1d266 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/VariableBinding.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/VariableBinding.kt @@ -8,6 +8,7 @@ import org.partiql.lang.eval.physical.operators.ValueExpression * @property setFunc The function to be invoked at evaluation-time to set the value of the variable. * @property expr The function to be invoked at evaluation-time to compute the value of the variable. */ +@Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("VariableBindingAsync")) class VariableBinding( val setFunc: SetVariableFunc, val expr: ValueExpression diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/VariableBindingAsync.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/VariableBindingAsync.kt new file mode 100644 index 0000000000..272629f2a3 --- /dev/null +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/VariableBindingAsync.kt @@ -0,0 +1,14 @@ +package org.partiql.lang.eval.physical + +import org.partiql.lang.eval.physical.operators.ValueExpressionAsync + +/** + * A compiled variable binding. + * + * @property setFunc The function to be invoked at evaluation-time to set the value of the variable. + * @property expr The function to be invoked at evaluation-time to compute the value of the variable. 
+ */ +class VariableBindingAsync( + val setFunc: SetVariableFunc, + val expr: ValueExpressionAsync +) diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/AggregateOperatorFactory.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/AggregateOperatorFactory.kt index 531eb89dc9..320aac61c9 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/AggregateOperatorFactory.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/AggregateOperatorFactory.kt @@ -31,10 +31,12 @@ import java.util.TreeMap * * @param name */ +@Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("AggregateOperatorFactoryAsync")) public abstract class AggregateOperatorFactory(name: String) : RelationalOperatorFactory { public override val key = RelationalOperatorFactoryKey(RelationalOperatorKind.AGGREGATE, name) + @Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("AggregateOperatorFactoryAsync.create")) public abstract fun create( source: RelationExpression, strategy: PartiqlPhysical.GroupingStrategy, @@ -43,12 +45,14 @@ public abstract class AggregateOperatorFactory(name: String) : RelationalOperato ): RelationExpression } +@Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("CompiledGroupKeyAsync")) public class CompiledGroupKey( val setGroupKeyVal: SetVariableFunc, val value: ValueExpression, val variable: PartiqlPhysical.VarDecl ) +@Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("CompiledAggregateFunctionAsync")) public class CompiledAggregateFunction( val name: String, val setAggregateVal: SetVariableFunc, diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/AggregateOperatorFactoryAsync.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/AggregateOperatorFactoryAsync.kt new file mode 100644 index 
0000000000..8f4cb2948a --- /dev/null +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/AggregateOperatorFactoryAsync.kt @@ -0,0 +1,112 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + */ +package org.partiql.lang.eval.physical.operators + +import org.partiql.lang.domains.PartiqlPhysical +import org.partiql.lang.eval.DEFAULT_COMPARATOR +import org.partiql.lang.eval.ExprValue +import org.partiql.lang.eval.physical.EvaluatorState +import org.partiql.lang.eval.physical.SetVariableFunc +import org.partiql.lang.eval.relation.RelationIterator +import org.partiql.lang.eval.relation.RelationType +import org.partiql.lang.eval.relation.relation +import org.partiql.lang.planner.transforms.DEFAULT_IMPL_NAME +import java.util.TreeMap + +/** + * Provides an implementation of the [PartiqlPhysical.Bexpr.Aggregate] operator. 
+ * + * @constructor + * + * @param name + */ +public abstract class AggregateOperatorFactoryAsync(name: String) : RelationalOperatorFactory { + + public override val key = RelationalOperatorFactoryKey(RelationalOperatorKind.AGGREGATE, name) + + public abstract fun create( + source: RelationExpressionAsync, + strategy: PartiqlPhysical.GroupingStrategy, + keys: List, + functions: List + ): RelationExpressionAsync +} + +public class CompiledGroupKeyAsync( + val setGroupKeyVal: SetVariableFunc, + val value: ValueExpressionAsync, + val variable: PartiqlPhysical.VarDecl +) + +public class CompiledAggregateFunctionAsync( + val name: String, + val setAggregateVal: SetVariableFunc, + val value: ValueExpressionAsync, + val quantifier: PartiqlPhysical.SetQuantifier, +) + +internal object AggregateOperatorFactoryDefaultAsync : AggregateOperatorFactoryAsync(DEFAULT_IMPL_NAME) { + override fun create( + source: RelationExpressionAsync, + strategy: PartiqlPhysical.GroupingStrategy, + keys: List, + functions: List + ): RelationExpressionAsync = AggregateOperatorDefaultAsync(source, keys, functions) +} + +internal class AggregateOperatorDefaultAsync( + val source: RelationExpressionAsync, + val keys: List, + val functions: List +) : RelationExpressionAsync { + override suspend fun evaluate(state: EvaluatorState): RelationIterator = relation(RelationType.BAG) { + val aggregationMap = TreeMap>(DEFAULT_COMPARATOR) + + val sourceIter = source.evaluate(state) + while (sourceIter.nextRow()) { + + // Initialize the AggregationMap + val evaluatedGroupByKeys = + keys.map { it.value.invoke(state) }.let { ExprValue.newList(it) } + val accumulators = aggregationMap.getOrPut(evaluatedGroupByKeys) { + functions.map { function -> + Accumulator.create(function.name, function.quantifier) + } + } + + // Aggregate Values in Aggregation State + functions.forEachIndexed { index, function -> + val valueToAggregate = function.value(state) + accumulators[index].next(valueToAggregate) + } + } + + // No 
Aggregations Created + if (keys.isEmpty() && aggregationMap.isEmpty()) { + functions.forEach { function -> + val accumulator = Accumulator.create(function.name, function.quantifier) + function.setAggregateVal(state, accumulator.compute()) + } + yield() + return@relation + } + + // Place Aggregated Values into Result State + aggregationMap.forEach { (exprList, accumulators) -> + exprList.forEachIndexed { index, exprValue -> keys[index].setGroupKeyVal(state, exprValue) } + accumulators.forEachIndexed { index, acc -> functions[index].setAggregateVal(state, acc.compute()) } + yield() + } + } +} diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/FilterRelationalOperatorFactory.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/FilterRelationalOperatorFactory.kt index 4d2f0e864a..4a8e6622bc 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/FilterRelationalOperatorFactory.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/FilterRelationalOperatorFactory.kt @@ -16,6 +16,7 @@ import org.partiql.lang.planner.transforms.DEFAULT_IMPL_NAME * * @param name */ +@Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("FilterRelationalOperatorFactoryAsync")) abstract class FilterRelationalOperatorFactory(name: String) : RelationalOperatorFactory { final override val key = RelationalOperatorFactoryKey(RelationalOperatorKind.FILTER, name) @@ -28,6 +29,7 @@ abstract class FilterRelationalOperatorFactory(name: String) : RelationalOperato * @param sourceBexpr * @return */ + @Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("FilterRelationalOperatorFactoryAsync.create")) abstract fun create( impl: PartiqlPhysical.Impl, predicate: ValueExpression, diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/FilterRelationalOperatorFactoryAsync.kt 
b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/FilterRelationalOperatorFactoryAsync.kt new file mode 100644 index 0000000000..4225db6f76 --- /dev/null +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/FilterRelationalOperatorFactoryAsync.kt @@ -0,0 +1,69 @@ +package org.partiql.lang.eval.physical.operators + +import org.partiql.lang.domains.PartiqlPhysical +import org.partiql.lang.eval.booleanValue +import org.partiql.lang.eval.isNotUnknown +import org.partiql.lang.eval.physical.EvaluatorState +import org.partiql.lang.eval.relation.RelationIterator +import org.partiql.lang.eval.relation.RelationType +import org.partiql.lang.eval.relation.relation +import org.partiql.lang.planner.transforms.DEFAULT_IMPL_NAME + +/** + * Provides an implementation of the [PartiqlPhysical.Bexpr.Filter] operator. + * + * @constructor + * + * @param name + */ +abstract class FilterRelationalOperatorFactoryAsync(name: String) : RelationalOperatorFactory { + + final override val key = RelationalOperatorFactoryKey(RelationalOperatorKind.FILTER, name) + + /** + * Creates a [RelationExpressionAsync] instance for [PartiqlPhysical.Bexpr.Filter]. 
+ * + * @param impl + * @param predicate + * @param sourceBexpr + * @return + */ + abstract fun create( + impl: PartiqlPhysical.Impl, + predicate: ValueExpressionAsync, + sourceBexpr: RelationExpressionAsync + ): RelationExpressionAsync +} + +internal object FilterRelationalOperatorFactoryDefaultAsync : FilterRelationalOperatorFactoryAsync(DEFAULT_IMPL_NAME) { + override fun create( + impl: PartiqlPhysical.Impl, + predicate: ValueExpressionAsync, + sourceBexpr: RelationExpressionAsync + ) = SelectOperatorDefaultAsync( + input = sourceBexpr, + predicate = predicate + ) +} + +internal class SelectOperatorDefaultAsync( + val input: RelationExpressionAsync, + val predicate: ValueExpressionAsync, +) : RelationExpressionAsync { + + override suspend fun evaluate(state: EvaluatorState): RelationIterator { + val input = input.evaluate(state) + return relation(RelationType.BAG) { + while (true) { + if (!input.nextRow()) { + break + } else { + val matches = predicate.invoke(state) + if (matches.isNotUnknown() && matches.booleanValue()) { + yield() + } + } + } + } + } +} diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/JoinRelationalOperatorFactory.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/JoinRelationalOperatorFactory.kt index 0cad0029fb..1590e0c89d 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/JoinRelationalOperatorFactory.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/JoinRelationalOperatorFactory.kt @@ -15,6 +15,7 @@ import org.partiql.lang.planner.transforms.DEFAULT_IMPL_NAME * * @param name */ +@Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("JoinRelationalOperatorFactoryAsync")) abstract class JoinRelationalOperatorFactory(name: String) : RelationalOperatorFactory { final override val key = RelationalOperatorFactoryKey(RelationalOperatorKind.JOIN, name) @@ -31,6 +32,7 @@ abstract class 
JoinRelationalOperatorFactory(name: String) : RelationalOperatorF * @param setRightSideVariablesToNull * @return */ + @Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("JoinRelationalOperatorFactoryAsync.create")) abstract fun create( impl: PartiqlPhysical.Impl, joinType: PartiqlPhysical.JoinType, @@ -87,7 +89,7 @@ internal object JoinRelationalOperatorFactoryDefault : JoinRelationalOperatorFac /** * See specification 5.6 */ -internal class InnerJoinOperator( +private class InnerJoinOperator( private val lhs: RelationExpression, private val rhs: RelationExpression, private val condition: (EvaluatorState) -> Boolean @@ -109,7 +111,7 @@ internal class InnerJoinOperator( /** * See specification 5.6 */ -internal class LeftJoinOperator( +private class LeftJoinOperator( private val lhs: RelationExpression, private val rhs: RelationExpression, private val condition: (EvaluatorState) -> Boolean, @@ -138,7 +140,7 @@ internal class LeftJoinOperator( /** * See specification 5.6 */ -internal class RightJoinOperator( +private class RightJoinOperator( private val lhs: RelationExpression, private val rhs: RelationExpression, private val condition: (EvaluatorState) -> Boolean, diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/JoinRelationalOperatorFactoryAsync.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/JoinRelationalOperatorFactoryAsync.kt new file mode 100644 index 0000000000..a476e9b91d --- /dev/null +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/JoinRelationalOperatorFactoryAsync.kt @@ -0,0 +1,165 @@ +package org.partiql.lang.eval.physical.operators + +import org.partiql.lang.domains.PartiqlPhysical +import org.partiql.lang.eval.booleanValue +import org.partiql.lang.eval.isNotUnknown +import org.partiql.lang.eval.physical.EvaluatorState +import org.partiql.lang.eval.relation.RelationType +import org.partiql.lang.eval.relation.relation +import 
org.partiql.lang.planner.transforms.DEFAULT_IMPL_NAME + +/** + * Provides an implementation of the [PartiqlPhysical.Bexpr.Join] operator. + * + * @constructor + * + * @param name + */ +abstract class JoinRelationalOperatorFactoryAsync(name: String) : RelationalOperatorFactory { + + final override val key = RelationalOperatorFactoryKey(RelationalOperatorKind.JOIN, name) + + /** + * Creates a [RelationExpressionAsync] instance for [PartiqlPhysical.Bexpr.Join]. + * + * @param impl static arguments + * @param joinType inner, left, right, outer + * @param leftBexpr left-hand-side of the join + * @param rightBexpr right-hand-side of the join + * @param predicateExpr condition for a theta join + * @param setLeftSideVariablesToNull + * @param setRightSideVariablesToNull + * @return + */ + abstract fun create( + impl: PartiqlPhysical.Impl, + joinType: PartiqlPhysical.JoinType, + leftBexpr: RelationExpressionAsync, + rightBexpr: RelationExpressionAsync, + predicateExpr: ValueExpressionAsync?, + setLeftSideVariablesToNull: (EvaluatorState) -> Unit, + setRightSideVariablesToNull: (EvaluatorState) -> Unit + ): RelationExpressionAsync +} + +internal object JoinRelationalOperatorFactoryDefaultAsync : JoinRelationalOperatorFactoryAsync(DEFAULT_IMPL_NAME) { + override fun create( + impl: PartiqlPhysical.Impl, + joinType: PartiqlPhysical.JoinType, + leftBexpr: RelationExpressionAsync, + rightBexpr: RelationExpressionAsync, + predicateExpr: ValueExpressionAsync?, + setLeftSideVariablesToNull: (EvaluatorState) -> Unit, + setRightSideVariablesToNull: (EvaluatorState) -> Unit + ): RelationExpressionAsync = when (joinType) { + is PartiqlPhysical.JoinType.Inner -> { + InnerJoinOperatorAsync( + lhs = leftBexpr, + rhs = rightBexpr, + condition = predicateExpr?.closure() ?: { true } + ) + } + is PartiqlPhysical.JoinType.Left -> { + LeftJoinOperatorAsync( + lhs = leftBexpr, + rhs = rightBexpr, + condition = predicateExpr?.closure() ?: { true }, + setRightSideVariablesToNull = 
setRightSideVariablesToNull + ) + } + is PartiqlPhysical.JoinType.Right -> { + RightJoinOperatorAsync( + lhs = leftBexpr, + rhs = rightBexpr, + condition = predicateExpr?.closure() ?: { true }, + setLeftSideVariablesToNull = setLeftSideVariablesToNull + ) + } + is PartiqlPhysical.JoinType.Full -> TODO("Full join") + } + + private fun ValueExpressionAsync.closure(): suspend (EvaluatorState) -> Boolean = { state: EvaluatorState -> + val v = invoke(state) + v.isNotUnknown() && v.booleanValue() + } +} + +/** + * See specification 5.6 + */ +private class InnerJoinOperatorAsync( + private val lhs: RelationExpressionAsync, + private val rhs: RelationExpressionAsync, + private val condition: suspend (EvaluatorState) -> Boolean +) : RelationExpressionAsync { + + override suspend fun evaluate(state: EvaluatorState) = relation(RelationType.BAG) { + val leftItr = lhs.evaluate(state) + while (leftItr.nextRow()) { + val rightItr = rhs.evaluate(state) + while (rightItr.nextRow()) { + if (condition(state)) { + yield() + } + } + } + } +} + +/** + * See specification 5.6 + */ +private class LeftJoinOperatorAsync( + private val lhs: RelationExpressionAsync, + private val rhs: RelationExpressionAsync, + private val condition: suspend (EvaluatorState) -> Boolean, + private val setRightSideVariablesToNull: (EvaluatorState) -> Unit +) : RelationExpressionAsync { + + override suspend fun evaluate(state: EvaluatorState) = relation(RelationType.BAG) { + val leftItr = lhs.evaluate(state) + while (leftItr.nextRow()) { + val rightItr = rhs.evaluate(state) + var yieldedSomething = false + while (rightItr.nextRow()) { + if (condition(state)) { + yield() + yieldedSomething = true + } + } + if (!yieldedSomething) { + setRightSideVariablesToNull(state) + yield() + } + } + } +} + +/** + * See specification 5.6 + */ +private class RightJoinOperatorAsync( + private val lhs: RelationExpressionAsync, + private val rhs: RelationExpressionAsync, + private val condition: suspend (EvaluatorState) -> 
Boolean, + private val setLeftSideVariablesToNull: (EvaluatorState) -> Unit +) : RelationExpressionAsync { + + override suspend fun evaluate(state: EvaluatorState) = relation(RelationType.BAG) { + val rightItr = rhs.evaluate(state) + while (rightItr.nextRow()) { + val leftItr = lhs.evaluate(state) + var yieldedSomething = false + while (leftItr.nextRow()) { + if (condition(state)) { + yield() + yieldedSomething = true + } + } + if (!yieldedSomething) { + setLeftSideVariablesToNull(state) + yield() + } + } + } +} diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/LetRelationalOperatorFactory.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/LetRelationalOperatorFactory.kt index 0da6d086cd..fb7047f51a 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/LetRelationalOperatorFactory.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/LetRelationalOperatorFactory.kt @@ -14,6 +14,7 @@ import org.partiql.lang.planner.transforms.DEFAULT_IMPL_NAME * * @param name */ +@Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("LetRelationalOperatorFactoryAsync")) abstract class LetRelationalOperatorFactory(name: String) : RelationalOperatorFactory { final override val key = RelationalOperatorFactoryKey(RelationalOperatorKind.LET, name) @@ -26,6 +27,7 @@ abstract class LetRelationalOperatorFactory(name: String) : RelationalOperatorFa * @param bindings list of [VariableBinding]s in the `LET` clause * @return */ + @Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("LetRelationalOperatorFactoryAsync.create")) abstract fun create( impl: PartiqlPhysical.Impl, sourceBexpr: RelationExpression, diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/LetRelationalOperatorFactoryAsync.kt 
b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/LetRelationalOperatorFactoryAsync.kt new file mode 100644 index 0000000000..21ff9365c5 --- /dev/null +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/LetRelationalOperatorFactoryAsync.kt @@ -0,0 +1,64 @@ +package org.partiql.lang.eval.physical.operators + +import org.partiql.lang.domains.PartiqlPhysical +import org.partiql.lang.eval.physical.EvaluatorState +import org.partiql.lang.eval.physical.VariableBindingAsync +import org.partiql.lang.eval.relation.RelationIterator +import org.partiql.lang.eval.relation.relation +import org.partiql.lang.planner.transforms.DEFAULT_IMPL_NAME + +/** + * Provides an implementation of the [PartiqlPhysical.Bexpr.Let] operator. + * + * @constructor + * + * @param name + */ +abstract class LetRelationalOperatorFactoryAsync(name: String) : RelationalOperatorFactory { + + final override val key = RelationalOperatorFactoryKey(RelationalOperatorKind.LET, name) + + /** + * Creates a [RelationExpressionAsync] instance for [PartiqlPhysical.Bexpr.Let]. 
+ * + * @param impl + * @param sourceBexpr + * @param bindings list of [VariableBindingAsync]s in the `LET` clause + * @return + */ + abstract fun create( + impl: PartiqlPhysical.Impl, + sourceBexpr: RelationExpressionAsync, + bindings: List + ): RelationExpressionAsync +} + +internal object LetRelationalOperatorFactoryDefaultAsync : LetRelationalOperatorFactoryAsync(DEFAULT_IMPL_NAME) { + + override fun create( + impl: PartiqlPhysical.Impl, + sourceBexpr: RelationExpressionAsync, + bindings: List + ) = LetOperatorAsync( + input = sourceBexpr, + bindings = bindings, + ) +} + +internal class LetOperatorAsync( + private val input: RelationExpressionAsync, + private val bindings: List +) : RelationExpressionAsync { + + override suspend fun evaluate(state: EvaluatorState): RelationIterator { + val rows = input.evaluate(state) + return relation(rows.relType) { + while (rows.nextRow()) { + bindings.forEach { + it.setFunc(state, it.expr(state)) + } + yield() + } + } + } +} diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/LimitRelationalOperatorFactory.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/LimitRelationalOperatorFactory.kt index e9a4791804..893690f4fa 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/LimitRelationalOperatorFactory.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/LimitRelationalOperatorFactory.kt @@ -19,6 +19,7 @@ import org.partiql.lang.planner.transforms.DEFAULT_IMPL_NAME * * @param name */ +@Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("LimitRelationalOperatorFactoryAsync")) abstract class LimitRelationalOperatorFactory(name: String) : RelationalOperatorFactory { final override val key = RelationalOperatorFactoryKey(RelationalOperatorKind.LIMIT, name) @@ -31,6 +32,7 @@ abstract class LimitRelationalOperatorFactory(name: String) : RelationalOperator * @param sourceBexpr * @return */ + 
@Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("LimitRelationalOperatorFactoryAsync.create")) abstract fun create( impl: PartiqlPhysical.Impl, rowCountExpr: ValueExpression, diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/LimitRelationalOperatorFactoryAsync.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/LimitRelationalOperatorFactoryAsync.kt new file mode 100644 index 0000000000..9442c0fbf1 --- /dev/null +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/LimitRelationalOperatorFactoryAsync.kt @@ -0,0 +1,105 @@ +package org.partiql.lang.eval.physical.operators + +import org.partiql.errors.ErrorCode +import org.partiql.errors.Property +import org.partiql.lang.domains.PartiqlPhysical +import org.partiql.lang.eval.ExprValueType +import org.partiql.lang.eval.err +import org.partiql.lang.eval.errorContextFrom +import org.partiql.lang.eval.numberValue +import org.partiql.lang.eval.physical.EvaluatorState +import org.partiql.lang.eval.relation.RelationIterator +import org.partiql.lang.eval.relation.relation +import org.partiql.lang.planner.transforms.DEFAULT_IMPL_NAME + +/** + * Provides an implementation of the [PartiqlPhysical.Bexpr.Limit] operator. + * + * @constructor + * + * @param name + */ +abstract class LimitRelationalOperatorFactoryAsync(name: String) : RelationalOperatorFactory { + + final override val key = RelationalOperatorFactoryKey(RelationalOperatorKind.LIMIT, name) + + /** + * Creates a [RelationExpressionAsync] instance for [PartiqlPhysical.Bexpr.Limit]. 
+ * + * @param impl + * @param rowCountExpr + * @param sourceBexpr + * @return + */ + abstract fun create( + impl: PartiqlPhysical.Impl, + rowCountExpr: ValueExpressionAsync, + sourceBexpr: RelationExpressionAsync + ): RelationExpressionAsync +} + +internal object LimitRelationalOperatorFactoryDefaultAsync : LimitRelationalOperatorFactoryAsync(DEFAULT_IMPL_NAME) { + + override fun create( + impl: PartiqlPhysical.Impl, + rowCountExpr: ValueExpressionAsync, + sourceBexpr: RelationExpressionAsync + ) = LimitOperatorAsync( + input = sourceBexpr, + limit = rowCountExpr + ) +} + +internal class LimitOperatorAsync( + private val input: RelationExpressionAsync, + private val limit: ValueExpressionAsync, +) : RelationExpressionAsync { + + override suspend fun evaluate(state: EvaluatorState): RelationIterator { + val limit = evalLimitRowCount(limit, state) + val rows = input.evaluate(state) + return relation(rows.relType) { + var rowCount = 0L + while (rowCount++ < limit && rows.nextRow()) { + yield() + } + } + } + + private suspend fun evalLimitRowCount(rowCountExpr: ValueExpressionAsync, env: EvaluatorState): Long { + val limitExprValue = rowCountExpr(env) + if (limitExprValue.type != ExprValueType.INT) { + err( + "LIMIT value was not an integer", + ErrorCode.EVALUATOR_NON_INT_LIMIT_VALUE, + errorContextFrom(rowCountExpr.sourceLocation).also { + it[Property.ACTUAL_TYPE] = limitExprValue.type.toString() + }, + internal = false + ) + } + + val originalLimitValue = limitExprValue.numberValue() + val limitValue = originalLimitValue.toLong() + if (originalLimitValue != limitValue as Number) { // Make sure `Number.toLong()` is a lossless transformation + err( + "Integer exceeds Long.MAX_VALUE provided as LIMIT value", + ErrorCode.INTERNAL_ERROR, + errorContextFrom(rowCountExpr.sourceLocation), + internal = true + ) + } + + if (limitValue < 0) { + err( + "negative LIMIT", + ErrorCode.EVALUATOR_NEGATIVE_LIMIT, + errorContextFrom(rowCountExpr.sourceLocation), + internal = false + ) 
+ } + // we can't use the Kotlin's Sequence.take(n) for this since it accepts only an integer. + // this references [Sequence.take(count: Long): Sequence] defined in [org.partiql.util]. + return limitValue + } +} diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/OffsetRelationalOperatorFactory.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/OffsetRelationalOperatorFactory.kt index d36fca5452..a27fd1f50c 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/OffsetRelationalOperatorFactory.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/OffsetRelationalOperatorFactory.kt @@ -19,6 +19,7 @@ import org.partiql.lang.planner.transforms.DEFAULT_IMPL_NAME * * @param name */ +@Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("OffsetRelationalOperatorFactoryAsync")) abstract class OffsetRelationalOperatorFactory(name: String) : RelationalOperatorFactory { final override val key = RelationalOperatorFactoryKey(RelationalOperatorKind.OFFSET, name) @@ -31,6 +32,7 @@ abstract class OffsetRelationalOperatorFactory(name: String) : RelationalOperato * @param sourceBexpr * @return */ + @Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("OffsetRelationalOperatorFactoryAsync.create")) abstract fun create( impl: PartiqlPhysical.Impl, rowCountExpr: ValueExpression, diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/OffsetRelationalOperatorFactoryAsync.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/OffsetRelationalOperatorFactoryAsync.kt new file mode 100644 index 0000000000..3df809d793 --- /dev/null +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/OffsetRelationalOperatorFactoryAsync.kt @@ -0,0 +1,107 @@ +package org.partiql.lang.eval.physical.operators + +import org.partiql.errors.ErrorCode +import 
org.partiql.errors.Property +import org.partiql.lang.domains.PartiqlPhysical +import org.partiql.lang.eval.ExprValueType +import org.partiql.lang.eval.err +import org.partiql.lang.eval.errorContextFrom +import org.partiql.lang.eval.numberValue +import org.partiql.lang.eval.physical.EvaluatorState +import org.partiql.lang.eval.relation.RelationIterator +import org.partiql.lang.eval.relation.relation +import org.partiql.lang.planner.transforms.DEFAULT_IMPL_NAME + +/** + * Provides an implementation of the [PartiqlPhysical.Bexpr.Offset] operator. + * + * @constructor + * + * @param name + */ +abstract class OffsetRelationalOperatorFactoryAsync(name: String) : RelationalOperatorFactory { + + final override val key = RelationalOperatorFactoryKey(RelationalOperatorKind.OFFSET, name) + + /** + * Creates a [RelationExpressionAsync] instance for [PartiqlPhysical.Bexpr.Offset]. + * + * @param impl + * @param rowCountExpr + * @param sourceBexpr + * @return + */ + abstract fun create( + impl: PartiqlPhysical.Impl, + rowCountExpr: ValueExpressionAsync, + sourceBexpr: RelationExpressionAsync + ): RelationExpressionAsync +} + +internal object OffsetRelationalOperatorFactoryDefaultAsync : OffsetRelationalOperatorFactoryAsync(DEFAULT_IMPL_NAME) { + + override fun create( + impl: PartiqlPhysical.Impl, + rowCountExpr: ValueExpressionAsync, + sourceBexpr: RelationExpressionAsync + ) = OffsetOperatorAsync( + input = sourceBexpr, + offset = rowCountExpr, + ) +} + +internal class OffsetOperatorAsync( + private val input: RelationExpressionAsync, + private val offset: ValueExpressionAsync, +) : RelationExpressionAsync { + + override suspend fun evaluate(state: EvaluatorState): RelationIterator { + val skipCount: Long = evalOffsetRowCount(offset, state) + val rows = input.evaluate(state) + return relation(rows.relType) { + var rowCount = 0L + while (rowCount++ < skipCount) { + // stop iterating if we run out of rows before we hit the offset. 
+ if (!rows.nextRow()) { + return@relation + } + } + yieldAll(rows) + } + } + + private suspend fun evalOffsetRowCount(rowCountExpr: ValueExpressionAsync, state: EvaluatorState): Long { + val offsetExprValue = rowCountExpr(state) + if (offsetExprValue.type != ExprValueType.INT) { + err( + "OFFSET value was not an integer", + ErrorCode.EVALUATOR_NON_INT_OFFSET_VALUE, + errorContextFrom(rowCountExpr.sourceLocation).also { + it[Property.ACTUAL_TYPE] = offsetExprValue.type.toString() + }, + internal = false + ) + } + + val originalOffsetValue = offsetExprValue.numberValue() + val offsetValue = originalOffsetValue.toLong() + if (originalOffsetValue != offsetValue as Number) { // Make sure `Number.toLong()` is a lossless transformation + err( + "Integer exceeds Long.MAX_VALUE provided as OFFSET value", + ErrorCode.INTERNAL_ERROR, + errorContextFrom(rowCountExpr.sourceLocation), + internal = true + ) + } + + if (offsetValue < 0) { + err( + "negative OFFSET", + ErrorCode.EVALUATOR_NEGATIVE_OFFSET, + errorContextFrom(rowCountExpr.sourceLocation), + internal = false + ) + } + return offsetValue + } +} diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/ProjectRelationalOperatorFactory.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/ProjectRelationalOperatorFactory.kt index efd451ba13..cd9a652f00 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/ProjectRelationalOperatorFactory.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/ProjectRelationalOperatorFactory.kt @@ -4,10 +4,12 @@ import org.partiql.lang.domains.PartiqlPhysical import org.partiql.lang.eval.physical.SetVariableFunc /** Provides an implementation of the [PartiqlPhysical.Bexpr.Project] operator.*/ +@Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("ProjectRelationalOperatorFactoryAsync")) abstract class ProjectRelationalOperatorFactory(name: String) : 
RelationalOperatorFactory { final override val key: RelationalOperatorFactoryKey = RelationalOperatorFactoryKey(RelationalOperatorKind.PROJECT, name) /** Creates a [RelationExpression] instance for [PartiqlPhysical.Bexpr.Project]. */ + @Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("ProjectRelationalOperatorFactoryAsync.create")) abstract fun create( /** * Contains any static arguments needed by the operator implementation that were supplied by the planner diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/ProjectRelationalOperatorFactoryAsync.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/ProjectRelationalOperatorFactoryAsync.kt new file mode 100644 index 0000000000..4268e6d6e0 --- /dev/null +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/ProjectRelationalOperatorFactoryAsync.kt @@ -0,0 +1,22 @@ +package org.partiql.lang.eval.physical.operators + +import org.partiql.lang.domains.PartiqlPhysical +import org.partiql.lang.eval.physical.SetVariableFunc + +/** Provides an implementation of the [PartiqlPhysical.Bexpr.Project] operator.*/ +abstract class ProjectRelationalOperatorFactoryAsync(name: String) : RelationalOperatorFactory { + final override val key: RelationalOperatorFactoryKey = RelationalOperatorFactoryKey(RelationalOperatorKind.PROJECT, name) + + /** Creates a [RelationExpressionAsync] instance for [PartiqlPhysical.Bexpr.Project]. */ + abstract fun create( + /** + * Contains any static arguments needed by the operator implementation that were supplied by the planner + * pass which specified the operator implementation. + */ + impl: PartiqlPhysical.Impl, + /** Invoke to set the binding for the current row. */ + setVar: SetVariableFunc, + /** Invoke to obtain evaluation-time arguments. 
*/ + args: List + ): RelationExpressionAsync +} diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/RelationExpression.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/RelationExpression.kt index cea67cd5af..0f1261d770 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/RelationExpression.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/RelationExpression.kt @@ -14,6 +14,8 @@ import org.partiql.lang.eval.relation.RelationIterator * Like [ValueExpression], this is public API that is supported long term and is intended to avoid exposing * implementation details such as [org.partiql.lang.eval.physical.RelationThunkEnv]. */ +@Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("RelationExpressionAsync")) fun interface RelationExpression { + @Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("RelationExpressionAsync.evaluate")) fun evaluate(state: EvaluatorState): RelationIterator } diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/RelationExpressionAsync.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/RelationExpressionAsync.kt new file mode 100644 index 0000000000..61f6f4a795 --- /dev/null +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/RelationExpressionAsync.kt @@ -0,0 +1,19 @@ +package org.partiql.lang.eval.physical.operators + +import org.partiql.lang.eval.physical.EvaluatorState +import org.partiql.lang.eval.relation.RelationIterator + +/** + * An implementation of a physical plan relational operator. + * + * PartiQL's relational algebra is based on + * [E.F. Codd's Relational Algebra](https://en.wikipedia.org/wiki/Relational_algebra), but to better support + * semi-structured, schemaless data, our "relations" are actually logical collections of bindings. 
Still, the term + * "relation" has remained, as well as most other concepts from E.F. Codd's relational algebra. + * + * Like [ValueExpression], this is public API that is supported long term and is intended to avoid exposing + * implementation details such as [org.partiql.lang.eval.physical.RelationThunkEnv]. + */ +fun interface RelationExpressionAsync { + suspend fun evaluate(state: EvaluatorState): RelationIterator +} diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/ScanRelationalOperatorFactory.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/ScanRelationalOperatorFactory.kt index 72cb720599..1dfa3e60f8 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/ScanRelationalOperatorFactory.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/ScanRelationalOperatorFactory.kt @@ -20,6 +20,7 @@ import org.partiql.lang.planner.transforms.DEFAULT_IMPL_NAME * * @param name */ +@Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("ScanRelationalOperatorFactoryAsync")) abstract class ScanRelationalOperatorFactory(name: String) : RelationalOperatorFactory { final override val key = RelationalOperatorFactoryKey(RelationalOperatorKind.SCAN, name) @@ -34,6 +35,7 @@ abstract class ScanRelationalOperatorFactory(name: String) : RelationalOperatorF * @param setByVar BY variable binding, if non-null * @return */ + @Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("ScanRelationalOperatorFactoryAsync.create")) abstract fun create( impl: PartiqlPhysical.Impl, expr: ValueExpression, diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/ScanRelationalOperatorFactoryAsync.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/ScanRelationalOperatorFactoryAsync.kt new file mode 100644 index 0000000000..54b7531c63 --- /dev/null +++ 
b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/ScanRelationalOperatorFactoryAsync.kt @@ -0,0 +1,82 @@ +package org.partiql.lang.eval.physical.operators + +import org.partiql.lang.domains.PartiqlPhysical +import org.partiql.lang.eval.ExprValue +import org.partiql.lang.eval.ExprValueType +import org.partiql.lang.eval.address +import org.partiql.lang.eval.name +import org.partiql.lang.eval.physical.EvaluatorState +import org.partiql.lang.eval.physical.SetVariableFunc +import org.partiql.lang.eval.relation.RelationIterator +import org.partiql.lang.eval.relation.RelationType +import org.partiql.lang.eval.relation.relation +import org.partiql.lang.eval.unnamedValue +import org.partiql.lang.planner.transforms.DEFAULT_IMPL_NAME + +/** + * Provides an implementation of the [PartiqlPhysical.Bexpr.Scan] operator. + * + * @constructor + * + * @param name + */ +abstract class ScanRelationalOperatorFactoryAsync(name: String) : RelationalOperatorFactory { + + final override val key = RelationalOperatorFactoryKey(RelationalOperatorKind.SCAN, name) + + /** + * Creates a [RelationExpressionAsync] instance for [PartiqlPhysical.Bexpr.Scan]. + * + * @param impl static arguments + * @param expr invoked to obtain an iterable value + * @param setAsVar AS variable binding + * @param setAtVar AT variable binding, if non-null + * @param setByVar BY variable binding, if non-null + * @return + */ + abstract fun create( + impl: PartiqlPhysical.Impl, + expr: ValueExpressionAsync, + setAsVar: SetVariableFunc, + setAtVar: SetVariableFunc?, + setByVar: SetVariableFunc? + ): RelationExpressionAsync +} + +internal object ScanRelationalOperatorFactoryDefaultAsync : ScanRelationalOperatorFactoryAsync(DEFAULT_IMPL_NAME) { + override fun create( + impl: PartiqlPhysical.Impl, + expr: ValueExpressionAsync, + setAsVar: SetVariableFunc, + setAtVar: SetVariableFunc?, + setByVar: SetVariableFunc? 
+ ) = ScanOperatorAsync(expr, setAsVar, setAtVar, setByVar) +} + +internal class ScanOperatorAsync( + private val expr: ValueExpressionAsync, + private val setAsVar: SetVariableFunc, + private val setAtVar: SetVariableFunc?, + private val setByVar: SetVariableFunc? +) : RelationExpressionAsync { + + override suspend fun evaluate(state: EvaluatorState): RelationIterator { + val value = expr(state) + val sequence: Sequence = when (value.type) { + ExprValueType.LIST, + ExprValueType.BAG -> value.asSequence() + else -> sequenceOf(value) + } + return relation(RelationType.BAG) { + val rows: Iterator = sequence.iterator() + while (rows.hasNext()) { + val item = rows.next() + // .unnamedValue() removes any ordinal that might exist on item + setAsVar(state, item.unnamedValue()) + setAtVar?.let { it(state, item.name ?: ExprValue.missingValue) } + setByVar?.let { it(state, item.address ?: ExprValue.missingValue) } + yield() + } + } + } +} diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/SortOperatorFactory.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/SortOperatorFactory.kt index 9ec9af5ea4..ec90d26ad0 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/SortOperatorFactory.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/SortOperatorFactory.kt @@ -4,12 +4,15 @@ import org.partiql.lang.domains.PartiqlPhysical import org.partiql.lang.eval.NaturalExprValueComparators /** Provides an implementation of the [PartiqlPhysical.Bexpr.Order] operator.*/ +@Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("SortOperatorFactoryAsync")) public abstract class SortOperatorFactory(name: String) : RelationalOperatorFactory { public final override val key: RelationalOperatorFactoryKey = RelationalOperatorFactoryKey(RelationalOperatorKind.SORT, name) + @Deprecated("To be removed in the next major version.", replaceWith = 
ReplaceWith("SortOperatorFactoryAsync.create")) public abstract fun create( sortKeys: List, sourceRelation: RelationExpression ): RelationExpression } +@Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("CompiledSortKeyAsync")) public class CompiledSortKey(val comparator: NaturalExprValueComparators, val value: ValueExpression) diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/SortOperatorFactoryAsync.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/SortOperatorFactoryAsync.kt new file mode 100644 index 0000000000..4447939475 --- /dev/null +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/SortOperatorFactoryAsync.kt @@ -0,0 +1,15 @@ +package org.partiql.lang.eval.physical.operators + +import org.partiql.lang.domains.PartiqlPhysical +import org.partiql.lang.eval.NaturalExprValueComparators + +/** Provides an implementation of the [PartiqlPhysical.Bexpr.Sort] operator.*/ +public abstract class SortOperatorFactoryAsync(name: String) : RelationalOperatorFactory { + public final override val key: RelationalOperatorFactoryKey = RelationalOperatorFactoryKey(RelationalOperatorKind.SORT, name) + public abstract fun create( + sortKeys: List, + sourceRelation: RelationExpressionAsync + ): RelationExpressionAsync +} + +public class CompiledSortKeyAsync(val comparator: NaturalExprValueComparators, val value: ValueExpressionAsync) diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/SortOperatorFactoryDefaultAsync.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/SortOperatorFactoryDefaultAsync.kt new file mode 100644 index 0000000000..66b17bafc3 --- /dev/null +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/SortOperatorFactoryDefaultAsync.kt @@ -0,0 +1,85 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + */ + +package org.partiql.lang.eval.physical.operators + +import org.partiql.lang.eval.ExprValue +import org.partiql.lang.eval.NaturalExprValueComparators +import org.partiql.lang.eval.physical.EvaluatorState +import org.partiql.lang.eval.relation.RelationIterator +import org.partiql.lang.eval.relation.RelationType +import org.partiql.lang.eval.relation.relation +import org.partiql.lang.planner.transforms.DEFAULT_IMPL_NAME + +internal object SortOperatorFactoryDefaultAsync : SortOperatorFactoryAsync(DEFAULT_IMPL_NAME) { + override fun create( + sortKeys: List, + sourceRelation: RelationExpressionAsync + ): RelationExpressionAsync = SortOperatorDefaultAsync(sortKeys, sourceRelation) +} + +internal class SortOperatorDefaultAsync(private val sortKeys: List, private val sourceRelation: RelationExpressionAsync) : RelationExpressionAsync { + override suspend fun evaluate(state: EvaluatorState): RelationIterator { + val source = sourceRelation.evaluate(state) + return relation(RelationType.LIST) { + val rows = mutableListOf>() + + // Consume Input + while (source.nextRow()) { + rows.add(state.registers.clone()) + } + + val rowWithValues = rows.map { row -> + state.load(row) + row to sortKeys.map { sk -> + sk.value(state) + } + }.toMutableList() + val comparator = getSortingComparator(sortKeys.map { it.comparator }) + + // Perform Sort + val sortedRows = rowWithValues.sortedWith(comparator) + + // Yield Sorted Rows + val iterator = sortedRows.iterator() + while 
(iterator.hasNext()) { + state.load(iterator.next().first) + yield() + } + } + } +} + +/** + * Returns a [Comparator] that compares arrays of registers by using un-evaluated sort keys. It does this by modifying + * the [EvaluatorState] to allow evaluation of the [sortKeys]. + */ +internal fun getSortingComparator(sortKeys: List): Comparator, List>> { + return object : Comparator, List>> { + override fun compare( + l: Pair, List>, + r: Pair, List> + ): Int { + val valsToCompare = l.second.zip(r.second) + sortKeys.zip(valsToCompare).map { + val comp = it.first + val cmpResult = comp.compare(it.second.first, it.second.second) + if (cmpResult != 0) { + return cmpResult + } + } + return 0 + } + } +} diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/UnpivotOperatorFactory.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/UnpivotOperatorFactory.kt index 8729359c1c..5ca8a3bf5f 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/UnpivotOperatorFactory.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/UnpivotOperatorFactory.kt @@ -4,10 +4,12 @@ import org.partiql.lang.domains.PartiqlPhysical import org.partiql.lang.eval.physical.SetVariableFunc /** Provides an implementation of the [PartiqlPhysical.Bexpr.Scan] operator.*/ +@Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("UnpivotOperatorFactoryAsync")) public abstract class UnpivotOperatorFactory(name: String) : RelationalOperatorFactory { public final override val key: RelationalOperatorFactoryKey = RelationalOperatorFactoryKey(RelationalOperatorKind.UNPIVOT, name) /** Creates a [RelationExpression] instance for [PartiqlPhysical.Bexpr.Scan]. 
*/ + @Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("UnpivotOperatorFactoryAsync.create")) public abstract fun create( /** Invoke to obtain the value to be iterated over.*/ expr: ValueExpression, diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/UnpivotOperatorFactoryAsync.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/UnpivotOperatorFactoryAsync.kt new file mode 100644 index 0000000000..5267e945e5 --- /dev/null +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/UnpivotOperatorFactoryAsync.kt @@ -0,0 +1,21 @@ +package org.partiql.lang.eval.physical.operators + +import org.partiql.lang.domains.PartiqlPhysical +import org.partiql.lang.eval.physical.SetVariableFunc + +/** Provides an implementation of the [PartiqlPhysical.Bexpr.Scan] operator.*/ +public abstract class UnpivotOperatorFactoryAsync(name: String) : RelationalOperatorFactory { + public final override val key: RelationalOperatorFactoryKey = RelationalOperatorFactoryKey(RelationalOperatorKind.UNPIVOT, name) + + /** Creates a [RelationExpressionAsync] instance for [PartiqlPhysical.Bexpr.Scan]. */ + public abstract fun create( + /** Invoke to obtain the value to be iterated over.*/ + expr: ValueExpressionAsync, + /** Invoke to set the `AS` variable binding. */ + setAsVar: SetVariableFunc, + /** Invoke to set the `AT` variable binding, if non-null */ + setAtVar: SetVariableFunc?, + /** Invoke to set the `BY` variable binding, if non-null. */ + setByVar: SetVariableFunc? 
+ ): RelationExpressionAsync +} diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/UnpivotOperatorFactoryDefaultAsync.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/UnpivotOperatorFactoryDefaultAsync.kt new file mode 100644 index 0000000000..6d768707a9 --- /dev/null +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/UnpivotOperatorFactoryDefaultAsync.kt @@ -0,0 +1,70 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + */ + +package org.partiql.lang.eval.physical.operators + +import org.partiql.lang.eval.ExprValue +import org.partiql.lang.eval.ExprValueType +import org.partiql.lang.eval.address +import org.partiql.lang.eval.name +import org.partiql.lang.eval.namedValue +import org.partiql.lang.eval.physical.EvaluatorState +import org.partiql.lang.eval.physical.SetVariableFunc +import org.partiql.lang.eval.relation.RelationIterator +import org.partiql.lang.eval.relation.RelationType +import org.partiql.lang.eval.relation.relation +import org.partiql.lang.eval.syntheticColumnName +import org.partiql.lang.eval.unnamedValue +import org.partiql.lang.planner.transforms.DEFAULT_IMPL_NAME + +internal object UnpivotOperatorFactoryDefaultAsync : UnpivotOperatorFactoryAsync(DEFAULT_IMPL_NAME) { + override fun create( + expr: ValueExpressionAsync, + setAsVar: SetVariableFunc, + setAtVar: SetVariableFunc?, + setByVar: SetVariableFunc? 
+ ): RelationExpressionAsync = UnpivotOperatorDefaultAsync(expr, setAsVar, setAtVar, setByVar) +} + +internal class UnpivotOperatorDefaultAsync( + private val expr: ValueExpressionAsync, + private val setAsVar: SetVariableFunc, + private val setAtVar: SetVariableFunc?, + private val setByVar: SetVariableFunc? +) : RelationExpressionAsync { + override suspend fun evaluate(state: EvaluatorState): RelationIterator { + val originalValue = expr(state) + val unpivot = originalValue.unpivot() + + return relation(RelationType.BAG) { + val iter = unpivot.iterator() + while (iter.hasNext()) { + val item = iter.next() + setAsVar(state, item.unnamedValue()) + setAtVar?.let { it(state, item.name ?: ExprValue.missingValue) } + setByVar?.let { it(state, item.address ?: ExprValue.missingValue) } + yield() + } + } + } + + private fun ExprValue.unpivot(): ExprValue = when (type) { + ExprValueType.STRUCT, ExprValueType.MISSING -> this + else -> ExprValue.newBag( + listOf( + this.namedValue(ExprValue.newString(syntheticColumnName(0))) + ) + ) + } +} diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/ValueExpression.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/ValueExpression.kt index f6dfe685fa..d72da08eef 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/ValueExpression.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/ValueExpression.kt @@ -11,11 +11,14 @@ import org.partiql.lang.eval.physical.EvaluatorState * avoid exposing implementation details (i.e. [org.partiql.lang.eval.physical.PhysicalPlanThunk]) of the evaluator. * This implementation accomplishes that and is intended as a publicly usable API that is supported long term. */ +@Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("ValueExpressionAsync")) interface ValueExpression { /** Evaluates the expression. 
*/ + @Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("ValueExpressionAsync.invoke")) operator fun invoke(state: EvaluatorState): ExprValue /** Provides the source location (line & column) of the expression, for error reporting purposes. */ + @Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("ValueExpressionAsync.sourceLocation")) val sourceLocation: SourceLocationMeta? } diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/ValueExpressionAsync.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/ValueExpressionAsync.kt new file mode 100644 index 0000000000..69093abcc3 --- /dev/null +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/ValueExpressionAsync.kt @@ -0,0 +1,30 @@ +package org.partiql.lang.eval.physical.operators + +import org.partiql.lang.ast.SourceLocationMeta +import org.partiql.lang.eval.ExprValue +import org.partiql.lang.eval.physical.EvaluatorState + +/** + * Evaluates a PartiQL expression returning an [ExprValue]. + * + * [RelationExpression] implementations need a mechanism to evaluate such expressions, and said mechanism should + * avoid exposing implementation details (i.e. [org.partiql.lang.eval.physical.PhysicalPlanThunk]) of the evaluator. + * This implementation accomplishes that and is intended as a publicly usable API that is supported long term. + */ +interface ValueExpressionAsync { + /** Evaluates the expression. */ + suspend operator fun invoke(state: EvaluatorState): ExprValue + + /** Provides the source location (line & column) of the expression, for error reporting purposes. */ + val sourceLocation: SourceLocationMeta? +} + +/** Convenience constructor for [ValueExpression]. 
*/ +internal inline fun valueExpressionAsync( + sourceLocation: SourceLocationMeta?, + crossinline invoke: suspend (EvaluatorState) -> ExprValue +) = + object : ValueExpressionAsync { + override suspend fun invoke(state: EvaluatorState): ExprValue = invoke(state) + override val sourceLocation: SourceLocationMeta? get() = sourceLocation + } diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/WindowRelationalOperatorFactory.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/WindowRelationalOperatorFactory.kt index 669b8b39a6..ed1a1acd26 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/WindowRelationalOperatorFactory.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/WindowRelationalOperatorFactory.kt @@ -6,11 +6,13 @@ import org.partiql.lang.eval.physical.SetVariableFunc import org.partiql.lang.eval.physical.window.WindowFunction @ExperimentalWindowFunctions +@Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("WindowRelationalOperatorFactoryAsync")) abstract class WindowRelationalOperatorFactory(name: String) : RelationalOperatorFactory { final override val key: RelationalOperatorFactoryKey = RelationalOperatorFactoryKey(RelationalOperatorKind.WINDOW, name) /** Creates a [RelationExpression] instance for [PartiqlPhysical.Bexpr.Window]. 
*/ + @Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("WindowRelationalOperatorFactoryAsync.create")) abstract fun create( source: RelationExpression, windowPartitionList: List, @@ -21,6 +23,7 @@ abstract class WindowRelationalOperatorFactory(name: String) : RelationalOperato } @ExperimentalWindowFunctions +@Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("CompiledWindowFunctionAsync")) class CompiledWindowFunction( val func: WindowFunction, val parameters: List, diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/WindowRelationalOperatorFactoryAsync.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/WindowRelationalOperatorFactoryAsync.kt new file mode 100644 index 0000000000..7e8c7775a3 --- /dev/null +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/WindowRelationalOperatorFactoryAsync.kt @@ -0,0 +1,32 @@ +package org.partiql.lang.eval.physical.operators + +import org.partiql.annotations.ExperimentalWindowFunctions +import org.partiql.lang.domains.PartiqlPhysical +import org.partiql.lang.eval.physical.SetVariableFunc +import org.partiql.lang.eval.physical.window.NavigationWindowFunctionAsync + +@ExperimentalWindowFunctions +abstract class WindowRelationalOperatorFactoryAsync(name: String) : RelationalOperatorFactory { + + final override val key: RelationalOperatorFactoryKey = RelationalOperatorFactoryKey(RelationalOperatorKind.WINDOW, name) + + /** Creates a [RelationExpressionAsync] instance for [PartiqlPhysical.Bexpr.Window]. 
*/ + abstract fun create( + source: RelationExpressionAsync, + windowPartitionList: List, + windowSortSpecList: List, + compiledWindowFunctions: List + + ): RelationExpressionAsync +} + +@ExperimentalWindowFunctions +class CompiledWindowFunctionAsync( + val func: NavigationWindowFunctionAsync, + val parameters: List, + /** + * This is [PartiqlPhysical.VarDecl] instead of [SetVariableFunc] because we would like to access the index of variable in the register + * when processing rows within the partition. + */ + val windowVarDecl: PartiqlPhysical.VarDecl +) diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/WindowRelationalOperatorFactoryDefaultAsync.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/WindowRelationalOperatorFactoryDefaultAsync.kt new file mode 100644 index 0000000000..034fcbd18e --- /dev/null +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/operators/WindowRelationalOperatorFactoryDefaultAsync.kt @@ -0,0 +1,119 @@ +package org.partiql.lang.eval.physical.operators + +import org.partiql.annotations.ExperimentalWindowFunctions +import org.partiql.lang.eval.ExprValue +import org.partiql.lang.eval.NaturalExprValueComparators +import org.partiql.lang.eval.exprEquals +import org.partiql.lang.eval.physical.EvaluatorState +import org.partiql.lang.eval.relation.RelationIterator +import org.partiql.lang.eval.relation.RelationType +import org.partiql.lang.eval.relation.relation +import org.partiql.lang.planner.transforms.DEFAULT_IMPL_NAME + +/** + * This is an experimental implementation of the window operator + * + * The general concept here is to sort the input relation, first by partition keys (if not null) then by sort keys (if not null). 
+ * After sorting, we do a sequence scan to create partition and materialize all the element in the same partition + * + */ +@ExperimentalWindowFunctions +internal object WindowRelationalOperatorFactoryDefaultAsync : WindowRelationalOperatorFactoryAsync(DEFAULT_IMPL_NAME) { + override fun create( + source: RelationExpressionAsync, + windowPartitionList: List, + windowSortSpecList: List, + compiledWindowFunctions: List + ): RelationExpressionAsync = WindowOperatorDefaultAsync(source, windowPartitionList, windowSortSpecList, compiledWindowFunctions) +} + +@ExperimentalWindowFunctions +internal class WindowOperatorDefaultAsync( + private val source: RelationExpressionAsync, + private val windowPartitionList: List, + private val windowSortSpecList: List, + private val compiledWindowFunctions: List +) : RelationExpressionAsync { + override suspend fun evaluate(state: EvaluatorState): RelationIterator { + // the following corresponding to materialization process + val sourceIter = source.evaluate(state) + val registers = sequence { + while (sourceIter.nextRow()) { + yield(state.registers.clone()) + } + } + + val partitionSortSpec = windowPartitionList.map { + CompiledSortKeyAsync(NaturalExprValueComparators.NULLS_FIRST_ASC, it) + } + + val sortKeys = partitionSortSpec + windowSortSpecList + + val newRegisters = registers.toList().map { row -> + state.load(row) + row to sortKeys.map { sk -> + sk.value(state) + } + }.toMutableList() + + val sortedRegisters = newRegisters.sortedWith(getSortingComparator(sortKeys.map { it.comparator })).map { it.first } + + // create the partition here + val partition = mutableListOf>>() + + // entire partition + if (windowPartitionList.isEmpty()) { + partition.add(sortedRegisters.toList()) + } + // need to be partitioned + else { + val iter = sortedRegisters.iterator() + val rowInPartition = mutableListOf>() + var previousPartition: ExprValue? 
= null + while (iter.hasNext()) { + val currentRow = iter.next() + state.load(currentRow) + val currentPartition = ExprValue.newSexp( + windowPartitionList.map { + it.invoke(state) + } + ) + // for the first time, + if (previousPartition == null) { + rowInPartition.add(currentRow) + previousPartition = currentPartition + } else if (previousPartition.exprEquals(currentPartition)) { + rowInPartition.add(currentRow) + } else { + partition.add(rowInPartition.toList()) + rowInPartition.clear() + previousPartition = currentPartition + rowInPartition.add(currentRow) + } + } + // finish up + partition.add(rowInPartition.toList()) + rowInPartition.clear() + } + + return relation(RelationType.BAG) { + partition.forEach { rowsInPartition -> + compiledWindowFunctions.forEach { + val windowFunc = it.func + // set the window function partition to the current partition + windowFunc.reset(rowsInPartition) + } + + rowsInPartition.forEach { + // process current row + compiledWindowFunctions.forEach { compiledWindowFunction -> + compiledWindowFunction.func.processRow(state, compiledWindowFunction.parameters, compiledWindowFunction.windowVarDecl) + } + + // yield the result + yield() + } + } + } + } +} diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/window/BuiltInWindowFunction.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/window/BuiltInWindowFunction.kt index ac9ea7e15d..e22b83adcd 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/window/BuiltInWindowFunction.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/window/BuiltInWindowFunction.kt @@ -9,3 +9,11 @@ internal fun createBuiltinWindowFunction(name: String) = "lead" -> Lead() else -> error("Window function $name has not been implemented") } + +@ExperimentalWindowFunctions +internal fun createBuiltinWindowFunctionAsync(name: String) = + when (name) { + "lag" -> LagAsync() + "lead" -> LeadAsync() + else -> error("Window function $name has not been 
implemented") + } diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/window/LagAsync.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/window/LagAsync.kt new file mode 100644 index 0000000000..f31fd1acc4 --- /dev/null +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/window/LagAsync.kt @@ -0,0 +1,45 @@ +package org.partiql.lang.eval.physical.window + +import org.partiql.annotations.ExperimentalWindowFunctions +import org.partiql.lang.eval.ExprValue +import org.partiql.lang.eval.numberValue +import org.partiql.lang.eval.physical.EvaluatorState +import org.partiql.lang.eval.physical.operators.ValueExpressionAsync + +// TODO: Decide if we should reduce the code duplication by combining lead and lag function +@ExperimentalWindowFunctions +internal class LagAsync : NavigationWindowFunctionAsync() { + override val name = "lag" + + companion object { + const val DEFAULT_OFFSET_VALUE = 1L + } + + override suspend fun processRow(state: EvaluatorState, arguments: List, currentPos: Int): ExprValue { + val (target, offset, default) = when (arguments.size) { + 1 -> listOf(arguments[0], null, null) + 2 -> listOf(arguments[0], arguments[1], null) + 3 -> listOf(arguments[0], arguments[1], arguments[2]) + else -> error("Wrong number of Parameter for Lag Function") + } + + val offsetValue = offset?.let { + val numberValue = it.invoke(state).numberValue().toLong() + if (numberValue >= 0) { + numberValue + } else { + error("offset need to be non-negative integer") + } + } ?: DEFAULT_OFFSET_VALUE + val defaultValue = default?.invoke(state) ?: ExprValue.nullValue + val targetIndex = currentPos - offsetValue + + return if (targetIndex >= 0 && targetIndex <= currentPartition.lastIndex) { + val targetRow = currentPartition[targetIndex.toInt()] + state.load(targetRow) + target!!.invoke(state) + } else { + defaultValue + } + } +} diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/window/LeadAsync.kt 
b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/window/LeadAsync.kt new file mode 100644 index 0000000000..6d8464171b --- /dev/null +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/window/LeadAsync.kt @@ -0,0 +1,46 @@ +package org.partiql.lang.eval.physical.window + +import org.partiql.annotations.ExperimentalWindowFunctions +import org.partiql.lang.eval.ExprValue +import org.partiql.lang.eval.numberValue +import org.partiql.lang.eval.physical.EvaluatorState +import org.partiql.lang.eval.physical.operators.ValueExpressionAsync + +// TODO: Decide if we should reduce the code duplication by combining lead and lag function. +@ExperimentalWindowFunctions +internal class LeadAsync : NavigationWindowFunctionAsync() { + + override val name = "lead" + + companion object { + const val DEFAULT_OFFSET_VALUE = 1L + } + + override suspend fun processRow(state: EvaluatorState, arguments: List, currentPos: Int): ExprValue { + val (target, offset, default) = when (arguments.size) { + 1 -> listOf(arguments[0], null, null) + 2 -> listOf(arguments[0], arguments[1], null) + 3 -> listOf(arguments[0], arguments[1], arguments[2]) + else -> error("Wrong number of Parameter for Lag Function") + } + + val offsetValue = offset?.let { + val numberValue = it.invoke(state).numberValue().toLong() + if (numberValue >= 0) { + numberValue + } else { + error("offset need to be non-negative integer") + } + } ?: DEFAULT_OFFSET_VALUE + val defaultValue = default?.invoke(state) ?: ExprValue.nullValue + val targetIndex = currentPos + offsetValue + + return if (targetIndex <= currentPartition.lastIndex) { + val targetRow = currentPartition[targetIndex.toInt()] + state.load(targetRow) + target!!.invoke(state) + } else { + defaultValue + } + } +} diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/window/NavigationWindowFunction.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/window/NavigationWindowFunction.kt index 0849398501..59d2caa4e7 
100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/window/NavigationWindowFunction.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/window/NavigationWindowFunction.kt @@ -12,6 +12,7 @@ import org.partiql.lang.eval.physical.toSetVariableFunc * TODO: When we support FIRST_VALUE, etc, we probably need to modify the process row function, since those function requires frame */ @ExperimentalWindowFunctions +@Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("NavigationWindowFunctionAsync")) abstract class NavigationWindowFunction() : WindowFunction { lateinit var currentPartition: List> @@ -38,5 +39,6 @@ abstract class NavigationWindowFunction() : WindowFunction { currentPos += 1 } + @Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("NavigationWindowFunctionAsync.processRow")) abstract fun processRow(state: EvaluatorState, arguments: List, currentPos: Int): ExprValue } diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/window/NavigationWindowFunctionAsync.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/window/NavigationWindowFunctionAsync.kt new file mode 100644 index 0000000000..b60fb2e0d1 --- /dev/null +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/window/NavigationWindowFunctionAsync.kt @@ -0,0 +1,42 @@ +package org.partiql.lang.eval.physical.window + +import org.partiql.annotations.ExperimentalWindowFunctions +import org.partiql.lang.domains.PartiqlPhysical +import org.partiql.lang.eval.ExprValue +import org.partiql.lang.eval.physical.EvaluatorState +import org.partiql.lang.eval.physical.operators.ValueExpressionAsync +import org.partiql.lang.eval.physical.toSetVariableFunc + +/** + * This abstract class holds some common logic for navigation window function, i.e., LAG, LEAD + * TODO: When we support FIRST_VALUE, etc, we probably need to modify the process row function, since those function requires frame + 
*/ +@ExperimentalWindowFunctions +abstract class NavigationWindowFunctionAsync : WindowFunctionAsync { + + lateinit var currentPartition: List> + private var currentPos: Int = 0 + + override fun reset(partition: List>) { + currentPartition = partition + currentPos = 0 + } + + override suspend fun processRow( + state: EvaluatorState, + arguments: List, + windowVarDecl: PartiqlPhysical.VarDecl + ) { + state.load(currentPartition[currentPos]) + val value = processRow(state, arguments, currentPos) + // before we declare the window function result, we need to go back to the current row + state.load(currentPartition[currentPos]) + windowVarDecl.toSetVariableFunc().invoke(state, value) + // make sure the change of state is reflected in the partition + // so the result of the current window function won't get removed by the time we process the next window function at the same row level. + currentPartition[currentPos][windowVarDecl.index.value.toInt()] = value + currentPos += 1 + } + + abstract suspend fun processRow(state: EvaluatorState, arguments: List, currentPos: Int): ExprValue +} diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/window/WindowFunction.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/window/WindowFunction.kt index ef214ba165..71af614f57 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/window/WindowFunction.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/window/WindowFunction.kt @@ -7,6 +7,7 @@ import org.partiql.lang.eval.physical.EvaluatorState import org.partiql.lang.eval.physical.operators.ValueExpression @ExperimentalWindowFunctions +@Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("WindowFunctionAsync")) interface WindowFunction { val name: String @@ -17,10 +18,12 @@ interface WindowFunction { * For now, a partition is represented by list>. * We could potentially benefit from further abstraction of partition. 
*/ + @Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("WindowFunctionAsync.reset")) fun reset(partition: List>) /** * Process a row by outputting the result of the window function. */ + @Deprecated("To be removed in the next major version.", replaceWith = ReplaceWith("WindowFunctionAsync.processRow")) fun processRow(state: EvaluatorState, arguments: List, windowVarDecl: PartiqlPhysical.VarDecl) } diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/window/WindowFunctionAsync.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/window/WindowFunctionAsync.kt new file mode 100644 index 0000000000..270e3e20eb --- /dev/null +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/window/WindowFunctionAsync.kt @@ -0,0 +1,26 @@ +package org.partiql.lang.eval.physical.window + +import org.partiql.annotations.ExperimentalWindowFunctions +import org.partiql.lang.domains.PartiqlPhysical +import org.partiql.lang.eval.ExprValue +import org.partiql.lang.eval.physical.EvaluatorState +import org.partiql.lang.eval.physical.operators.ValueExpressionAsync + +@ExperimentalWindowFunctions +interface WindowFunctionAsync { + + val name: String + + /** + * The reset function should be called before entering a new partition (including the first one). + * + * For now, a partition is represented by list>. + * We could potentially benefit from further abstraction of partition. + */ + fun reset(partition: List>) + + /** + * Process a row by outputting the result of the window function. 
+ */ + suspend fun processRow(state: EvaluatorState, arguments: List, windowVarDecl: PartiqlPhysical.VarDecl) +} diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/syntax/impl/PartiQLPigVisitor.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/syntax/impl/PartiQLPigVisitor.kt index 1d10d5ce49..98a864e91e 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/syntax/impl/PartiQLPigVisitor.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/syntax/impl/PartiQLPigVisitor.kt @@ -225,7 +225,11 @@ internal class PartiQLPigVisitor( } override fun visitDropTable(ctx: PartiQLParser.DropTableContext) = PartiqlAst.build { - val id = visitSymbolPrimitive(ctx.tableName().symbolPrimitive()) + val id = if (ctx.qualifiedName().qualifier.isEmpty()) { + visitSymbolPrimitive(ctx.qualifiedName().name) + } else { + throw ParserException("PIG Parser does not support qualified name as table name", ErrorCode.PARSE_UNEXPECTED_TOKEN) + } dropTable(id.toIdentifier(), ctx.DROP().getSourceMetaContainer()) } @@ -236,7 +240,11 @@ internal class PartiQLPigVisitor( } override fun visitCreateTable(ctx: PartiQLParser.CreateTableContext) = PartiqlAst.build { - val name = visitSymbolPrimitive(ctx.tableName().symbolPrimitive()).name + val name = if (ctx.qualifiedName().qualifier.isEmpty()) { + visitSymbolPrimitive(ctx.qualifiedName().name).name + } else { + throw ParserException("PIG Parser does not support qualified name as table name", ErrorCode.PARSE_UNEXPECTED_TOKEN) + } val def = ctx.tableDef()?.let { visitTableDef(it) } createTable_(name, def, ctx.CREATE().getSourceMetaContainer()) } diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/types/PartiqlPhysicalTypeExtensions.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/types/PartiqlPhysicalTypeExtensions.kt new file mode 100644 index 0000000000..41026cd5a5 --- /dev/null +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/types/PartiqlPhysicalTypeExtensions.kt @@ -0,0 +1,114 @@ +package org.partiql.lang.types + +import 
org.partiql.lang.domains.PartiqlPhysical +import org.partiql.types.DecimalType +import org.partiql.types.IntType +import org.partiql.types.NumberConstraint +import org.partiql.types.StaticType +import org.partiql.types.StringType +import org.partiql.types.TimeType + +/** + * Helper to convert [PartiqlPhysical.Type] in AST to a [TypedOpParameter]. + */ +internal fun PartiqlPhysical.Type.toTypedOpParameter(customTypedOpParameters: Map): TypedOpParameter = + when (this) { + is PartiqlPhysical.Type.MissingType -> TypedOpParameter(StaticType.MISSING) + is PartiqlPhysical.Type.NullType -> TypedOpParameter(StaticType.NULL) + is PartiqlPhysical.Type.BooleanType -> TypedOpParameter(StaticType.BOOL) + is PartiqlPhysical.Type.SmallintType -> TypedOpParameter(IntType(IntType.IntRangeConstraint.SHORT)) + is PartiqlPhysical.Type.Integer4Type -> TypedOpParameter(IntType(IntType.IntRangeConstraint.INT4)) + is PartiqlPhysical.Type.Integer8Type -> TypedOpParameter(IntType(IntType.IntRangeConstraint.LONG)) + is PartiqlPhysical.Type.IntegerType -> TypedOpParameter(IntType(IntType.IntRangeConstraint.LONG)) + is PartiqlPhysical.Type.FloatType, is PartiqlPhysical.Type.RealType, is PartiqlPhysical.Type.DoublePrecisionType -> TypedOpParameter( + StaticType.FLOAT + ) + is PartiqlPhysical.Type.DecimalType -> when { + this.precision == null && this.scale == null -> TypedOpParameter(StaticType.DECIMAL) + this.precision != null && this.scale == null -> TypedOpParameter( + DecimalType( + DecimalType.PrecisionScaleConstraint.Constrained( + this.precision!!.value.toInt() + ) + ) + ) + else -> TypedOpParameter( + DecimalType( + DecimalType.PrecisionScaleConstraint.Constrained( + this.precision!!.value.toInt(), + this.scale!!.value.toInt() + ) + ) + ) + } + is PartiqlPhysical.Type.NumericType -> when { + this.precision == null && this.scale == null -> TypedOpParameter(StaticType.DECIMAL) + this.precision != null && this.scale == null -> TypedOpParameter( + DecimalType( + 
DecimalType.PrecisionScaleConstraint.Constrained( + this.precision!!.value.toInt() + ) + ) + ) + else -> TypedOpParameter( + DecimalType( + DecimalType.PrecisionScaleConstraint.Constrained( + this.precision!!.value.toInt(), + this.scale!!.value.toInt() + ) + ) + ) + } + is PartiqlPhysical.Type.TimestampType -> TypedOpParameter(StaticType.TIMESTAMP) + is PartiqlPhysical.Type.CharacterType -> when { + this.length == null -> TypedOpParameter( + StringType( + StringType.StringLengthConstraint.Constrained( + NumberConstraint.Equals(1) + ) + ) + ) + else -> TypedOpParameter( + StringType( + StringType.StringLengthConstraint.Constrained( + NumberConstraint.Equals(this.length!!.value.toInt()) + ) + ) + ) + } + is PartiqlPhysical.Type.CharacterVaryingType -> when (val length = this.length) { + null -> TypedOpParameter(StringType(StringType.StringLengthConstraint.Unconstrained)) + else -> TypedOpParameter( + StringType( + StringType.StringLengthConstraint.Constrained( + NumberConstraint.UpTo( + length.value.toInt() + ) + ) + ) + ) + } + is PartiqlPhysical.Type.StringType -> TypedOpParameter(StaticType.STRING) + is PartiqlPhysical.Type.SymbolType -> TypedOpParameter(StaticType.SYMBOL) + is PartiqlPhysical.Type.ClobType -> TypedOpParameter(StaticType.CLOB) + is PartiqlPhysical.Type.BlobType -> TypedOpParameter(StaticType.BLOB) + is PartiqlPhysical.Type.StructType -> TypedOpParameter(StaticType.STRUCT) + is PartiqlPhysical.Type.TupleType -> TypedOpParameter(StaticType.STRUCT) + is PartiqlPhysical.Type.ListType -> TypedOpParameter(StaticType.LIST) + is PartiqlPhysical.Type.SexpType -> TypedOpParameter(StaticType.SEXP) + is PartiqlPhysical.Type.BagType -> TypedOpParameter(StaticType.BAG) + is PartiqlPhysical.Type.AnyType -> TypedOpParameter(StaticType.ANY) + is PartiqlPhysical.Type.CustomType -> + customTypedOpParameters.mapKeys { (k, _) -> + k.lowercase() + }[this.name.text.lowercase()] ?: error("Could not find parameter for $this") + is PartiqlPhysical.Type.DateType -> 
TypedOpParameter(StaticType.DATE) + is PartiqlPhysical.Type.TimeType -> TypedOpParameter( + TimeType(this.precision?.value?.toInt(), withTimeZone = false) + ) + is PartiqlPhysical.Type.TimeWithTimeZoneType -> TypedOpParameter( + TimeType(this.precision?.value?.toInt(), withTimeZone = true) + ) + + is PartiqlPhysical.Type.TimestampWithTimeZoneType -> TODO() + } diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/IntegrationTests.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/IntegrationTests.kt index 591b9aea4b..556546e27a 100644 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/IntegrationTests.kt +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/IntegrationTests.kt @@ -1,5 +1,7 @@ package org.partiql.lang.compiler +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.test.runTest import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Assertions.assertFalse import org.junit.jupiter.api.Assertions.assertTrue @@ -30,15 +32,24 @@ class TestContext { assertEquals(expectedIon, result.toIonValue(ION)) } + // Executes query on async evaluator + suspend fun executeAndAssertAsync( + expectedResultAsIonText: String, + sql: String, + ) { + val expectedIon = ION.singleValue(expectedResultAsIonText) + val result = queryEngine.executeQueryAsync(sql) + assertEquals(expectedIon, result.toIonValue(ION)) + } + fun intKey(value: Int) = ExprValue.newList(listOf(ExprValue.newInt(value))) } /** * Tests the query planner with some basic DML and SFW queries against using [QueryEngine] and [MemoryDatabase]. 
*/ - +@OptIn(ExperimentalCoroutinesApi::class) class IntegrationTests { - @Test fun `insert, select and delete`() { val ctx = TestContext() @@ -85,6 +96,52 @@ class IntegrationTests { assertFalse(db.tableContainsKey(customerMetadata.tableId, ctx.intKey(3))) } + @Test + fun `insert, select and delete async`() = runTest { + val ctx = TestContext() + val db = ctx.db + val customerMetadata = db.findTableMetadata(BindingName("customer", BindingCase.SENSITIVE))!! + + // start by inserting 4 rows + ctx.executeAndAssertAsync("{rows_effected:1}", "INSERT INTO customer << { 'id': 1, 'name': 'bob' } >>") + ctx.executeAndAssertAsync("{rows_effected:1}", "INSERT INTO customer << { 'id': 2, 'name': 'jane' } >>") + ctx.executeAndAssertAsync("{rows_effected:1}", "INSERT INTO customer << { 'id': 3, 'name': 'moe' } >>") + ctx.executeAndAssertAsync("{rows_effected:1}", "INSERT INTO customer << { 'id': 4, 'name': 'sue' } >>") + + // assert each of the rows is present in the actual table. + assertEquals(4, db.getRowCount(customerMetadata.tableId)) + assertTrue(db.tableContainsKey(customerMetadata.tableId, ctx.intKey(1))) + assertTrue(db.tableContainsKey(customerMetadata.tableId, ctx.intKey(2))) + assertTrue(db.tableContainsKey(customerMetadata.tableId, ctx.intKey(3))) + assertTrue(db.tableContainsKey(customerMetadata.tableId, ctx.intKey(4))) + + // commented code intentionally kept. 
Uncomment to see detailed debug information in the console when + // this test is run + // ctx.queryEngine.enableDebugOutput = true + + // run some simple SFW queries + ctx.executeAndAssertAsync("$BAG_ANNOTATION::[{ name: \"bob\"}]", "SELECT c.name FROM customer AS c WHERE c.id = 1") + ctx.executeAndAssertAsync("$BAG_ANNOTATION::[{ name: \"jane\"}]", "SELECT c.name FROM customer AS c WHERE c.id = 2") + ctx.executeAndAssertAsync("$BAG_ANNOTATION::[{ name: \"moe\"}]", "SELECT c.name FROM customer AS c WHERE c.id = 3") + ctx.executeAndAssertAsync("$BAG_ANNOTATION::[{ name: \"sue\"}]", "SELECT c.name FROM customer AS c WHERE c.id = 4") + + // now delete 2 rows and assert that they are no longer present (test DELETE FROM with WHERE predicate) + + ctx.executeAndAssertAsync("{rows_effected:1}", "DELETE FROM customer AS c WHERE c.id = 2") + assertEquals(3, db.getRowCount(customerMetadata.tableId)) + assertFalse(db.tableContainsKey(customerMetadata.tableId, ctx.intKey(2))) + + ctx.executeAndAssertAsync("{rows_effected:1}", "DELETE FROM customer AS c WHERE c.id = 4") + assertFalse(db.tableContainsKey(customerMetadata.tableId, ctx.intKey(4))) + + // finally, delete all remaining rows (test DELETE FROM without WHERE predicate) + + ctx.executeAndAssertAsync("{rows_effected:2}", "DELETE FROM customer") + assertEquals(0, db.getRowCount(customerMetadata.tableId)) + assertFalse(db.tableContainsKey(customerMetadata.tableId, ctx.intKey(1))) + assertFalse(db.tableContainsKey(customerMetadata.tableId, ctx.intKey(3))) + } + @Test fun `insert with select`() { val ctx = TestContext() @@ -110,4 +167,30 @@ class IntegrationTests { ctx.executeAndAssert("$BAG_ANNOTATION::[{ name: \"bob\"}]", "SELECT c.name FROM more_customer AS c where c.id = 1") ctx.executeAndAssert("$BAG_ANNOTATION::[{ name: \"moe\"}]", "SELECT c.name FROM more_customer AS c where c.id = 3") } + + @Test + fun `insert with select async`() = runTest { + val ctx = TestContext() + val db = ctx.db + // first put some data into 
the customer table + ctx.executeAndAssertAsync("{rows_effected:1}", "INSERT INTO customer << { 'id': 1, 'name': 'bob' } >>") + ctx.executeAndAssertAsync("{rows_effected:1}", "INSERT INTO customer << { 'id': 2, 'name': 'jane' } >>") + ctx.executeAndAssertAsync("{rows_effected:1}", "INSERT INTO customer << { 'id': 3, 'name': 'moe' } >>") + ctx.executeAndAssertAsync("{rows_effected:1}", "INSERT INTO customer << { 'id': 4, 'name': 'sue' } >>") + + // copy that data into the more_customer table by INSERTing the result of an SFW query + ctx.executeAndAssertAsync( + "{rows_effected:2}", + "INSERT INTO more_customer SELECT c.id, c.name FROM customer AS c WHERE c.id IN (1, 3)" + ) + + val moreCustomerMetadata = db.findTableMetadata(BindingName("more_customer", BindingCase.SENSITIVE))!! + assertEquals(2, db.getRowCount(moreCustomerMetadata.tableId)) + assertTrue(db.tableContainsKey(moreCustomerMetadata.tableId, ctx.intKey(1))) + assertTrue(db.tableContainsKey(moreCustomerMetadata.tableId, ctx.intKey(3))) + + // lastly, assert we have the correct data + ctx.executeAndAssertAsync("$BAG_ANNOTATION::[{ name: \"bob\"}]", "SELECT c.name FROM more_customer AS c where c.id = 1") + ctx.executeAndAssertAsync("$BAG_ANNOTATION::[{ name: \"moe\"}]", "SELECT c.name FROM more_customer AS c where c.id = 3") + } } diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/PartiQLCompilerPipelineAsyncSmokeTests.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/PartiQLCompilerPipelineAsyncSmokeTests.kt new file mode 100644 index 0000000000..e67c00a0e5 --- /dev/null +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/PartiQLCompilerPipelineAsyncSmokeTests.kt @@ -0,0 +1,191 @@ +package org.partiql.lang.compiler + +import com.amazon.ionelement.api.ionInt +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.test.runTest +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Test +import 
org.junit.jupiter.api.assertDoesNotThrow +import org.junit.jupiter.api.assertThrows +import org.partiql.annotations.ExperimentalPartiQLCompilerPipeline +import org.partiql.errors.Problem +import org.partiql.errors.ProblemDetails +import org.partiql.errors.ProblemLocation +import org.partiql.errors.ProblemSeverity +import org.partiql.lang.ast.SourceLocationMeta +import org.partiql.lang.domains.PartiqlPhysical +import org.partiql.lang.errors.PartiQLException +import org.partiql.lang.eval.physical.SetVariableFunc +import org.partiql.lang.eval.physical.operators.RelationExpressionAsync +import org.partiql.lang.eval.physical.operators.ScanRelationalOperatorFactoryAsync +import org.partiql.lang.eval.physical.operators.ValueExpressionAsync +import org.partiql.lang.eval.physical.sourceLocationMetaOrUnknown +import org.partiql.lang.planner.PartiQLPhysicalPass +import org.partiql.lang.planner.PartiQLPlanner +import org.partiql.lang.planner.PlannerEventCallback +import org.partiql.lang.planner.PlanningProblemDetails +import org.partiql.lang.planner.createFakeGlobalsResolver +import org.partiql.lang.planner.transforms.DEFAULT_IMPL_NAME +import org.partiql.lang.planner.transforms.PLAN_VERSION_NUMBER + +internal fun createFakeErrorProblem(sourceLocationMeta: SourceLocationMeta): Problem { + data class FakeProblemDetails( + override val severity: ProblemSeverity = ProblemSeverity.ERROR, + override val message: String = "Ack, the query author presented us with a logical conundrum!" 
+ ) : ProblemDetails + + return Problem( + sourceLocationMeta.toProblemLocation(), + FakeProblemDetails() + ) +} + +@OptIn(ExperimentalCoroutinesApi::class, ExperimentalPartiQLCompilerPipeline::class) +class PartiQLCompilerPipelineAsyncSmokeTests { + private fun createPlannerPipelineAsyncForTest( + allowUndefinedVariables: Boolean, + plannerEventCallback: PlannerEventCallback?, + block: PartiQLCompilerPipelineAsync.Builder.() -> Unit = { } + ) = PartiQLCompilerPipelineAsync.build { + planner.options( + PartiQLPlanner.Options( + allowedUndefinedVariables = allowUndefinedVariables, + ) + ).callback { + plannerEventCallback?.invoke(it) + }.globalVariableResolver(createFakeGlobalsResolver("Customer" to "fake_uid_for_Customer")) + block() + } + + @Test + fun `happy path`() = runTest { + var pecCallbacks = 0 + val plannerEventCallback: PlannerEventCallback = { _ -> + pecCallbacks++ + } + + val pipeline = createPlannerPipelineAsyncForTest(allowUndefinedVariables = true, plannerEventCallback = plannerEventCallback) + + // the constructed ASTs are tested separately; here we check that the compile function does not throw any exception. 
+ assertDoesNotThrow { + pipeline.compile("SELECT c.* FROM Customer AS c WHERE c.primaryKey = 42") + } + + // pec should be called once for each pass in the planner: + // - normalize ast + // - ast -> logical + // - logical -> logical resolved + // - logical resolved -> default physical + assertEquals(4, pecCallbacks) + } + + @Test + fun `undefined variable`() = runTest { + val qp = createPlannerPipelineAsyncForTest(allowUndefinedVariables = false, plannerEventCallback = null) + + val error = assertThrows { + qp.compile("SELECT undefined.* FROM Customer AS c") + } + + // TODO: We use string comparison until we finalized the error reporting mechanism for PartiQLCompilerPipeline + assertEquals( + listOf(Problem(ProblemLocation(1, 8, 9), PlanningProblemDetails.UndefinedVariable("undefined", caseSensitive = false))).toString(), + error.message + ) + } + + @Test + fun `physical plan pass - happy path`() = runTest { + val qp = createPlannerPipelineAsyncForTest(allowUndefinedVariables = false, plannerEventCallback = null) { + planner.physicalPlannerPasses( + listOf( + PartiQLPhysicalPass { plan, _ -> + assertEquals(createFakePlan(1), plan) + createFakePlan(2) + }, + PartiQLPhysicalPass { plan, _ -> + assertEquals(createFakePlan(2), plan) + createFakePlan(3) + }, + PartiQLPhysicalPass { plan, _ -> + assertEquals(createFakePlan(3), plan) + createFakePlan(4) + }, + ) + ) + } + + assertDoesNotThrow { + qp.compile("1") + } + } + + private fun createFakePlan(number: Int) = + PartiqlPhysical.build { + plan( + stmt = query(lit(ionInt(number.toLong()))), + version = PLAN_VERSION_NUMBER + ) + } + + @Test + fun `physical plan pass - first user pass sends semantic error`() = runTest { + val qp = createPlannerPipelineAsyncForTest(allowUndefinedVariables = false, plannerEventCallback = null) { + planner.physicalPlannerPasses( + listOf( + PartiQLPhysicalPass { plan, problemHandler -> + problemHandler.handleProblem( + createFakeErrorProblem(plan.stmt.metas.sourceLocationMetaOrUnknown) + 
) + plan + }, + PartiQLPhysicalPass { _, _ -> + error( + "This pass should not be reached due to an error being sent to to the problem handler " + + "in the previous pass" + ) + }, + PartiQLPhysicalPass { plan, _ -> + assertEquals(createFakePlan(3), plan) + createFakePlan(4) + }, + ) + ) + } + val expectedError = createFakeErrorProblem(SourceLocationMeta(1, 1, 57)) + + val error = assertThrows { + qp.compile( + // the actual expression doesn't matter as long as it doesn't have an error detected by a built-in pass + "'the meaning of life, the universe, and everything is 42'" + ) + } + + // TODO: We use string comparison until we finalized the error reporting mechanism for PartiQLCompilerPipeline + assertEquals(listOf(expectedError).toString(), error.message) + } + + @Test + fun `duplicate physical operator factories are blocked`() { + // This will duplicate the default async scan operator factory. + val fakeOperator = object : ScanRelationalOperatorFactoryAsync(DEFAULT_IMPL_NAME) { + override fun create( + impl: PartiqlPhysical.Impl, + expr: ValueExpressionAsync, + setAsVar: SetVariableFunc, + setAtVar: SetVariableFunc?, + setByVar: SetVariableFunc? 
+ ): RelationExpressionAsync { + TODO("doesn't matter won't be invoked") + } + } + + assertThrows { + PartiQLCompilerPipelineAsync.build { + compiler.customOperatorFactories( + listOf(fakeOperator) + ) + } + } + } +} diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/PartiQLCompilerPipelineExplainTests.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/PartiQLCompilerPipelineExplainTests.kt index ce26714c38..0b93120bf2 100644 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/PartiQLCompilerPipelineExplainTests.kt +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/PartiQLCompilerPipelineExplainTests.kt @@ -15,6 +15,8 @@ package org.partiql.lang.compiler import com.amazon.ionelement.api.ionInt +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.test.runTest import org.junit.Assert.assertEquals import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.ArgumentsSource @@ -32,6 +34,7 @@ import org.partiql.lang.util.ArgumentsProviderBase class PartiQLCompilerPipelineExplainTests { val compiler = PartiQLCompilerPipeline.standard() + private val compilerAsync = PartiQLCompilerPipelineAsync.standard() data class ExplainTestCase( val description: String? 
= null, @@ -44,12 +47,23 @@ class PartiQLCompilerPipelineExplainTests { @ParameterizedTest fun successTests(tc: ExplainTestCase) = runSuccessTest(tc) + @ArgumentsSource(SuccessTestProvider::class) + @ParameterizedTest + fun successTestsAsync(tc: ExplainTestCase) = runSuccessTestAsync(tc) + private fun runSuccessTest(tc: ExplainTestCase) { val statement = compiler.compile(tc.query) val result = statement.eval(tc.session) assertEquals(tc.expected, result) } + @OptIn(ExperimentalCoroutinesApi::class) + private fun runSuccessTestAsync(tc: ExplainTestCase) = runTest { + val statement = compilerAsync.compile(tc.query) + val result = statement.eval(tc.session) + assertEquals(tc.expected, result) + } + class SuccessTestProvider : ArgumentsProviderBase() { override fun getParameters(): List = listOf( ExplainTestCase( diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/PartiQLCompilerPipelineSmokeTests.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/PartiQLCompilerPipelineSmokeTests.kt index 67a414c1e7..e5ff7df4d0 100644 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/PartiQLCompilerPipelineSmokeTests.kt +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/PartiQLCompilerPipelineSmokeTests.kt @@ -7,9 +7,7 @@ import org.junit.jupiter.api.assertDoesNotThrow import org.junit.jupiter.api.assertThrows import org.partiql.annotations.ExperimentalPartiQLCompilerPipeline import org.partiql.errors.Problem -import org.partiql.errors.ProblemDetails import org.partiql.errors.ProblemLocation -import org.partiql.errors.ProblemSeverity import org.partiql.lang.ast.SourceLocationMeta import org.partiql.lang.domains.PartiqlPhysical import org.partiql.lang.errors.PartiQLException @@ -27,6 +25,8 @@ import org.partiql.lang.planner.transforms.DEFAULT_IMPL_NAME import org.partiql.lang.planner.transforms.PLAN_VERSION_NUMBER @OptIn(ExperimentalPartiQLCompilerPipeline::class) +// Equivalent to `PartiQLCompilerPipelineAsyncSmokeTests.kt` but using 
synchronous physical plan evaluator APIs. +// To be removed next major version class PartiQLCompilerPipelineSmokeTests { private fun createPlannerPipelineForTest( @@ -152,18 +152,6 @@ class PartiQLCompilerPipelineSmokeTests { assertEquals(listOf(expectedError).toString(), error.message) } - private fun createFakeErrorProblem(sourceLocationMeta: SourceLocationMeta): Problem { - data class FakeProblemDetails( - override val severity: ProblemSeverity = ProblemSeverity.ERROR, - override val message: String = "Ack, the query author presented us with a logical conundrum!" - ) : ProblemDetails - - return Problem( - sourceLocationMeta.toProblemLocation(), - FakeProblemDetails() - ) - } - @Test fun `duplicate physical operator factories are blocked`() { // This will duplicate the default scan operator factory. diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/async/AsyncOperatorTests.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/async/AsyncOperatorTests.kt new file mode 100644 index 0000000000..0b145bc0af --- /dev/null +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/async/AsyncOperatorTests.kt @@ -0,0 +1,120 @@ +package org.partiql.lang.compiler.async + +import com.amazon.ionelement.api.ionInt +import com.amazon.ionelement.api.ionString +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.delay +import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.test.runTest +import org.junit.jupiter.api.Test +import org.partiql.annotations.ExperimentalPartiQLCompilerPipeline +import org.partiql.lang.compiler.PartiQLCompilerPipelineAsync +import org.partiql.lang.domains.PartiqlPhysical +import org.partiql.lang.eval.EvaluationSession +import org.partiql.lang.eval.PartiQLResult +import org.partiql.lang.eval.PartiQLStatementAsync +import org.partiql.lang.eval.booleanValue +import org.partiql.lang.eval.isNotUnknown +import 
org.partiql.lang.eval.physical.operators.FilterRelationalOperatorFactoryAsync +import org.partiql.lang.eval.physical.operators.RelationExpressionAsync +import org.partiql.lang.eval.physical.operators.ValueExpressionAsync +import org.partiql.lang.eval.relation.RelationType +import org.partiql.lang.eval.relation.relation +import org.partiql.lang.planner.litTrue +import org.partiql.lang.planner.transforms.DEFAULT_IMPL +import org.partiql.lang.planner.transforms.PLAN_VERSION_NUMBER + +private const val FAKE_IMPL_NAME = "test_async_fake" +private val FAKE_IMPL_NODE = PartiqlPhysical.build { impl(FAKE_IMPL_NAME) } + +/** + * Test is included to demonstrate the previous behavior for a relational operator expression that calls an async + * functions. Previously, in the synchronous evaluator, making an async function call would require wrapping the call + * in [runBlocking], which blocks the current thread of execution. This results in the 10 evaluation calls to be + * executed one after the other, waiting for the previous call to finish. + * + * Since the [PartiQLStatementAsync] evaluation is now async, the [runBlocking] around the async function is no longer + * required. Thus, the result is the 10 evaluation calls can be executed without waiting for the previous call to + * finish. 
+ */ +@OptIn(ExperimentalPartiQLCompilerPipeline::class) +class AsyncOperatorTests { + private val fakeOperatorFactories = listOf( + object : FilterRelationalOperatorFactoryAsync(FAKE_IMPL_NAME) { + override fun create( + impl: PartiqlPhysical.Impl, + predicate: ValueExpressionAsync, + sourceBexpr: RelationExpressionAsync + ): RelationExpressionAsync = RelationExpressionAsync { state -> + // If `RelationExpressionAsync`'s `evaluate` was NOT a `suspend fun`, then `runBlocking` would be + // required +// runBlocking { + println("Calling") + someAsyncOp() +// } + val input = sourceBexpr.evaluate(state) + + relation(RelationType.BAG) { + while (true) { + if (!input.nextRow()) { + break + } else { + val matches = predicate.invoke(state) + if (matches.isNotUnknown() && matches.booleanValue()) { + yield() + } + } + } + } + } + } + ) + + private suspend fun someAsyncOp() { + println("sleeping") + delay(2000L) + println("done sleeping") + } + + @OptIn(ExperimentalCoroutinesApi::class) + @Test + fun compilePlan() = runTest { + val pipeline = PartiQLCompilerPipelineAsync.build { + compiler + .customOperatorFactories( + fakeOperatorFactories.map { it } + ) + } + val plan = PartiqlPhysical.build { + plan( + stmt = query( + bindingsToValues( + exp = lit(ionInt(42)), + query = filter( + i = FAKE_IMPL_NODE, + predicate = litTrue(), + source = scan( + i = DEFAULT_IMPL, + expr = bag(struct(listOf(structField(fieldName = lit(ionString("a")), value = lit(ionInt(1)))))), + asDecl = varDecl(0) + ) + ) + ) + ), + version = PLAN_VERSION_NUMBER, + locals = listOf(localVariable("_1", 0)) + ) + } + val statement = pipeline.compile(plan) + // asynchronously evaluate 10 statements and print out the results + repeat(10) { index -> + launch { + print("\nCompiling $index. 
") + val result = statement.eval(EvaluationSession.standard()) as PartiQLResult.Value + println("About to print value; $index") + println(result.value) + } + } + } +} diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/memorydb/QueryEngine.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/memorydb/QueryEngine.kt index f7efe7e1a1..5b69067ab8 100644 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/memorydb/QueryEngine.kt +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/memorydb/QueryEngine.kt @@ -4,7 +4,9 @@ import com.amazon.ionelement.api.toIonValue import org.partiql.annotations.ExperimentalPartiQLCompilerPipeline import org.partiql.lang.ION import org.partiql.lang.compiler.PartiQLCompilerPipeline +import org.partiql.lang.compiler.PartiQLCompilerPipelineAsync import org.partiql.lang.compiler.memorydb.operators.GetByKeyProjectRelationalOperatorFactory +import org.partiql.lang.compiler.memorydb.operators.GetByKeyProjectRelationalOperatorFactoryAsync import org.partiql.lang.domains.PartiqlPhysical import org.partiql.lang.eval.BindingCase import org.partiql.lang.eval.BindingName @@ -17,6 +19,7 @@ import org.partiql.lang.eval.namedValue import org.partiql.lang.planner.GlobalResolutionResult import org.partiql.lang.planner.GlobalVariableResolver import org.partiql.lang.planner.PartiQLPhysicalPass +import org.partiql.lang.planner.PartiQLPlannerBuilder import org.partiql.lang.planner.StaticTypeResolver import org.partiql.lang.planner.transforms.optimizations.createConcatWindowFunctionPass import org.partiql.lang.planner.transforms.optimizations.createFilterScanToKeyLookupPass @@ -36,7 +39,7 @@ internal const val DB_CONTEXT_VAR = "in-memory-database" */ @OptIn(ExperimentalPartiQLCompilerPipeline::class) class QueryEngine(val db: MemoryDatabase) { - var enableDebugOutput = false + private var enableDebugOutput = false /** Given a [BindingName], inform the planner the unique identifier of the global variable (usually a 
table). */ private val globalVariableResolver = GlobalVariableResolver { bindingName -> @@ -60,7 +63,7 @@ class QueryEngine(val db: MemoryDatabase) { // TODO: nothing in the planner uses the contentClosed property yet, but "technically" do have open // content since nothing is constraining the fields in the table. contentClosed = false, - // The FilterScanTokeyLookup pass does use this. + // The FilterScanToKeyLookup pass does use this. primaryKeyFields = tableMetadata.primaryKeyFields ) ) @@ -92,75 +95,85 @@ class QueryEngine(val db: MemoryDatabase) { } } - private val compilerPipeline = PartiQLCompilerPipeline.build { - planner - .callback { - fun prettyPrint(label: String, data: Any) { - val padding = 10 - when (data) { - is DomainNode -> { - println("$label:") - val sexpElement = data.toIonElement() - println(SexpAstPrettyPrinter.format(sexpElement.asAnyElement().toIonValue(ION))) - } - else -> - println("$label:".padEnd(padding) + data.toString()) + // session data + val session = EvaluationSession.build { + globals(bindings) + // Please note that the context here is immutable once the call to .build above + // returns, (Hopefully that will reduce the chances of it being abused.) 
+ withContextVariable("in-memory-database", db) + } + + private fun PartiQLPlannerBuilder.plannerBlock() = this + .callback { + fun prettyPrint(label: String, data: Any) { + val padding = 10 + when (data) { + is DomainNode -> { + println("$label:") + val sexpElement = data.toIonElement() + println(SexpAstPrettyPrinter.format(sexpElement.asAnyElement().toIonValue(ION))) } - } - if (this@QueryEngine.enableDebugOutput) { - prettyPrint("event", it.eventName) - prettyPrint("duration", it.duration) - if (it.eventName == "parse_sql") prettyPrint("input", it.input) - prettyPrint("output", it.output) + else -> + println("$label:".padEnd(padding) + data.toString()) } } - .globalVariableResolver(globalVariableResolver) - .physicalPlannerPasses( - listOf( - // TODO: push-down filters on top of scans before this pass. - PartiQLPhysicalPass { plan, problemHandler -> - createFilterScanToKeyLookupPass( - customProjectOperatorName = GET_BY_KEY_PROJECT_IMPL_NAME, - staticTypeResolver = staticTypeResolver, - createKeyValueConstructor = { recordType, keyFieldEqualityPredicates -> - require(recordType.primaryKeyFields.size == keyFieldEqualityPredicates.size) - PartiqlPhysical.build { - list( - // Key values are expressed to the in-memory storage engine as ordered list. Therefore, we need - // to ensure that the list we pass in as an argument to the custom_get_by_key project operator - // impl is in the right order. - recordType.primaryKeyFields.map { keyFieldName -> - keyFieldEqualityPredicates.single { it.keyFieldName == keyFieldName }.equivalentValue - } - ) - } + if (this@QueryEngine.enableDebugOutput) { + prettyPrint("event", it.eventName) + prettyPrint("duration", it.duration) + if (it.eventName == "parse_sql") prettyPrint("input", it.input) + prettyPrint("output", it.output) + } + } + .globalVariableResolver(globalVariableResolver) + .physicalPlannerPasses( + listOf( + // TODO: push-down filters on top of scans before this pass. 
+ PartiQLPhysicalPass { plan, problemHandler -> + createFilterScanToKeyLookupPass( + customProjectOperatorName = GET_BY_KEY_PROJECT_IMPL_NAME, + staticTypeResolver = staticTypeResolver, + createKeyValueConstructor = { recordType, keyFieldEqualityPredicates -> + require(recordType.primaryKeyFields.size == keyFieldEqualityPredicates.size) + PartiqlPhysical.build { + list( + // Key values are expressed to the in-memory storage engine as ordered list. Therefore, we need + // to ensure that the list we pass in as an argument to the custom_get_by_key project operator + // impl is in the right order. + recordType.primaryKeyFields.map { keyFieldName -> + keyFieldEqualityPredicates.single { it.keyFieldName == keyFieldName }.equivalentValue + } + ) } - ).apply(plan, problemHandler) - }, - // Note that the order of the following plans is relevant--the "remove useless filters" pass - // will not work correctly if "remove useless ands" pass is not executed first. - - // After the filter-scan-to-key-lookup pass above, we may be left with some `(and ...)` expressions - // whose operands were replaced with `(lit true)`. This pass removes `(lit true)` operands from `and` - // expressions, and replaces any `and` expressions with only `(lit true)` operands with `(lit true)`. - // This happens recursively, so an entire tree of useless `(and ...)` expressions will be replaced - // with a single `(lit true)`. - // A constant folding pass might one day eliminate the need for this, but that is not within the current scope. - PartiQLPhysicalPass { plan, problemHandler -> - createRemoveUselessAndsPass().apply(plan, problemHandler) - }, - - // After the previous pass, we may have some `(filter ... )` nodes with `(lit true)` as a predicate. - // This pass removes these useless filter nodes. 
- PartiQLPhysicalPass { plan, problemHandler -> - createRemoveUselessFiltersPass().apply(plan, problemHandler) - }, - - PartiQLPhysicalPass { plan, problemHandler -> - createConcatWindowFunctionPass().apply(plan, problemHandler) - }, - ) + } + ).apply(plan, problemHandler) + }, + // Note that the order of the following plans is relevant--the "remove useless filters" pass + // will not work correctly if "remove useless ands" pass is not executed first. + + // After the filter-scan-to-key-lookup pass above, we may be left with some `(and ...)` expressions + // whose operands were replaced with `(lit true)`. This pass removes `(lit true)` operands from `and` + // expressions, and replaces any `and` expressions with only `(lit true)` operands with `(lit true)`. + // This happens recursively, so an entire tree of useless `(and ...)` expressions will be replaced + // with a single `(lit true)`. + // A constant folding pass might one day eliminate the need for this, but that is not within the current scope. + PartiQLPhysicalPass { plan, problemHandler -> + createRemoveUselessAndsPass().apply(plan, problemHandler) + }, + + // After the previous pass, we may have some `(filter ... )` nodes with `(lit true)` as a predicate. + // This pass removes these useless filter nodes. 
+ PartiQLPhysicalPass { plan, problemHandler -> + createRemoveUselessFiltersPass().apply(plan, problemHandler) + }, + + PartiQLPhysicalPass { plan, problemHandler -> + createConcatWindowFunctionPass().apply(plan, problemHandler) + }, ) + ) + + private val compilerPipeline = PartiQLCompilerPipeline.build { + planner.plannerBlock() compiler .customOperatorFactories( listOf( @@ -169,23 +182,38 @@ class QueryEngine(val db: MemoryDatabase) { ) } + private val compilerPipelineAsync = PartiQLCompilerPipelineAsync.build { + planner.plannerBlock() + compiler + .customOperatorFactories( + listOf( + GetByKeyProjectRelationalOperatorFactoryAsync() // using async version here + ) + ) + } + fun executeQuery(sql: String): ExprValue { + // compile query to statement + val statement = compilerPipeline.compile(sql) - // session data - val session = EvaluationSession.build { - globals(bindings) - // Please note that the context here is immutable once the call to .build above - // returns, (Hopefully that will reduce the chances of it being abused.) - withContextVariable("in-memory-database", db) - } + // First step is to plan the query. + // This parses the query and runs it through all the planner passes: + // AST -> logical plan -> resolved logical plan -> default physical plan -> custom physical plan + return convertResultToExprValue(statement.eval(session)) + } + suspend fun executeQueryAsync(sql: String): ExprValue { // compile query to statement - val statement = compilerPipeline.compile(sql) + val statement = compilerPipelineAsync.compile(sql) // First step is to plan the query. 
// This parses the query and runs it through all the planner passes: // AST -> logical plan -> resolved logical plan -> default physical plan -> custom physical plan - return when (val result = statement.eval(session)) { + return convertResultToExprValue(statement.eval(session)) + } + + private fun convertResultToExprValue(result: PartiQLResult): ExprValue = + when (result) { is PartiQLResult.Value -> result.value is PartiQLResult.Delete -> { val targetTableId = UUID.fromString(result.target) @@ -220,5 +248,4 @@ class QueryEngine(val db: MemoryDatabase) { is PartiQLResult.Replace -> TODO("Not implemented yet") is PartiQLResult.Explain.Domain -> TODO("Not implemented yet") } - } } diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/memorydb/operators/GetByKeyProjectRelationalOperatorFactoryAsync.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/memorydb/operators/GetByKeyProjectRelationalOperatorFactoryAsync.kt new file mode 100644 index 0000000000..323ec13b6c --- /dev/null +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/memorydb/operators/GetByKeyProjectRelationalOperatorFactoryAsync.kt @@ -0,0 +1,108 @@ +package org.partiql.lang.compiler.memorydb.operators + +import org.partiql.lang.compiler.memorydb.DB_CONTEXT_VAR +import org.partiql.lang.compiler.memorydb.GET_BY_KEY_PROJECT_IMPL_NAME +import org.partiql.lang.compiler.memorydb.MemoryDatabase +import org.partiql.lang.domains.PartiqlPhysical +import org.partiql.lang.eval.physical.SetVariableFunc +import org.partiql.lang.eval.physical.operators.ProjectRelationalOperatorFactoryAsync +import org.partiql.lang.eval.physical.operators.RelationExpressionAsync +import org.partiql.lang.eval.physical.operators.ValueExpressionAsync +import org.partiql.lang.eval.relation.RelationIterator +import org.partiql.lang.eval.relation.RelationScope +import org.partiql.lang.eval.relation.RelationType +import org.partiql.lang.eval.relation.relation +import java.util.UUID + +/** + * A `project` 
operator implementation that performs a lookup of a single record stored in a [MemoryDatabase] given its + * primary key. + * + * Operator implementations comprise two phases: + * + * - A compile phase, where one-time computation can be performed and stored in a [RelationExpressionAsync], which + * is essentially a closure. + *- An evaluation phase, where the closure is invoked. The closure returns a [RelationIterator], which is a + * coroutine created by the [relation] function. + * + * In general, the `project` operator implementations must fetch the next row from the data store, call the provided + * [SetVariableFunc] to set the variable, and then call [RelationScope.yield]. + */ + +class GetByKeyProjectRelationalOperatorFactoryAsync : ProjectRelationalOperatorFactoryAsync(GET_BY_KEY_PROJECT_IMPL_NAME) { + /** + * This function is called at compile-time to create an instance of the operator [RelationExpressionAsync] + * that will be invoked at evaluation-time. + */ + override fun create( + impl: PartiqlPhysical.Impl, + setVar: SetVariableFunc, + args: List + ): RelationExpressionAsync { + // Compile phase starts here. We should do as much pre-computation as possible to avoid repeating during the + // evaluation phase. + + // Sanity check the static and dynamic arguments of this operator. If either of these checks fail, it would + // indicate a bug in the rewrite which created this (project ...) operator. 
+ require(impl.staticArgs.size == 1) { + "Expected one static argument to $GET_BY_KEY_PROJECT_IMPL_NAME but found ${args.size}" + } + require(args.size == 1) { + "Expected one argument to $GET_BY_KEY_PROJECT_IMPL_NAME but found ${args.size}" + } + + // Extract the key value constructor + val keyValueExpressionAsync = args.single() + + // Parse the tableId, so we don't have to at evaluation-time + val tableId = UUID.fromString(impl.staticArgs.single().textValue) + + var exhausted = false + + // Finally, return a RelationExpressionAsync which evaluates the key value expression and returns a + // RelationIterator containing a single row corresponding to the key (or no rows if nothing matches) + return RelationExpressionAsync { state -> + // this code runs at evaluation-time. + + if (exhausted) { + throw IllegalStateException("Exhausted result set") + } + + // Get the current database from the EvaluationSession context. + // Please note that the state.session.context map is immutable, therefore it is not possible + // for custom operators or functions to put stuff in there. (Hopefully that will reduce the + // chances of it being abused.) + val db = state.session.context[DB_CONTEXT_VAR] as MemoryDatabase + + // Compute the value of the key using the keyValueExpressionAsync + val keyValue = keyValueExpressionAsync.invoke(state) + + // get the record requested. + val record = db.getRecordByKey(tableId, keyValue) + + exhausted = true + + // if the record was not found, return an empty relation: + if (record == null) + relation(RelationType.BAG) { + // this relation is empty because there is no call to yield() + } + else { + // Return the relation which is Kotlin-coroutine that simply projects the single record we + // found above into the one variable allowed by the project operator, yields, and then returns. + relation(RelationType.BAG) { + // `state` is sacrosanct and should not be modified outside PartiQL. 
PartiQL + // provides the setVar function so that embedders can safely set the value of the + // variable from within the relation without clobbering anything else. + // It is important to call setVar *before* the yield since otherwise the value + // of the variable will not be assigned before it is accessed. + setVar(state, record) + yield() + + // also note that in this case there is only one record--to return multiple records we would + // iterate over each record normally, calling `setVar` and `yield` once for each record. + } + } + } + } +} diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/operators/CustomOperatorFactoryTests.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/operators/CustomOperatorFactoryTests.kt index 93766dbdcb..c1dcfee4be 100644 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/operators/CustomOperatorFactoryTests.kt +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/operators/CustomOperatorFactoryTests.kt @@ -1,26 +1,39 @@ package org.partiql.lang.compiler.operators import com.amazon.ionelement.api.ionBool +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.test.runTest import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.assertThrows import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.ArgumentsSource import org.partiql.annotations.ExperimentalPartiQLCompilerPipeline import org.partiql.lang.compiler.PartiQLCompilerPipeline +import org.partiql.lang.compiler.PartiQLCompilerPipelineAsync import org.partiql.lang.domains.PartiqlPhysical import org.partiql.lang.eval.physical.EvaluatorState import org.partiql.lang.eval.physical.SetVariableFunc import org.partiql.lang.eval.physical.VariableBinding +import org.partiql.lang.eval.physical.VariableBindingAsync import org.partiql.lang.eval.physical.operators.FilterRelationalOperatorFactory +import 
org.partiql.lang.eval.physical.operators.FilterRelationalOperatorFactoryAsync import org.partiql.lang.eval.physical.operators.JoinRelationalOperatorFactory +import org.partiql.lang.eval.physical.operators.JoinRelationalOperatorFactoryAsync import org.partiql.lang.eval.physical.operators.LetRelationalOperatorFactory +import org.partiql.lang.eval.physical.operators.LetRelationalOperatorFactoryAsync import org.partiql.lang.eval.physical.operators.LimitRelationalOperatorFactory +import org.partiql.lang.eval.physical.operators.LimitRelationalOperatorFactoryAsync import org.partiql.lang.eval.physical.operators.OffsetRelationalOperatorFactory +import org.partiql.lang.eval.physical.operators.OffsetRelationalOperatorFactoryAsync import org.partiql.lang.eval.physical.operators.ProjectRelationalOperatorFactory +import org.partiql.lang.eval.physical.operators.ProjectRelationalOperatorFactoryAsync import org.partiql.lang.eval.physical.operators.RelationExpression +import org.partiql.lang.eval.physical.operators.RelationExpressionAsync import org.partiql.lang.eval.physical.operators.RelationalOperatorKind import org.partiql.lang.eval.physical.operators.ScanRelationalOperatorFactory +import org.partiql.lang.eval.physical.operators.ScanRelationalOperatorFactoryAsync import org.partiql.lang.eval.physical.operators.ValueExpression +import org.partiql.lang.eval.physical.operators.ValueExpressionAsync import org.partiql.lang.planner.transforms.DEFAULT_IMPL import org.partiql.lang.planner.transforms.PLAN_VERSION_NUMBER import org.partiql.lang.util.ArgumentsProviderBase @@ -96,6 +109,67 @@ class CustomOperatorFactoryTests { } ) + private val fakeAsyncOperatorFactories = listOf( + object : ProjectRelationalOperatorFactoryAsync(FAKE_IMPL_NAME) { + override fun create( + impl: PartiqlPhysical.Impl, + setVar: SetVariableFunc, + args: List + ): RelationExpressionAsync = + throw CreateFunctionWasCalledException(RelationalOperatorKind.PROJECT) + }, + object : 
ScanRelationalOperatorFactoryAsync(FAKE_IMPL_NAME) { + override fun create( + impl: PartiqlPhysical.Impl, + expr: ValueExpressionAsync, + setAsVar: SetVariableFunc, + setAtVar: SetVariableFunc?, + setByVar: SetVariableFunc? + ): RelationExpressionAsync = + throw CreateFunctionWasCalledException(RelationalOperatorKind.SCAN) + }, + object : FilterRelationalOperatorFactoryAsync(FAKE_IMPL_NAME) { + override fun create(impl: PartiqlPhysical.Impl, predicate: ValueExpressionAsync, sourceBexpr: RelationExpressionAsync) = + throw CreateFunctionWasCalledException(RelationalOperatorKind.FILTER) + }, + object : JoinRelationalOperatorFactoryAsync(FAKE_IMPL_NAME) { + override fun create( + impl: PartiqlPhysical.Impl, + joinType: PartiqlPhysical.JoinType, + leftBexpr: RelationExpressionAsync, + rightBexpr: RelationExpressionAsync, + predicateExpr: ValueExpressionAsync?, + setLeftSideVariablesToNull: (EvaluatorState) -> Unit, + setRightSideVariablesToNull: (EvaluatorState) -> Unit + ): RelationExpressionAsync = + throw CreateFunctionWasCalledException(RelationalOperatorKind.JOIN) + }, + object : OffsetRelationalOperatorFactoryAsync(FAKE_IMPL_NAME) { + override fun create( + impl: PartiqlPhysical.Impl, + rowCountExpr: ValueExpressionAsync, + sourceBexpr: RelationExpressionAsync + ): RelationExpressionAsync = + throw CreateFunctionWasCalledException(RelationalOperatorKind.OFFSET) + }, + object : LimitRelationalOperatorFactoryAsync(FAKE_IMPL_NAME) { + override fun create( + impl: PartiqlPhysical.Impl, + rowCountExpr: ValueExpressionAsync, + sourceBexpr: RelationExpressionAsync + ): RelationExpressionAsync = + throw CreateFunctionWasCalledException(RelationalOperatorKind.LIMIT) + }, + object : LetRelationalOperatorFactoryAsync(FAKE_IMPL_NAME) { + override fun create( + impl: PartiqlPhysical.Impl, + sourceBexpr: RelationExpressionAsync, + bindings: List + ) = + throw CreateFunctionWasCalledException(RelationalOperatorKind.LET) + } + ) + @ParameterizedTest 
@ArgumentsSource(CustomOperatorCases::class) fun `make sure custom operator implementations are called`(tc: CustomOperatorCases.TestCase) { @@ -113,11 +187,29 @@ class CustomOperatorFactoryTests { assertEquals(tc.expectedThrownFromOperator, ex.thrownFromOperator) } + @OptIn(ExperimentalCoroutinesApi::class) + @ParameterizedTest + @ArgumentsSource(CustomOperatorCases::class) + fun `make sure custom async operator implementations are called`(tc: CustomOperatorCases.TestCase) = runTest { + val pipeline = PartiQLCompilerPipelineAsync.build { + compiler + .customOperatorFactories( + fakeAsyncOperatorFactories.map { + it + } + ) + } + val ex = assertThrows { + pipeline.compile(tc.plan) + } + assertEquals(tc.expectedThrownFromOperator, ex.thrownFromOperator) + } + class CustomOperatorCases : ArgumentsProviderBase() { class TestCase(val expectedThrownFromOperator: RelationalOperatorKind, val plan: PartiqlPhysical.Plan) override fun getParameters() = listOf( // The key parts of the cases below are the setting of FAKE_IMPL_NODE which causes the custom operator - // factories to be called. The rest is the minimum gibberish needed to make complete PartiqlPhsyical.Bexpr + // factories to be called. The rest is the minimum gibberish needed to make complete PartiqlPhysical.Bexpr // nodes. There must only be one FAKE_IMPL_NODE per plan otherwise the CreateFunctionWasCalledException // might be called for an operator other than the one intended. 
createTestCase(RelationalOperatorKind.PROJECT) { project(FAKE_IMPL_NODE, varDecl(0)) }, diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/eval/EvaluatingCompilerCollectionAggregationsTest.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/eval/EvaluatingCompilerCollectionAggregationsTest.kt index 7f38f567ec..c533581ff2 100644 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/eval/EvaluatingCompilerCollectionAggregationsTest.kt +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/eval/EvaluatingCompilerCollectionAggregationsTest.kt @@ -51,6 +51,13 @@ internal class EvaluatingCompilerCollectionAggregationsTest : EvaluatorTestBase( runEvaluatorTestCase(newTc, SESSION) } + @ParameterizedTest + @ArgumentsSource(ValidTestArguments::class) + fun validTestsAsync(tc: EvaluatorTestCase) { + val newTc = tc.copy(targetPipeline = EvaluatorTestTarget.PARTIQL_PIPELINE_ASYNC) + runEvaluatorTestCase(newTc, SESSION) + } + @ParameterizedTest @ArgumentsSource(ErrorTestArguments::class) fun errorTests(tc: EvaluatorErrorTestCase) { @@ -58,6 +65,13 @@ internal class EvaluatingCompilerCollectionAggregationsTest : EvaluatorTestBase( runEvaluatorErrorTestCase(newTc, SESSION) } + @ParameterizedTest + @ArgumentsSource(ErrorTestArguments::class) + fun errorTestsAsync(tc: EvaluatorErrorTestCase) { + val newTc = tc.copy(targetPipeline = EvaluatorTestTarget.PARTIQL_PIPELINE_ASYNC) + runEvaluatorErrorTestCase(newTc, SESSION) + } + internal class ValidTestArguments : ArgumentsProviderBase() { override fun getParameters() = listOf( EvaluatorTestCase( diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/eval/EvaluatingCompilerExcludeTests.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/eval/EvaluatingCompilerExcludeTests.kt index 931539de15..874898ac18 100644 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/eval/EvaluatingCompilerExcludeTests.kt +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/eval/EvaluatingCompilerExcludeTests.kt @@ -18,6 +18,14 @@ class 
EvaluatingCompilerExcludeTests : EvaluatorTestBase() { "SELECT t.* EXCLUDE t.a FROM <<{'a': {'b': 2}, 'foo': 'bar', 'foo2': 'bar2'}>> AS t", """<<{'foo': 'bar', 'foo2': 'bar2'}>>""" ), + EvaluatorTestCase( + """ + SELECT tbl2.* EXCLUDE tbl2.derivedColumn FROM + (SELECT tbl1.*, tbl1.a.b + 2 AS derivedColumn FROM + <<{'a': {'b': 2}, 'foo': 'bar', 'foo2': 'bar2'}>> AS tbl1) + AS tbl2""", + " <<{'a': {'b': 2}, 'foo': 'bar', 'foo2': 'bar2'}>>" + ), EvaluatorTestCase( // EXCLUDE tuple attr using bracket syntax; same output as above "SELECT t.* EXCLUDE t['a'] FROM <<{'a': {'b': 2}, 'foo': 'bar', 'foo2': 'bar2'}>> AS t", """<<{'foo': 'bar', 'foo2': 'bar2'}>>""" diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/eval/EvaluatingCompilerFromLetTests.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/eval/EvaluatingCompilerFromLetTests.kt index 67b5f9371c..0fd6dd179f 100644 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/eval/EvaluatingCompilerFromLetTests.kt +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/eval/EvaluatingCompilerFromLetTests.kt @@ -73,6 +73,11 @@ class EvaluatingCompilerFromLetTests : EvaluatorTestBase() { """<< { 'id': 1 }>>""", target = EvaluatorTestTarget.PARTIQL_PIPELINE ), + EvaluatorTestCase( + "SELECT * FROM A LET 100 AS A", + """<< { 'id': 1 }>>""", + target = EvaluatorTestTarget.PARTIQL_PIPELINE_ASYNC + ), // LET using other variables EvaluatorTestCase( diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/eval/EvaluatingCompilerGroupByTest.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/eval/EvaluatingCompilerGroupByTest.kt index 01a5c45296..da1ce6027e 100644 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/eval/EvaluatingCompilerGroupByTest.kt +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/eval/EvaluatingCompilerGroupByTest.kt @@ -104,8 +104,9 @@ class EvaluatingCompilerGroupByTest : EvaluatorTestBase() { ) /** - * The [EvaluatorTestTarget.PARTIQL_PIPELINE] does NOT support [UndefinedVariableBehavior.MISSING], so if 
the - * [compOptions] includes the [UndefinedVariableBehavior], we should use the [EvaluatorTestTarget.COMPILER_PIPELINE] + * The [EvaluatorTestTarget.PARTIQL_PIPELINE] and [EvaluatorTestTarget.PARTIQL_PIPELINE_ASYNC] do NOT support + * [UndefinedVariableBehavior.MISSING], so if the [compOptions] includes the [UndefinedVariableBehavior], we + * should use the [EvaluatorTestTarget.COMPILER_PIPELINE]. */ private fun getTestTarget(compOptions: CompOptions, default: EvaluatorTestTarget): EvaluatorTestTarget = when (compOptions) { CompOptions.UNDEF_VAR_MISSING -> EvaluatorTestTarget.COMPILER_PIPELINE @@ -1058,6 +1059,45 @@ class EvaluatingCompilerGroupByTest : EvaluatorTestBase() { """, targetPipeline = EvaluatorTestTarget.PARTIQL_PIPELINE ), + EvaluatorTestCase( + groupName = "SELECT with nested aggregates (complex)", + query = """ + SELECT + i2 AS outerKey, + g2 AS outerGroupAs, + COUNT(*) AS outerCount, + SUM(innerQuery.innerSum) AS outerSum, + MIN(innerQuery.innerSum) AS outerMin + FROM ( + SELECT + i, + g, + SUM(col1) AS innerSum + FROM simple_1_col_1_group_2 AS innerFromSource + GROUP BY col1 AS i GROUP AS g + ) AS innerQuery + GROUP BY innerQuery.i AS i2, innerQuery.g AS g2 + """, + expectedResult = """ + << + { + 'outerKey': 1, + 'outerGroupAs': << { 'innerFromSource': { 'col1': 1 } } >>, + 'outerCount': 1, + 'outerSum': 1, + 'outerMin': 1 + }, + { + 'outerKey': 5, + 'outerGroupAs': << { 'innerFromSource': { 'col1': 5 } } >>, + 'outerCount': 1, + 'outerSum': 5, + 'outerMin': 5 + } + >> + """, + targetPipeline = EvaluatorTestTarget.PARTIQL_PIPELINE_ASYNC + ), ) @Test @@ -1103,6 +1143,45 @@ class EvaluatingCompilerGroupByTest : EvaluatorTestBase() { """, targetPipeline = EvaluatorTestTarget.PARTIQL_PIPELINE ), + EvaluatorTestCase( + groupName = "SELECT with nested aggregates (complex)", + query = """ + SELECT + i2 AS outerKey, + g2 AS outerGroupAs, + MIN(innerQuery.innerSum) AS outerMin, + ( + SELECT VALUE SUM(i2) + FROM << 0, 1 >> + ) AS projListSubQuery + FROM ( 
+ SELECT + i, + g, + SUM(col1) AS innerSum + FROM simple_1_col_1_group_2 AS innerFromSource + GROUP BY col1 AS i GROUP AS g + ) AS innerQuery + GROUP BY innerQuery.i AS i2, innerQuery.g AS g2 + """, + expectedResult = """ + << + { + 'outerKey': 1, + 'outerGroupAs': << { 'innerFromSource': { 'col1': 1 } } >>, + 'outerMin': 1, + 'projListSubQuery': << 2 >> + }, + { + 'outerKey': 5, + 'outerGroupAs': << { 'innerFromSource': { 'col1': 5 } } >>, + 'outerMin': 5, + 'projListSubQuery': << 10 >> + } + >> + """, + targetPipeline = EvaluatorTestTarget.PARTIQL_PIPELINE_ASYNC + ), ) @Test @@ -1355,6 +1434,12 @@ class EvaluatingCompilerGroupByTest : EvaluatorTestBase() { propertyValueMapOf(1, 8, Property.BINDING_NAME to "foo"), target = EvaluatorTestTarget.PARTIQL_PIPELINE ) + runEvaluatorErrorTestCase( + "SELECT foo AS someSelectListAlias FROM <<{ 'a': 1 }>> GROUP BY someSelectListAlias", + ErrorCode.EVALUATOR_VARIABLE_NOT_INCLUDED_IN_GROUP_BY, + propertyValueMapOf(1, 8, Property.BINDING_NAME to "foo"), + target = EvaluatorTestTarget.PARTIQL_PIPELINE_ASYNC + ) } @Test @@ -1422,6 +1507,16 @@ class EvaluatingCompilerGroupByTest : EvaluatorTestBase() { target = EvaluatorTestTarget.PARTIQL_PIPELINE, session = session ) + runEvaluatorErrorTestCase( + """ + SELECT "O".customerId, MAX(o.cost) + FROM orders as o + """, + ErrorCode.EVALUATOR_VARIABLE_NOT_INCLUDED_IN_GROUP_BY, + propertyValueMapOf(2, 28, Property.BINDING_NAME to "O"), + target = EvaluatorTestTarget.PARTIQL_PIPELINE_ASYNC, + session = session + ) } @Test diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/eval/EvaluatorTestBase.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/eval/EvaluatorTestBase.kt index 13293d0f99..7623b07c91 100644 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/eval/EvaluatorTestBase.kt +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/eval/EvaluatorTestBase.kt @@ -27,6 +27,7 @@ import org.partiql.lang.eval.evaluatortestframework.EvaluatorTestCase import 
org.partiql.lang.eval.evaluatortestframework.EvaluatorTestTarget import org.partiql.lang.eval.evaluatortestframework.ExpectedResultFormat import org.partiql.lang.eval.evaluatortestframework.MultipleTestAdapter +import org.partiql.lang.eval.evaluatortestframework.PartiQLCompilerPipelineFactoryAsync import org.partiql.lang.eval.evaluatortestframework.PipelineEvaluatorTestAdapter import org.partiql.lang.eval.evaluatortestframework.VisitorTransformBaseTestAdapter import org.partiql.lang.graph.ExternalGraphReader @@ -42,6 +43,7 @@ abstract class EvaluatorTestBase : TestBase() { listOf( PipelineEvaluatorTestAdapter(CompilerPipelineFactory()), PipelineEvaluatorTestAdapter(PartiQLCompilerPipelineFactory()), + PipelineEvaluatorTestAdapter(PartiQLCompilerPipelineFactoryAsync()), VisitorTransformBaseTestAdapter() ) ) diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/eval/EvaluatorTests.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/eval/EvaluatorTests.kt index 5f0673a424..c5af21ffb3 100644 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/eval/EvaluatorTests.kt +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/eval/EvaluatorTests.kt @@ -8,6 +8,7 @@ import org.partiql.lang.eval.evaluatortestframework.CompilerPipelineFactory import org.partiql.lang.eval.evaluatortestframework.EvaluatorTestCase import org.partiql.lang.eval.evaluatortestframework.EvaluatorTestTarget import org.partiql.lang.eval.evaluatortestframework.ExpectedResultFormat +import org.partiql.lang.eval.evaluatortestframework.PartiQLCompilerPipelineFactoryAsync import org.partiql.lang.eval.evaluatortestframework.PipelineEvaluatorTestAdapter import org.partiql.lang.mockdb.MockDb import org.partiql.lang.util.testdsl.IonResultTestCase @@ -106,6 +107,10 @@ class EvaluatorTests { @ParameterizedTest @MethodSource("planEvaluatorTests") fun planEvaluatorTests(tc: IonResultTestCase) = tc.runTestCase(mockDb, EvaluatorTestTarget.PARTIQL_PIPELINE) + + @ParameterizedTest + @MethodSource("planEvaluatorTests") + 
fun planEvaluatorTestsAsync(tc: IonResultTestCase) = tc.runTestCase(mockDb, EvaluatorTestTarget.PARTIQL_PIPELINE_ASYNC) } fun IonResultTestCase.runTestCase( @@ -118,6 +123,7 @@ fun IonResultTestCase.runTestCase( when (target) { EvaluatorTestTarget.COMPILER_PIPELINE -> CompilerPipelineFactory() EvaluatorTestTarget.PARTIQL_PIPELINE -> PartiQLCompilerPipelineFactory() + EvaluatorTestTarget.PARTIQL_PIPELINE_ASYNC -> PartiQLCompilerPipelineFactoryAsync() // We don't support ALL_PIPELINES here because each pipeline needs a separate skip list, which // is decided by the caller of this function. EvaluatorTestTarget.ALL_PIPELINES -> error("May only test one pipeline at a time with IonResultTestCase") diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/eval/evaluatortestframework/PartiQLCompilerPipelineFactoryAsync.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/eval/evaluatortestframework/PartiQLCompilerPipelineFactoryAsync.kt new file mode 100644 index 0000000000..7ab6947261 --- /dev/null +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/eval/evaluatortestframework/PartiQLCompilerPipelineFactoryAsync.kt @@ -0,0 +1,108 @@ +package org.partiql.lang.eval.evaluatortestframework + +import kotlinx.coroutines.runBlocking +import org.partiql.annotations.ExperimentalPartiQLCompilerPipeline +import org.partiql.lang.compiler.PartiQLCompilerAsyncBuilder +import org.partiql.lang.compiler.PartiQLCompilerPipelineAsync +import org.partiql.lang.eval.EvaluationSession +import org.partiql.lang.eval.ExprValue +import org.partiql.lang.eval.PartiQLResult +import org.partiql.lang.eval.TypingMode +import org.partiql.lang.eval.UndefinedVariableBehavior +import org.partiql.lang.planner.EvaluatorOptions +import org.partiql.lang.planner.GlobalResolutionResult +import org.partiql.lang.planner.GlobalVariableResolver +import org.partiql.lang.planner.PartiQLPlanner +import org.partiql.lang.planner.PartiQLPlannerBuilder +import org.partiql.lang.syntax.PartiQLParserBuilder +import 
kotlin.test.assertNotEquals +import kotlin.test.assertNull + +/** + * TODO delete this once evaluator tests are replaced by `partiql-tests` + */ +@OptIn(ExperimentalPartiQLCompilerPipeline::class) +internal class PartiQLCompilerPipelineFactoryAsync : PipelineFactory { + + override val pipelineName: String = "PartiQLCompilerPipelineAsync" + + override val target: EvaluatorTestTarget = EvaluatorTestTarget.PARTIQL_PIPELINE_ASYNC + + override fun createPipeline( + evaluatorTestDefinition: EvaluatorTestDefinition, + session: EvaluationSession, + forcePermissiveMode: Boolean + ): AbstractPipeline { + + // Construct a legacy CompilerPipeline + val legacyPipeline = evaluatorTestDefinition.createCompilerPipeline(forcePermissiveMode) + val co = legacyPipeline.compileOptions + + assertNotEquals( + co.undefinedVariable, UndefinedVariableBehavior.MISSING, + "The planner and physical plan evaluator do not support UndefinedVariableBehavior.MISSING. " + + "Please set target = EvaluatorTestTarget.COMPILER_PIPELINE for this test.\n" + + "Test groupName: ${evaluatorTestDefinition.groupName}" + ) + + assertNull( + legacyPipeline.globalTypeBindings, + "The planner and evaluator do not currently support globalTypeBindings" + + "Please set target = EvaluatorTestTarget.COMPILER_PIPELINE for this test." 
+ ) + + val evaluatorOptions = EvaluatorOptions.build { + typingMode(co.typingMode) + thunkOptions(co.thunkOptions) + defaultTimezoneOffset(co.defaultTimezoneOffset) + typedOpBehavior(co.typedOpBehavior) + projectionIteration(co.projectionIteration) + } + + val globalVariableResolver = GlobalVariableResolver { + val value = session.globals[it] + if (value != null) { + GlobalResolutionResult.GlobalVariable(it.name) + } else { + GlobalResolutionResult.Undefined + } + } + + val plannerOptions = PartiQLPlanner.Options( + allowedUndefinedVariables = true, + typedOpBehavior = evaluatorOptions.typedOpBehavior + ) + + val pipeline = PartiQLCompilerPipelineAsync( + parser = PartiQLParserBuilder().customTypes(legacyPipeline.customDataTypes).build(), + planner = PartiQLPlannerBuilder.standard() + .options(plannerOptions) + .globalVariableResolver(globalVariableResolver) + .build(), + compiler = PartiQLCompilerAsyncBuilder.standard() + .options(evaluatorOptions) + .customTypes(legacyPipeline.customDataTypes) + .customFunctions(legacyPipeline.functions.values.toList()) + .customProcedures(legacyPipeline.procedures.values.toList()) + .build() + ) + + return object : AbstractPipeline { + + override val typingMode: TypingMode = evaluatorOptions.typingMode + + override fun evaluate(query: String): ExprValue { + return runBlocking { + val statement = pipeline.compile(query) + when (val result = statement.eval(session)) { + is PartiQLResult.Delete, + is PartiQLResult.Insert, + is PartiQLResult.Replace -> error("DML is not supported by test suite") + is PartiQLResult.Value -> result.value + is PartiQLResult.Explain -> error("EXPLAIN is not supported by test suite") + } + } + } + } + } +} diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/eval/internal/builtins/functions/DynamicLookupExprFunctionTest.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/eval/internal/builtins/functions/DynamicLookupExprFunctionTest.kt index b54337fc2e..a14bc3aae6 100644 --- 
a/partiql-lang/src/test/kotlin/org/partiql/lang/eval/internal/builtins/functions/DynamicLookupExprFunctionTest.kt +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/eval/internal/builtins/functions/DynamicLookupExprFunctionTest.kt @@ -10,6 +10,8 @@ import org.partiql.lang.eval.builtins.DYNAMIC_LOOKUP_FUNCTION_NAME import org.partiql.lang.eval.evaluatortestframework.EvaluatorErrorTestCase import org.partiql.lang.eval.evaluatortestframework.EvaluatorTestTarget import org.partiql.lang.eval.evaluatortestframework.ExpectedResultFormat +import org.partiql.lang.eval.internal.builtins.ExprFunctionTestCase +import org.partiql.lang.eval.internal.builtins.checkInvalidArity import org.partiql.lang.util.ArgumentsProviderBase import org.partiql.lang.util.propertyValueMapOf import org.partiql.lang.util.to @@ -33,6 +35,18 @@ class DynamicLookupExprFunctionTest : EvaluatorTestBase() { target = EvaluatorTestTarget.PARTIQL_PIPELINE ) + // Pass test cases + @ParameterizedTest + @ArgumentsSource(ToStringPassCases::class) + fun runPassTestsAsync(testCase: ExprFunctionTestCase) = + runEvaluatorTestCase( + query = testCase.source, + session = session, + expectedResult = testCase.expectedLegacyModeResult, + expectedResultFormat = ExpectedResultFormat.ION, + target = EvaluatorTestTarget.PARTIQL_PIPELINE_ASYNC + ) + // We rely on the built-in [DEFAULT_COMPARATOR] for the actual definition of equality, which is not being tested // here. 
class ToStringPassCases : ArgumentsProviderBase() { @@ -132,9 +146,20 @@ class DynamicLookupExprFunctionTest : EvaluatorTestBase() { session = session ) + @ParameterizedTest + @ArgumentsSource(MismatchCaseSensitiveCases::class) + fun mismatchedCaseSensitiveTestsAsync(testCase: EvaluatorErrorTestCase) = + runEvaluatorErrorTestCase( + testCase.copy( + expectedPermissiveModeResult = "MISSING", + targetPipeline = EvaluatorTestTarget.PARTIQL_PIPELINE_ASYNC + ), + session = session + ) + class MismatchCaseSensitiveCases : ArgumentsProviderBase() { override fun getParameters(): List = listOf( - // Can't find these variables due to case mismatch when perform case sensitive lookup + // Can't find these variables due to case mismatch when perform case-sensitive lookup EvaluatorErrorTestCase( query = "\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(`fOo`, `case_sensitive`, `locals_then_globals`, [f, b])", expectedErrorCode = ErrorCode.EVALUATOR_QUOTED_BINDING_DOES_NOT_EXIST, @@ -181,6 +206,23 @@ class DynamicLookupExprFunctionTest : EvaluatorTestBase() { target = EvaluatorTestTarget.PARTIQL_PIPELINE ) + @ParameterizedTest + @ArgumentsSource(InvalidArgCases::class) + fun invalidArgTypeTestCasesAsync(testCase: InvalidArgTestCase) = + runEvaluatorErrorTestCase( + query = testCase.source, + expectedErrorCode = ErrorCode.EVALUATOR_INCORRECT_TYPE_OF_ARGUMENTS_TO_FUNC_CALL, + expectedErrorContext = propertyValueMapOf( + 1, 1, + Property.FUNCTION_NAME to DYNAMIC_LOOKUP_FUNCTION_NAME, + Property.EXPECTED_ARGUMENT_TYPES to "SYMBOL", + Property.ACTUAL_ARGUMENT_TYPES to testCase.actualArgumentType, + Property.ARGUMENT_POSITION to testCase.argumentPosition + ), + expectedPermissiveModeResult = "MISSING", + target = EvaluatorTestTarget.PARTIQL_PIPELINE_ASYNC + ) + class InvalidArgCases : ArgumentsProviderBase() { override fun getParameters(): List = listOf( InvalidArgTestCase("\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"(1, `case_insensitive`, `locals_then_globals`, [])", 1, "INT"), @@ -196,4 +238,12 @@ class 
DynamicLookupExprFunctionTest : EvaluatorTestBase() { minArity = 3, targetPipeline = EvaluatorTestTarget.PARTIQL_PIPELINE ) + + @Test + fun invalidArityTestAsync() = checkInvalidArity( + funcName = "\"$DYNAMIC_LOOKUP_FUNCTION_NAME\"", + maxArity = Int.MAX_VALUE, + minArity = 3, + targetPipeline = EvaluatorTestTarget.PARTIQL_PIPELINE_ASYNC + ) } diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/eval/internal/builtins/windowFunctions/WindowFunctionTests.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/eval/internal/builtins/windowFunctions/WindowFunctionTests.kt index 78d6008175..bee3a481c1 100644 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/eval/internal/builtins/windowFunctions/WindowFunctionTests.kt +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/eval/internal/builtins/windowFunctions/WindowFunctionTests.kt @@ -25,6 +25,14 @@ class WindowFunctionTests : EvaluatorTestBase() { tc = tc.copy(targetPipeline = EvaluatorTestTarget.PARTIQL_PIPELINE), session = session ) + + @ParameterizedTest + @ArgumentsSource(LagFunctionTestsProvider::class) + fun lagFunctionTestsAsync(tc: EvaluatorTestCase) = runEvaluatorTestCase( + tc = tc.copy(targetPipeline = EvaluatorTestTarget.PARTIQL_PIPELINE_ASYNC), + session = session + ) + class LagFunctionTestsProvider : ArgumentsProviderBase() { override fun getParameters() = listOf( // Lag Function with PARTITION BY AND ORDER BY @@ -205,6 +213,13 @@ class WindowFunctionTests : EvaluatorTestBase() { session = session ) + @ParameterizedTest + @ArgumentsSource(LeadFunctionTestsProvider::class) + fun leadFunctionTestsAsync(tc: EvaluatorTestCase) = runEvaluatorTestCase( + tc = tc.copy(targetPipeline = EvaluatorTestTarget.PARTIQL_PIPELINE_ASYNC), + session = session + ) + class LeadFunctionTestsProvider : ArgumentsProviderBase() { override fun getParameters() = listOf( EvaluatorTestCase( @@ -378,6 +393,14 @@ class WindowFunctionTests : EvaluatorTestBase() { tc = tc.copy(targetPipeline = EvaluatorTestTarget.PARTIQL_PIPELINE), 
session = session, ) + + @ParameterizedTest + @ArgumentsSource(MultipleFunctionTestsProvider::class) + fun multipleFunctionTestsAsync(tc: EvaluatorTestCase) = runEvaluatorTestCase( + tc = tc.copy(targetPipeline = EvaluatorTestTarget.PARTIQL_PIPELINE_ASYNC), + session = session, + ) + class MultipleFunctionTestsProvider : ArgumentsProviderBase() { override fun getParameters() = listOf( EvaluatorTestCase( diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserDDLTest.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserDDLTest.kt new file mode 100644 index 0000000000..8bdba4c8f2 --- /dev/null +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserDDLTest.kt @@ -0,0 +1,41 @@ +package org.partiql.lang.syntax + +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ArgumentsSource +import org.partiql.errors.ErrorCode +import org.partiql.errors.Property +import org.partiql.lang.util.ArgumentsProviderBase + +internal class PartiQLParserDDLTest : PartiQLParserTestBase() { + // As we expended the functionality of DDL, making sure that the PIG Parser is not impacted. + + override val targets: Array = arrayOf(ParserTarget.DEFAULT) + + internal data class ParserErrorTestCase( + val description: String? 
= null, + val query: String, + val code: ErrorCode, + val context: Map = emptyMap() + ) + + @ArgumentsSource(ErrorTestProvider::class) + @ParameterizedTest + fun errorTests(tc: ParserErrorTestCase) = checkInputThrowingParserException(tc.query, tc.code, tc.context, assertContext = false) + + class ErrorTestProvider : ArgumentsProviderBase() { + override fun getParameters() = listOf( + ParserErrorTestCase( + description = "PIG Parser does not support qualified Identifier as input for Create", + query = "CREATE TABLE foo.bar", + code = ErrorCode.PARSE_UNEXPECTED_TOKEN, + context = mapOf() + ), + ParserErrorTestCase( + description = "PIG Parser does not support qualified Identifier as input for DROP", + query = "DROP Table foo.bar", + code = ErrorCode.PARSE_UNEXPECTED_TOKEN, + context = mapOf(), + ) + ) + } +} diff --git a/partiql-lang/src/testFixtures/kotlin/org/partiql/lang/eval/evaluatortestframework/AbstractPipeline.kt b/partiql-lang/src/testFixtures/kotlin/org/partiql/lang/eval/evaluatortestframework/AbstractPipeline.kt index c08d677f8d..441877d866 100644 --- a/partiql-lang/src/testFixtures/kotlin/org/partiql/lang/eval/evaluatortestframework/AbstractPipeline.kt +++ b/partiql-lang/src/testFixtures/kotlin/org/partiql/lang/eval/evaluatortestframework/AbstractPipeline.kt @@ -4,8 +4,9 @@ import org.partiql.lang.eval.ExprValue import org.partiql.lang.eval.TypingMode /** - * Represents an abstract pipeline (either [org.partiql.lang.CompilerPipeline] or - * [org.partiql.lang.compiler.PartiQLCompilerPipeline]) so that [PipelineEvaluatorTestAdapter] can work with either. + * Represents an abstract pipeline (one of [org.partiql.lang.CompilerPipeline], + * [org.partiql.lang.compiler.PartiQLCompilerPipeline], or [org.partiql.lang.compiler.PartiQLCompilerPipelineAsync]) + * so that [PipelineEvaluatorTestAdapter] can work on any of them. * * Includes only those properties and methods that are required for testing purposes. 
*/ diff --git a/partiql-lang/src/testFixtures/kotlin/org/partiql/lang/eval/evaluatortestframework/EvaluatorErrorTestCase.kt b/partiql-lang/src/testFixtures/kotlin/org/partiql/lang/eval/evaluatortestframework/EvaluatorErrorTestCase.kt index f9f52e22db..c38bb2a7d8 100644 --- a/partiql-lang/src/testFixtures/kotlin/org/partiql/lang/eval/evaluatortestframework/EvaluatorErrorTestCase.kt +++ b/partiql-lang/src/testFixtures/kotlin/org/partiql/lang/eval/evaluatortestframework/EvaluatorErrorTestCase.kt @@ -4,6 +4,8 @@ import org.partiql.errors.ErrorCode import org.partiql.errors.PropertyValueMap import org.partiql.lang.CompilerPipeline import org.partiql.lang.SqlException +import org.partiql.lang.compiler.PartiQLCompilerPipeline +import org.partiql.lang.compiler.PartiQLCompilerPipelineAsync import org.partiql.lang.eval.CompileOptions /** @@ -55,8 +57,8 @@ data class EvaluatorErrorTestCase( val additionalExceptionAssertBlock: (SqlException) -> Unit = { }, /** - * Determines which pipeline this test should run against; the [CompilerPipeline], - * [PartiQLCompilerPipeline] or both. + * Determines which pipeline this test should run against; the [CompilerPipeline], [PartiQLCompilerPipeline], + * [PartiQLCompilerPipelineAsync], or all of them. 
*/ override val targetPipeline: EvaluatorTestTarget = EvaluatorTestTarget.ALL_PIPELINES, diff --git a/partiql-lang/src/testFixtures/kotlin/org/partiql/lang/eval/evaluatortestframework/EvaluatorTestCase.kt b/partiql-lang/src/testFixtures/kotlin/org/partiql/lang/eval/evaluatortestframework/EvaluatorTestCase.kt index 50daf3fe4e..b5bf2f804d 100644 --- a/partiql-lang/src/testFixtures/kotlin/org/partiql/lang/eval/evaluatortestframework/EvaluatorTestCase.kt +++ b/partiql-lang/src/testFixtures/kotlin/org/partiql/lang/eval/evaluatortestframework/EvaluatorTestCase.kt @@ -1,6 +1,8 @@ package org.partiql.lang.eval.evaluatortestframework import org.partiql.lang.CompilerPipeline +import org.partiql.lang.compiler.PartiQLCompilerPipeline +import org.partiql.lang.compiler.PartiQLCompilerPipelineAsync import org.partiql.lang.eval.CompileOptions import org.partiql.lang.eval.ExprValue @@ -55,8 +57,8 @@ data class EvaluatorTestCase( override val implicitPermissiveModeTest: Boolean = true, /** - * Determines which pipeline this test should run against; the [CompilerPipeline], - * [org.partiql.lang.compiler.PartiQLCompilerPipeline] or both. + * Determines which pipeline this test should run against; the [CompilerPipeline], [PartiQLCompilerPipeline], + * [PartiQLCompilerPipelineAsync], or all of them. 
*/ override val targetPipeline: EvaluatorTestTarget = EvaluatorTestTarget.ALL_PIPELINES, diff --git a/partiql-lang/src/testFixtures/kotlin/org/partiql/lang/eval/evaluatortestframework/EvaluatorTestDefinition.kt b/partiql-lang/src/testFixtures/kotlin/org/partiql/lang/eval/evaluatortestframework/EvaluatorTestDefinition.kt index 5489421bdf..de54f55bb1 100644 --- a/partiql-lang/src/testFixtures/kotlin/org/partiql/lang/eval/evaluatortestframework/EvaluatorTestDefinition.kt +++ b/partiql-lang/src/testFixtures/kotlin/org/partiql/lang/eval/evaluatortestframework/EvaluatorTestDefinition.kt @@ -1,6 +1,8 @@ package org.partiql.lang.eval.evaluatortestframework import org.partiql.lang.CompilerPipeline +import org.partiql.lang.compiler.PartiQLCompilerPipeline +import org.partiql.lang.compiler.PartiQLCompilerPipelineAsync import org.partiql.lang.eval.CompileOptions /** @@ -32,8 +34,8 @@ interface EvaluatorTestDefinition { val implicitPermissiveModeTest: Boolean /** - * Determines which pipeline this test should run against; the [CompilerPipeline], - * [org.partiql.lang.compiler.PartiQLCompilerPipeline] or both. + * Determines which pipeline this test should run against; the [CompilerPipeline], [PartiQLCompilerPipeline], + * [PartiQLCompilerPipelineAsync], or all of them. */ val targetPipeline: EvaluatorTestTarget diff --git a/partiql-lang/src/testFixtures/kotlin/org/partiql/lang/eval/evaluatortestframework/EvaluatorTestTarget.kt b/partiql-lang/src/testFixtures/kotlin/org/partiql/lang/eval/evaluatortestframework/EvaluatorTestTarget.kt index 3c65ed0131..b4e3bd3e43 100644 --- a/partiql-lang/src/testFixtures/kotlin/org/partiql/lang/eval/evaluatortestframework/EvaluatorTestTarget.kt +++ b/partiql-lang/src/testFixtures/kotlin/org/partiql/lang/eval/evaluatortestframework/EvaluatorTestTarget.kt @@ -21,6 +21,17 @@ enum class EvaluatorTestTarget { /** * Run the test on [org.partiql.lang.compiler.PartiQLCompilerPipeline]. 
Set this when the test case covers features not * supported by [org.partiql.lang.CompilerPipeline], or when testing features unique to the former. + * + * Since [org.partiql.lang.compiler.PartiQLCompilerPipeline] is deprecated and will be removed in favor of + * [org.partiql.lang.compiler.PartiQLCompilerPipelineAsync], opt to use [PARTIQL_PIPELINE_ASYNC] or [ALL_PIPELINES]. */ PARTIQL_PIPELINE, + + /** + * Run the test on [org.partiql.lang.compiler.PartiQLCompilerPipelineAsync]. Set this when the test case covers + * features not supported by [org.partiql.lang.CompilerPipeline], or when testing features unique to the former. + * + * This is the async version of [PARTIQL_PIPELINE]. + */ + PARTIQL_PIPELINE_ASYNC } diff --git a/partiql-lang/src/testFixtures/kotlin/org/partiql/lang/eval/evaluatortestframework/PipelineFactory.kt b/partiql-lang/src/testFixtures/kotlin/org/partiql/lang/eval/evaluatortestframework/PipelineFactory.kt index 3ccc107c64..7654a95bc0 100644 --- a/partiql-lang/src/testFixtures/kotlin/org/partiql/lang/eval/evaluatortestframework/PipelineFactory.kt +++ b/partiql-lang/src/testFixtures/kotlin/org/partiql/lang/eval/evaluatortestframework/PipelineFactory.kt @@ -4,8 +4,8 @@ import org.partiql.lang.eval.EvaluationSession /** * The implementation of this interface is passed to the constructor of [PipelineEvaluatorTestAdapter]. Determines - * which pipeline (either [org.partiql.lang.CompilerPipeline] or [org.partiql.lang.compiler.PartiQLCompilerPipeline]) will be - * tested. + * which pipeline (either [org.partiql.lang.CompilerPipeline], [org.partiql.lang.CompilerPipelineAsync], + * [org.partiql.lang.compiler.PartiQLCompilerPipeline]) will be tested. 
*/ interface PipelineFactory { val pipelineName: String diff --git a/partiql-parser/src/main/antlr/PartiQL.g4 b/partiql-parser/src/main/antlr/PartiQL.g4 index 041c5e50ae..7d05a22ab0 100644 --- a/partiql-parser/src/main/antlr/PartiQL.g4 +++ b/partiql-parser/src/main/antlr/PartiQL.g4 @@ -71,6 +71,9 @@ execCommand * Currently, this is a small subset of SQL DDL that is likely to make sense for PartiQL as well. */ +// ::= [ ] +qualifiedName : (qualifier+=symbolPrimitive PERIOD)* name=symbolPrimitive; + tableName : symbolPrimitive; tableConstraintName : symbolPrimitive; columnName : symbolPrimitive; @@ -82,12 +85,12 @@ ddl ; createCommand - : CREATE TABLE tableName ( PAREN_LEFT tableDef PAREN_RIGHT )? # CreateTable + : CREATE TABLE qualifiedName ( PAREN_LEFT tableDef PAREN_RIGHT )? # CreateTable | CREATE INDEX ON symbolPrimitive PAREN_LEFT pathSimple ( COMMA pathSimple )* PAREN_RIGHT # CreateIndex ; dropCommand - : DROP TABLE target=tableName # DropTable + : DROP TABLE qualifiedName # DropTable | DROP INDEX target=symbolPrimitive ON on=symbolPrimitive # DropIndex ; @@ -716,8 +719,7 @@ functionCall // SQL-99 10.4 — ::= [ ] functionName - : (qualifier+=symbolPrimitive PERIOD)* name=( CHAR_LENGTH | CHARACTER_LENGTH | OCTET_LENGTH | BIT_LENGTH | - UPPER | LOWER | SIZE | EXISTS | COUNT | MOD) # FunctionNameReserved + : (qualifier+=symbolPrimitive PERIOD)* name=( CHAR_LENGTH | CHARACTER_LENGTH | OCTET_LENGTH | BIT_LENGTH | UPPER | LOWER | SIZE | EXISTS | COUNT | MOD ) # FunctionNameReserved | (qualifier+=symbolPrimitive PERIOD)* name=symbolPrimitive # FunctionNameSymbol ; diff --git a/partiql-parser/src/main/kotlin/org/partiql/parser/internal/PartiQLParserDefault.kt b/partiql-parser/src/main/kotlin/org/partiql/parser/internal/PartiQLParserDefault.kt index 26b365c306..481d08acd7 100644 --- a/partiql-parser/src/main/kotlin/org/partiql/parser/internal/PartiQLParserDefault.kt +++ b/partiql-parser/src/main/kotlin/org/partiql/parser/internal/PartiQLParserDefault.kt @@ -570,6 +570,18 
@@ internal class PartiQLParserDefault : PartiQLParser { } } + override fun visitQualifiedName(ctx: org.partiql.parser.antlr.PartiQLParser.QualifiedNameContext) = translate(ctx) { + val qualifier = ctx.qualifier.map { visitSymbolPrimitive(it) } + val name = visitSymbolPrimitive(ctx.name) + if (qualifier.isEmpty()) { + name + } else { + val root = qualifier.first() + val steps = qualifier.drop(1) + listOf(name) + identifierQualified(root, steps) + } + } + /** * * DATA DEFINITION LANGUAGE (DDL) @@ -579,7 +591,7 @@ internal class PartiQLParserDefault : PartiQLParser { override fun visitQueryDdl(ctx: GeneratedParser.QueryDdlContext): AstNode = visitDdl(ctx.ddl()) override fun visitDropTable(ctx: GeneratedParser.DropTableContext) = translate(ctx) { - val table = visitSymbolPrimitive(ctx.tableName().symbolPrimitive()) + val table = visitQualifiedName(ctx.qualifiedName()) statementDDLDropTable(table) } @@ -590,7 +602,7 @@ internal class PartiQLParserDefault : PartiQLParser { } override fun visitCreateTable(ctx: GeneratedParser.CreateTableContext) = translate(ctx) { - val table = visitSymbolPrimitive(ctx.tableName().symbolPrimitive()) + val table = visitQualifiedName(ctx.qualifiedName()) val definition = ctx.tableDef()?.let { visitTableDef(it) } statementDDLCreateTable(table, definition) } diff --git a/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserDDLTests.kt b/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserDDLTests.kt new file mode 100644 index 0000000000..3fbb0321a4 --- /dev/null +++ b/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserDDLTests.kt @@ -0,0 +1,134 @@ +package org.partiql.parser.internal + +import org.junit.jupiter.api.extension.ExtensionContext +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.Arguments +import org.junit.jupiter.params.provider.ArgumentsProvider +import org.junit.jupiter.params.provider.ArgumentsSource +import 
org.partiql.ast.AstNode +import org.partiql.ast.Identifier +import org.partiql.ast.identifierQualified +import org.partiql.ast.identifierSymbol +import org.partiql.ast.statementDDLCreateTable +import org.partiql.ast.statementDDLDropTable +import java.util.stream.Stream +import kotlin.test.assertEquals + +class PartiQLParserDDLTests { + + private val parser = PartiQLParserDefault() + + data class SuccessTestCase( + val description: String? = null, + val query: String, + val node: AstNode + ) + + @ArgumentsSource(TestProvider::class) + @ParameterizedTest + fun errorTests(tc: SuccessTestCase) = assertExpression(tc.query, tc.node) + + class TestProvider : ArgumentsProvider { + val createTableTests = listOf( + SuccessTestCase( + "CREATE TABLE with unqualified case insensitive name", + "CREATE TABLE foo", + statementDDLCreateTable( + identifierSymbol("foo", Identifier.CaseSensitivity.INSENSITIVE), + null + ) + ), + // Support Case Sensitive identifier as table name + // Subsequent process may need to change + // See: https://www.db-fiddle.com/f/9A8mknSNYuRGLfkqkLeiHD/0 for reference. 
+ SuccessTestCase( + "CREATE TABLE with unqualified case sensitive name", + "CREATE TABLE \"foo\"", + statementDDLCreateTable( + identifierSymbol("foo", Identifier.CaseSensitivity.SENSITIVE), + null + ) + ), + SuccessTestCase( + "CREATE TABLE with qualified case insensitive name", + "CREATE TABLE myCatalog.mySchema.foo", + statementDDLCreateTable( + identifierQualified( + identifierSymbol("myCatalog", Identifier.CaseSensitivity.INSENSITIVE), + listOf( + identifierSymbol("mySchema", Identifier.CaseSensitivity.INSENSITIVE), + identifierSymbol("foo", Identifier.CaseSensitivity.INSENSITIVE), + ) + ), + null + ) + ), + SuccessTestCase( + "CREATE TABLE with qualified name with mixed case sensitivity", + "CREATE TABLE myCatalog.\"mySchema\".foo", + statementDDLCreateTable( + identifierQualified( + identifierSymbol("myCatalog", Identifier.CaseSensitivity.INSENSITIVE), + listOf( + identifierSymbol("mySchema", Identifier.CaseSensitivity.SENSITIVE), + identifierSymbol("foo", Identifier.CaseSensitivity.INSENSITIVE), + ) + ), + null + ) + ), + ) + + val dropTableTests = listOf( + SuccessTestCase( + "DROP TABLE with unqualified case insensitive name", + "DROP TABLE foo", + statementDDLDropTable( + identifierSymbol("foo", Identifier.CaseSensitivity.INSENSITIVE), + ) + ), + SuccessTestCase( + "DROP TABLE with unqualified case sensitive name", + "DROP TABLE \"foo\"", + statementDDLDropTable( + identifierSymbol("foo", Identifier.CaseSensitivity.SENSITIVE), + ) + ), + SuccessTestCase( + "DROP TABLE with qualified case insensitive name", + "DROP TABLE myCatalog.mySchema.foo", + statementDDLDropTable( + identifierQualified( + identifierSymbol("myCatalog", Identifier.CaseSensitivity.INSENSITIVE), + listOf( + identifierSymbol("mySchema", Identifier.CaseSensitivity.INSENSITIVE), + identifierSymbol("foo", Identifier.CaseSensitivity.INSENSITIVE), + ) + ), + ) + ), + SuccessTestCase( + "DROP TABLE with qualified name with mixed case sensitivity", + "DROP TABLE myCatalog.\"mySchema\".foo", + 
statementDDLDropTable( + identifierQualified( + identifierSymbol("myCatalog", Identifier.CaseSensitivity.INSENSITIVE), + listOf( + identifierSymbol("mySchema", Identifier.CaseSensitivity.SENSITIVE), + identifierSymbol("foo", Identifier.CaseSensitivity.INSENSITIVE), + ) + ), + ) + ), + ) + + override fun provideArguments(context: ExtensionContext?): Stream = + (createTableTests + dropTableTests).map { Arguments.of(it) }.stream() + } + + private fun assertExpression(input: String, expected: AstNode) { + val result = parser.parse(input) + val actual = result.root + assertEquals(expected, actual) + } +} diff --git a/partiql-plan/src/main/resources/partiql_plan.ion b/partiql-plan/src/main/resources/partiql_plan.ion index 3f6f3fade5..30cfaf0a18 100644 --- a/partiql-plan/src/main/resources/partiql_plan.ion +++ b/partiql-plan/src/main/resources/partiql_plan.ion @@ -149,6 +149,15 @@ rex::{ ], }, + nullif::{ + value: rex, + nullifier: rex + }, + + coalesce::{ + args: list::[rex] + }, + collection::{ values: list::[rex], }, diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/Errors.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/Errors.kt index 8d388a17b7..61b8dbb552 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/Errors.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/Errors.kt @@ -3,8 +3,7 @@ package org.partiql.planner import org.partiql.errors.ProblemDetails import org.partiql.errors.ProblemSeverity import org.partiql.plan.Identifier -import org.partiql.spi.BindingCase -import org.partiql.spi.BindingPath +import org.partiql.planner.internal.utils.PlanUtils import org.partiql.types.StaticType /** @@ -27,16 +26,60 @@ public sealed class PlanningProblemDetails( public data class CompileError(val errorMessage: String) : PlanningProblemDetails(ProblemSeverity.ERROR, { errorMessage }) - public data class UndefinedVariable(val id: BindingPath) : - PlanningProblemDetails( - ProblemSeverity.ERROR, - { - val caseSensitive = 
id.steps.any { it.case == BindingCase.SENSITIVE } - "Undefined variable '${id.key}'." + - quotationHint(caseSensitive) + public data class UndefinedVariable( + val name: Identifier, + val inScopeVariables: Set + ) : PlanningProblemDetails( + ProblemSeverity.ERROR, + { + val humanReadableName = PlanUtils.identifierToString(name) + "Variable $humanReadableName does not exist in the database environment and is not an attribute of the following in-scope variables $inScopeVariables." + + quotationHint(isSymbolAndCaseSensitive(name)) + } + ) { + + @Deprecated("This will be removed in a future major version release.", replaceWith = ReplaceWith("name")) + val variableName: String = when (name) { + is Identifier.Symbol -> name.symbol + is Identifier.Qualified -> when (name.steps.size) { + 0 -> name.root.symbol + else -> name.steps.last().symbol } + } + + @Deprecated("This will be removed in a future major version release.", replaceWith = ReplaceWith("name")) + val caseSensitive: Boolean = when (name) { + is Identifier.Symbol -> name.caseSensitivity == Identifier.CaseSensitivity.SENSITIVE + is Identifier.Qualified -> when (name.steps.size) { + 0 -> name.root.caseSensitivity == Identifier.CaseSensitivity.SENSITIVE + else -> name.steps.last().caseSensitivity == Identifier.CaseSensitivity.SENSITIVE + } + } + + @Deprecated("This will be removed in a future major version release.", replaceWith = ReplaceWith("UndefinedVariable(Identifier, Set)")) + public constructor(variableName: String, caseSensitive: Boolean) : this( + Identifier.Symbol( + variableName, + when (caseSensitive) { + true -> Identifier.CaseSensitivity.SENSITIVE + false -> Identifier.CaseSensitivity.INSENSITIVE + } + ), + emptySet() ) + private companion object { + /** + * Used to check whether the [id] is an [Identifier.Symbol] and whether it is case-sensitive. This is helpful + * for giving the [quotationHint] to the user. 
+ */ + private fun isSymbolAndCaseSensitive(id: Identifier): Boolean = when (id) { + is Identifier.Symbol -> id.caseSensitivity == Identifier.CaseSensitivity.SENSITIVE + is Identifier.Qualified -> false + } + } + } + public data class UndefinedDmlTarget(val variableName: String, val caseSensitive: Boolean) : PlanningProblemDetails( ProblemSeverity.ERROR, @@ -98,6 +141,14 @@ public sealed class PlanningProblemDetails( "Unknown function `$identifier($types)" }) + public data class UnknownAggregateFunction( + val identifier: Identifier, + val args: List, + ) : PlanningProblemDetails(ProblemSeverity.ERROR, { + val types = args.joinToString { "<${it.toString().lowercase()}>" } + "Unknown aggregate function `$identifier($types)" + }) + public object ExpressionAlwaysReturnsNullOrMissing : PlanningProblemDetails( severity = ProblemSeverity.ERROR, messageFormatter = { "Expression always returns null or missing." } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/Nodes.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/Nodes.kt index 3b479a787e..0c0d96d4d6 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/Nodes.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/Nodes.kt @@ -50,9 +50,11 @@ import org.partiql.planner.internal.ir.builder.RexOpCaseBranchBuilder import org.partiql.planner.internal.ir.builder.RexOpCaseBuilder import org.partiql.planner.internal.ir.builder.RexOpCastResolvedBuilder import org.partiql.planner.internal.ir.builder.RexOpCastUnresolvedBuilder +import org.partiql.planner.internal.ir.builder.RexOpCoalesceBuilder import org.partiql.planner.internal.ir.builder.RexOpCollectionBuilder import org.partiql.planner.internal.ir.builder.RexOpErrBuilder import org.partiql.planner.internal.ir.builder.RexOpLitBuilder +import org.partiql.planner.internal.ir.builder.RexOpNullifBuilder import org.partiql.planner.internal.ir.builder.RexOpPathIndexBuilder import 
org.partiql.planner.internal.ir.builder.RexOpPathKeyBuilder import org.partiql.planner.internal.ir.builder.RexOpPathSymbolBuilder @@ -74,6 +76,15 @@ import org.partiql.types.StaticType import org.partiql.value.PartiQLValue import org.partiql.value.PartiQLValueExperimental import org.partiql.value.PartiQLValueType +import kotlin.Boolean +import kotlin.Char +import kotlin.Int +import kotlin.OptIn +import kotlin.String +import kotlin.collections.List +import kotlin.collections.Set +import kotlin.jvm.JvmField +import kotlin.jvm.JvmStatic import kotlin.random.Random internal abstract class PlanNode { @@ -267,6 +278,8 @@ internal data class Rex( is Cast -> visitor.visitRexOpCast(this, ctx) is Call -> visitor.visitRexOpCall(this, ctx) is Case -> visitor.visitRexOpCase(this, ctx) + is Nullif -> visitor.visitRexOpNullif(this, ctx) + is Coalesce -> visitor.visitRexOpCoalesce(this, ctx) is Collection -> visitor.visitRexOpCollection(this, ctx) is Struct -> visitor.visitRexOpStruct(this, ctx) is Pivot -> visitor.visitRexOpPivot(this, ctx) @@ -302,7 +315,7 @@ internal data class Rex( internal data class Local( @JvmField internal val depth: Int, - @JvmField internal val ref: Int + @JvmField internal val ref: Int, ) : Var() { public override val children: List = emptyList() @@ -594,6 +607,44 @@ internal data class Rex( } } + internal data class Nullif( + @JvmField internal val `value`: Rex, + @JvmField internal val nullifier: Rex, + ) : Op() { + public override val children: List by lazy { + val kids = mutableListOf() + kids.add(value) + kids.add(nullifier) + kids.filterNotNull() + } + + public override fun accept(visitor: PlanVisitor, ctx: C): R = + visitor.visitRexOpNullif(this, ctx) + + internal companion object { + @JvmStatic + internal fun builder(): RexOpNullifBuilder = RexOpNullifBuilder() + } + } + + internal data class Coalesce( + @JvmField internal val args: List, + ) : Op() { + public override val children: List by lazy { + val kids = mutableListOf() + kids.addAll(args) 
+ kids.filterNotNull() + } + + public override fun accept(visitor: PlanVisitor, ctx: C): R = + visitor.visitRexOpCoalesce(this, ctx) + + internal companion object { + @JvmStatic + internal fun builder(): RexOpCoalesceBuilder = RexOpCoalesceBuilder() + } + } + internal data class Collection( @JvmField internal val values: List, ) : Op() { @@ -673,7 +724,7 @@ internal data class Rex( } internal data class Subquery( - @JvmField internal val constructor: Rex, + @JvmField internal val `constructor`: Rex, @JvmField internal val rel: Rel, @JvmField internal val coercion: Coercion, ) : Op() { @@ -1112,6 +1163,10 @@ internal data class Rel( FULL, PARTIAL, } + internal enum class SetQuantifier { + ALL, DISTINCT, + } + internal sealed class Call : PlanNode() { public override fun accept(visitor: PlanVisitor, ctx: C): R = when (this) { is Unresolved -> visitor.visitRelOpAggregateCallUnresolved(this, ctx) @@ -1161,10 +1216,6 @@ internal data class Rel( } } - internal enum class SetQuantifier { - ALL, DISTINCT, - } - internal companion object { @JvmStatic internal fun builder(): RelOpAggregateBuilder = RelOpAggregateBuilder() diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt index e7102da68f..dc8573677d 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt @@ -72,7 +72,8 @@ internal object PlanTransform { override fun visitRefAgg(node: Ref.Agg, ctx: Unit) = symbols.insert(node) @OptIn(PartiQLValueExperimental::class) - override fun visitRefCast(node: Ref.Cast, ctx: Unit) = org.partiql.plan.refCast(node.input, node.target, node.isNullable) + override fun visitRefCast(node: Ref.Cast, ctx: Unit) = + org.partiql.plan.refCast(node.input, node.target, node.isNullable) override fun visitStatement(node: 
Statement, ctx: Unit) = super.visitStatement(node, ctx) as org.partiql.plan.Statement @@ -194,6 +195,14 @@ internal object PlanTransform { branches = node.branches.map { visitRexOpCaseBranch(it, ctx) }, default = visitRex(node.default, ctx) ) + override fun visitRexOpNullif(node: Rex.Op.Nullif, ctx: Unit) = org.partiql.plan.Rex.Op.Nullif( + value = visitRex(node.value, ctx), + nullifier = visitRex(node.nullifier, ctx), + ) + + override fun visitRexOpCoalesce(node: Rex.Op.Coalesce, ctx: Unit) = + org.partiql.plan.Rex.Op.Coalesce(args = node.args.map { visitRex(it, ctx) }) + override fun visitRexOpCaseBranch(node: Rex.Op.Case.Branch, ctx: Unit) = org.partiql.plan.Rex.Op.Case.Branch( condition = visitRex(node.condition, ctx), rex = visitRex(node.rex, ctx) ) @@ -275,10 +284,11 @@ internal object PlanTransform { predicate = visitRex(node.predicate, ctx), ) - override fun visitRelOpSort(node: Rel.Op.Sort, ctx: Unit) = org.partiql.plan.Rel.Op.Sort( - input = visitRel(node.input, ctx), - specs = node.specs.map { visitRelOpSortSpec(it, ctx) } - ) + override fun visitRelOpSort(node: Rel.Op.Sort, ctx: Unit) = + org.partiql.plan.Rel.Op.Sort( + input = visitRel(node.input, ctx), + specs = node.specs.map { visitRelOpSortSpec(it, ctx) } + ) override fun visitRelOpSortSpec(node: Rel.Op.Sort.Spec, ctx: Unit) = org.partiql.plan.Rel.Op.Sort.Spec( rex = visitRex(node.rex, ctx), @@ -366,7 +376,9 @@ internal object PlanTransform { override fun visitRelOpExcludePath(node: Rel.Op.Exclude.Path, ctx: Unit): org.partiql.plan.Rel.Op.Exclude.Path { val root = when (node.root) { - is Rex.Op.Var.Unresolved -> org.partiql.plan.Rex.Op.Var(-1, -1) // unresolved in `PlanTyper` results in error + is Rex.Op.Var.Unresolved -> org.partiql.plan.Rex.Op.Var( + -1, -1 + ) // unresolved in `PlanTyper` results in error is Rex.Op.Var.Local -> visitRexOpVarLocal(node.root, ctx) is Rex.Op.Var.Global -> error("EXCLUDE only disallows values coming from the input record.") } diff --git 
a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt index be719b1acb..11db9199b2 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt @@ -38,8 +38,10 @@ import org.partiql.planner.internal.ir.relType import org.partiql.planner.internal.ir.rex import org.partiql.planner.internal.ir.rexOpCallUnresolved import org.partiql.planner.internal.ir.rexOpCastUnresolved +import org.partiql.planner.internal.ir.rexOpCoalesce import org.partiql.planner.internal.ir.rexOpCollection import org.partiql.planner.internal.ir.rexOpLit +import org.partiql.planner.internal.ir.rexOpNullif import org.partiql.planner.internal.ir.rexOpPathIndex import org.partiql.planner.internal.ir.rexOpPathKey import org.partiql.planner.internal.ir.rexOpPathSymbol @@ -608,43 +610,21 @@ internal object RexConverter { return rex(type, call) } - // coalesce(expr1, expr2, ... exprN) -> - // CASE - // WHEN expr1 IS NOT NULL THEN EXPR1 - // ... 
- // WHEN exprn is NOT NULL THEN exprn - // ELSE NULL END - override fun visitExprCoalesce(node: Expr.Coalesce, ctx: Env): Rex = plan { + override fun visitExprCoalesce(node: Expr.Coalesce, ctx: Env): Rex { val type = StaticType.ANY - val createBranch: (Rex) -> Rex.Op.Case.Branch = { expr: Rex -> - val updatedCondition = rex(type, negate(call("is_null", expr))) - rexOpCaseBranch(updatedCondition, expr) + val args = node.args.map { arg -> + visitExprCoerce(arg, ctx) } - - val branches = node.args.map { - createBranch(visitExpr(it, ctx)) - }.toMutableList() - - val defaultRex = rex(type = StaticType.NULL, op = rexOpLit(value = nullValue())) - val op = rexOpCase(branches, defaultRex) - rex(type, op) + val op = rexOpCoalesce(args) + return rex(type, op) } - // nullIf(expr1, expr2) -> - // CASE - // WHEN expr1 = expr2 THEN NULL - // ELSE expr1 END - override fun visitExprNullIf(node: Expr.NullIf, ctx: Env): Rex = plan { + override fun visitExprNullIf(node: Expr.NullIf, ctx: Env): Rex { val type = StaticType.ANY - val expr1 = visitExpr(node.value, ctx) - val expr2 = visitExpr(node.nullifier, ctx) - val id = identifierSymbol(Expr.Binary.Op.EQ.name.lowercase(), Identifier.CaseSensitivity.SENSITIVE) - val call = rexOpCallUnresolved(id, listOf(expr1, expr2)) - val branches = listOf( - rexOpCaseBranch(rex(type, call), rex(type = StaticType.NULL, op = rexOpLit(value = nullValue()))), - ) - val op = rexOpCase(branches.toMutableList(), expr1) - rex(type, op) + val value = visitExprCoerce(node.value, ctx) + val nullifier = visitExprCoerce(node.nullifier, ctx) + val op = rexOpNullif(value, nullifier) + return rex(type, op) } /** diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/DynamicTyper.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/DynamicTyper.kt new file mode 100644 index 0000000000..9ae72cc6ad --- /dev/null +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/DynamicTyper.kt @@ -0,0 +1,381 @@ 
+@file:OptIn(PartiQLValueExperimental::class) + +package org.partiql.planner.internal.typer + +import org.partiql.types.MissingType +import org.partiql.types.NullType +import org.partiql.types.StaticType +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.PartiQLValueType +import org.partiql.value.PartiQLValueType.ANY +import org.partiql.value.PartiQLValueType.BAG +import org.partiql.value.PartiQLValueType.BINARY +import org.partiql.value.PartiQLValueType.BLOB +import org.partiql.value.PartiQLValueType.BOOL +import org.partiql.value.PartiQLValueType.BYTE +import org.partiql.value.PartiQLValueType.CHAR +import org.partiql.value.PartiQLValueType.CLOB +import org.partiql.value.PartiQLValueType.DATE +import org.partiql.value.PartiQLValueType.DECIMAL +import org.partiql.value.PartiQLValueType.DECIMAL_ARBITRARY +import org.partiql.value.PartiQLValueType.FLOAT32 +import org.partiql.value.PartiQLValueType.FLOAT64 +import org.partiql.value.PartiQLValueType.INT +import org.partiql.value.PartiQLValueType.INT16 +import org.partiql.value.PartiQLValueType.INT32 +import org.partiql.value.PartiQLValueType.INT64 +import org.partiql.value.PartiQLValueType.INT8 +import org.partiql.value.PartiQLValueType.INTERVAL +import org.partiql.value.PartiQLValueType.LIST +import org.partiql.value.PartiQLValueType.MISSING +import org.partiql.value.PartiQLValueType.NULL +import org.partiql.value.PartiQLValueType.SEXP +import org.partiql.value.PartiQLValueType.STRING +import org.partiql.value.PartiQLValueType.STRUCT +import org.partiql.value.PartiQLValueType.SYMBOL +import org.partiql.value.PartiQLValueType.TIME +import org.partiql.value.PartiQLValueType.TIMESTAMP + +/** + * Graph of super types for quick lookup because we don't have a tree. + */ +internal typealias SuperGraph = Array> + +/** + * For lack of a better name, this is the "dynamic typer" which implements the typing rules of SQL-99 9.3. 
+ * + * SQL-99 9.3 Data types of results of aggregations (, , ) + * > https://web.cecs.pdx.edu/~len/sql1999.pdf#page=359 + * + * Usage, + * To calculate the type of an "aggregation" create a new instance and "accumulate" each possible type. + * This is a pain with StaticType... + */ +@OptIn(PartiQLValueExperimental::class) +internal class DynamicTyper { + + private var supertype: PartiQLValueType? = null + private var args = mutableListOf() + + private var nullable = false + private var missable = false + private val types = mutableSetOf() + + /** + * This primarily unpacks a StaticType because of NULL, MISSING. + * + * - T + * - NULL + * - MISSING + * - (NULL) + * - (MISSING) + * - (T..) + * - (T..|NULL) + * - (T..|MISSING) + * - (T..|NULL|MISSING) + * - (NULL|MISSING) + * + * @param type + */ + fun accumulate(type: StaticType) { + val nonAbsentTypes = mutableSetOf() + val flatType = type.flatten() + if (flatType == StaticType.ANY) { + // Use ANY runtime; do not expand ANY + types.add(flatType) + args.add(ANY) + calculate(ANY) + return + } + for (t in flatType.allTypes) { + when (t) { + is NullType -> nullable = true + is MissingType -> missable = true + else -> nonAbsentTypes.add(t) + } + } + when (nonAbsentTypes.size) { + 0 -> { + // Ignore in calculating supertype. + args.add(NULL) + } + 1 -> { + // Had single type + val single = nonAbsentTypes.first() + val singleRuntime = single.toRuntimeType() + types.add(single) + args.add(singleRuntime) + calculate(singleRuntime) + } + else -> { + // Had a union; use ANY runtime + types.addAll(nonAbsentTypes) + args.add(ANY) + calculate(ANY) + } + } + } + + /** + * Returns a pair of the return StaticType and the coercion. + * + * If the list is null, then no mapping is required. 
+ * + * @return + */ + fun mapping(): Pair>?> { + val modifiers = mutableSetOf() + if (nullable) modifiers.add(StaticType.NULL) + if (missable) modifiers.add(StaticType.MISSING) + // If at top supertype, then return union of all accumulated types + if (supertype == ANY) { + return StaticType.unionOf(types + modifiers).flatten() to null + } + // If a collection, then return union of all accumulated types as these coercion rules are not defined by SQL. + if (supertype == STRUCT || supertype == BAG || supertype == LIST || supertype == SEXP) { + return StaticType.unionOf(types + modifiers) to null + } + // If not initialized, then return null, missing, or null|missing. + val s = supertype + if (s == null) { + val t = if (modifiers.isEmpty()) StaticType.MISSING else StaticType.unionOf(modifiers).flatten() + return t to null + } + // Otherwise, return the supertype along with the coercion mapping + val type = s.toNonNullStaticType() + val mapping = args.map { it to s } + return if (modifiers.isEmpty()) { + type to mapping + } else { + StaticType.unionOf(setOf(type) + modifiers).flatten() to mapping + } + } + + private fun calculate(type: PartiQLValueType) { + val s = supertype + // Initialize + if (s == null) { + supertype = type + return + } + // Don't bother calculating the new supertype, we've already hit `dynamic`. + if (s == ANY) return + // Lookup and set the new minimum common supertype + supertype = when { + type == ANY -> type + type == NULL || type == MISSING || s == type -> return // skip + else -> graph[s][type] ?: ANY // lookup, if missing then go to top. + } + } + + private operator fun Array.get(t: PartiQLValueType): T = get(t.ordinal) + + /** + * !! IMPORTANT !! + * + * This is duplicated from the TypeLattice because that was removed in v1.0.0. I wanted to implement this as + * a standalone component so that it is easy to merge (and later merge with CastTable) into v1.0.0. 
+ */ + companion object { + + private operator fun Array.set(t: PartiQLValueType, value: T): Unit = this.set(t.ordinal, value) + + @JvmStatic + private val N = PartiQLValueType.values().size + + @JvmStatic + private fun edges(vararg edges: Pair): Array { + val arr = arrayOfNulls(N) + for (type in edges) { + arr[type.first] = type.second + } + return arr + } + + /** + * This table defines the rules in the SQL-99 section 9.3 BUT we don't have type constraints yet. + * + * TODO collection supertypes + * TODO datetime supertypes + */ + @JvmStatic + internal val graph: SuperGraph = run { + val graph = arrayOfNulls>(N) + for (type in PartiQLValueType.values()) { + // initialize all with empty edges + graph[type] = arrayOfNulls(N) + } + graph[ANY] = edges() + graph[NULL] = edges() + graph[MISSING] = edges() + graph[BOOL] = edges( + BOOL to BOOL + ) + graph[INT8] = edges( + INT8 to INT8, + INT16 to INT16, + INT32 to INT32, + INT64 to INT64, + INT to INT, + DECIMAL to DECIMAL, + DECIMAL_ARBITRARY to DECIMAL_ARBITRARY, + FLOAT32 to FLOAT32, + FLOAT64 to FLOAT64, + ) + graph[INT16] = edges( + INT8 to INT16, + INT16 to INT16, + INT32 to INT32, + INT64 to INT64, + INT to INT, + DECIMAL to DECIMAL, + DECIMAL_ARBITRARY to DECIMAL_ARBITRARY, + FLOAT32 to FLOAT32, + FLOAT64 to FLOAT64, + ) + graph[INT32] = edges( + INT8 to INT32, + INT16 to INT32, + INT32 to INT32, + INT64 to INT64, + INT to INT, + DECIMAL to DECIMAL, + DECIMAL_ARBITRARY to DECIMAL_ARBITRARY, + FLOAT32 to FLOAT32, + FLOAT64 to FLOAT64, + ) + graph[INT64] = edges( + INT8 to INT64, + INT16 to INT64, + INT32 to INT64, + INT64 to INT64, + INT to INT, + DECIMAL to DECIMAL, + DECIMAL_ARBITRARY to DECIMAL_ARBITRARY, + FLOAT32 to FLOAT32, + FLOAT64 to FLOAT64, + ) + graph[INT] = edges( + INT8 to INT, + INT16 to INT, + INT32 to INT, + INT64 to INT, + INT to INT, + DECIMAL to DECIMAL, + DECIMAL_ARBITRARY to DECIMAL_ARBITRARY, + FLOAT32 to FLOAT32, + FLOAT64 to FLOAT64, + ) + graph[DECIMAL] = edges( + INT8 to DECIMAL, + INT16 
to DECIMAL, + INT32 to DECIMAL, + INT64 to DECIMAL, + INT to DECIMAL, + DECIMAL to DECIMAL, + DECIMAL_ARBITRARY to DECIMAL_ARBITRARY, + FLOAT32 to FLOAT32, + FLOAT64 to FLOAT64, + ) + graph[DECIMAL_ARBITRARY] = edges( + INT8 to DECIMAL_ARBITRARY, + INT16 to DECIMAL_ARBITRARY, + INT32 to DECIMAL_ARBITRARY, + INT64 to DECIMAL_ARBITRARY, + INT to DECIMAL_ARBITRARY, + DECIMAL to DECIMAL_ARBITRARY, + DECIMAL_ARBITRARY to DECIMAL_ARBITRARY, + FLOAT32 to FLOAT32, + FLOAT64 to FLOAT64, + ) + graph[FLOAT32] = edges( + INT8 to FLOAT32, + INT16 to FLOAT32, + INT32 to FLOAT32, + INT64 to FLOAT32, + INT to FLOAT32, + DECIMAL to FLOAT32, + DECIMAL_ARBITRARY to FLOAT32, + FLOAT32 to FLOAT32, + FLOAT64 to FLOAT64, + ) + graph[FLOAT64] = edges( + INT8 to FLOAT64, + INT16 to FLOAT64, + INT32 to FLOAT64, + INT64 to FLOAT64, + INT to FLOAT64, + DECIMAL to FLOAT64, + DECIMAL_ARBITRARY to FLOAT64, + FLOAT32 to FLOAT64, + FLOAT64 to FLOAT64, + ) + graph[CHAR] = edges( + CHAR to CHAR, + STRING to STRING, + SYMBOL to STRING, + CLOB to CLOB, + ) + graph[STRING] = edges( + CHAR to STRING, + STRING to STRING, + SYMBOL to STRING, + CLOB to CLOB, + ) + graph[SYMBOL] = edges( + CHAR to SYMBOL, + STRING to STRING, + SYMBOL to SYMBOL, + CLOB to CLOB, + ) + graph[BINARY] = edges( + BINARY to BINARY, + ) + graph[BYTE] = edges( + BYTE to BYTE, + BLOB to BLOB, + ) + graph[BLOB] = edges( + BYTE to BLOB, + BLOB to BLOB, + ) + graph[DATE] = edges( + DATE to DATE, + ) + graph[CLOB] = edges( + CHAR to CLOB, + STRING to CLOB, + SYMBOL to CLOB, + CLOB to CLOB, + ) + graph[TIME] = edges( + TIME to TIME, + ) + graph[TIMESTAMP] = edges( + TIMESTAMP to TIMESTAMP, + ) + graph[INTERVAL] = edges( + INTERVAL to INTERVAL, + ) + graph[LIST] = edges( + LIST to LIST, + SEXP to SEXP, + BAG to BAG, + ) + graph[SEXP] = edges( + LIST to SEXP, + SEXP to SEXP, + BAG to BAG, + ) + graph[BAG] = edges( + LIST to BAG, + SEXP to BAG, + BAG to BAG, + ) + graph[STRUCT] = edges( + STRUCT to STRUCT, + ) + graph.requireNoNulls() + } + 
} +} diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt index 9dbf331d94..8c139be9f0 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt @@ -48,9 +48,11 @@ import org.partiql.planner.internal.ir.relOpUnpivot import org.partiql.planner.internal.ir.relType import org.partiql.planner.internal.ir.rex import org.partiql.planner.internal.ir.rexOpCaseBranch +import org.partiql.planner.internal.ir.rexOpCoalesce import org.partiql.planner.internal.ir.rexOpCollection import org.partiql.planner.internal.ir.rexOpErr import org.partiql.planner.internal.ir.rexOpLit +import org.partiql.planner.internal.ir.rexOpNullif import org.partiql.planner.internal.ir.rexOpPathIndex import org.partiql.planner.internal.ir.rexOpPathKey import org.partiql.planner.internal.ir.rexOpPathSymbol @@ -62,6 +64,7 @@ import org.partiql.planner.internal.ir.rexOpSubquery import org.partiql.planner.internal.ir.rexOpTupleUnion import org.partiql.planner.internal.ir.statementQuery import org.partiql.planner.internal.ir.util.PlanRewriter +import org.partiql.planner.internal.utils.PlanUtils import org.partiql.spi.BindingCase import org.partiql.spi.BindingName import org.partiql.spi.BindingPath @@ -455,8 +458,8 @@ internal class PlanTyper( Scope.GLOBAL -> env.resolveObj(path) ?: locals.resolve(path) } if (resolvedVar == null) { - handleUndefinedVariable(node.identifier) - return rexErr("Undefined variable `${node.identifier.debug()}`") + val details = handleUndefinedVariable(node.identifier, locals.schema.map { it.name }.toSet()) + return rex(MISSING, rexOpErr(details.message)) } return visitRex(resolvedVar, null) } @@ -521,7 +524,6 @@ internal class PlanTyper( override fun visitRexOpPathSymbol(node: Rex.Op.Path.Symbol, ctx: StaticType?): Rex { val root = 
visitRex(node.root, node.root.type) - val paths = root.type.allTypes.map { type -> val struct = type as? StructType ?: return@map rex(MISSING, rexOpLit(missingValue())) val (pathType, replacementId) = inferStructLookup( @@ -579,13 +581,14 @@ internal class PlanTyper( } override fun visitRexOpCastResolved(node: Rex.Op.Cast.Resolved, ctx: StaticType?): Rex { - val missable = node.arg.type.isMissable() || node.cast.safety == UNSAFE - var type = when (node.cast.isNullable) { - true -> node.cast.target.toStaticType() - false -> node.cast.target.toNonNullStaticType() + var type = node.cast.target.toNonNullStaticType() + val nullable = node.arg.type.isNullable() + if (nullable) { + type = type.asNullable() } + val missable = node.arg.type.isMissable() || node.cast.safety == UNSAFE if (missable) { - type = unionOf(type, MISSING) + type = unionOf(type, MISSING).flatten() } return rex(type, node) } @@ -671,34 +674,128 @@ internal class PlanTyper( } override fun visitRexOpCase(node: Rex.Op.Case, ctx: StaticType?): Rex { - // Type branches and prune branches known to never execute - val newBranches = node.branches.map { visitRexOpCaseBranch(it, it.rex.type) } - .filterNot { isLiteralBool(it.condition, false) } + // Rewrite CASE-WHEN branches + val oldBranches = node.branches.toTypedArray() + val newBranches = mutableListOf() + val typer = DynamicTyper() + for (i in oldBranches.indices) { + + // Type the branch + var branch = oldBranches[i] + branch = visitRexOpCaseBranch(branch, branch.rex.type) + + // Check if branch condition is a literal + if (boolOrNull(branch.condition.op) == false) { + continue // prune + } - newBranches.forEach { branch -> - if (canBeBoolean(branch.condition.type).not()) { + // Emit typing error if a branch condition is never a boolean (prune) + if (!canBeBoolean(branch.condition.type)) { onProblem.invoke( Problem( UNKNOWN_PROBLEM_LOCATION, PlanningProblemDetails.IncompatibleTypesForOp(branch.condition.type.allTypes, "CASE_WHEN") ) ) + // prune, always 
false + continue } + + // Accumulate typing information + typer.accumulate(branch.rex.type) + newBranches.add(branch) } - val default = visitRex(node.default, node.default.type) - - // Calculate final expression (short-circuit to first branch if the condition is always TRUE). - val resultTypes = newBranches.map { it.rex }.map { it.type } + listOf(default.type) - return when (newBranches.size) { - 0 -> default - else -> when (isLiteralBool(newBranches[0].condition, true)) { - true -> newBranches[0].rex - false -> rex( - type = unionOf(resultTypes.toSet()).flatten(), - node.copy(branches = newBranches, default = default) - ) + + // Rewrite ELSE branch + var newDefault = visitRex(node.default, null) + if (newBranches.isEmpty()) { + return newDefault + } + typer.accumulate(newDefault.type) + + // Compute the CASE-WHEN type from the accumulator + val (type, mapping) = typer.mapping() + + // Rewrite branches if we have coercions. + if (mapping != null) { + val msize = mapping.size + val bsize = newBranches.size + 1 + assert(msize == bsize) { "Coercion mappings `len $msize` did not match the number of CASE-WHEN branches `len $bsize`" } + // Rewrite branches + for (i in newBranches.indices) { + val (operand, target) = mapping[i] + if (operand == target) continue // skip + val branch = newBranches[i] + val cast = env.resolveCast(branch.rex, target)!! + val rex = rex(type, cast) + newBranches[i] = branch.copy(rex = rex) + } + // Rewrite default + val (operand, target) = mapping.last() + if (operand != target) { + val cast = env.resolveCast(newDefault, target)!! 
+ newDefault = rex(type, cast) } } + + // TODO constant folding in planner which also means branch pruning + // This is added for backwards compatibility, we return the first branch if it's true + if (boolOrNull(newBranches[0].condition.op) == true) { + return newBranches[0].rex + } + + val op = Rex.Op.Case(newBranches, newDefault) + return rex(type, op) + } + + // COALESCE(v1, v2,..., vN) + // == + // CASE + // WHEN v1 IS NOT NULL THEN v1 -- WHEN branch always a boolean + // WHEN v2 IS NOT NULL THEN v2 -- WHEN branch always a boolean + // ... -- similarly for v3..vN-1 + // ELSE vN + // END + // --> minimal common supertype of(, , ..., ) + override fun visitRexOpCoalesce(node: Rex.Op.Coalesce, ctx: StaticType?): Rex { + val args = node.args.map { visitRex(it, it.type) }.toMutableList() + val typer = DynamicTyper() + args.forEach { v -> + typer.accumulate(v.type) + } + val (type, mapping) = typer.mapping() + if (mapping != null) { + assert(mapping.size == args.size) { "Coercion mappings `len ${mapping.size}` did not match the number of COALESCE arguments `len ${args.size}`" } + for (i in args.indices) { + val (operand, target) = mapping[i] + if (operand == target) continue // skip; no coercion needed + val cast = env.resolveCast(args[i], target) + if (cast != null) { + val rex = rex(type, cast) + args[i] = rex + } + } + } + val op = rexOpCoalesce(args) + return rex(type, op) + } + + // NULLIF(v1, v2) + // == + // CASE + // WHEN v1 = v2 THEN NULL -- WHEN branch always a boolean + // ELSE v1 + // END + // --> minimal common supertype of (NULL, ) + override fun visitRexOpNullif(node: Rex.Op.Nullif, ctx: StaticType?): Rex { + val value = visitRex(node.value, node.value.type) + val nullifier = visitRex(node.nullifier, node.nullifier.type) + val typer = DynamicTyper() + typer.accumulate(NULL) + typer.accumulate(value.type) + val (type, _) = typer.mapping() + val op = rexOpNullif(value, nullifier) + return rex(type, op) } /** @@ -713,11 +810,12 @@ internal class PlanTyper( 
} } + /** + * Returns the boolean value of the expression. For now, only handle literals. + */ @OptIn(PartiQLValueExperimental::class) - private fun isLiteralBool(rex: Rex, bool: Boolean): Boolean { - val op = rex.op as? Rex.Op.Lit ?: return false - val value = op.value as? BoolValue ?: return false - return value.value == bool + private fun boolOrNull(op: Rex.Op): Boolean? { + return if (op is Rex.Op.Lit && op.value is BoolValue) op.value.value else null } /** @@ -784,7 +882,7 @@ internal class PlanTyper( } // Replace the result's type - val type = AnyOfType(ref.type.allTypes.filterIsInstance().toSet()) + val type = AnyOfType(ref.type.allTypes.filterIsInstance().toSet()).flatten() val replacementVal = ref.copy(type = type) val rex = when (ref.op is Rex.Op.Var.Local) { true -> RexReplacer.replace(result, ref, replacementVal) @@ -1185,7 +1283,7 @@ internal class PlanTyper( // AKA, the Function IS MISSING // return signature return type !fn.isMissable && !fn.isMissingCall && !fn.isNullable && !fn.isNullCall -> fn.returns.toNonNullStaticType() - isNull || (!fn.isMissable && hadMissing) -> fn.returns.toStaticType() + isNull || (!fn.isMissable && hadMissing) -> NULL isNullable -> fn.returns.toStaticType() else -> fn.returns.toNonNullStaticType() } @@ -1328,13 +1426,20 @@ internal class PlanTyper( // ERRORS - private fun handleUndefinedVariable(id: Identifier) { - val publicId = id.toBindingPath() + /** + * Invokes [onProblem] with a newly created [PlanningProblemDetails.UndefinedVariable] and returns the + * [PlanningProblemDetails.UndefinedVariable]. 
+ */ + private fun handleUndefinedVariable(name: Identifier, locals: Set): PlanningProblemDetails.UndefinedVariable { + val planName = PlanUtils.externalize(name) + val details = PlanningProblemDetails.UndefinedVariable(planName, locals) onProblem( Problem( - sourceLocation = UNKNOWN_PROBLEM_LOCATION, details = PlanningProblemDetails.UndefinedVariable(publicId) + sourceLocation = UNKNOWN_PROBLEM_LOCATION, + details = details ) ) + return details } private fun handleUnexpectedType(actual: StaticType, expected: Set) { diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeUtils.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeUtils.kt index 26726686b8..062791ffb1 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeUtils.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeUtils.kt @@ -120,9 +120,15 @@ private fun StaticType.asRuntimeType(): PartiQLValueType = when (this) { is ListType -> PartiQLValueType.LIST is SexpType -> PartiQLValueType.SEXP is DateType -> PartiQLValueType.DATE - is DecimalType -> when (this.precisionScaleConstraint) { - is DecimalType.PrecisionScaleConstraint.Constrained -> PartiQLValueType.DECIMAL - DecimalType.PrecisionScaleConstraint.Unconstrained -> PartiQLValueType.DECIMAL_ARBITRARY + // TODO: Run time decimal type does not model precision scale constraint yet + // despite that we match to Decimal vs Decimal_ARBITRARY (PVT) here + // but when mapping it back to Static Type, (i.e, mapping function return type to Value Type) + // we can only map to Unconstrained decimal (Static Type) + is DecimalType -> { + when (this.precisionScaleConstraint) { + is DecimalType.PrecisionScaleConstraint.Constrained -> PartiQLValueType.DECIMAL + DecimalType.PrecisionScaleConstraint.Unconstrained -> PartiQLValueType.DECIMAL_ARBITRARY + } } is FloatType -> PartiQLValueType.FLOAT64 is GraphType -> error("Graph type missing from runtime types") diff 
--git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/utils/PlanUtils.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/utils/PlanUtils.kt new file mode 100644 index 0000000000..4a0e1b5191 --- /dev/null +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/utils/PlanUtils.kt @@ -0,0 +1,50 @@ +package org.partiql.planner.internal.utils + +import org.partiql.plan.Identifier + +internal object PlanUtils { + + /** + * Transforms an identifier to a human-readable string. + * + * Example output: aCaseInsensitiveCatalog."aCaseSensitiveSchema".aCaseInsensitiveTable + */ + fun identifierToString(node: Identifier): String = when (node) { + is Identifier.Symbol -> identifierSymbolToString(node) + is Identifier.Qualified -> { + val toJoin = listOf(node.root) + node.steps + toJoin.joinToString(separator = ".") { ident -> + identifierSymbolToString(ident) + } + } + } + + private fun identifierSymbolToString(node: Identifier.Symbol) = when (node.caseSensitivity) { + Identifier.CaseSensitivity.SENSITIVE -> "\"${node.symbol}\"" + Identifier.CaseSensitivity.INSENSITIVE -> node.symbol + } + + fun externalize(node: org.partiql.planner.internal.ir.Identifier): Identifier = when (node) { + is org.partiql.planner.internal.ir.Identifier.Symbol -> externalize(node) + is org.partiql.planner.internal.ir.Identifier.Qualified -> externalize(node) + } + + private fun externalize(node: org.partiql.planner.internal.ir.Identifier.Symbol): Identifier.Symbol { + val symbol = node.symbol + val case = externalize(node.caseSensitivity) + return Identifier.Symbol(symbol, case) + } + + private fun externalize(node: org.partiql.planner.internal.ir.Identifier.Qualified): Identifier.Qualified { + val root = externalize(node.root) + val steps = node.steps.map { externalize(it) } + return Identifier.Qualified(root, steps) + } + + private fun externalize(node: org.partiql.planner.internal.ir.Identifier.CaseSensitivity): Identifier.CaseSensitivity { + return when 
(node) { + org.partiql.planner.internal.ir.Identifier.CaseSensitivity.SENSITIVE -> Identifier.CaseSensitivity.SENSITIVE + org.partiql.planner.internal.ir.Identifier.CaseSensitivity.INSENSITIVE -> Identifier.CaseSensitivity.INSENSITIVE + } + } +} diff --git a/partiql-planner/src/main/resources/partiql_plan_internal.ion b/partiql-planner/src/main/resources/partiql_plan_internal.ion index 4e1c06830a..f13f4ab192 100644 --- a/partiql-planner/src/main/resources/partiql_plan_internal.ion +++ b/partiql-planner/src/main/resources/partiql_plan_internal.ion @@ -168,6 +168,15 @@ rex::{ ], }, + nullif::{ + value: rex, + nullifier: rex + }, + + coalesce::{ + args: list::[rex] + }, + collection::{ values: list::[rex], }, diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/EnvTest.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/EnvTest.kt deleted file mode 100644 index 962586b8cb..0000000000 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/EnvTest.kt +++ /dev/null @@ -1,68 +0,0 @@ -// package org.partiql.planner.internal -// -// import org.partiql.plan.Catalog -// import org.partiql.planner.internal.typer.TypeEnv -// -// class EnvTest { -// -// companion object { -// -// private val root = this::class.java.getResource("/catalogs/default/pql")!!.toURI().toPath() -// -// private val EMPTY_TYPE_ENV = TypeEnv(schema = emptyList()) -// -// private val GLOBAL_OS = Catalog( -// name = "pql", -// symbols = listOf( -// Catalog.Symbol(path = listOf("main", "os"), type = StaticType.STRING) -// ) -// ) -// } -// -// private lateinit var env: Env -// -// @BeforeEach -// fun init() { -// env = Env( -// PartiQLPlanner.Session( -// queryId = Random().nextInt().toString(), -// userId = "test-user", -// currentCatalog = "pql", -// currentDirectory = listOf("main"), -// catalogs = mapOf( -// "pql" to LocalConnector.Metadata(root) -// ), -// ) -// ) -// } -// -// @Test -// fun testGlobalMatchingSensitiveName() { -// val path = 
BindingPath(listOf(BindingName("os", BindingCase.SENSITIVE))) -// assertNotNull(env.resolve(path, EMPTY_TYPE_ENV, Scope.GLOBAL)) -// assertEquals(1, env.catalogs.size) -// assert(env.catalogs.contains(GLOBAL_OS)) -// } -// -// @Test -// fun testGlobalMatchingInsensitiveName() { -// val path = BindingPath(listOf(BindingName("oS", BindingCase.INSENSITIVE))) -// assertNotNull(env.resolve(path, EMPTY_TYPE_ENV, Scope.GLOBAL)) -// assertEquals(1, env.catalogs.size) -// assert(env.catalogs.contains(GLOBAL_OS)) -// } -// -// @Test -// fun testGlobalNotMatchingSensitiveName() { -// val path = BindingPath(listOf(BindingName("oS", BindingCase.SENSITIVE))) -// assertNull(env.resolve(path, EMPTY_TYPE_ENV, Scope.GLOBAL)) -// assert(env.catalogs.isEmpty()) -// } -// -// @Test -// fun testGlobalNotMatchingInsensitiveName() { -// val path = BindingPath(listOf(BindingName("nonexistent", BindingCase.INSENSITIVE))) -// assertNull(env.resolve(path, EMPTY_TYPE_ENV, Scope.GLOBAL)) -// assert(env.catalogs.isEmpty()) -// } -// } diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PartiQLTyperTestBase.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PartiQLTyperTestBase.kt index c5eef7767f..b0bfcc9d7e 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PartiQLTyperTestBase.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PartiQLTyperTestBase.kt @@ -20,6 +20,7 @@ import org.partiql.spi.BindingPath import org.partiql.spi.connector.ConnectorMetadata import org.partiql.spi.connector.ConnectorSession import org.partiql.types.StaticType +import org.partiql.value.PartiQLValueExperimental import java.util.Random import java.util.stream.Stream @@ -66,6 +67,7 @@ abstract class PartiQLTyperTestBase { /** * Build a ConnectorMetadata instance from the list of types. 
*/ + @OptIn(PartiQLValueExperimental::class) private fun buildMetadata(catalog: String, types: List): ConnectorMetadata { val cat = MemoryCatalog.PartiQL().name(catalog).build() val connector = MemoryConnector(cat) diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt index a7427f3b92..667f46e708 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt @@ -1,6 +1,8 @@ package org.partiql.planner.internal.typer import com.amazon.ionelement.api.loadSingleElement +import org.junit.jupiter.api.Disabled +import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertThrows import org.junit.jupiter.api.extension.ExtensionContext import org.junit.jupiter.api.parallel.Execution @@ -13,6 +15,7 @@ import org.junit.jupiter.params.provider.MethodSource import org.partiql.errors.Problem import org.partiql.errors.UNKNOWN_PROBLEM_LOCATION import org.partiql.parser.PartiQLParser +import org.partiql.plan.Identifier import org.partiql.plan.PartiQLPlan import org.partiql.plan.Statement import org.partiql.plan.debug.PlanPrinter @@ -38,8 +41,11 @@ import org.partiql.types.BagType import org.partiql.types.ListType import org.partiql.types.SexpType import org.partiql.types.StaticType +import org.partiql.types.StaticType.Companion.MISSING +import org.partiql.types.StaticType.Companion.unionOf import org.partiql.types.StructType import org.partiql.types.TupleConstraint +import org.partiql.value.PartiQLValueExperimental import java.util.stream.Stream import kotlin.reflect.KClass import kotlin.test.assertEquals @@ -50,7 +56,7 @@ class PlanTyperTestsPorted { sealed class TestCase { class SuccessTestCase( - val name: String, + val name: String? = null, val key: PartiQLTest.Key? = null, val query: String? 
= null, val catalog: String = "pql", @@ -58,7 +64,12 @@ class PlanTyperTestsPorted { val expected: StaticType, val warnings: ProblemHandler? = null, ) : TestCase() { - override fun toString(): String = "$name : ${query ?: key}" + override fun toString(): String { + if (key != null) { + return "${key.group} : ${key.name}" + } + return "${name!!} : $query" + } } class ErrorTestCase( @@ -106,9 +117,22 @@ class PlanTyperTestsPorted { override fun getUserId(): String = "user-id" } + private fun id(vararg parts: Identifier.Symbol): Identifier { + return when (parts.size) { + 0 -> error("Identifier requires more than one part.") + 1 -> parts.first() + else -> Identifier.Qualified(parts.first(), parts.drop(1)) + } + } + + private fun sensitive(part: String): Identifier.Symbol = Identifier.Symbol(part, Identifier.CaseSensitivity.SENSITIVE) + + private fun insensitive(part: String): Identifier.Symbol = Identifier.Symbol(part, Identifier.CaseSensitivity.INSENSITIVE) + /** * MemoryConnector.Factory from reading the resources in /resource_path.txt for Github CI/CD. */ + @OptIn(PartiQLValueExperimental::class) val catalogs: List> by lazy { val inputStream = this::class.java.getResourceAsStream("/resource_path.txt")!! 
val map = mutableMapOf>>() @@ -911,7 +935,7 @@ class PlanTyperTestsPorted { problemHandler = assertProblemExists { Problem( UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UndefinedVariable(insensitiveId("a")) + PlanningProblemDetails.UndefinedVariable(insensitive("a"), setOf("t1", "t2")) ) } ), @@ -2129,7 +2153,7 @@ class PlanTyperTestsPorted { problemHandler = assertProblemExists { Problem( UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UndefinedVariable(insensitiveId("unknown_col")) + PlanningProblemDetails.UndefinedVariable(insensitive("unknown_col"), setOf("pets")) ) } ), @@ -2517,6 +2541,384 @@ class PlanTyperTestsPorted { ) ) ), + // + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-00"), + catalog = "pql", + expected = StaticType.INT4 + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-02"), + catalog = "pql", + expected = StaticType.INT4 + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-03"), + catalog = "pql", + expected = StaticType.INT8 + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-04"), + catalog = "pql", + expected = StaticType.INT + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-05"), + catalog = "pql", + expected = StaticType.INT + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-06"), + catalog = "pql", + expected = StaticType.INT + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-07"), + catalog = "pql", + expected = StaticType.INT8 + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-08"), + catalog = "pql", + expected = unionOf(StaticType.INT, StaticType.NULL), + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-09"), + catalog = "pql", + expected = unionOf(StaticType.INT, StaticType.NULL), + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-10"), + catalog = "pql", + expected = unionOf(StaticType.DECIMAL, StaticType.NULL), + ), + SuccessTestCase( + key = 
PartiQLTest.Key("basics", "case-when-11"), + catalog = "pql", + expected = unionOf(StaticType.INT, StaticType.MISSING), + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-12"), + catalog = "pql", + expected = StaticType.FLOAT + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-13"), + catalog = "pql", + expected = unionOf(StaticType.FLOAT, StaticType.NULL), + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-14"), + catalog = "pql", + expected = StaticType.STRING, + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-15"), + catalog = "pql", + expected = unionOf(StaticType.STRING, StaticType.NULL), + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-16"), + catalog = "pql", + expected = StaticType.CLOB, + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-17"), + catalog = "pql", + expected = unionOf(StaticType.CLOB, StaticType.NULL), + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-18"), + catalog = "pql", + expected = unionOf(StaticType.STRING, StaticType.NULL), + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-19"), + catalog = "pql", + expected = unionOf(StaticType.STRING, StaticType.NULL), + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-20"), + catalog = "pql", + expected = StaticType.NULL, + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-21"), + catalog = "pql", + expected = unionOf(StaticType.STRING, StaticType.NULL), + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-22"), + catalog = "pql", + expected = unionOf(StaticType.INT4, StaticType.NULL, StaticType.MISSING), + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-23"), + catalog = "pql", + expected = StaticType.INT4, + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-24"), + catalog = "pql", + expected = unionOf(StaticType.INT4, StaticType.INT8, StaticType.STRING), + ), + 
SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-25"), + catalog = "pql", + expected = unionOf(StaticType.INT4, StaticType.INT8, StaticType.STRING, StaticType.NULL), + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-26"), + catalog = "pql", + expected = unionOf(StaticType.INT4, StaticType.INT8, StaticType.STRING, StaticType.NULL), + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-27"), + catalog = "pql", + expected = unionOf(StaticType.INT2, StaticType.INT4, StaticType.INT8, StaticType.INT, StaticType.DECIMAL, StaticType.STRING, StaticType.CLOB), + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-28"), + catalog = "pql", + expected = unionOf(StaticType.INT2, StaticType.INT4, StaticType.INT8, StaticType.INT, StaticType.DECIMAL, StaticType.STRING, StaticType.CLOB, StaticType.NULL), + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-29"), + catalog = "pql", + expected = unionOf( + StructType( + fields = listOf( + StructType.Field("x", StaticType.INT4), + StructType.Field("y", StaticType.INT4), + ), + ), + StructType( + fields = listOf( + StructType.Field("x", StaticType.INT8), + StructType.Field("y", StaticType.INT8), + ), + ), + StaticType.NULL, + ), + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-30"), + catalog = "pql", + expected = MISSING + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-31"), + catalog = "pql", + expected = StaticType.ANY + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-32"), + catalog = "pql", + expected = StaticType.ANY + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-33"), + catalog = "pql", + expected = StaticType.ANY + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-34"), + catalog = "pql", + expected = StaticType.ANY + ), + ) + + @JvmStatic + fun nullIf() = listOf( + SuccessTestCase( + key = PartiQLTest.Key("basics", "nullif-00"), + catalog = "pql", + 
expected = StaticType.INT4.asNullable() + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "nullif-01"), + catalog = "pql", + expected = StaticType.INT4.asNullable() + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "nullif-02"), + catalog = "pql", + expected = StaticType.INT4.asNullable() + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "nullif-03"), + catalog = "pql", + expected = StaticType.INT4.asNullable() + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "nullif-04"), + catalog = "pql", + expected = StaticType.INT8.asNullable() + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "nullif-05"), + catalog = "pql", + expected = StaticType.INT4.asNullable() + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "nullif-06"), + catalog = "pql", + expected = StaticType.NULL + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "nullif-07"), + catalog = "pql", + expected = StaticType.INT4.asNullable() + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "nullif-08"), + catalog = "pql", + expected = StaticType.NULL_OR_MISSING + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "nullif-09"), + catalog = "pql", + expected = StaticType.INT4.asNullable() + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "nullif-10"), + catalog = "pql", + expected = StaticType.INT4.asNullable() + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "nullif-11"), + catalog = "pql", + expected = StaticType.INT4.asNullable() + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "nullif-12"), + catalog = "pql", + expected = StaticType.INT8.asNullable() + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "nullif-13"), + catalog = "pql", + expected = StaticType.INT4.asNullable() + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "nullif-14"), + catalog = "pql", + expected = StaticType.STRING.asNullable() + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "nullif-15"), + catalog = "pql", + expected = 
StaticType.INT4.asNullable() + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "nullif-16"), + catalog = "pql", + expected = unionOf(StaticType.INT2, StaticType.INT4, StaticType.INT8, StaticType.INT, StaticType.DECIMAL, StaticType.NULL) + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "nullif-17"), + catalog = "pql", + expected = StaticType.INT4.asNullable() + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "nullif-18"), + catalog = "pql", + expected = StaticType.ANY + ), + ) + + @JvmStatic + fun coalesce() = listOf( + SuccessTestCase( + key = PartiQLTest.Key("basics", "coalesce-00"), + catalog = "pql", + expected = StaticType.INT4 + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "coalesce-01"), + catalog = "pql", + expected = StaticType.INT4 + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "coalesce-02"), + catalog = "pql", + expected = StaticType.DECIMAL + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "coalesce-03"), + catalog = "pql", + expected = unionOf(StaticType.NULL, StaticType.DECIMAL) + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "coalesce-04"), + catalog = "pql", + expected = unionOf(StaticType.NULL, StaticType.MISSING, StaticType.DECIMAL) + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "coalesce-05"), + catalog = "pql", + expected = unionOf(StaticType.NULL, StaticType.MISSING, StaticType.DECIMAL) + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "coalesce-06"), + catalog = "pql", + expected = StaticType.INT4 + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "coalesce-07"), + catalog = "pql", + expected = StaticType.INT4 + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "coalesce-08"), + catalog = "pql", + expected = StaticType.INT8 + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "coalesce-09"), + catalog = "pql", + expected = StaticType.INT8.asNullable() + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "coalesce-10"), + catalog = "pql", 
+ expected = unionOf(StaticType.INT8, StaticType.NULL, StaticType.MISSING) + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "coalesce-11"), + catalog = "pql", + expected = unionOf(StaticType.INT8, StaticType.STRING) + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "coalesce-12"), + catalog = "pql", + expected = unionOf(StaticType.INT8, StaticType.NULL, StaticType.STRING) + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "coalesce-13"), + catalog = "pql", + expected = unionOf(StaticType.INT2, StaticType.INT4, StaticType.INT8, StaticType.INT, StaticType.DECIMAL) + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "coalesce-14"), + catalog = "pql", + expected = unionOf(StaticType.INT2, StaticType.INT4, StaticType.INT8, StaticType.INT, StaticType.DECIMAL, StaticType.STRING) + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "coalesce-15"), + catalog = "pql", + expected = unionOf(StaticType.INT2, StaticType.INT4, StaticType.INT8, StaticType.INT, StaticType.DECIMAL, StaticType.STRING, StaticType.NULL) + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "coalesce-16"), + catalog = "pql", + expected = StaticType.ANY + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "coalesce-17"), + catalog = "pql", + expected = StaticType.ANY + ), ) @JvmStatic @@ -2649,7 +3051,7 @@ class PlanTyperTestsPorted { problemHandler = assertProblemExists { Problem( UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UndefinedVariable(idQualified("pql" to BindingCase.SENSITIVE, "main" to BindingCase.SENSITIVE)) + PlanningProblemDetails.UndefinedVariable(id(sensitive("pql"), sensitive("main")), setOf()) ) } ), @@ -2662,7 +3064,7 @@ class PlanTyperTestsPorted { problemHandler = assertProblemExists { Problem( UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UndefinedVariable(sensitiveId("pql")) + PlanningProblemDetails.UndefinedVariable(sensitive("pql"), setOf()) ) } ), @@ -2795,14 +3197,16 @@ class PlanTyperTestsPorted { fun aggregationCases() = 
listOf( SuccessTestCase( name = "AGGREGATE over INTS, without alias", - query = "SELECT a, COUNT(*), SUM(a), MIN(b) FROM << {'a': 1, 'b': 2} >> GROUP BY a", + query = "SELECT a, COUNT(*), COUNT(a), SUM(a), MIN(b), MAX(a) FROM << {'a': 1, 'b': 2} >> GROUP BY a", expected = BagType( StructType( fields = mapOf( "a" to StaticType.INT4, - "_1" to StaticType.INT4, - "_2" to StaticType.INT4.asNullable(), + "_1" to StaticType.INT8, + "_2" to StaticType.INT8, "_3" to StaticType.INT4.asNullable(), + "_4" to StaticType.INT4.asNullable(), + "_5" to StaticType.INT4.asNullable(), ), contentClosed = true, constraints = setOf( @@ -2815,12 +3219,13 @@ class PlanTyperTestsPorted { ), SuccessTestCase( name = "AGGREGATE over INTS, with alias", - query = "SELECT a, COUNT(*) AS c, SUM(a) AS s, MIN(b) AS m FROM << {'a': 1, 'b': 2} >> GROUP BY a", + query = "SELECT a, COUNT(*) AS c_s, COUNT(a) AS c, SUM(a) AS s, MIN(b) AS m FROM << {'a': 1, 'b': 2} >> GROUP BY a", expected = BagType( StructType( fields = mapOf( "a" to StaticType.INT4, - "c" to StaticType.INT4, + "c_s" to StaticType.INT8, + "c" to StaticType.INT8, "s" to StaticType.INT4.asNullable(), "m" to StaticType.INT4.asNullable(), ), @@ -2840,7 +3245,7 @@ class PlanTyperTestsPorted { StructType( fields = mapOf( "a" to StaticType.DECIMAL, - "c" to StaticType.INT4, + "c" to StaticType.INT8, "s" to StaticType.DECIMAL.asNullable(), "m" to StaticType.DECIMAL.asNullable(), ), @@ -2853,6 +3258,89 @@ class PlanTyperTestsPorted { ) ) ), + SuccessTestCase( + name = "AGGREGATE over nullable integers", + query = """ + SELECT + a AS a, + COUNT(*) AS count_star, + COUNT(a) AS count_a, + COUNT(b) AS count_b, + SUM(a) AS sum_a, + SUM(b) AS sum_b, + MIN(a) AS min_a, + MIN(b) AS min_b, + MAX(a) AS max_a, + MAX(b) AS max_b, + AVG(a) AS avg_a, + AVG(b) AS avg_b + FROM << + { 'a': 1, 'b': 2 }, + { 'a': 3, 'b': 4 }, + { 'a': 5, 'b': NULL } + >> GROUP BY a + """.trimIndent(), + expected = BagType( + StructType( + fields = mapOf( + "a" to StaticType.INT4, + 
"count_star" to StaticType.INT8, + "count_a" to StaticType.INT8, + "count_b" to StaticType.INT8, + "sum_a" to StaticType.INT4.asNullable(), + "sum_b" to StaticType.INT4.asNullable(), + "min_a" to StaticType.INT4.asNullable(), + "min_b" to StaticType.INT4.asNullable(), + "max_a" to StaticType.INT4.asNullable(), + "max_b" to StaticType.INT4.asNullable(), + "avg_a" to StaticType.INT4.asNullable(), + "avg_b" to StaticType.INT4.asNullable(), + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true), + TupleConstraint.Ordered + ) + ) + ) + ), + SuccessTestCase( + name = "AGGREGATE over nullable integers", + query = """ + SELECT T1.a + FROM T1 + LEFT JOIN T2 AS T2_1 + ON T2_1.d = + ( + SELECT + CASE WHEN COUNT(f) = 1 THEN MAX(f) ELSE 0 END AS e + FROM T3 AS T3_mapping + ) + LEFT JOIN T2 AS T2_2 + ON T2_2.d = + ( + SELECT + CASE WHEN COUNT(f) = 1 THEN MAX(f) ELSE 0 END AS e + FROM T3 AS T3_mapping + ) + ; + """.trimIndent(), + expected = BagType( + StructType( + fields = mapOf( + "a" to StaticType.BOOL + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true), + TupleConstraint.Ordered + ) + ) + ), + catalog = "aggregations" + ) ) @JvmStatic @@ -3120,6 +3608,60 @@ class PlanTyperTestsPorted { // // Parameterized Tests // + + @Test + @Disabled("The planner doesn't support heterogeneous input to aggregation functions (yet?).") + fun failingTest() { + val tc = SuccessTestCase( + name = "AGGREGATE over heterogeneous data", + query = """ + SELECT + a AS a, + COUNT(*) AS count_star, + COUNT(a) AS count_a, + COUNT(b) AS count_b, + SUM(a) AS sum_a, + SUM(b) AS sum_b, + MIN(a) AS min_a, + MIN(b) AS min_b, + MAX(a) AS max_a, + MAX(b) AS max_b, + AVG(a) AS avg_a, + AVG(b) AS avg_b + FROM << + { 'a': 1.0, 'b': 2.0 }, + { 'a': 3, 'b': 4 }, + { 'a': 5, 'b': NULL } + >> GROUP BY a + """.trimIndent(), + expected = BagType( + StructType( + fields = mapOf( + "a" to 
StaticType.DECIMAL, + "count_star" to StaticType.INT8, + "count_a" to StaticType.INT8, + "count_b" to StaticType.INT8, + "sum_a" to StaticType.DECIMAL.asNullable(), + "sum_b" to StaticType.DECIMAL.asNullable(), + "min_a" to StaticType.DECIMAL.asNullable(), + "min_b" to StaticType.DECIMAL.asNullable(), + "max_a" to StaticType.DECIMAL.asNullable(), + "max_b" to StaticType.DECIMAL.asNullable(), + "avg_a" to StaticType.DECIMAL.asNullable(), + "avg_b" to StaticType.DECIMAL.asNullable(), + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true), + TupleConstraint.Ordered + ) + ) + ) + ) + runTest(tc) + } + @ParameterizedTest @ArgumentsSource(TestProvider::class) fun test(tc: TestCase) = runTest(tc) @@ -3194,6 +3736,16 @@ class PlanTyperTestsPorted { @Execution(ExecutionMode.CONCURRENT) fun testCaseWhens(tc: TestCase) = runTest(tc) + @ParameterizedTest + @MethodSource("nullIf") + @Execution(ExecutionMode.CONCURRENT) + fun testNullIf(tc: TestCase) = runTest(tc) + + @ParameterizedTest + @MethodSource("coalesce") + @Execution(ExecutionMode.CONCURRENT) + fun testCoalesce(tc: TestCase) = runTest(tc) + @ParameterizedTest @MethodSource("subqueryCases") @Execution(ExecutionMode.CONCURRENT) @@ -3375,7 +3927,7 @@ class PlanTyperTestsPorted { problemHandler = assertProblemExists { Problem( UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UndefinedVariable(insensitiveId("pets")) + PlanningProblemDetails.UndefinedVariable(insensitive("pets"), emptySet()) ) } ), @@ -3397,7 +3949,7 @@ class PlanTyperTestsPorted { problemHandler = assertProblemExists { Problem( UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UndefinedVariable(insensitiveId("pets")) + PlanningProblemDetails.UndefinedVariable(insensitive("pets"), emptySet()) ) } ), @@ -3453,14 +4005,7 @@ class PlanTyperTestsPorted { problemHandler = assertProblemExists { Problem( UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UndefinedVariable( - BindingPath( - steps = 
listOf( - BindingName("ddb", BindingCase.INSENSITIVE), - BindingName("pets", BindingCase.INSENSITIVE), - ) - ) - ) + PlanningProblemDetails.UndefinedVariable(id(insensitive("ddb"), insensitive("pets")), emptySet()) ) } ), @@ -3726,11 +4271,12 @@ class PlanTyperTestsPorted { query = "non_existing_column = 1", // Function resolves to EQ__ANY_ANY__BOOL // Which can return BOOL Or NULL + // TODO this is maybe an error? Depends on -Werror settings.. expected = StaticType.MISSING, problemHandler = assertProblemExists { Problem( UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UndefinedVariable(insensitiveId("non_existing_column")) + PlanningProblemDetails.UndefinedVariable(insensitive("non_existing_column"), emptySet()) ) } ), @@ -3773,7 +4319,7 @@ class PlanTyperTestsPorted { query = "SELECT unknown_col FROM orders WHERE customer_id = 1", expected = BagType( StructType( - fields = emptyList(), + fields = listOf(), contentClosed = true, constraints = setOf( TupleConstraint.Open(false), @@ -3785,7 +4331,7 @@ class PlanTyperTestsPorted { problemHandler = assertProblemExists { Problem( UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UndefinedVariable(insensitiveId("unknown_col")) + PlanningProblemDetails.UndefinedVariable(insensitive("unknown_col"), setOf("orders")) ) } ), diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/functions/NullIfTest.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/functions/NullIfTest.kt index 8b8fb9fc49..6f1a56a84e 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/functions/NullIfTest.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/functions/NullIfTest.kt @@ -8,27 +8,34 @@ import org.partiql.planner.util.cartesianProduct import org.partiql.types.StaticType import java.util.stream.Stream -// TODO: Model handling of Truth Value in typer better. 
+/** + * The NULLIF() function returns NULL if two expressions are equal, otherwise it returns the first expression + * + * The type of NULLIF(arg_0: T_0, arg_1: arg_1) should be (null|T_0). + * + * CASE + * WHEN x = y THEN NULL + * ELSE x + * END + * + * TODO: Model handling of Truth Value in typer better. + */ class NullIfTest : PartiQLTyperTestBase() { @TestFactory fun nullIf(): Stream { - val tests = listOf( - "func-00", - ).map { inputs.get("basics", it)!! } - val argsMap = buildMap { - val successArgs = cartesianProduct(allSupportedType, allSupportedType) + val tests = listOf("func-00").map { inputs.get("basics", it)!! } + val argsMap = mutableMapOf>>() - successArgs.forEach { args: List -> - val returnType = StaticType.unionOf(args.first(), StaticType.NULL).flatten() - (this[TestResult.Success(returnType)] ?: setOf(args)).let { - put(TestResult.Success(returnType), it + setOf(args)) - } - Unit - } - put(TestResult.Failure, emptySet>()) + // Generate all success cases + cartesianProduct(allSupportedType, allSupportedType).forEach { args -> + val expected = StaticType.unionOf(args[0], StaticType.NULL).flatten() + val result = TestResult.Success(expected) + argsMap[result] = setOf(args) } + // No failure case + argsMap[TestResult.Failure] = emptySet() return super.testGen("nullIf", tests, argsMap) } diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/path/SanityTests.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/path/SanityTests.kt index cbe7c7e3f2..852ff3cb5c 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/path/SanityTests.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/path/SanityTests.kt @@ -10,19 +10,26 @@ import java.util.stream.Stream * This test makes sure that the planner can resolve various path expression */ class SanityTests : PartiQLTyperTestBase() { + @TestFactory fun path(): Stream { + + val start = 0 + val end = 14 + val tests = 
buildList { - (0..14).forEach { - this.add("paths-${it.toString().padStart(2,'0')}") + (start..end).forEach { + this.add("paths-${it.toString().padStart(2, '0')}") } }.map { inputs.get("basics", it)!! } - + // -- t1 -> ANY + // -- t2 -> ANY + val argTypes = listOf(StaticType.ANY, StaticType.ANY) + // -- All paths return ANY because t1 and t2 are both ANY val argsMap: Map>> = buildMap { - put(TestResult.Success(StaticType.ANY), setOf(listOf(StaticType.ANY, StaticType.ANY))) + put(TestResult.Success(StaticType.ANY), setOf(argTypes)) put(TestResult.Failure, emptySet>()) } - return super.testGen("path", tests, argsMap) } } diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/predicate/OpTypeAssertionTest.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/predicate/OpTypeAssertionTest.kt index 4939a3b400..bef6a89a66 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/predicate/OpTypeAssertionTest.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/predicate/OpTypeAssertionTest.kt @@ -28,7 +28,7 @@ class OpTypeAssertionTest : PartiQLTyperTestBase() { }.toSet() val failureArgs = setOf(listOf(MissingType)) accumulateSuccesses(StaticType.BOOL, successArgs) - accumulateSuccessNullCall(StaticType.BOOL, listOf(StaticType.NULL)) + accumulateSuccessNullCall(StaticType.NULL, listOf(StaticType.NULL)) put(TestResult.Failure, failureArgs) } diff --git a/partiql-planner/src/testFixtures/resources/catalogs/default/aggregations/T1.ion b/partiql-planner/src/testFixtures/resources/catalogs/default/aggregations/T1.ion new file mode 100644 index 0000000000..f0defe828f --- /dev/null +++ b/partiql-planner/src/testFixtures/resources/catalogs/default/aggregations/T1.ion @@ -0,0 +1,17 @@ +{ + type: "bag", + items: { + type: "struct", + constraints: [ closed ], + fields: [ + { + name: "a", + type: "bool", + }, + { + name: "b", + type: "int32", + }, + ] + } +} diff --git 
a/partiql-planner/src/testFixtures/resources/catalogs/default/aggregations/T2.ion b/partiql-planner/src/testFixtures/resources/catalogs/default/aggregations/T2.ion new file mode 100644 index 0000000000..9f51c844e0 --- /dev/null +++ b/partiql-planner/src/testFixtures/resources/catalogs/default/aggregations/T2.ion @@ -0,0 +1,17 @@ +{ + type: "bag", + items: { + type: "struct", + constraints: [ closed ], + fields: [ + { + name: "c", + type: "bool", + }, + { + name: "d", + type: "int32", + }, + ] + } +} diff --git a/partiql-planner/src/testFixtures/resources/catalogs/default/aggregations/T3.ion b/partiql-planner/src/testFixtures/resources/catalogs/default/aggregations/T3.ion new file mode 100644 index 0000000000..40e7812425 --- /dev/null +++ b/partiql-planner/src/testFixtures/resources/catalogs/default/aggregations/T3.ion @@ -0,0 +1,17 @@ +{ + type: "bag", + items: { + type: "struct", + constraints: [ closed ], + fields: [ + { + name: "e", + type: "bool", + }, + { + name: "f", + type: "int32", + }, + ] + } +} diff --git a/partiql-planner/src/testFixtures/resources/catalogs/default/pql/t_item.ion b/partiql-planner/src/testFixtures/resources/catalogs/default/pql/t_item.ion new file mode 100644 index 0000000000..e4435d624c --- /dev/null +++ b/partiql-planner/src/testFixtures/resources/catalogs/default/pql/t_item.ion @@ -0,0 +1,184 @@ +// simple item which various types for testing +{ + type: "struct", + constraints: [ closed, unique ], + fields: [ + // Boolean + { + name: "t_bool", + type: "bool", + }, + { + name: "t_bool_nul", + type: ["bool","null"], + }, + // Exact Numeric +// { +// name: "t_int8", +// type: "int8", +// }, +// { +// name: "t_int8_null", +// type: ["int8", "null"], +// }, + { + name: "t_int16", + type: "int16", + }, + { + name: "t_int16_null", + type: ["int16", "null"], + }, + { + name: "t_int32", + type: "int32", + }, + { + name: "t_int32_null", + type: ["int32", "null"], + }, + { + name: "t_int64", + type: "int64", + }, + { + name: "t_int64_null", + 
type: ["int64", "null"], + }, + { + name: "t_int", + type: "int", + }, + { + name: "t_int_null", + type: ["int", "null"], + }, + { + name: "t_decimal", + type: "decimal", + }, + { + name: "t_decimal_null", + type: ["decimal", "null"], + }, + // Approximate Numeric + { + name: "t_float32", + type: "float32", + }, + { + name: "t_float32_null", + type: ["float32", "null"], + }, + { + name: "t_float64", + type: "float64", + }, + { + name: "t_float64_null", + type: ["float64", "null"], + }, + // Strings + { + name: "t_string", + type: "string", + }, + { + name: "t_string_null", + type: ["string", "null"], + }, + { + name: "t_clob", + type: "clob", + }, + { + name: "t_clob_null", + type: ["clob", "null"], + }, + // absent + { + name: "t_null", + type: "null", + }, + { + name: "t_missing", + type: "missing", + }, + { + name: "t_absent", + type: ["null", "missing"], + }, + // collections + { + name: "t_bag", + type: { + type: "bag", + items: "any", + }, + }, + { + name: "t_list", + type: { + type: "list", + items: "any", + } + }, + { + name: "t_sexp", + type: { + type: "sexp", + items: "any", + } + }, + // structs + { + name: "t_struct_a", + type: { + type: "struct", + fields: [ + { + name: "x", + type: "int32", + }, + { + name: "y", + type: "int32", + }, + ] + }, + }, + { + name: "t_struct_b", + type: { + type: "struct", + fields: [ + { + name: "x", + type: "int64", + }, + { + name: "y", + type: "int64", + }, + ] + }, + }, + { + name: "t_any", + type: "any", + }, + // unions + { + name: "t_num_exact", + type: [ "int16", "int32", "int64", "int", "decimal" ], + }, + { + name: "t_num_exact_null", + type: [ "int16", "int32", "int64", "int", "decimal", "null" ], + }, + { + name: "t_str", + type: [ "clob", "string" ], + } + ] +} diff --git a/partiql-planner/src/testFixtures/resources/inputs/basics/case.sql b/partiql-planner/src/testFixtures/resources/inputs/basics/case.sql index f7099d53cc..590e088578 100644 --- 
a/partiql-planner/src/testFixtures/resources/inputs/basics/case.sql +++ b/partiql-planner/src/testFixtures/resources/inputs/basics/case.sql @@ -1,30 +1,328 @@ ---#[case-00] +-- ----------------------------- +-- Exact Numeric +-- ----------------------------- + +--#[case-when-00] +-- type: (int32) +CASE t_item.t_bool + WHEN true THEN 0 + WHEN false THEN 1 + ELSE 2 +END; + +--#[case-when-02] +-- type: (int32) +CASE t_item.t_string + WHEN 'a' THEN t_item.t_int16 -- cast(.. AS INT4) + ELSE t_item.t_int32 -- INT4 +END; + +--#[case-when-03] +-- type: (int64) +CASE t_item.t_string + WHEN 'a' THEN t_item.t_int16 -- cast(.. AS INT8) + WHEN 'b' THEN t_item.t_int32 -- cast(.. AS INT8) + ELSE t_item.t_int64 -- INT8 +END; + +--#[case-when-04] +-- type: (int) +CASE t_item.t_string + WHEN 'a' THEN t_item.t_int16 -- cast(.. AS INT) + WHEN 'b' THEN t_item.t_int32 -- cast(.. AS INT) + WHEN 'c' THEN t_item.t_int64 -- cast(.. AS INT) + ELSE t_item.t_int -- INT +END; + +--#[case-when-05] +-- type: (int) +CASE t_item.t_string + WHEN 'b' THEN t_item.t_int32 -- cast(.. AS INT) + WHEN 'c' THEN t_item.t_int64 -- cast(.. AS INT) + ELSE t_item.t_int -- INT +END; + +--#[case-when-06] +-- type: (int) +CASE t_item.t_string + WHEN 'a' THEN t_item.t_int16 -- cast(.. AS INT) + WHEN 'b' THEN t_item.t_int32 -- cast(.. AS INT) + ELSE t_item.t_int -- INT +END; + +--#[case-when-07] +-- type: (int64) +CASE t_item.t_string + WHEN 'a' THEN t_item.t_int32 -- cast(.. AS INT8) + WHEN 'b' THEN t_item.t_int64 -- INT8 + ELSE t_item.t_int16 -- cast(.. AS INT8) +END; + +--#[case-when-08] +-- type: (int|null) +-- nullable default +CASE t_item.t_string + WHEN 'a' THEN t_item.t_int16 -- cast(.. AS INT) + WHEN 'b' THEN t_item.t_int32 -- cast(.. AS INT) + ELSE t_item.t_int_null -- INT +END; + +--#[case-when-09] +-- type: (int|null) +CASE t_item.t_string + WHEN 'a' THEN t_item.t_int16_null -- cast(.. AS INT) + WHEN 'b' THEN t_item.t_int32 -- cast(.. 
AS INT) + ELSE t_item.t_int +END; + +--#[case-when-10] +-- type: (decimal|null) +-- nullable branch +CASE t_item.t_string + WHEN 'a' THEN t_item.t_decimal + WHEN 'b' THEN t_item.t_int32 + ELSE NULL +END; + +--#[case-when-11] +-- type: (int|missing) +COALESCE(CAST(t_item.t_string AS INT), 1); + +-- ----------------------------- +-- Approximate Numeric +-- ----------------------------- + +-- TODO model approximate numeric +-- We do not have the appropriate StaticType for this. + +--#[case-when-12] +-- type: (float64) +CASE t_item.t_string + WHEN 'a' THEN t_item.t_int + ELSE t_item.t_float64 +END; + +--#[case-when-13] +-- type: (float64|null) +-- nullable branch +CASE t_item.t_string + WHEN 'a' THEN t_item.t_int + WHEN 'b' THEN t_item.t_float64 + ELSE NULL +END; + +-- ----------------------------- +-- Character Strings +-- ----------------------------- + +--#[case-when-14] +-- type: string +CASE t_item.t_string + WHEN 'a' THEN t_item.t_string + ELSE 'default' +END; + +--#[case-when-15] +-- type: (string|null) +-- null default +CASE t_item.t_string + WHEN 'a' THEN t_item.t_string + ELSE NULL +END; + +--#[case-when-16] +-- type: clob +CASE t_item.t_string + WHEN 'a' THEN t_item.t_string + WHEN 'b' THEN t_item.t_clob + ELSE 'default' +END; + +--#[case-when-17] +-- type: (clob|null) +-- null default +CASE t_item.t_string + WHEN 'a' THEN t_item.t_string + WHEN 'b' THEN t_item.t_clob + ELSE NULL +END; + +-- ---------------------------------- +-- Variations of null and missing +-- ---------------------------------- + +--#[case-when-18] +-- type: (string|null) +CASE t_item.t_string + WHEN 'a' THEN NULL + ELSE 'default' +END; + +--#[case-when-19] +-- type: (string|null) +CASE t_item.t_string + WHEN 'a' THEN NULL + WHEN 'b' THEN NULL + WHEN 'c' THEN NULL + WHEN 'd' THEN NULL + ELSE 'default' +END; + +--#[case-when-20] +-- type: null +-- no default, null anyways +CASE t_item.t_string + WHEN 'a' THEN NULL +END; + +--#[case-when-21] +-- type: (string|null) +-- no default +CASE 
t_item.t_string + WHEN 'a' THEN 'ok!' +END; + +--#[case-when-22] +-- type: (null|missing|int32) +CASE t_item.t_string + WHEN 'a' THEN t_item.t_absent + ELSE -1 +END; + +--#[case-when-23] +-- type: int32 +-- false branch is pruned +CASE + WHEN false THEN t_item.t_absent + ELSE -1 +END; + +-- ----------------------------- +-- Heterogeneous Branches +-- ----------------------------- + +--#[case-when-24] +-- type: (int32|int64|string) +CASE t_item.t_string + WHEN 'a' THEN t_item.t_int32 + WHEN 'b' THEN t_item.t_int64 + ELSE 'default' +END; + +--#[case-when-25] +-- type: (int32|int64|string|null) +CASE t_item.t_string + WHEN 'a' THEN t_item.t_int32 + WHEN 'b' THEN t_item.t_int64 + WHEN 'c' THEN t_item.t_string + ELSE NULL +END; + +--#[case-when-26] +-- type: (int32|int64|string|null) +CASE t_item.t_string + WHEN 'a' THEN t_item.t_int32 + WHEN 'b' THEN t_item.t_int64_null + ELSE 'default' +END; + +--#[case-when-27] +-- type: (int16|int32|int64|int|decimal|string|clob) +CASE t_item.t_string + WHEN 'a' THEN t_item.t_num_exact + WHEN 'b' THEN t_item.t_str + ELSE 'default' +END; + +--#[case-when-28] +-- type: (int16|int32|int64|int|decimal|string|clob|null) +CASE t_item.t_string + WHEN 'a' THEN t_item.t_num_exact + WHEN 'b' THEN t_item.t_str +END; + +--#[case-when-29] +-- type: (struct_a|struct_b|null) +CASE t_item.t_string + WHEN 'a' THEN t_item.t_struct_a + WHEN 'b' THEN t_item.t_struct_b +END; + +--#[case-when-30] +-- type: missing +CASE t_item.t_string + WHEN 'a' THEN MISSING + WHEN 'b' THEN MISSING + ELSE MISSING +END; + +-- ----------------------------- +-- Any Branches +-- ----------------------------- + +--#[case-when-31] +-- type: (any) +CASE t_item.t_string + WHEN 'a' THEN t_item.t_any + WHEN 'b' THEN t_item.t_int32 + ELSE NULL +END; + +--#[case-when-32] +-- type: (any) +CASE t_item.t_string + WHEN 'a' THEN t_item.t_int32 + WHEN 'b' THEN t_item.t_any + ELSE NULL +END; + +--#[case-when-33] +-- type: (any) +CASE t_item.t_string + WHEN 'a' THEN t_item.t_int32 + WHEN 
'b' THEN NULL + ELSE t_item.t_any +END; + +--#[case-when-34] +-- type: (any) +CASE t_item.t_string + WHEN 'a' THEN t_item.t_int32_null + WHEN 'b' THEN t_item.t_any + ELSE t_item.t_any +END; + +-- ----------------------------- +-- (Unused) old tests +-- ----------------------------- + +--#[old-case-when-00] CASE WHEN FALSE THEN 0 WHEN TRUE THEN 1 ELSE 2 END; ---#[case-01] +--#[old-case-when-01] CASE WHEN 1 = 2 THEN 0 WHEN 2 = 3 THEN 1 ELSE 3 END; ---#[case-02] +--#[old-case-when-02] CASE 1 WHEN 1 THEN 'MATCH!' ELSE 'NO MATCH!' END; ---#[case-03] +--#[old-case-when-03] CASE 'Hello World' WHEN 'Hello World' THEN TRUE ELSE FALSE END; ---#[case-04] +--#[old-case-when-04] SELECT CASE a WHEN TRUE THEN 'a IS TRUE' @@ -32,7 +330,7 @@ SELECT END AS result FROM T; ---#[case-05] +--#[old-case-when-05] SELECT CASE WHEN a = TRUE THEN 'a IS TRUE' @@ -40,7 +338,7 @@ SELECT END AS result FROM T; ---#[case-06] +--#[old-case-when-06] SELECT CASE b WHEN 10 THEN 'b IS 10' @@ -48,7 +346,7 @@ SELECT END AS result FROM T; ---#[case-07] +--#[old-case-when-07] -- TODO: This is currently failing as we seemingly cannot search for a nested attribute of a global. SELECT CASE d.e @@ -57,7 +355,7 @@ SELECT END AS result FROM T; ---#[case-08] +--#[old-case-when-08] SELECT CASE x WHEN 'WATER' THEN 'x IS WATER' @@ -66,7 +364,7 @@ SELECT END AS result FROM T; ---#[case-09] +--#[old-case-when-09] -- TODO: When using `x IS STRING` or `x IS DECIMAL`, I found that there are issues with the SqlCalls not receiving -- the length/precision/scale parameters. This doesn't have to do with CASE_WHEN, but it needs to be addressed. 
SELECT @@ -77,7 +375,7 @@ SELECT END AS result FROM T; ---#[case-10] +--#[old-case-when-10] CASE WHEN FALSE THEN 0 WHEN FALSE THEN 1 diff --git a/partiql-planner/src/testFixtures/resources/inputs/basics/coalesce.sql b/partiql-planner/src/testFixtures/resources/inputs/basics/coalesce.sql new file mode 100644 index 0000000000..c8e10d18fa --- /dev/null +++ b/partiql-planner/src/testFixtures/resources/inputs/basics/coalesce.sql @@ -0,0 +1,71 @@ +--#[coalesce-00] +-- type: (int32) +COALESCE(1); + +--#[coalesce-01] +-- type: (int32) +COALESCE(1, 2); + +--#[coalesce-02] +-- type: (decimal) +COALESCE(1, 1.23); + +--#[coalesce-03] +-- type: (null | decimal) +COALESCE(NULL, 1, 1.23); + +--#[coalesce-04] +-- type: (null | missing | decimal) +COALESCE(NULL, MISSING, 1, 1.23); + +--#[coalesce-05] +-- type: (null | missing | decimal); same as above +COALESCE(1, 1.23, NULL, MISSING); + +--#[coalesce-06] +-- type: (int32) +COALESCE(t_item.t_int32); + +--#[coalesce-07] +-- type: (int32) +COALESCE(t_item.t_int32, t_item.t_int32); + +--#[coalesce-08] +-- type: (int64) +COALESCE(t_item.t_int64, t_item.t_int32); + +--#[coalesce-09] +-- type: (int64 | null) +COALESCE(t_item.t_int64_null, t_item.t_int32, t_item.t_int32_null); + +--#[coalesce-10] +-- type: (int64 | null | missing) +COALESCE(t_item.t_int64_null, t_item.t_int32, t_item.t_int32_null, MISSING); + +--#[coalesce-11] +-- type: (int64 | string) +COALESCE(t_item.t_int64, t_item.t_string); + +--#[coalesce-12] +-- type: (int64 | null | string) +COALESCE(t_item.t_int64_null, t_item.t_string); + +--#[coalesce-13] +-- type: (int16 | int32 | int64 | int | decimal) +COALESCE(t_item.t_num_exact, t_item.t_int32); + +--#[coalesce-14] +-- type: (int16 | int32 | int64 | int | decimal, string) +COALESCE(t_item.t_num_exact, t_item.t_string); + +--#[coalesce-15] +-- type: (int16 | int32 | int64 | int | decimal, string, null) +COALESCE(t_item.t_num_exact, t_item.t_string, NULL); + +--#[coalesce-16] +-- type: (any) +COALESCE(t_item.t_any, 
t_item.t_int32); + +--#[coalesce-17] +-- type: (any) +COALESCE(t_item.t_int32, t_item.t_any); \ No newline at end of file diff --git a/partiql-planner/src/testFixtures/resources/inputs/basics/nullif.sql b/partiql-planner/src/testFixtures/resources/inputs/basics/nullif.sql new file mode 100644 index 0000000000..d03c3d863a --- /dev/null +++ b/partiql-planner/src/testFixtures/resources/inputs/basics/nullif.sql @@ -0,0 +1,77 @@ +--#[nullif-00] +-- Currently, no constant-folding. If there was, return type could be int32. +-- type: (int32 | null) +NULLIF(1, 2); + +--#[nullif-01] +-- Currently, no constant-folding. If there was, return type could be null. +-- type: (int32 | null) +NULLIF(1, 1); + +--#[nullif-02] +-- type: (int32 | null) +NULLIF(t_item.t_int32, t_item.t_int32); + +--#[nullif-03] +-- type: (int32 | null) +NULLIF(t_item.t_int32, t_item.t_int64); + +--#[nullif-04] +-- type: (int64 | null) +NULLIF(t_item.t_int64, t_item.t_int32); + +--#[nullif-05] +-- type: (int32 | null) +NULLIF(t_item.t_int32, t_item.t_null); + +--#[nullif-06] +-- type: (null) +NULLIF(t_item.t_null, t_item.t_int32); + +--#[nullif-07] +-- type: (int32 | null) +NULLIF(t_item.t_int32, MISSING); + +--#[nullif-08] +-- type: (missing | null) +NULLIF(MISSING, t_item.t_int32); + +--#[nullif-09] +-- type: (int32 | null) +NULLIF(t_item.t_int32, t_item.t_int32_null); + +--#[nullif-10] +-- type: (int32 | null) +NULLIF(t_item.t_int32_null, t_item.t_int32); + +--#[nullif-11] +-- type: (int32 | null) +NULLIF(t_item.t_int32, t_item.t_int64_null); + +--#[nullif-12] +-- type: (int64 | null) +NULLIF(t_item.t_int64_null, t_item.t_int32); + +--#[nullif-13] +-- type: (int32 | null) +NULLIF(t_item.t_int32, t_item.t_string); + +--#[nullif-14] +-- type: (string | null) +NULLIF(t_item.t_string, t_item.t_int32); + +--#[nullif-15] +-- type: (int32 | null) +NULLIF(t_item.t_int32, t_item.t_num_exact); + +--#[nullif-16] +-- type: (int16 | int32 | int64 | int | decimal | null) +NULLIF(t_item.t_num_exact, t_item.t_int32); + 
+--#[nullif-17] +-- type: (int32 | null) +NULLIF(t_item.t_int32, t_item.t_any); + +--#[nullif-18] +-- type: (any) +NULLIF(t_item.t_any, t_item.t_int32); diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/SqlBuiltins.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/SqlBuiltins.kt index 829a824cf9..5df65ed9e6 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/SqlBuiltins.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/SqlBuiltins.kt @@ -196,12 +196,6 @@ internal object SqlBuiltins { Fn_GTE__DATE_DATE__BOOL, Fn_GTE__TIME_TIME__BOOL, Fn_GTE__TIMESTAMP_TIMESTAMP__BOOL, - Fn_IN_COLLECTION__NULL_BAG__BOOL, - Fn_IN_COLLECTION__NULL_LIST__BOOL, - Fn_IN_COLLECTION__NULL_SEXP__BOOL, - Fn_IN_COLLECTION__MISSING_BAG__BOOL, - Fn_IN_COLLECTION__MISSING_LIST__BOOL, - Fn_IN_COLLECTION__MISSING_SEXP__BOOL, Fn_IN_COLLECTION__BOOL_BAG__BOOL, Fn_IN_COLLECTION__BOOL_LIST__BOOL, Fn_IN_COLLECTION__BOOL_SEXP__BOOL, diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggCount.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggCount.kt index 0496f236db..8264a5f807 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggCount.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggCount.kt @@ -10,14 +10,14 @@ import org.partiql.spi.fn.FnExperimental import org.partiql.spi.fn.FnParameter import org.partiql.value.PartiQLValueExperimental import org.partiql.value.PartiQLValueType.ANY -import org.partiql.value.PartiQLValueType.INT32 +import org.partiql.value.PartiQLValueType.INT64 @OptIn(PartiQLValueExperimental::class, FnExperimental::class) public object Agg_COUNT__ANY__INT32 : Agg { override val signature: AggSignature = AggSignature( name = "count", - returns = INT32, + returns = INT64, parameters = listOf( FnParameter("value", ANY), ), diff --git 
a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggCountStar.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggCountStar.kt index d8088a2017..cb63fc7411 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggCountStar.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/AggCountStar.kt @@ -8,14 +8,14 @@ import org.partiql.spi.fn.Agg import org.partiql.spi.fn.AggSignature import org.partiql.spi.fn.FnExperimental import org.partiql.value.PartiQLValueExperimental -import org.partiql.value.PartiQLValueType.INT32 +import org.partiql.value.PartiQLValueType.INT64 @OptIn(PartiQLValueExperimental::class, FnExperimental::class) public object Agg_COUNT_STAR____INT32 : Agg { override val signature: AggSignature = AggSignature( name = "count_star", - returns = INT32, + returns = INT64, parameters = listOf(), isNullable = false, isDecomposable = true diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnInCollection.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnInCollection.kt index eca17cbde4..40452eb68b 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnInCollection.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnInCollection.kt @@ -3,7 +3,6 @@ package org.partiql.spi.connector.sql.builtins -import org.partiql.errors.TypeCheckException import org.partiql.spi.fn.Fn import org.partiql.spi.fn.FnExperimental import org.partiql.spi.fn.FnParameter @@ -26,7 +25,6 @@ import org.partiql.value.Int8Value import org.partiql.value.IntValue import org.partiql.value.IntervalValue import org.partiql.value.ListValue -import org.partiql.value.NullValue import org.partiql.value.PartiQLValue import org.partiql.value.PartiQLValueExperimental import org.partiql.value.PartiQLValueType.ANY @@ -49,8 +47,6 @@ import org.partiql.value.PartiQLValueType.INT64 import 
org.partiql.value.PartiQLValueType.INT8 import org.partiql.value.PartiQLValueType.INTERVAL import org.partiql.value.PartiQLValueType.LIST -import org.partiql.value.PartiQLValueType.MISSING -import org.partiql.value.PartiQLValueType.NULL import org.partiql.value.PartiQLValueType.SEXP import org.partiql.value.PartiQLValueType.STRING import org.partiql.value.PartiQLValueType.STRUCT @@ -2249,144 +2245,3 @@ internal object Fn_IN_COLLECTION__STRUCT_SEXP__BOOL : Fn { return boolValue(false) } } - -@OptIn(PartiQLValueExperimental::class, FnExperimental::class) -internal object Fn_IN_COLLECTION__NULL_BAG__BOOL : Fn { - - override val signature = FnSignature( - name = "in_collection", - returns = BOOL, - parameters = listOf( - FnParameter("value", NULL), - FnParameter("collection", BAG), - ), - isNullCall = true, - isNullable = false, - ) - - override fun invoke(args: Array): PartiQLValue { - val value = args[0].check() - val collection = args[1].check>() - val iter = collection.iterator() - while (iter.hasNext()) { - val v = iter.next() - if (PartiQLValue.comparator().compare(value, v) == 0) { - return boolValue(true) - } - } - return boolValue(false) - } -} - -@OptIn(PartiQLValueExperimental::class, FnExperimental::class) -internal object Fn_IN_COLLECTION__NULL_LIST__BOOL : Fn { - - override val signature = FnSignature( - name = "in_collection", - returns = BOOL, - parameters = listOf( - FnParameter("value", NULL), - FnParameter("collection", LIST), - ), - isNullCall = true, - isNullable = false, - ) - - override fun invoke(args: Array): PartiQLValue { - val value = args[0].check() - val collection = args[1].check>() - val iter = collection.iterator() - while (iter.hasNext()) { - val v = iter.next() - if (PartiQLValue.comparator().compare(value, v) == 0) { - return boolValue(true) - } - } - return boolValue(false) - } -} - -@OptIn(PartiQLValueExperimental::class, FnExperimental::class) -internal object Fn_IN_COLLECTION__NULL_SEXP__BOOL : Fn { - - override val signature = 
FnSignature( - name = "in_collection", - returns = BOOL, - parameters = listOf( - FnParameter("value", NULL), - FnParameter("collection", SEXP), - ), - isNullCall = true, - isNullable = false, - ) - - override fun invoke(args: Array): PartiQLValue { - val value = args[0].check() - val collection = args[1].check>() - val iter = collection.iterator() - while (iter.hasNext()) { - val v = iter.next() - if (PartiQLValue.comparator().compare(value, v) == 0) { - return boolValue(true) - } - } - return boolValue(false) - } -} - -@OptIn(PartiQLValueExperimental::class, FnExperimental::class) -internal object Fn_IN_COLLECTION__MISSING_BAG__BOOL : Fn { - - override val signature = FnSignature( - name = "in_collection", - returns = BOOL, - parameters = listOf( - FnParameter("value", MISSING), - FnParameter("collection", BAG), - ), - isNullCall = true, - isNullable = false, - ) - - override fun invoke(args: Array): PartiQLValue { - throw TypeCheckException() - } -} - -@OptIn(PartiQLValueExperimental::class, FnExperimental::class) -internal object Fn_IN_COLLECTION__MISSING_LIST__BOOL : Fn { - - override val signature = FnSignature( - name = "in_collection", - returns = BOOL, - parameters = listOf( - FnParameter("value", MISSING), - FnParameter("collection", LIST), - ), - isNullCall = true, - isNullable = false, - ) - - override fun invoke(args: Array): PartiQLValue { - throw TypeCheckException() - } -} - -@OptIn(PartiQLValueExperimental::class, FnExperimental::class) -internal object Fn_IN_COLLECTION__MISSING_SEXP__BOOL : Fn { - - override val signature = FnSignature( - name = "in_collection", - returns = BOOL, - parameters = listOf( - FnParameter("value", MISSING), - FnParameter("collection", SEXP), - ), - isNullCall = true, - isNullable = false, - ) - - override fun invoke(args: Array): PartiQLValue { - throw TypeCheckException() - } -} diff --git a/partiql-types/src/main/kotlin/org/partiql/types/StaticType.kt b/partiql-types/src/main/kotlin/org/partiql/types/StaticType.kt 
index d2ef1756f8..5eeba1b397 100644 --- a/partiql-types/src/main/kotlin/org/partiql/types/StaticType.kt +++ b/partiql-types/src/main/kotlin/org/partiql/types/StaticType.kt @@ -599,9 +599,10 @@ public data class StructType( get() = listOf(this) override fun toString(): String { - val firstSeveral = fields.take(3).joinToString { "${it.key}: ${it.value}" } + val firstFieldsSize = 15 + val firstSeveral = fields.take(firstFieldsSize).joinToString { "${it.key}: ${it.value}" } return when { - fields.size <= 3 -> "struct($firstSeveral, $constraints)" + fields.size <= firstFieldsSize -> "struct($firstSeveral, $constraints)" else -> "struct($firstSeveral, ... and ${fields.size - 3} other field(s), $constraints)" } } @@ -630,7 +631,7 @@ public data class AnyOfType(val types: Set, override val metas: Map< types = this.types.flatMap { when (it) { is SingleType -> listOf(it) - is AnyType -> listOf(it) + is AnyType -> return@flatten it // if `AnyType`, return `AnyType` is AnyOfType -> it.types } }.toSet() @@ -642,21 +643,10 @@ public data class AnyOfType(val types: Set, override val metas: Map< } } - override fun toString(): String = - when (val flattenedType = flatten()) { - is AnyOfType -> { - val unionedTypes = flattenedType.types - when (unionedTypes.size) { - 0 -> "\$null" - 1 -> unionedTypes.first().toString() - else -> { - val types = unionedTypes.joinToString { it.toString() } - "union($types)" - } - } - } - else -> flattenedType.toString() - } + override fun toString(): String { + val types = types.joinToString { it.toString() } + return "union($types)" + } override val allTypes: List get() = this.types.map { it.flatten() }