diff --git a/.editorconfig b/.editorconfig index 3d674d593..cc987b518 100644 --- a/.editorconfig +++ b/.editorconfig @@ -16,3 +16,9 @@ end_of_line = unset insert_final_newline = unset indent_style = unset trim_trailing_whitespace = unset + +[*.gradle.kts] +indent_size = 2 + +[.gitmodules] +indent_style = tab diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index 7ddc02f9a..2feebebea 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -2,6 +2,11 @@ name: PR Build Check on: pull_request: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + jobs: editorconfig-checker: name: Check editorconfig @@ -10,7 +15,7 @@ jobs: - uses: actions/checkout@v4 with: fetch-depth: 0 - - uses: editorconfig-checker/action-editorconfig-checker@main + - uses: editorconfig-checker/action-editorconfig-checker@v2 - run: editorconfig-checker commitlint: name: Lint commits for semantic-release @@ -28,7 +33,43 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: gradle/wrapper-validation-action@v1 + - uses: gradle/actions/wrapper-validation@v3 + cyclonedx-sbom: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-java@v4 + with: + java-version: '17' + distribution: 'temurin' + - name: Generate SBOMs + run: ./gradlew cyclonedxBom + - name: Upload SBOMs + uses: actions/upload-artifact@v4 + with: + name: cyclonedx-sbom + path: | + core/build/reports/bom.json + isthmus/build/reports/bom.json + isthmus-cli/build/reports/bom.json + osv-scanner: + needs: cyclonedx-sbom + runs-on: ubuntu-latest + continue-on-error: true + strategy: + fail-fast: false + matrix: + project: + - core + - isthmus + - isthmus-cli + steps: + - name: Download SBOMs + uses: actions/download-artifact@v4 + with: + name: cyclonedx-sbom + - name: Scan + run: docker run --rm -v "${PWD}/${{ matrix.project }}/build/reports/bom.json:/bom.json" ghcr.io/google/osv-scanner --sbom /bom.json java: name: Build and Test Java runs-on: ubuntu-latest @@ -37,10 +78,10 @@ jobs: with: submodules: recursive - name: Set up JDK 17 - uses: actions/setup-java@v3 + uses: actions/setup-java@v4 with: java-version: '17' - distribution: 'adopt' + distribution: 'temurin' - name: Setup Gradle uses: gradle/actions/setup-gradle@v3 - name: Build with Gradle @@ -71,17 +112,18 @@ jobs: - name: Build with Gradle run: gradle nativeImage - name: Smoke Test - run: ./isthmus/src/test/script/smoke.sh - ./isthmus/src/test/script/tpch_smoke.sh + run: | + ./isthmus-cli/src/test/script/smoke.sh + ./isthmus-cli/src/test/script/tpch_smoke.sh - name: Rename the artifact to OS-unique name shell: bash run: | - value=`mv isthmus/build/graal/isthmus isthmus/build/graal/isthmus-${{ matrix.os }}` + value=`mv isthmus-cli/build/graal/isthmus isthmus-cli/build/graal/isthmus-${{ matrix.os }}` - name: Publish artifact uses: actions/upload-artifact@v4 with: name: isthmus-${{ matrix.os }} - path: isthmus/build/graal/isthmus-${{ matrix.os }} + path: isthmus-cli/build/graal/isthmus-${{ matrix.os }} dry-run-release: name: Dry-run release runs-on: ubuntu-latest diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 019c5e5ee..b8d444875 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -37,17 +37,18 @@ jobs: - name: Build with Gradle run: gradle nativeImage - name: Smoke Test - run: ./isthmus/src/test/script/smoke.sh - ./isthmus/src/test/script/tpch_smoke.sh + run: | + ./isthmus-cli/src/test/script/smoke.sh + 
./isthmus-cli/src/test/script/tpch_smoke.sh
- name: Rename the artifact to OS-unique name
shell: bash
run: |
- value=`mv isthmus/build/graal/isthmus isthmus/build/graal/isthmus-${{ matrix.os }}`
+ value=`mv isthmus-cli/build/graal/isthmus isthmus-cli/build/graal/isthmus-${{ matrix.os }}`
- name: Publish artifact
uses: actions/upload-artifact@v4
with:
name: isthmus-${{ matrix.os }}
- path: isthmus/build/graal/isthmus-${{ matrix.os }}
+ path: isthmus-cli/build/graal/isthmus-${{ matrix.os }}
semantic-release:
if: github.repository == 'substrait-io/substrait-java'
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
submodules: recursive
- name: Set up JDK 17
- uses: actions/setup-java@v3
+ uses: actions/setup-java@v4
with:
java-version: '17'
- distribution: 'adopt'
+ distribution: 'temurin'
- uses: actions/setup-node@v4
with:
node-version: '20'
diff --git a/.gitignore b/.gitignore
index 0ea09b09b..d7d9428ea 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,3 +7,4 @@ gen
*.iml
out/**
*.iws
+.vscode
diff --git a/CHANGELOG.md b/CHANGELOG.md
index ce7fed21e..1bca44eef 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,6 +1,130 @@
Release Notes
---
+## [0.37.0](https://github.com/substrait-io/substrait-java/compare/v0.36.0...v0.37.0) (2024-07-21)
+
+### ⚠ BREAKING CHANGES
+
+* AdvancedExtension#getOptimization() has been removed. Use getOptimizations() instead.
+
+### Features
+
+* literal support for precision timestamp types ([#283](https://github.com/substrait-io/substrait-java/issues/283)) ([94996f9](https://github.com/substrait-io/substrait-java/commit/94996f916478ed8141e5fb54b1c8411cc80f4abd))
+* validate VirtualTableScan field names with schema ([#284](https://github.com/substrait-io/substrait-java/issues/284)) ([0f8514a](https://github.com/substrait-io/substrait-java/commit/0f8514a95f0ffa2c3cca645652ef91ebfd3ccb9d))
+
+### Miscellaneous Chores
+
+* update to substrait 0.52.0 ([#282](https://github.com/substrait-io/substrait-java/issues/282)) ([ada8d0b](https://github.com/substrait-io/substrait-java/commit/ada8d0be54b8bbd260b194c4e93f02ed42821b5d))
+
+## [0.36.0](https://github.com/substrait-io/substrait-java/compare/v0.35.0...v0.36.0) (2024-07-14)
+
+### ⚠ BREAKING CHANGES
+
+* Expression#options now returns List<FunctionOption>
+* ProtoAggregateFunctionConverter#from(AggregateFunction) now returns AggregateFunctionInvocation
+
+### Bug Fixes
+
+* include FunctionOptions when converting functions ([#278](https://github.com/substrait-io/substrait-java/issues/278)) ([e574913](https://github.com/substrait-io/substrait-java/commit/e57491333c7dd05ae3b1400e2185f807af1f5f88))
+
+## [0.35.0](https://github.com/substrait-io/substrait-java/compare/v0.34.0...v0.35.0) (2024-06-30)
+
+### Features
+
+* deprecate Timestamp and TimestampTZ visit functions ([#273](https://github.com/substrait-io/substrait-java/issues/273)) ([8a8253e](https://github.com/substrait-io/substrait-java/commit/8a8253ec1f077b81d0da6503e299662048fca825))
+* introduce substrait-spark module ([#271](https://github.com/substrait-io/substrait-java/issues/271)) ([8537dca](https://github.com/substrait-io/substrait-java/commit/8537dca93b410177f2ee5aefffe83f7c02a3668c))
+
+## [0.34.0](https://github.com/substrait-io/substrait-java/compare/v0.33.0...v0.34.0) (2024-06-23)
+
+### ⚠ BREAKING CHANGES
+
+* getDfsNames() has been removed from VirtualTableScan
+* getInitialSchema() no longer has a default implementation in VirtualTableScan
+
+### Bug Fixes
+
+* set VirtualTableScan schema explicitly ([#272](https://github.com/substrait-io/substrait-java/issues/272)) ([f1192cf](https://github.com/substrait-io/substrait-java/commit/f1192cfaf6c84fb1e466bae6eda75ba164444aa8))
+
+## [0.33.0](https://github.com/substrait-io/substrait-java/compare/v0.32.0...v0.33.0) (2024-06-16)
+
+### Features
+
+* **isthmus:** support for PrecisionTimestamp conversions ([#262](https://github.com/substrait-io/substrait-java/issues/262)) ([e726904](https://github.com/substrait-io/substrait-java/commit/e72690425cb31e52bc37550c1c4851db1b927651))
+
+### Bug Fixes
+
+* **isthmus:** correct SLF4J dependency ([#268](https://github.com/substrait-io/substrait-java/issues/268)) ([3134504](https://github.com/substrait-io/substrait-java/commit/31345045d522bf85bc60a59d52e4dd55601abbf8))
+
+## [0.32.0](https://github.com/substrait-io/substrait-java/compare/v0.31.0...v0.32.0) (2024-06-04)
+
+### ⚠ BREAKING CHANGES
+
+* Substrait FP32 is now mapped to Calcite REAL instead of FLOAT
+* Calcite FLOAT is now mapped to Substrait FP64 instead of FP32
+
+In Calcite, the SQL type names DOUBLE and FLOAT correspond to FP64, and REAL corresponds to FP32
+
+### Bug Fixes
+
+* account for struct fields in VirtualTableScan check ([#255](https://github.com/substrait-io/substrait-java/issues/255)) ([3bbcf82](https://github.com/substrait-io/substrait-java/commit/3bbcf82687bc51fdb1695436c198e91ba56befed))
+* map Calcite REAL to Substrait FP32 ([#261](https://github.com/substrait-io/substrait-java/issues/261)) ([37331c2](https://github.com/substrait-io/substrait-java/commit/37331c2fbee679fd5ec482d8ff4d16f1c7c1c5c0))
+
+## [0.31.0](https://github.com/substrait-io/substrait-java/compare/v0.30.0...v0.31.0) (2024-05-05)
+
+
+### ⚠ BREAKING CHANGES
+
+* **isthmus:** CLI related functionality is now in the io.substrait.isthmus.cli package
+
+### Features
+
+* allow deployment time selection of logging framework [#243](https://github.com/substrait-io/substrait-java/issues/243) ([#244](https://github.com/substrait-io/substrait-java/issues/244)) ([72bab63](https://github.com/substrait-io/substrait-java/commit/72bab63edf6c4ffb12c3c4b0e4f49d066e0c5524))
+* **isthmus:** extract CLI into isthmus-cli project [#248](https://github.com/substrait-io/substrait-java/issues/248) ([#249](https://github.com/substrait-io/substrait-java/issues/249)) ([a49de62](https://github.com/substrait-io/substrait-java/commit/a49de62c670274cccfa8b94fb86e88b36fc716d3))
+
+## [0.30.0](https://github.com/substrait-io/substrait-java/compare/v0.29.1...v0.30.0) (2024-04-28)
+
+
+### ⚠ BREAKING CHANGES
+
+* ParameterizedTypeVisitor has new visit methods
+* TypeExpressionVisitor has new visit methods
+* TypeVisitor has new visit methods
+* BaseProtoTypes has new visit methods
+
+### Bug Fixes
+
+* handle FetchRels with only offset set ([#252](https://github.com/substrait-io/substrait-java/issues/252)) ([b334e1d](https://github.com/substrait-io/substrait-java/commit/b334e1d4004ebc4598cab7bc076f3d477e97a52a))
+
+
+### Miscellaneous Chores
+
+* update to substrait 0.48.0 ([#250](https://github.com/substrait-io/substrait-java/issues/250)) ([572fe57](https://github.com/substrait-io/substrait-java/commit/572fe57ccf473e3d680f8928dd5f6833583350cc))
+
+## [0.29.1](https://github.com/substrait-io/substrait-java/compare/v0.29.0...v0.29.1) (2024-03-31)
+
+
+### Bug Fixes
+
+* correct function compound names for IntervalDay and IntervalYear [#240](https://github.com/substrait-io/substrait-java/issues/240) ([#242](https://github.com/substrait-io/substrait-java/issues/242)) ([856331b](https://github.com/substrait-io/substrait-java/commit/856331bae9901e618663622bbf60eaf923dea5b8))
+
+## [0.29.0](https://github.com/substrait-io/substrait-java/compare/v0.28.1...v0.29.0) (2024-03-17)
+
+
+### ⚠ BREAKING CHANGES
+
+* **isthmus:** method ExpressionCreator.cast(Type, Expression) has been removed
+
+### Features
+
+* **isthmus:** support for safe casting ([#236](https://github.com/substrait-io/substrait-java/issues/236)) ([72785ad](https://github.com/substrait-io/substrait-java/commit/72785ad1a4bd1ba8481d75ddaf4f1a822bf9bf6b))
+
+## [0.28.1](https://github.com/substrait-io/substrait-java/compare/v0.28.0...v0.28.1) (2024-03-10)
+
+
+### Bug Fixes
+
+* use coercive function matcher before least restrictive matcher ([#237](https://github.com/substrait-io/substrait-java/issues/237)) ([e7aa8ff](https://github.com/substrait-io/substrait-java/commit/e7aa8ff1fe11dd784074138bf75c1afa140b59db))
+
## [0.28.0](https://github.com/substrait-io/substrait-java/compare/v0.27.0...v0.28.0) (2024-02-25)
diff --git a/build.gradle.kts b/build.gradle.kts
index 47a9da29f..0911e7092 100644
--- a/build.gradle.kts
+++ b/build.gradle.kts
@@ -9,18 +9,23 @@ plugins {
id("com.github.vlsi.gradle-extensions") version "1.74"
id("com.diffplug.spotless") version "6.11.0"
id("io.github.gradle-nexus.publish-plugin") version "1.1.0"
+ id("org.cyclonedx.bom") version "1.8.2"
}
+
+var IMMUTABLES_VERSION = properties.get("immutables.version")
+var JUNIT_VERSION = properties.get("junit.version")
+var SLF4J_VERSION = properties.get("slf4j.version")
+
repositories { mavenCentral() }
java { toolchain { languageVersion.set(JavaLanguageVersion.of(17)) } }
dependencies {
- testImplementation("org.junit.jupiter:junit-jupiter-api:5.9.2")
- testRuntimeOnly("org.junit.jupiter:junit-jupiter-engine")
- implementation("org.slf4j:slf4j-jdk14:1.7.30")
- annotationProcessor("org.immutables:value:2.8.8")
- compileOnly("org.immutables:value-annotations:2.8.8")
+ testImplementation("org.junit.jupiter:junit-jupiter-api:${JUNIT_VERSION}")
+ testRuntimeOnly("org.junit.jupiter:junit-jupiter-engine:${JUNIT_VERSION}")
+ implementation("org.slf4j:slf4j-api:${SLF4J_VERSION}")
+ annotationProcessor("org.immutables:value:${IMMUTABLES_VERSION}")
+ compileOnly("org.immutables:value-annotations:${IMMUTABLES_VERSION}")
annotationProcessor("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2")
compileOnly("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2")
}
@@ -64,6 +69,21 @@ allprojects {
}
}
}
+
+ if (listOf("core", "isthmus", "isthmus-cli").contains(project.name)) {
+ apply(plugin = "org.cyclonedx.bom")
+ tasks.cyclonedxBom {
+ setIncludeConfigs(listOf("runtimeClasspath"))
+ setSkipConfigs(listOf("compileClasspath", "testCompileClasspath"))
+ setProjectType("library")
+ setSchemaVersion("1.5")
+ setDestination(project.file("build/reports"))
+ setOutputName("bom")
+ setOutputFormat("json")
+ setIncludeBomSerialNumber(false)
+ setIncludeLicenseText(false)
+ }
+ }
}
nexusPublishing {
diff --git a/ci/release/dry_run.sh b/ci/release/dry_run.sh
index 25d770e40..238c57b25 100755
--- a/ci/release/dry_run.sh
+++ b/ci/release/dry_run.sh
@@ -29,7 +29,7 @@ npx --yes \
-p "@semantic-release/changelog" \
-p "@semantic-release/exec" \
-p "@semantic-release/git" \
- -p "conventional-changelog-conventionalcommits@6.1.0" \
+ -p "conventional-changelog-conventionalcommits" \
semantic-release \
--ci false \
--dry-run \
diff --git a/ci/release/publish.sh b/ci/release/publish.sh
index 7d9cb64e2..1768d0a00 100755
--- a/ci/release/publish.sh
+++
b/ci/release/publish.sh @@ -4,4 +4,4 @@ set -euo pipefail gradle wrapper -./gradlew clean :core:publishToSonatype :isthmus:publishToSonatype closeAndReleaseSonatypeStagingRepository +./gradlew clean :core:publishToSonatype :isthmus:publishToSonatype :spark:publishToSonatype closeAndReleaseSonatypeStagingRepository diff --git a/ci/release/run.sh b/ci/release/run.sh index 9441ef23f..438e00c77 100755 --- a/ci/release/run.sh +++ b/ci/release/run.sh @@ -11,6 +11,5 @@ npx --yes \ -p "@semantic-release/github" \ -p "@semantic-release/exec" \ -p "@semantic-release/git" \ - -p "conventional-changelog-conventionalcommits@6.1.0" \ + -p "conventional-changelog-conventionalcommits" \ semantic-release --ci - diff --git a/core/build.gradle.kts b/core/build.gradle.kts index 6b8cfac66..6cb64ed51 100644 --- a/core/build.gradle.kts +++ b/core/build.gradle.kts @@ -1,5 +1,3 @@ -import com.google.protobuf.gradle.protobuf -import com.google.protobuf.gradle.protoc import org.gradle.plugins.ide.idea.model.IdeaModel plugins { @@ -7,8 +5,9 @@ plugins { id("java") id("idea") id("antlr") - id("com.google.protobuf") version "0.8.17" + id("com.google.protobuf") version "0.9.4" id("com.diffplug.spotless") version "6.11.0" + id("com.github.johnrengelman.shadow") version "8.1.1" signing } @@ -45,8 +44,8 @@ publishing { repositories { maven { name = "local" - val releasesRepoUrl = "$buildDir/repos/releases" - val snapshotsRepoUrl = "$buildDir/repos/snapshots" + val releasesRepoUrl = layout.buildDirectory.dir("repos/releases") + val snapshotsRepoUrl = layout.buildDirectory.dir("repos/snapshots") url = uri(if (version.toString().endsWith("SNAPSHOT")) snapshotsRepoUrl else releasesRepoUrl) } } @@ -67,26 +66,56 @@ signing { sign(publishing.publications["maven-publish"]) } +val ANTLR_VERSION = properties.get("antlr.version") +val IMMUTABLES_VERSION = properties.get("immutables.version") +val JACKSON_VERSION = properties.get("jackson.version") +val JUNIT_VERSION = properties.get("junit.version") +val SLF4J_VERSION = properties.get("slf4j.version") +val PROTOBUF_VERSION = properties.get("protobuf.version") + +// This allows specifying deps to be shadowed so that they don't get included in the POM file +val shadowImplementation by configurations.creating + +configurations[JavaPlugin.COMPILE_ONLY_CONFIGURATION_NAME].extendsFrom(shadowImplementation) + +configurations[JavaPlugin.TEST_IMPLEMENTATION_CONFIGURATION_NAME].extendsFrom(shadowImplementation) + dependencies { - testImplementation("org.junit.jupiter:junit-jupiter-api:5.9.2") - testImplementation("org.junit.jupiter:junit-jupiter-params:5.9.2") - testRuntimeOnly("org.junit.jupiter:junit-jupiter-engine") - implementation("com.google.protobuf:protobuf-java:3.17.3") - implementation("com.fasterxml.jackson.core:jackson-databind:2.13.4") - implementation("com.fasterxml.jackson.core:jackson-annotations:2.13.4") - implementation("com.fasterxml.jackson.datatype:jackson-datatype-jdk8:2.13.4") - implementation("com.fasterxml.jackson.dataformat:jackson-dataformat-yaml:2.13.4") + testImplementation("org.junit.jupiter:junit-jupiter-api:${JUNIT_VERSION}") + testImplementation("org.junit.jupiter:junit-jupiter-params:${JUNIT_VERSION}") + testRuntimeOnly("org.junit.jupiter:junit-jupiter-engine:${JUNIT_VERSION}") + implementation("com.google.protobuf:protobuf-java:${PROTOBUF_VERSION}") + implementation("com.fasterxml.jackson.core:jackson-databind:${JACKSON_VERSION}") + implementation("com.fasterxml.jackson.core:jackson-annotations:${JACKSON_VERSION}") + 
implementation("com.fasterxml.jackson.datatype:jackson-datatype-jdk8:${JACKSON_VERSION}") + implementation("com.fasterxml.jackson.dataformat:jackson-dataformat-yaml:${JACKSON_VERSION}") implementation("com.google.code.findbugs:jsr305:3.0.2") - antlr("org.antlr:antlr4:4.9.2") - implementation("org.slf4j:slf4j-jdk14:1.7.30") - implementation("org.antlr:antlr4-runtime:4.9.2") - annotationProcessor("org.immutables:value:2.8.8") - compileOnly("org.immutables:value-annotations:2.8.8") + antlr("org.antlr:antlr4:${ANTLR_VERSION}") + shadowImplementation("org.antlr:antlr4-runtime:${ANTLR_VERSION}") + implementation("org.slf4j:slf4j-api:${SLF4J_VERSION}") + annotationProcessor("org.immutables:value:${IMMUTABLES_VERSION}") + compileOnly("org.immutables:value-annotations:${IMMUTABLES_VERSION}") annotationProcessor("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2") compileOnly("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2") } +configurations[JavaPlugin.API_CONFIGURATION_NAME].let { apiConfiguration -> + // Workaround for https://github.com/gradle/gradle/issues/820 + apiConfiguration.setExtendsFrom(apiConfiguration.extendsFrom.filter { it.name != "antlr" }) +} + +tasks { + shadowJar { + archiveClassifier.set("") // to override ".jar" instead of producing "-all.jar" + minimize() + // bundle the deps from shadowImplementation into the jar + configurations = listOf(shadowImplementation) + // rename the shadowed deps so that they don't conflict with consumer's own deps + relocate("org.antlr.v4.runtime", "io.substrait.org.antlr.v4.runtime") + } +} + java { toolchain { languageVersion.set(JavaLanguageVersion.of(17)) @@ -128,7 +157,8 @@ tasks.named("generateGrammarSource") { arguments.add("-Werror") arguments.add("-Xexact-output-dir") setSource(fileTree("src/main/antlr/SubstraitType.g4")) - outputDirectory = File(buildDir, "generated/sources/antlr/main/java/io/substrait/type") + outputDirectory = + layout.buildDirectory.dir("generated/sources/antlr/main/java/io/substrait/type").get().asFile } -protobuf { protoc { artifact = "com.google.protobuf:protoc:3.17.3" } } +protobuf { protoc { artifact = "com.google.protobuf:protoc:${PROTOBUF_VERSION}" } } diff --git a/core/src/main/antlr/SubstraitType.g4 b/core/src/main/antlr/SubstraitType.g4 index 4202eb611..9323a598e 100644 --- a/core/src/main/antlr/SubstraitType.g4 +++ b/core/src/main/antlr/SubstraitType.g4 @@ -51,6 +51,8 @@ IntervalYear: I N T E R V A L '_' Y E A R; IntervalDay: I N T E R V A L '_' D A Y; UUID : U U I D; Decimal : D E C I M A L; +PrecisionTimestamp: P R E C I S I O N '_' T I M E S T A M P; +PrecisionTimestampTZ: P R E C I S I O N '_' T I M E S T A M P '_' T Z; FixedChar: F I X E D C H A R; VarChar : V A R C H A R; FixedBinary: F I X E D B I N A R Y; @@ -97,7 +99,7 @@ SingleQuote: '\''; Number - : '-'? Int + : '-'? Int ; Identifier @@ -105,73 +107,75 @@ Identifier ; LineComment - : '//' ~[\r\n]* -> channel(HIDDEN) - ; + : '//' ~[\r\n]* -> channel(HIDDEN) + ; BlockComment - : ( '/*' - ( '/'* BlockComment - | ~[/*] - | '/'+ ~[/*] - | '*'+ ~[/*] - )* - '*'* - '*/' - ) -> channel(HIDDEN) - ; + : ( '/*' + ( '/'* BlockComment + | ~[/*] + | '/'+ ~[/*] + | '*'+ ~[/*] + )* + '*'* + '*/' + ) -> channel(HIDDEN) + ; Whitespace - : [ \t]+ -> channel(HIDDEN) - ; + : [ \t]+ -> channel(HIDDEN) + ; Newline - : ( '\r' '\n'? - | '\n' - ) - ; + : ( '\r' '\n'? 
+ | '\n' + ) + ; fragment Int - : '1'..'9' Digit* - | '0' + : '1'..'9' Digit* + | '0' ; fragment Digit - : '0'..'9' + : '0'..'9' ; start: expr EOF; scalarType - : Boolean #Boolean - | I8 #i8 - | I16 #i16 - | I32 #i32 - | I64 #i64 - | FP32 #fp32 - | FP64 #fp64 - | String #string - | Binary #binary - | Timestamp #timestamp - | TimestampTZ #timestampTz - | Date #date - | Time #time - | IntervalDay #intervalDay - | IntervalYear #intervalYear - | UUID #uuid - | UserDefined Identifier #userDefined - ; + : Boolean #Boolean + | I8 #i8 + | I16 #i16 + | I32 #i32 + | I64 #i64 + | FP32 #fp32 + | FP64 #fp64 + | String #string + | Binary #binary + | Timestamp #timestamp + | TimestampTZ #timestampTz + | Date #date + | Time #time + | IntervalDay #intervalDay + | IntervalYear #intervalYear + | UUID #uuid + | UserDefined Identifier #userDefined + ; parameterizedType - : FixedChar isnull='?'? Lt len=numericParameter Gt #fixedChar - | VarChar isnull='?'? Lt len=numericParameter Gt #varChar - | FixedBinary isnull='?'? Lt len=numericParameter Gt #fixedBinary - | Decimal isnull='?'? Lt precision=numericParameter Comma scale=numericParameter Gt #decimal - | Struct isnull='?'? Lt expr (Comma expr)* Gt #struct - | NStruct isnull='?'? Lt Identifier expr (Comma Identifier expr)* Gt #nStruct - | List isnull='?'? Lt expr Gt #list - | Map isnull='?'? Lt key=expr Comma value=expr Gt #map - ; + : FixedChar isnull='?'? Lt len=numericParameter Gt #fixedChar + | VarChar isnull='?'? Lt len=numericParameter Gt #varChar + | FixedBinary isnull='?'? Lt len=numericParameter Gt #fixedBinary + | Decimal isnull='?'? Lt precision=numericParameter Comma scale=numericParameter Gt #decimal + | PrecisionTimestamp isnull='?'? Lt precision=numericParameter Gt #precisionTimestamp + | PrecisionTimestampTZ isnull='?'? Lt precision=numericParameter Gt #precisionTimestampTZ + | Struct isnull='?'? Lt expr (Comma expr)* Gt #struct + | NStruct isnull='?'? Lt Identifier expr (Comma Identifier expr)* Gt #nStruct + | List isnull='?'? Lt expr Gt #list + | Map isnull='?'? 
Lt key=expr Comma value=expr Gt #map + ; numericParameter : Number #numericLiteral diff --git a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java index a1e5e3fba..1a03c3027 100644 --- a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java +++ b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java @@ -12,6 +12,7 @@ import io.substrait.expression.ImmutableExpression.SingleOrList; import io.substrait.expression.ImmutableExpression.Switch; import io.substrait.expression.ImmutableFieldReference; +import io.substrait.expression.WindowBound; import io.substrait.extension.DefaultExtensionCatalog; import io.substrait.extension.SimpleExtension; import io.substrait.function.ToTypeString; @@ -38,6 +39,7 @@ import java.util.Arrays; import java.util.List; import java.util.Optional; +import java.util.OptionalLong; import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -108,14 +110,30 @@ private Cross cross(Rel left, Rel right, Optional remap) { } public Fetch fetch(long offset, long count, Rel input) { - return fetch(offset, count, Optional.empty(), input); + return fetch(offset, OptionalLong.of(count), Optional.empty(), input); } public Fetch fetch(long offset, long count, Rel.Remap remap, Rel input) { - return fetch(offset, count, Optional.of(remap), input); + return fetch(offset, OptionalLong.of(count), Optional.of(remap), input); } - private Fetch fetch(long offset, long count, Optional remap, Rel input) { + public Fetch limit(long limit, Rel input) { + return fetch(0, OptionalLong.of(limit), Optional.empty(), input); + } + + public Fetch limit(long limit, Rel.Remap remap, Rel input) { + return fetch(0, OptionalLong.of(limit), Optional.of(remap), input); + } + + public Fetch offset(long offset, Rel input) { + return fetch(offset, OptionalLong.empty(), Optional.empty(), input); + } + + public Fetch offset(long offset, Rel.Remap remap, Rel input) { + return fetch(offset, OptionalLong.empty(), Optional.of(remap), input); + } + + private Fetch fetch(long offset, OptionalLong count, Optional remap, Rel input) { return Fetch.builder().offset(offset).count(count).input(input).remap(remap).build(); } @@ -336,6 +354,10 @@ public Expression.I32Literal i32(int v) { return Expression.I32Literal.builder().value(v).build(); } + public Expression.FP64Literal fp64(double v) { + return Expression.FP64Literal.builder().value(v).build(); + } + public Expression cast(Expression input, Type type) { return Cast.builder() .input(input) @@ -600,6 +622,30 @@ public Expression.ScalarFunctionInvocation scalarFn( .build(); } + public Expression.WindowFunctionInvocation windowFn( + String namespace, + String key, + Type outputType, + Expression.AggregationPhase aggregationPhase, + Expression.AggregationInvocation invocation, + Expression.WindowBoundsType boundsType, + WindowBound lowerBound, + WindowBound upperBound, + Expression... 
args) { + var declaration = + extensions.getWindowFunction(SimpleExtension.FunctionAnchor.of(namespace, key)); + return Expression.WindowFunctionInvocation.builder() + .declaration(declaration) + .outputType(outputType) + .aggregationPhase(aggregationPhase) + .invocation(invocation) + .boundsType(boundsType) + .lowerBound(lowerBound) + .upperBound(upperBound) + .arguments(Arrays.stream(args).collect(java.util.stream.Collectors.toList())) + .build(); + } + // Types public Type.UserDefined userDefinedType(String namespace, String typeName) { diff --git a/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java b/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java index 83916a46e..6b1c9177c 100644 --- a/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java +++ b/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java @@ -74,6 +74,16 @@ public OUTPUT visit(Expression.TimestampTZLiteral expr) throws EXCEPTION { return visitFallback(expr); } + @Override + public OUTPUT visit(Expression.PrecisionTimestampLiteral expr) throws EXCEPTION { + return visitFallback(expr); + } + + @Override + public OUTPUT visit(Expression.PrecisionTimestampTZLiteral expr) throws EXCEPTION { + return visitFallback(expr); + } + @Override public OUTPUT visit(Expression.IntervalYearLiteral expr) throws EXCEPTION { return visitFallback(expr); diff --git a/core/src/main/java/io/substrait/expression/AggregateFunctionInvocation.java b/core/src/main/java/io/substrait/expression/AggregateFunctionInvocation.java index 5a8796d50..987ffcd23 100644 --- a/core/src/main/java/io/substrait/expression/AggregateFunctionInvocation.java +++ b/core/src/main/java/io/substrait/expression/AggregateFunctionInvocation.java @@ -3,7 +3,6 @@ import io.substrait.extension.SimpleExtension; import io.substrait.type.Type; import java.util.List; -import java.util.Map; import org.immutables.value.Value; @Value.Immutable @@ -12,7 +11,7 @@ public abstract class AggregateFunctionInvocation { public abstract List arguments(); - public abstract Map options(); + public abstract List options(); public abstract Expression.AggregationPhase aggregationPhase(); diff --git a/core/src/main/java/io/substrait/expression/Expression.java b/core/src/main/java/io/substrait/expression/Expression.java index aa5ef0561..b592341b5 100644 --- a/core/src/main/java/io/substrait/expression/Expression.java +++ b/core/src/main/java/io/substrait/expression/Expression.java @@ -270,6 +270,44 @@ public R accept(ExpressionVisitor visitor) throws } } + @Value.Immutable + abstract static class PrecisionTimestampLiteral implements Literal { + public abstract long value(); + + public abstract int precision(); + + public Type getType() { + return Type.withNullability(nullable()).precisionTimestamp(precision()); + } + + public static ImmutableExpression.PrecisionTimestampLiteral.Builder builder() { + return ImmutableExpression.PrecisionTimestampLiteral.builder(); + } + + public R accept(ExpressionVisitor visitor) throws E { + return visitor.visit(this); + } + } + + @Value.Immutable + abstract static class PrecisionTimestampTZLiteral implements Literal { + public abstract long value(); + + public abstract int precision(); + + public Type getType() { + return Type.withNullability(nullable()).precisionTimestampTZ(precision()); + } + + public static ImmutableExpression.PrecisionTimestampTZLiteral.Builder builder() { + return ImmutableExpression.PrecisionTimestampTZLiteral.builder(); + } + + public R accept(ExpressionVisitor 
visitor) throws E { + return visitor.visit(this); + } + } + @Value.Immutable abstract static class IntervalYearLiteral implements Literal { public abstract int years(); @@ -596,7 +634,7 @@ abstract static class ScalarFunctionInvocation implements Expression { public abstract List arguments(); - public abstract Map options(); + public abstract List options(); public abstract Type outputType(); @@ -620,7 +658,7 @@ abstract class WindowFunctionInvocation implements Expression { public abstract List arguments(); - public abstract Map options(); + public abstract List options(); public abstract AggregationPhase aggregationPhase(); diff --git a/core/src/main/java/io/substrait/expression/ExpressionCreator.java b/core/src/main/java/io/substrait/expression/ExpressionCreator.java index 8ade39b47..671d9d465 100644 --- a/core/src/main/java/io/substrait/expression/ExpressionCreator.java +++ b/core/src/main/java/io/substrait/expression/ExpressionCreator.java @@ -10,6 +10,7 @@ import java.time.Instant; import java.time.LocalDateTime; import java.time.ZoneOffset; +import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.UUID; @@ -75,10 +76,18 @@ public static Expression.TimeLiteral time(boolean nullable, long value) { return Expression.TimeLiteral.builder().nullable(nullable).value(value).build(); } + /** + * @deprecated Timestamp is deprecated in favor of PrecisionTimestamp + */ + @Deprecated public static Expression.TimestampLiteral timestamp(boolean nullable, long value) { return Expression.TimestampLiteral.builder().nullable(nullable).value(value).build(); } + /** + * @deprecated Timestamp is deprecated in favor of PrecisionTimestamp + */ + @Deprecated public static Expression.TimestampLiteral timestamp(boolean nullable, LocalDateTime value) { var epochMicro = TimeUnit.SECONDS.toMicros(value.toEpochSecond(ZoneOffset.UTC)) @@ -86,6 +95,10 @@ public static Expression.TimestampLiteral timestamp(boolean nullable, LocalDateT return timestamp(nullable, epochMicro); } + /** + * @deprecated Timestamp is deprecated in favor of PrecisionTimestamp + */ + @Deprecated public static Expression.TimestampLiteral timestamp( boolean nullable, int year, @@ -101,10 +114,18 @@ public static Expression.TimestampLiteral timestamp( .withNano((int) TimeUnit.MICROSECONDS.toNanos(micros))); } + /** + * @deprecated TimestampTZ is deprecated in favor of PrecisionTimestampTZ + */ + @Deprecated public static Expression.TimestampTZLiteral timestampTZ(boolean nullable, long value) { return Expression.TimestampTZLiteral.builder().nullable(nullable).value(value).build(); } + /** + * @deprecated TimestampTZ is deprecated in favor of PrecisionTimestampTZ + */ + @Deprecated public static Expression.TimestampTZLiteral timestampTZ(boolean nullable, Instant value) { var epochMicro = TimeUnit.SECONDS.toMicros(value.getEpochSecond()) @@ -112,6 +133,24 @@ public static Expression.TimestampTZLiteral timestampTZ(boolean nullable, Instan return timestampTZ(nullable, epochMicro); } + public static Expression.PrecisionTimestampLiteral precisionTimestamp( + boolean nullable, long value, int precision) { + return Expression.PrecisionTimestampLiteral.builder() + .nullable(nullable) + .value(value) + .precision(precision) + .build(); + } + + public static Expression.PrecisionTimestampTZLiteral precisionTimestampTZ( + boolean nullable, long value, int precision) { + return Expression.PrecisionTimestampTZLiteral.builder() + .nullable(nullable) + .value(value) + .precision(precision) + .build(); + } + public static 
Expression.IntervalYearLiteral intervalYear( boolean nullable, int years, int months) { return Expression.IntervalYearLiteral.builder() @@ -284,13 +323,13 @@ public static Expression.ScalarFunctionInvocation scalarFunction( SimpleExtension.ScalarFunctionVariant declaration, Type outputType, FunctionArg... arguments) { - return Expression.ScalarFunctionInvocation.builder() - .declaration(declaration) - .outputType(outputType) - .addArguments(arguments) - .build(); + return scalarFunction(declaration, outputType, Arrays.asList(arguments)); } + /** + * Use {@link Expression.ScalarFunctionInvocation#builder()} directly to specify other parameters, + * e.g. options + */ public static Expression.ScalarFunctionInvocation scalarFunction( SimpleExtension.ScalarFunctionVariant declaration, Type outputType, @@ -302,6 +341,10 @@ public static Expression.ScalarFunctionInvocation scalarFunction( .build(); } + /** + * Use {@link AggregateFunctionInvocation#builder()} directly to specify other parameters, e.g. + * options + */ public static AggregateFunctionInvocation aggregateFunction( SimpleExtension.AggregateFunctionVariant declaration, Type outputType, @@ -326,16 +369,14 @@ public static AggregateFunctionInvocation aggregateFunction( List sort, Expression.AggregationInvocation invocation, FunctionArg... arguments) { - return AggregateFunctionInvocation.builder() - .declaration(declaration) - .outputType(outputType) - .aggregationPhase(phase) - .sort(sort) - .invocation(invocation) - .addArguments(arguments) - .build(); + return aggregateFunction( + declaration, outputType, phase, sort, invocation, Arrays.asList(arguments)); } + /** + * Use {@link Expression.WindowFunctionInvocation#builder()} directly to specify other parameters, + * e.g. options + */ public static Expression.WindowFunctionInvocation windowFunction( SimpleExtension.WindowFunctionVariant declaration, Type outputType, @@ -361,6 +402,10 @@ public static Expression.WindowFunctionInvocation windowFunction( .build(); } + /** + * Use {@link ConsistentPartitionWindow.WindowRelFunctionInvocation#builder()} directly to specify + * other parameters, e.g. options + */ public static ConsistentPartitionWindow.WindowRelFunctionInvocation windowRelFunction( SimpleExtension.WindowFunctionVariant declaration, Type outputType, @@ -393,22 +438,17 @@ public static Expression.WindowFunctionInvocation windowFunction( WindowBound lowerBound, WindowBound upperBound, FunctionArg... 
arguments) { - return Expression.WindowFunctionInvocation.builder() - .declaration(declaration) - .outputType(outputType) - .aggregationPhase(phase) - .sort(sort) - .invocation(invocation) - .partitionBy(partitionBy) - .boundsType(boundsType) - .lowerBound(lowerBound) - .upperBound(upperBound) - .addArguments(arguments) - .build(); - } - - public static Expression cast(Type type, Expression expression) { - return cast(type, expression, Expression.FailureBehavior.UNSPECIFIED); + return windowFunction( + declaration, + outputType, + phase, + sort, + invocation, + partitionBy, + boundsType, + lowerBound, + upperBound, + Arrays.asList(arguments)); } public static Expression cast( diff --git a/core/src/main/java/io/substrait/expression/ExpressionVisitor.java b/core/src/main/java/io/substrait/expression/ExpressionVisitor.java index 5e2c00854..42c78c184 100644 --- a/core/src/main/java/io/substrait/expression/ExpressionVisitor.java +++ b/core/src/main/java/io/substrait/expression/ExpressionVisitor.java @@ -31,6 +31,10 @@ public interface ExpressionVisitor { R visit(Expression.TimestampTZLiteral expr) throws E; + R visit(Expression.PrecisionTimestampLiteral expr) throws E; + + R visit(Expression.PrecisionTimestampTZLiteral expr) throws E; + R visit(Expression.IntervalYearLiteral expr) throws E; R visit(Expression.IntervalDayLiteral expr) throws E; diff --git a/core/src/main/java/io/substrait/expression/FunctionOption.java b/core/src/main/java/io/substrait/expression/FunctionOption.java index 4ee3fdafd..5ee775789 100644 --- a/core/src/main/java/io/substrait/expression/FunctionOption.java +++ b/core/src/main/java/io/substrait/expression/FunctionOption.java @@ -9,4 +9,8 @@ public abstract class FunctionOption { public abstract String getName(); public abstract List values(); + + public static ImmutableFunctionOption.Builder builder() { + return ImmutableFunctionOption.builder(); + } } diff --git a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java index c4a313453..7b0bb687b 100644 --- a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java @@ -10,6 +10,7 @@ import io.substrait.extension.SimpleExtension; import io.substrait.proto.Expression; import io.substrait.proto.FunctionArgument; +import io.substrait.proto.FunctionOption; import io.substrait.proto.Rel; import io.substrait.proto.SortField; import io.substrait.proto.Type; @@ -112,6 +113,32 @@ public Expression visit(io.substrait.expression.Expression.TimestampTZLiteral ex return lit(bldr -> bldr.setNullable(expr.nullable()).setTimestampTz(expr.value())); } + @Override + public Expression visit(io.substrait.expression.Expression.PrecisionTimestampLiteral expr) { + return lit( + bldr -> + bldr.setNullable(expr.nullable()) + .setPrecisionTimestamp( + Expression.Literal.PrecisionTimestamp.newBuilder() + .setValue(expr.value()) + .setPrecision(expr.precision()) + .build()) + .build()); + } + + @Override + public Expression visit(io.substrait.expression.Expression.PrecisionTimestampTZLiteral expr) { + return lit( + bldr -> + bldr.setNullable(expr.nullable()) + .setPrecisionTimestampTz( + Expression.Literal.PrecisionTimestamp.newBuilder() + .setValue(expr.value()) + .setPrecision(expr.precision()) + .build()) + .build()); + } + @Override public Expression visit(io.substrait.expression.Expression.IntervalYearLiteral expr) { return lit( @@ 
-314,10 +341,21 @@ public Expression visit(io.substrait.expression.Expression.ScalarFunctionInvocat .addAllArguments( expr.arguments().stream() .map(a -> a.accept(expr.declaration(), 0, argVisitor)) + .collect(java.util.stream.Collectors.toList())) + .addAllOptions( + expr.options().stream() + .map(ExpressionProtoConverter::from) .collect(java.util.stream.Collectors.toList()))) .build(); } + public static FunctionOption from(io.substrait.expression.FunctionOption option) { + return FunctionOption.newBuilder() + .setName(option.getName()) + .addAllPreference(option.values()) + .build(); + } + @Override public Expression visit(io.substrait.expression.Expression.Cast expr) { return Expression.newBuilder() @@ -495,7 +533,11 @@ public Expression visit(io.substrait.expression.Expression.WindowFunctionInvocat .addAllPartitions(partitionExprs) .setBoundsType(expr.boundsType().toProto()) .setLowerBound(lowerBound) - .setUpperBound(upperBound)) + .setUpperBound(upperBound) + .addAllOptions( + expr.options().stream() + .map(ExpressionProtoConverter::from) + .collect(java.util.stream.Collectors.toList()))) .build(); } diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index e2fc56460..b429338ea 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -6,7 +6,6 @@ import io.substrait.expression.FunctionArg; import io.substrait.expression.FunctionOption; import io.substrait.expression.ImmutableExpression; -import io.substrait.expression.ImmutableFunctionOption; import io.substrait.expression.WindowBound; import io.substrait.extension.ExtensionLookup; import io.substrait.extension.SimpleExtension; @@ -117,10 +116,15 @@ public Expression from(io.substrait.proto.Expression expr) { IntStream.range(0, scalarFunction.getArgumentsCount()) .mapToObj(i -> pF.convert(declaration, i, scalarFunction.getArguments(i))) .collect(java.util.stream.Collectors.toList()); + var options = + scalarFunction.getOptionsList().stream() + .map(ProtoExpressionConverter::fromFunctionOption) + .collect(Collectors.toList()); yield ImmutableExpression.ScalarFunctionInvocation.builder() .addAllArguments(args) .declaration(declaration) .outputType(protoTypeConverter.from(scalarFunction.getOutputType())) + .options(options) .build(); } case WINDOW_FUNCTION -> fromWindowFunction(expr.getWindowFunction()); @@ -174,7 +178,9 @@ public Expression from(io.substrait.proto.Expression expr) { .build(); } case CAST -> ExpressionCreator.cast( - protoTypeConverter.from(expr.getCast().getType()), from(expr.getCast().getInput())); + protoTypeConverter.from(expr.getCast().getType()), + from(expr.getCast().getInput()), + Expression.FailureBehavior.fromProto(expr.getCast().getFailureBehavior())); case SUBQUERY -> { switch (expr.getSubquery().getSubqueryTypeCase()) { case SET_PREDICATE -> { @@ -239,8 +245,8 @@ public Expression.WindowFunctionInvocation fromWindowFunction( .collect(Collectors.toList()); var options = windowFunction.getOptionsList().stream() - .map(this::fromFunctionOption) - .collect(Collectors.toMap(FunctionOption::getName, Function.identity())); + .map(ProtoExpressionConverter::fromFunctionOption) + .collect(Collectors.toList()); WindowBound lowerBound = toWindowBound(windowFunction.getLowerBound()); WindowBound upperBound = toWindowBound(windowFunction.getUpperBound()); @@ -274,8 +280,8 @@ public 
ConsistentPartitionWindow.WindowRelFunctionInvocation fromWindowRelFuncti windowRelFunction::getArguments); var options = windowRelFunction.getOptionsList().stream() - .map(this::fromFunctionOption) - .collect(Collectors.toMap(FunctionOption::getName, Function.identity())); + .map(ProtoExpressionConverter::fromFunctionOption) + .collect(Collectors.toList()); WindowBound lowerBound = toWindowBound(windowRelFunction.getLowerBound()); WindowBound upperBound = toWindowBound(windowRelFunction.getUpperBound()); @@ -318,6 +324,16 @@ public Expression.Literal from(io.substrait.proto.Expression.Literal literal) { case STRING -> ExpressionCreator.string(literal.getNullable(), literal.getString()); case BINARY -> ExpressionCreator.binary(literal.getNullable(), literal.getBinary()); case TIMESTAMP -> ExpressionCreator.timestamp(literal.getNullable(), literal.getTimestamp()); + case TIMESTAMP_TZ -> ExpressionCreator.timestampTZ( + literal.getNullable(), literal.getTimestampTz()); + case PRECISION_TIMESTAMP -> ExpressionCreator.precisionTimestamp( + literal.getNullable(), + literal.getPrecisionTimestamp().getValue(), + literal.getPrecisionTimestamp().getPrecision()); + case PRECISION_TIMESTAMP_TZ -> ExpressionCreator.precisionTimestampTZ( + literal.getNullable(), + literal.getPrecisionTimestampTz().getValue(), + literal.getPrecisionTimestampTz().getPrecision()); case DATE -> ExpressionCreator.date(literal.getNullable(), literal.getDate()); case TIME -> ExpressionCreator.time(literal.getNullable(), literal.getTime()); case INTERVAL_YEAR_TO_MONTH -> ExpressionCreator.intervalYear( @@ -348,8 +364,6 @@ public Expression.Literal from(io.substrait.proto.Expression.Literal literal) { literal.getNullable(), literal.getMap().getKeyValuesList().stream() .collect(Collectors.toMap(kv -> from(kv.getKey()), kv -> from(kv.getValue())))); - case TIMESTAMP_TZ -> ExpressionCreator.timestampTZ( - literal.getNullable(), literal.getTimestampTz()); case UUID -> ExpressionCreator.uuid(literal.getNullable(), literal.getUuid()); case NULL -> ExpressionCreator.typedNull(protoTypeConverter.from(literal.getNull())); case LIST -> ExpressionCreator.list( @@ -391,10 +405,7 @@ public Expression.SortField fromSortField(SortField s) { .build(); } - public FunctionOption fromFunctionOption(io.substrait.proto.FunctionOption o) { - return ImmutableFunctionOption.builder() - .name(o.getName()) - .addAllValues(o.getPreferenceList()) - .build(); + public static FunctionOption fromFunctionOption(io.substrait.proto.FunctionOption o) { + return FunctionOption.builder().name(o.getName()).addAllValues(o.getPreferenceList()).build(); } } diff --git a/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java b/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java index 2d4085432..1938a2b51 100644 --- a/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java +++ b/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java @@ -64,9 +64,12 @@ public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExp break; case MEASURE: io.substrait.relation.Aggregate.Measure measure = - new ProtoAggregateFunctionConverter( - functionLookup, extensionCollection, protoExpressionConverter) - .from(expressionReference.getMeasure()); + io.substrait.relation.Aggregate.Measure.builder() + .function( + new ProtoAggregateFunctionConverter( + functionLookup, extensionCollection, protoExpressionConverter) + 
.from(expressionReference.getMeasure())) + .build(); ImmutableAggregateFunctionReference buildMeasure = ImmutableAggregateFunctionReference.builder() .measure(measure) diff --git a/core/src/main/java/io/substrait/extension/AdvancedExtension.java b/core/src/main/java/io/substrait/extension/AdvancedExtension.java index bb4490efd..0d7278ef4 100644 --- a/core/src/main/java/io/substrait/extension/AdvancedExtension.java +++ b/core/src/main/java/io/substrait/extension/AdvancedExtension.java @@ -1,20 +1,21 @@ package io.substrait.extension; import io.substrait.relation.Extension; +import java.util.List; import java.util.Optional; import org.immutables.value.Value; @Value.Immutable public abstract class AdvancedExtension { - public abstract Optional getOptimization(); + public abstract List getOptimizations(); public abstract Optional getEnhancement(); public io.substrait.proto.AdvancedExtension toProto() { var builder = io.substrait.proto.AdvancedExtension.newBuilder(); getEnhancement().ifPresent(e -> builder.setEnhancement(e.toProto())); - getOptimization().ifPresent(e -> builder.setOptimization(e.toProto())); + getOptimizations().forEach(e -> builder.addOptimization(e.toProto())); return builder.build(); } diff --git a/core/src/main/java/io/substrait/extension/SimpleExtension.java b/core/src/main/java/io/substrait/extension/SimpleExtension.java index 3bbed56c8..29ba5fe2c 100644 --- a/core/src/main/java/io/substrait/extension/SimpleExtension.java +++ b/core/src/main/java/io/substrait/extension/SimpleExtension.java @@ -796,7 +796,7 @@ public static ExtensionCollection buildExtensionCollection( .windowFunctions(allWindowFunctionVariants) .addAllTypes(extensionSignatures.types()) .build(); - logger.debug( + logger.atDebug().log( "Loaded {} aggregate functions and {} scalar functions from {}.", collection.aggregateFunctions().size(), collection.scalarFunctions().size(), diff --git a/core/src/main/java/io/substrait/function/ParameterizedType.java b/core/src/main/java/io/substrait/function/ParameterizedType.java index 7b2202a27..767ad9253 100644 --- a/core/src/main/java/io/substrait/function/ParameterizedType.java +++ b/core/src/main/java/io/substrait/function/ParameterizedType.java @@ -106,6 +106,36 @@ public static ImmutableParameterizedType.Decimal.Builder builder() { } } + @Value.Immutable + abstract static class PrecisionTimestamp extends BaseParameterizedType implements NullableType { + public abstract StringLiteral precision(); + + @Override + R accept(final ParameterizedTypeVisitor parameterizedTypeVisitor) + throws E { + return parameterizedTypeVisitor.visit(this); + } + + public static ImmutableParameterizedType.PrecisionTimestamp.Builder builder() { + return ImmutableParameterizedType.PrecisionTimestamp.builder(); + } + } + + @Value.Immutable + abstract static class PrecisionTimestampTZ extends BaseParameterizedType implements NullableType { + public abstract StringLiteral precision(); + + @Override + R accept(final ParameterizedTypeVisitor parameterizedTypeVisitor) + throws E { + return parameterizedTypeVisitor.visit(this); + } + + public static ImmutableParameterizedType.PrecisionTimestampTZ.Builder builder() { + return ImmutableParameterizedType.PrecisionTimestampTZ.builder(); + } + } + @Value.Immutable abstract static class Struct extends BaseParameterizedType implements NullableType { public abstract java.util.List fields(); diff --git a/core/src/main/java/io/substrait/function/ParameterizedTypeCreator.java b/core/src/main/java/io/substrait/function/ParameterizedTypeCreator.java index 
75c81945c..63db88887 100644 --- a/core/src/main/java/io/substrait/function/ParameterizedTypeCreator.java +++ b/core/src/main/java/io/substrait/function/ParameterizedTypeCreator.java @@ -49,6 +49,20 @@ public ParameterizedType decimalE(String precision, String scale) { .build(); } + public ParameterizedType precisionTimestampE(String precision) { + return ParameterizedType.PrecisionTimestamp.builder() + .nullable(nullable) + .precision(parameter(precision, false)) + .build(); + } + + public ParameterizedType precisionTimestampTZE(String precision) { + return ParameterizedType.PrecisionTimestampTZ.builder() + .nullable(nullable) + .precision(parameter(precision, false)) + .build(); + } + public ParameterizedType structE(ParameterizedType... types) { return ParameterizedType.Struct.builder().nullable(nullable).addFields(types).build(); } diff --git a/core/src/main/java/io/substrait/function/ParameterizedTypeVisitor.java b/core/src/main/java/io/substrait/function/ParameterizedTypeVisitor.java index 8f451453b..67e685cc5 100644 --- a/core/src/main/java/io/substrait/function/ParameterizedTypeVisitor.java +++ b/core/src/main/java/io/substrait/function/ParameterizedTypeVisitor.java @@ -11,6 +11,10 @@ public interface ParameterizedTypeVisitor extends TypeVi R visit(ParameterizedType.Decimal expr) throws E; + R visit(ParameterizedType.PrecisionTimestamp expr) throws E; + + R visit(ParameterizedType.PrecisionTimestampTZ expr) throws E; + R visit(ParameterizedType.Struct expr) throws E; R visit(ParameterizedType.ListType expr) throws E; @@ -46,6 +50,16 @@ public R visit(ParameterizedType.Decimal expr) throws E { throw t(); } + @Override + public R visit(ParameterizedType.PrecisionTimestamp expr) throws E { + throw t(); + } + + @Override + public R visit(ParameterizedType.PrecisionTimestampTZ expr) throws E { + throw t(); + } + @Override public R visit(ParameterizedType.Struct expr) throws E { throw t(); diff --git a/core/src/main/java/io/substrait/function/ToTypeString.java b/core/src/main/java/io/substrait/function/ToTypeString.java index b86898528..f68733cc6 100644 --- a/core/src/main/java/io/substrait/function/ToTypeString.java +++ b/core/src/main/java/io/substrait/function/ToTypeString.java @@ -82,12 +82,12 @@ public String visit(final Type.Timestamp expr) { @Override public String visit(final Type.IntervalYear expr) { - return "year"; + return "iyear"; } @Override public String visit(final Type.IntervalDay expr) { - return "day"; + return "iday"; } @Override @@ -115,6 +115,16 @@ public String visit(final Type.Decimal expr) { return "dec"; } + @Override + public String visit(final Type.PrecisionTimestamp expr) { + return "pts"; + } + + @Override + public String visit(final Type.PrecisionTimestampTZ expr) { + return "ptstz"; + } + @Override public String visit(final Type.Struct expr) { return "struct"; @@ -155,6 +165,16 @@ public String visit(ParameterizedType.Decimal expr) throws RuntimeException { return "dec"; } + @Override + public String visit(ParameterizedType.PrecisionTimestamp expr) throws RuntimeException { + return "pts"; + } + + @Override + public String visit(ParameterizedType.PrecisionTimestampTZ expr) throws RuntimeException { + return "ptstz"; + } + @Override public String visit(ParameterizedType.Struct expr) throws RuntimeException { return "struct"; diff --git a/core/src/main/java/io/substrait/function/TypeExpression.java b/core/src/main/java/io/substrait/function/TypeExpression.java index 41ee976ef..e9f7c5ce1 100644 --- a/core/src/main/java/io/substrait/function/TypeExpression.java 
+++ b/core/src/main/java/io/substrait/function/TypeExpression.java @@ -84,6 +84,36 @@ public static ImmutableTypeExpression.Decimal.Builder builder() { } } + @Value.Immutable + abstract static class PrecisionTimestamp extends BaseTypeExpression implements NullableType { + + public abstract TypeExpression precision(); + + @Override + R acceptE(final TypeExpressionVisitor visitor) throws E { + return visitor.visit(this); + } + + public static ImmutableTypeExpression.PrecisionTimestamp.Builder builder() { + return ImmutableTypeExpression.PrecisionTimestamp.builder(); + } + } + + @Value.Immutable + abstract static class PrecisionTimestampTZ extends BaseTypeExpression implements NullableType { + + public abstract TypeExpression precision(); + + @Override + R acceptE(final TypeExpressionVisitor visitor) throws E { + return visitor.visit(this); + } + + public static ImmutableTypeExpression.PrecisionTimestampTZ.Builder builder() { + return ImmutableTypeExpression.PrecisionTimestampTZ.builder(); + } + } + @Value.Immutable abstract static class Struct extends BaseTypeExpression implements NullableType { public abstract java.util.List fields(); diff --git a/core/src/main/java/io/substrait/function/TypeExpressionCreator.java b/core/src/main/java/io/substrait/function/TypeExpressionCreator.java index ce1cb5aee..5dc56584f 100644 --- a/core/src/main/java/io/substrait/function/TypeExpressionCreator.java +++ b/core/src/main/java/io/substrait/function/TypeExpressionCreator.java @@ -35,6 +35,20 @@ public TypeExpression decimalE(TypeExpression precision, TypeExpression scale) { .build(); } + public TypeExpression precisionTimestampE(TypeExpression precision) { + return TypeExpression.PrecisionTimestamp.builder() + .nullable(nullable) + .precision(precision) + .build(); + } + + public TypeExpression precisionTimestampTZE(TypeExpression precision) { + return TypeExpression.PrecisionTimestampTZ.builder() + .nullable(nullable) + .precision(precision) + .build(); + } + public TypeExpression structE(TypeExpression... 
types) { return TypeExpression.Struct.builder().nullable(nullable).addFields(types).build(); } diff --git a/core/src/main/java/io/substrait/function/TypeExpressionVisitor.java b/core/src/main/java/io/substrait/function/TypeExpressionVisitor.java index 72be9dd67..a30891e75 100644 --- a/core/src/main/java/io/substrait/function/TypeExpressionVisitor.java +++ b/core/src/main/java/io/substrait/function/TypeExpressionVisitor.java @@ -10,6 +10,10 @@ public interface TypeExpressionVisitor R visit(TypeExpression.Decimal expr) throws E; + R visit(TypeExpression.PrecisionTimestamp expr) throws E; + + R visit(TypeExpression.PrecisionTimestampTZ expr) throws E; + R visit(TypeExpression.Struct expr) throws E; R visit(TypeExpression.ListType expr) throws E; @@ -54,6 +58,16 @@ public R visit(TypeExpression.Decimal expr) throws E { throw t(); } + @Override + public R visit(TypeExpression.PrecisionTimestamp expr) throws E { + throw t(); + } + + @Override + public R visit(TypeExpression.PrecisionTimestampTZ expr) throws E { + throw t(); + } + @Override public R visit(TypeExpression.Struct expr) throws E { throw t(); diff --git a/core/src/main/java/io/substrait/relation/ConsistentPartitionWindow.java b/core/src/main/java/io/substrait/relation/ConsistentPartitionWindow.java index f0f7cc781..d1344581c 100644 --- a/core/src/main/java/io/substrait/relation/ConsistentPartitionWindow.java +++ b/core/src/main/java/io/substrait/relation/ConsistentPartitionWindow.java @@ -9,7 +9,6 @@ import io.substrait.type.Type; import io.substrait.type.TypeCreator; import java.util.List; -import java.util.Map; import java.util.stream.Stream; import org.immutables.value.Value; @@ -49,7 +48,7 @@ public abstract static class WindowRelFunctionInvocation { public abstract List arguments(); - public abstract Map options(); + public abstract List options(); public abstract Type outputType(); diff --git a/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java b/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java index 179b25094..c88f7e68d 100644 --- a/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java +++ b/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java @@ -98,6 +98,16 @@ public Optional visit(Expression.TimestampTZLiteral expr) throws EXC return visitLiteral(expr); } + @Override + public Optional visit(Expression.PrecisionTimestampLiteral expr) throws EXCEPTION { + return visitLiteral(expr); + } + + @Override + public Optional visit(Expression.PrecisionTimestampTZLiteral expr) throws EXCEPTION { + return visitLiteral(expr); + } + @Override public Optional visit(Expression.IntervalYearLiteral expr) throws EXCEPTION { return visitLiteral(expr); diff --git a/core/src/main/java/io/substrait/relation/ProtoAggregateFunctionConverter.java b/core/src/main/java/io/substrait/relation/ProtoAggregateFunctionConverter.java index 34ba0a1e3..92ca23707 100644 --- a/core/src/main/java/io/substrait/relation/ProtoAggregateFunctionConverter.java +++ b/core/src/main/java/io/substrait/relation/ProtoAggregateFunctionConverter.java @@ -3,12 +3,14 @@ import io.substrait.expression.AggregateFunctionInvocation; import io.substrait.expression.Expression; import io.substrait.expression.FunctionArg; +import io.substrait.expression.FunctionOption; import io.substrait.expression.proto.ProtoExpressionConverter; import io.substrait.extension.ExtensionLookup; import io.substrait.extension.SimpleExtension; import io.substrait.type.proto.ProtoTypeConverter; import 
java.io.IOException; import java.util.List; +import java.util.stream.Collectors; import java.util.stream.IntStream; /** @@ -37,7 +39,7 @@ public ProtoAggregateFunctionConverter( this.protoExpressionConverter = protoExpressionConverter; } - public io.substrait.relation.Aggregate.Measure from( + public io.substrait.expression.AggregateFunctionInvocation from( io.substrait.proto.AggregateFunction measure) { FunctionArg.ProtoFrom protoFrom = new FunctionArg.ProtoFrom(protoExpressionConverter, protoTypeConverter); @@ -47,15 +49,17 @@ public io.substrait.relation.Aggregate.Measure from( IntStream.range(0, measure.getArgumentsCount()) .mapToObj(i -> protoFrom.convert(aggregateFunction, i, measure.getArguments(i))) .collect(java.util.stream.Collectors.toList()); - return Aggregate.Measure.builder() - .function( - AggregateFunctionInvocation.builder() - .arguments(functionArgs) - .declaration(aggregateFunction) - .outputType(protoTypeConverter.from(measure.getOutputType())) - .aggregationPhase(Expression.AggregationPhase.fromProto(measure.getPhase())) - .invocation(Expression.AggregationInvocation.fromProto(measure.getInvocation())) - .build()) + List options = + measure.getOptionsList().stream() + .map(ProtoExpressionConverter::fromFunctionOption) + .collect(Collectors.toList()); + return AggregateFunctionInvocation.builder() + .arguments(functionArgs) + .declaration(aggregateFunction) + .outputType(protoTypeConverter.from(measure.getOutputType())) + .aggregationPhase(Expression.AggregationPhase.fromProto(measure.getPhase())) + .invocation(Expression.AggregationInvocation.fromProto(measure.getInvocation())) + .options(options) .build(); } } diff --git a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java index 18808303c..2d66c5c4c 100644 --- a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java +++ b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java @@ -1,6 +1,5 @@ package io.substrait.relation; -import io.substrait.expression.AggregateFunctionInvocation; import io.substrait.expression.Expression; import io.substrait.expression.FunctionArg; import io.substrait.expression.ImmutableExpression; @@ -334,12 +333,11 @@ private VirtualTableScan newVirtualTable(ReadRel rel) { .collect(java.util.stream.Collectors.toList())) .build()); } - var fieldNames = - rel.getBaseSchema().getNamesList().stream().collect(java.util.stream.Collectors.toList()); + var builder = VirtualTableScan.builder() .filter(Optional.ofNullable(rel.hasFilter() ? 
converter.from(rel.getFilter()) : null)) - .addAllDfsNames(fieldNames) + .initialSchema(NamedStruct.fromProto(rel.getBaseSchema(), protoTypeConverter)) .rows(structLiterals); builder @@ -353,7 +351,12 @@ private VirtualTableScan newVirtualTable(ReadRel rel) { private Fetch newFetch(FetchRel rel) { var input = from(rel.getInput()); - var builder = Fetch.builder().input(input).count(rel.getCount()).offset(rel.getOffset()); + var builder = Fetch.builder().input(input).offset(rel.getOffset()); + if (rel.getCount() != -1) { + // -1 is used as a sentinel value to signal LIMIT ALL + // count only needs to be set when it is not -1 + builder.count(rel.getCount()); + } builder .commonExtension(optionalAdvancedExtension(rel.getCommon())) @@ -388,6 +391,9 @@ private Aggregate newAggregate(AggregateRel rel) { var input = from(rel.getInput()); var protoExprConverter = new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this); + var protoAggrFuncConverter = + new ProtoAggregateFunctionConverter(lookup, extensions, protoExprConverter); + List groupings = new ArrayList<>(rel.getGroupingsCount()); for (var grouping : rel.getGroupingsList()) { groupings.add( @@ -409,14 +415,7 @@ private Aggregate newAggregate(AggregateRel rel) { .collect(java.util.stream.Collectors.toList()); measures.add( Aggregate.Measure.builder() - .function( - AggregateFunctionInvocation.builder() - .arguments(args) - .declaration(funcDecl) - .outputType(protoTypeConverter.from(func.getOutputType())) - .aggregationPhase(Expression.AggregationPhase.fromProto(func.getPhase())) - .invocation(Expression.AggregationInvocation.fromProto(func.getInvocation())) - .build()) + .function(protoAggrFuncConverter.from(measure.getMeasure())) .preMeasureFilter( Optional.ofNullable( measure.hasFilter() ? 
protoExprConverter.from(measure.getFilter()) : null)) @@ -659,15 +658,18 @@ private AdvancedExtension advancedExtension( if (advancedExtension.hasEnhancement()) { builder.enhancement(enhancementFromAdvancedExtension(advancedExtension.getEnhancement())); } - if (advancedExtension.hasOptimization()) { - builder.optimization(optimizationFromAdvancedExtension(advancedExtension.getOptimization())); - } + advancedExtension + .getOptimizationList() + .forEach( + optimization -> + builder.addOptimizations(optimizationFromAdvancedExtension(optimization))); + return builder.build(); } /** * Override to provide a custom converter for {@link - * io.substrait.proto.AdvancedExtension#getOptimization()} data + * io.substrait.proto.AdvancedExtension#getOptimizationList()} data */ protected Extension.Optimization optimizationFromAdvancedExtension(com.google.protobuf.Any any) { return new EmptyOptimization(); } diff --git a/core/src/main/java/io/substrait/relation/RelProtoConverter.java b/core/src/main/java/io/substrait/relation/RelProtoConverter.java index cebf9faa3..44dcc681c 100644 --- a/core/src/main/java/io/substrait/relation/RelProtoConverter.java +++ b/core/src/main/java/io/substrait/relation/RelProtoConverter.java @@ -15,7 +15,6 @@ import io.substrait.proto.ExtensionSingleRel; import io.substrait.proto.FetchRel; import io.substrait.proto.FilterRel; -import io.substrait.proto.FunctionOption; import io.substrait.proto.HashJoinRel; import io.substrait.proto.JoinRel; import io.substrait.proto.MergeJoinRel; @@ -117,7 +116,11 @@ private AggregateRel.Measure toProto(Aggregate.Measure measure) { .collect(java.util.stream.Collectors.toList())) .addAllSorts(toProtoS(measure.getFunction().sort())) .setFunctionReference( - functionCollector.getFunctionReference(measure.getFunction().declaration())); + functionCollector.getFunctionReference(measure.getFunction().declaration())) + .addAllOptions( + measure.getFunction().options().stream() + .map(ExpressionProtoConverter::from) + .collect(java.util.stream.Collectors.toList())); var builder = AggregateRel.Measure.newBuilder().setMeasure(func); @@ -149,9 +152,9 @@ public Rel visit(Fetch fetch) throws RuntimeException { FetchRel.newBuilder() .setCommon(common(fetch)) .setInput(toProto(fetch.getInput())) - .setOffset(fetch.getOffset()); - - fetch.getCount().ifPresent(f -> builder.setCount(f)); + .setOffset(fetch.getOffset()) + // -1 is used as a sentinel value to signal LIMIT ALL + .setCount(fetch.getCount().orElse(-1)); fetch.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto())); return Rel.newBuilder().setFetch(builder).build(); @@ -345,13 +348,8 @@ private List toProtoWindowRelFun .mapToObj(i -> args.get(i).accept(aggFuncDef, i, argVisitor)) .collect(Collectors.toList()); var options = - f.options().entrySet().stream() - .map( - o -> - FunctionOption.newBuilder() - .setName(o.getKey()) - .addAllPreference(o.getValue().values()) - .build()) + f.options().stream() + .map(ExpressionProtoConverter::from) .collect(java.util.stream.Collectors.toList()); return ConsistentPartitionWindowRel.WindowRelFunction.newBuilder() diff --git a/core/src/main/java/io/substrait/relation/VirtualTableScan.java b/core/src/main/java/io/substrait/relation/VirtualTableScan.java index ad596e7a2..c35dab8cb 100644 --- a/core/src/main/java/io/substrait/relation/VirtualTableScan.java +++ b/core/src/main/java/io/substrait/relation/VirtualTableScan.java @@ -1,16 +1,14 @@ package io.substrait.relation; import io.substrait.expression.Expression; -import
io.substrait.type.NamedStruct; import io.substrait.type.Type; +import io.substrait.type.TypeVisitor; import java.util.List; import org.immutables.value.Value; @Value.Immutable public abstract class VirtualTableScan extends AbstractReadRel { - public abstract List getDfsNames(); - public abstract List getRows(); /** @@ -25,17 +23,17 @@ public abstract class VirtualTableScan extends AbstractReadRel { */ @Value.Check protected void check() { - var names = getDfsNames(); + var names = getInitialSchema().names(); + + assert names.size() + == NamedFieldCountingTypeVisitor.countNames(this.getInitialSchema().struct()); var rows = getRows(); assert rows.size() > 0 && names.stream().noneMatch(s -> s == null) - && rows.stream().noneMatch(r -> r == null || r.fields().size() != names.size()); - } - - @Override - public final NamedStruct getInitialSchema() { - return NamedStruct.of(getDfsNames(), (Type.Struct) getRows().get(0).getType()); + && rows.stream().noneMatch(r -> r == null) + && rows.stream() + .allMatch(r -> NamedFieldCountingTypeVisitor.countNames(r.getType()) == names.size()); } @Override @@ -46,4 +44,147 @@ public O accept(RelVisitor visitor) throws E { public static ImmutableVirtualTableScan.Builder builder() { return ImmutableVirtualTableScan.builder(); } + + private static class NamedFieldCountingTypeVisitor + implements TypeVisitor { + + private static final NamedFieldCountingTypeVisitor VISITOR = + new NamedFieldCountingTypeVisitor(); + + private static Integer countNames(Type type) { + return type.accept(VISITOR); + } + + @Override + public Integer visit(Type.Bool type) throws RuntimeException { + return 0; + } + + @Override + public Integer visit(Type.I8 type) throws RuntimeException { + return 0; + } + + @Override + public Integer visit(Type.I16 type) throws RuntimeException { + return 0; + } + + @Override + public Integer visit(Type.I32 type) throws RuntimeException { + return 0; + } + + @Override + public Integer visit(Type.I64 type) throws RuntimeException { + return 0; + } + + @Override + public Integer visit(Type.FP32 type) throws RuntimeException { + return 0; + } + + @Override + public Integer visit(Type.FP64 type) throws RuntimeException { + return 0; + } + + @Override + public Integer visit(Type.Str type) throws RuntimeException { + return 0; + } + + @Override + public Integer visit(Type.Binary type) throws RuntimeException { + return 0; + } + + @Override + public Integer visit(Type.Date type) throws RuntimeException { + return 0; + } + + @Override + public Integer visit(Type.Time type) throws RuntimeException { + return 0; + } + + @Override + public Integer visit(Type.TimestampTZ type) throws RuntimeException { + return 0; + } + + @Override + public Integer visit(Type.Timestamp type) throws RuntimeException { + return 0; + } + + @Override + public Integer visit(Type.PrecisionTimestamp type) throws RuntimeException { + return 0; + } + + @Override + public Integer visit(Type.PrecisionTimestampTZ type) throws RuntimeException { + return 0; + } + + @Override + public Integer visit(Type.IntervalYear type) throws RuntimeException { + return 0; + } + + @Override + public Integer visit(Type.IntervalDay type) throws RuntimeException { + return 0; + } + + @Override + public Integer visit(Type.UUID type) throws RuntimeException { + return 0; + } + + @Override + public Integer visit(Type.FixedChar type) throws RuntimeException { + return 0; + } + + @Override + public Integer visit(Type.VarChar type) throws RuntimeException { + return 0; + } + + @Override + public Integer 
visit(Type.FixedBinary type) throws RuntimeException { + return 0; + } + + @Override + public Integer visit(Type.Decimal type) throws RuntimeException { + return 0; + } + + @Override + public Integer visit(Type.Struct type) throws RuntimeException { + // Only struct fields have names - the top-level column names are also + // captured by this since the whole schema is wrapped in a Struct type + return type.fields().stream().mapToInt(field -> 1 + field.accept(this)).sum(); + } + + @Override + public Integer visit(Type.ListType type) throws RuntimeException { + return type.elementType().accept(this); + } + + @Override + public Integer visit(Type.Map type) throws RuntimeException { + return type.key().accept(this) + type.value().accept(this); + } + + @Override + public Integer visit(Type.UserDefined type) throws RuntimeException { + return 0; + } + } } diff --git a/core/src/main/java/io/substrait/relation/extensions/EmptyOptimization.java b/core/src/main/java/io/substrait/relation/extensions/EmptyOptimization.java index 74cc40ad4..46a11888d 100644 --- a/core/src/main/java/io/substrait/relation/extensions/EmptyOptimization.java +++ b/core/src/main/java/io/substrait/relation/extensions/EmptyOptimization.java @@ -1,11 +1,12 @@ package io.substrait.relation.extensions; import com.google.protobuf.Any; +import io.substrait.proto.AdvancedExtension; import io.substrait.relation.Extension; /** - * Default type to which {@link io.substrait.proto.AdvancedExtension#getOptimization()} data is - * converted to by the {@link io.substrait.relation.ProtoRelConverter} + * Default type to which {@link AdvancedExtension#getOptimizationList()} data is converted by the + * {@link io.substrait.relation.ProtoRelConverter} */ public class EmptyOptimization implements Extension.Optimization { @Override diff --git a/core/src/main/java/io/substrait/type/StringTypeVisitor.java b/core/src/main/java/io/substrait/type/StringTypeVisitor.java index 17fa8f523..668b770cd 100644 --- a/core/src/main/java/io/substrait/type/StringTypeVisitor.java +++ b/core/src/main/java/io/substrait/type/StringTypeVisitor.java @@ -109,6 +109,16 @@ public String visit(Type.Decimal type) throws RuntimeException { return String.format("decimal<%d,%d>%s", type.precision(), type.scale(), n(type)); } + @Override + public String visit(Type.PrecisionTimestamp type) throws RuntimeException { + return String.format("precision_timestamp<%d>%s", type.precision(), n(type)); + } + + @Override + public String visit(Type.PrecisionTimestampTZ type) throws RuntimeException { + return String.format("precision_timestamp_tz<%d>%s", type.precision(), n(type)); + } + @Override public String visit(Type.Struct type) throws RuntimeException { return String.format( diff --git a/core/src/main/java/io/substrait/type/Type.java b/core/src/main/java/io/substrait/type/Type.java index d0f14c047..43d1fd60c 100644 --- a/core/src/main/java/io/substrait/type/Type.java +++ b/core/src/main/java/io/substrait/type/Type.java @@ -155,8 +155,12 @@ public R accept(final TypeVisitor typeVisitor) th } } + /** Deprecated, use {@link PrecisionTimestampTZ} instead */ @Value.Immutable + @Deprecated abstract static class TimestampTZ implements Type { + + /** Deprecated, use {@link PrecisionTimestampTZ#builder()} instead */ public static ImmutableType.TimestampTZ.Builder builder() { return ImmutableType.TimestampTZ.builder(); } @@ -167,8 +171,13 @@ public R accept(final TypeVisitor typeVisitor) th } } + /** Deprecated, use {@link PrecisionTimestamp} instead */ @Value.Immutable + @Deprecated
abstract static class Timestamp implements Type { + + /** Deprecated, use {@link PrecisionTimestamp#builder()} instead */ + @Deprecated public static ImmutableType.Timestamp.Builder builder() { return ImmutableType.Timestamp.builder(); } @@ -273,6 +282,34 @@ public R accept(final TypeVisitor typeVisitor) th } } + @Value.Immutable + abstract static class PrecisionTimestamp implements Type { + public abstract int precision(); + + public static ImmutableType.PrecisionTimestamp.Builder builder() { + return ImmutableType.PrecisionTimestamp.builder(); + } + + @Override + public R accept(final TypeVisitor typeVisitor) throws E { + return typeVisitor.visit(this); + } + } + + @Value.Immutable + abstract static class PrecisionTimestampTZ implements Type { + public abstract int precision(); + + public static ImmutableType.PrecisionTimestampTZ.Builder builder() { + return ImmutableType.PrecisionTimestampTZ.builder(); + } + + @Override + public R accept(final TypeVisitor typeVisitor) throws E { + return typeVisitor.visit(this); + } + } + @Value.Immutable abstract static class Struct implements Type { public abstract java.util.List fields(); diff --git a/core/src/main/java/io/substrait/type/TypeCreator.java b/core/src/main/java/io/substrait/type/TypeCreator.java index 0a7943e13..adaedd0b1 100644 --- a/core/src/main/java/io/substrait/type/TypeCreator.java +++ b/core/src/main/java/io/substrait/type/TypeCreator.java @@ -66,6 +66,14 @@ public final Type.Struct struct(Type... types) { return Type.Struct.builder().nullable(nullable).addFields(types).build(); } + public final Type precisionTimestamp(int precision) { + return Type.PrecisionTimestamp.builder().nullable(nullable).precision(precision).build(); + } + + public final Type precisionTimestampTZ(int precision) { + return Type.PrecisionTimestampTZ.builder().nullable(nullable).precision(precision).build(); + } + public Type.Struct struct(Iterable types) { return Type.Struct.builder().nullable(nullable).addAllFields(types).build(); } @@ -215,6 +223,16 @@ public Type visit(Type.Decimal type) throws RuntimeException { return Type.Decimal.builder().from(type).nullable(nullability).build(); } + @Override + public Type visit(Type.PrecisionTimestamp type) throws RuntimeException { + return Type.PrecisionTimestamp.builder().from(type).nullable(nullability).build(); + } + + @Override + public Type visit(Type.PrecisionTimestampTZ type) throws RuntimeException { + return Type.PrecisionTimestampTZ.builder().from(type).nullable(nullability).build(); + } + @Override public Type visit(Type.Struct type) throws RuntimeException { return Type.Struct.builder().from(type).nullable(nullability).build(); diff --git a/core/src/main/java/io/substrait/type/TypeVisitor.java b/core/src/main/java/io/substrait/type/TypeVisitor.java index 4eae3362a..9d377499c 100644 --- a/core/src/main/java/io/substrait/type/TypeVisitor.java +++ b/core/src/main/java/io/substrait/type/TypeVisitor.java @@ -23,10 +23,16 @@ public interface TypeVisitor { R visit(Type.Time type) throws E; + @Deprecated R visit(Type.TimestampTZ type) throws E; + @Deprecated R visit(Type.Timestamp type) throws E; + R visit(Type.PrecisionTimestamp type) throws E; + + R visit(Type.PrecisionTimestampTZ type) throws E; + R visit(Type.IntervalYear type) throws E; R visit(Type.IntervalDay type) throws E; @@ -162,6 +168,16 @@ public R visit(Type.Decimal type) throws E { throw t(); } + @Override + public R visit(Type.PrecisionTimestamp type) throws E { + throw t(); + } + + @Override + public R visit(Type.PrecisionTimestampTZ type) 
throws E { + throw t(); + } + @Override public R visit(Type.Struct type) throws E { throw t(); diff --git a/core/src/main/java/io/substrait/type/YamlRead.java b/core/src/main/java/io/substrait/type/YamlRead.java index 9d3f1230d..89f8a9163 100644 --- a/core/src/main/java/io/substrait/type/YamlRead.java +++ b/core/src/main/java/io/substrait/type/YamlRead.java @@ -56,7 +56,7 @@ private static Stream parse(String name) { .registerModule(Deserializers.MODULE); var doc = mapper.readValue(new File(name), SimpleExtension.ExtensionSignatures.class); - logger.debug( + logger.atDebug().log( "Parsed {} functions in file {}.", Optional.ofNullable(doc.scalars()).map(List::size).orElse(0) + Optional.ofNullable(doc.aggregates()).map(List::size).orElse(0), diff --git a/core/src/main/java/io/substrait/type/parser/ParseToPojo.java b/core/src/main/java/io/substrait/type/parser/ParseToPojo.java index a79a2ebec..9ae85f438 100644 --- a/core/src/main/java/io/substrait/type/parser/ParseToPojo.java +++ b/core/src/main/java/io/substrait/type/parser/ParseToPojo.java @@ -237,6 +237,40 @@ public TypeExpression visitDecimal(final SubstraitTypeParser.DecimalContext ctx) return withNullE(nullable).decimalE(ctx.precision.accept(this), ctx.scale.accept(this)); } + @Override + public TypeExpression visitPrecisionTimestamp( + final SubstraitTypeParser.PrecisionTimestampContext ctx) { + boolean nullable = ctx.isnull != null; + Object precision = i(ctx.precision); + if (precision instanceof Integer p) { + return withNull(nullable).precisionTimestamp(p); + } + if (precision instanceof String s) { + checkParameterizedOrExpression(); + return withNullP(nullable).precisionTimestampE(s); + } + + checkExpression(); + return withNullE(nullable).precisionTimestampE(ctx.precision.accept(this)); + } + + @Override + public TypeExpression visitPrecisionTimestampTZ( + final SubstraitTypeParser.PrecisionTimestampTZContext ctx) { + boolean nullable = ctx.isnull != null; + Object precision = i(ctx.precision); + if (precision instanceof Integer p) { + return withNull(nullable).precisionTimestampTZ(p); + } + if (precision instanceof String s) { + checkParameterizedOrExpression(); + return withNullP(nullable).precisionTimestampTZE(s); + } + + checkExpression(); + return withNullE(nullable).precisionTimestampTZE(ctx.precision.accept(this)); + } + private Object i(SubstraitTypeParser.NumericParameterContext ctx) { TypeExpression type = ctx.accept(this); if (type instanceof TypeExpression.IntegerLiteral) { diff --git a/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java b/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java index 913bcde71..f21a03f79 100644 --- a/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java @@ -124,6 +124,16 @@ public final T visit(final Type.Decimal expr) { return typeContainer(expr).decimal(expr.scale(), expr.precision()); } + @Override + public final T visit(final Type.PrecisionTimestamp expr) { + return typeContainer(expr).precisionTimestamp(expr.precision()); + } + + @Override + public final T visit(final Type.PrecisionTimestampTZ expr) { + return typeContainer(expr).precisionTimestampTZ(expr.precision()); + } + @Override public final T visit(final Type.Struct expr) { return typeContainer(expr) diff --git a/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java b/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java index 8c63a7d73..ac4c13521 100644 --- 
a/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java +++ b/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java @@ -81,6 +81,14 @@ public final T decimal(int scale, I precision) { return decimal(i(scale), precision); } + public final T precisionTimestamp(int precision) { + return precisionTimestamp(i(precision)); + } + + public final T precisionTimestampTZ(int precision) { + return precisionTimestampTZ(i(precision)); + } + public abstract T typeParam(String name); public abstract I integerParam(String name); @@ -91,6 +99,10 @@ public final T decimal(int scale, I precision) { public abstract T decimal(I scale, I precision); + public abstract T precisionTimestamp(I precision); + + public abstract T precisionTimestampTZ(I precision); + public final T struct(T... types) { return struct(Arrays.asList(types)); } diff --git a/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java b/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java index 070cf92cc..f2270b822 100644 --- a/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java @@ -54,6 +54,18 @@ public ParameterizedType visit(io.substrait.function.ParameterizedType.Decimal e return typeContainer(expr).decimal(i(expr.precision()), i(expr.scale())); } + @Override + public ParameterizedType visit(io.substrait.function.ParameterizedType.PrecisionTimestamp expr) + throws RuntimeException { + return typeContainer(expr).precisionTimestamp(i(expr.precision())); + } + + @Override + public ParameterizedType visit(io.substrait.function.ParameterizedType.PrecisionTimestampTZ expr) + throws RuntimeException { + return typeContainer(expr).precisionTimestampTZ(i(expr.precision())); + } + @Override public ParameterizedType visit(io.substrait.function.ParameterizedType.Struct expr) throws RuntimeException { @@ -174,6 +186,24 @@ public ParameterizedType decimal( .build()); } + @Override + public ParameterizedType precisionTimestamp(ParameterizedType.IntegerOption precision) { + return wrap( + ParameterizedType.ParameterizedPrecisionTimestamp.newBuilder() + .setPrecision(precision) + .setNullability(nullability) + .build()); + } + + @Override + public ParameterizedType precisionTimestampTZ(ParameterizedType.IntegerOption precision) { + return wrap( + ParameterizedType.ParameterizedPrecisionTimestampTZ.newBuilder() + .setPrecision(precision) + .setNullability(nullability) + .build()); + } + public ParameterizedType struct(Iterable types) { return wrap( ParameterizedType.ParameterizedStruct.newBuilder() @@ -246,6 +276,10 @@ protected ParameterizedType wrap(final Object o) { return bldr.setFixedBinary(t).build(); } else if (o instanceof ParameterizedType.ParameterizedDecimal t) { return bldr.setDecimal(t).build(); + } else if (o instanceof ParameterizedType.ParameterizedPrecisionTimestamp t) { + return bldr.setPrecisionTimestamp(t).build(); + } else if (o instanceof ParameterizedType.ParameterizedPrecisionTimestampTZ t) { + return bldr.setPrecisionTimestampTz(t).build(); } else if (o instanceof ParameterizedType.ParameterizedStruct t) { return bldr.setStruct(t).build(); } else if (o instanceof ParameterizedType.ParameterizedList t) { diff --git a/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java b/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java index 9208970f3..a4352c463 100644 --- a/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java +++ 
b/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java @@ -42,6 +42,10 @@ public Type from(io.substrait.proto.Type type) { .fixedBinary(type.getFixedBinary().getLength()); case DECIMAL -> n(type.getDecimal().getNullability()) .decimal(type.getDecimal().getPrecision(), type.getDecimal().getScale()); + case PRECISION_TIMESTAMP -> n(type.getPrecisionTimestamp().getNullability()) + .precisionTimestamp(type.getPrecisionTimestamp().getPrecision()); + case PRECISION_TIMESTAMP_TZ -> n(type.getPrecisionTimestampTz().getNullability()) + .precisionTimestampTZ(type.getPrecisionTimestampTz().getPrecision()); case STRUCT -> n(type.getStruct().getNullability()) .struct( type.getStruct().getTypesList().stream() diff --git a/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java b/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java index b32d0d71d..07e5b1071 100644 --- a/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java +++ b/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java @@ -121,6 +121,16 @@ public DerivationExpression visit(ParameterizedType.Decimal expr) { return typeContainer(expr).decimal(expr.precision().accept(this), expr.scale().accept(this)); } + @Override + public DerivationExpression visit(ParameterizedType.PrecisionTimestamp expr) { + return typeContainer(expr).precisionTimestamp(expr.precision().accept(this)); + } + + @Override + public DerivationExpression visit(TypeExpression.PrecisionTimestampTZ expr) { + return typeContainer(expr).precisionTimestampTZ(expr.precision().accept(this)); + } + @Override public DerivationExpression visit(ParameterizedType.Struct expr) { return typeContainer(expr) @@ -235,6 +245,24 @@ public DerivationExpression decimal( .build()); } + @Override + public DerivationExpression precisionTimestamp(DerivationExpression precision) { + return wrap( + DerivationExpression.ExpressionPrecisionTimestamp.newBuilder() + .setPrecision(precision) + .setNullability(nullability) + .build()); + } + + @Override + public DerivationExpression precisionTimestampTZ(DerivationExpression precision) { + return wrap( + DerivationExpression.ExpressionPrecisionTimestampTZ.newBuilder() + .setPrecision(precision) + .setNullability(nullability) + .build()); + } + public DerivationExpression struct(Iterable types) { return wrap( DerivationExpression.ExpressionStruct.newBuilder() @@ -311,6 +339,10 @@ protected DerivationExpression wrap(final Object o) { return bldr.setFixedBinary(t).build(); } else if (o instanceof DerivationExpression.ExpressionDecimal t) { return bldr.setDecimal(t).build(); + } else if (o instanceof DerivationExpression.ExpressionPrecisionTimestamp t) { + return bldr.setPrecisionTimestamp(t).build(); + } else if (o instanceof DerivationExpression.ExpressionPrecisionTimestampTZ t) { + return bldr.setPrecisionTimestampTz(t).build(); } else if (o instanceof DerivationExpression.ExpressionStruct t) { return bldr.setStruct(t).build(); } else if (o instanceof DerivationExpression.ExpressionList t) { diff --git a/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java b/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java index ce5ffb786..e21cc158a 100644 --- a/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java @@ -61,6 +61,22 @@ public Type decimal(Integer scale, Integer precision) { .build()); } + public Type precisionTimestamp(Integer precision) { + return wrap( 
+ Type.PrecisionTimestamp.newBuilder() + .setPrecision(precision) + .setNullability(nullability) + .build()); + } + + public Type precisionTimestampTZ(Integer precision) { + return wrap( + Type.PrecisionTimestampTZ.newBuilder() + .setPrecision(precision) + .setNullability(nullability) + .build()); + } + public Type struct(Iterable types) { return wrap(Type.Struct.newBuilder().addAllTypes(types).setNullability(nullability).build()); } @@ -121,6 +137,10 @@ protected Type wrap(final Object o) { return bldr.setFixedBinary(t).build(); } else if (o instanceof Type.Decimal t) { return bldr.setDecimal(t).build(); + } else if (o instanceof Type.PrecisionTimestamp t) { + return bldr.setPrecisionTimestamp(t).build(); + } else if (o instanceof Type.PrecisionTimestampTZ t) { + return bldr.setPrecisionTimestampTz(t).build(); } else if (o instanceof Type.Struct t) { return bldr.setStruct(t).build(); } else if (o instanceof Type.List t) { diff --git a/core/src/test/java/io/substrait/relation/ProtoRelConverterTest.java b/core/src/test/java/io/substrait/relation/ProtoRelConverterTest.java index 1f16e29fa..51e0dfae3 100644 --- a/core/src/test/java/io/substrait/relation/ProtoRelConverterTest.java +++ b/core/src/test/java/io/substrait/relation/ProtoRelConverterTest.java @@ -36,12 +36,12 @@ Rel relWithExtension(AdvancedExtension advancedExtension) { Rel emptyAdvancedExtension = relWithExtension(AdvancedExtension.builder().build()); Rel advancedExtensionWithOptimization = - relWithExtension(AdvancedExtension.builder().optimization(OPTIMIZED).build()); + relWithExtension(AdvancedExtension.builder().addOptimizations(OPTIMIZED).build()); Rel advancedExtensionWithEnhancement = relWithExtension(AdvancedExtension.builder().enhancement(ENHANCED).build()); Rel advancedExtensionWithEnhancementAndOptimization = relWithExtension( - AdvancedExtension.builder().enhancement(ENHANCED).optimization(OPTIMIZED).build()); + AdvancedExtension.builder().enhancement(ENHANCED).addOptimizations(OPTIMIZED).build()); @Test void emptyAdvancedExtension() { diff --git a/core/src/test/java/io/substrait/relation/VirtualTableScanTest.java b/core/src/test/java/io/substrait/relation/VirtualTableScanTest.java new file mode 100644 index 000000000..ca6fceaa7 --- /dev/null +++ b/core/src/test/java/io/substrait/relation/VirtualTableScanTest.java @@ -0,0 +1,69 @@ +package io.substrait.relation; + +import static io.substrait.expression.ExpressionCreator.list; +import static io.substrait.expression.ExpressionCreator.map; +import static io.substrait.expression.ExpressionCreator.string; +import static io.substrait.expression.ExpressionCreator.struct; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; + +import io.substrait.TestBase; +import io.substrait.expression.Expression; +import io.substrait.type.NamedStruct; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.stream.Collectors; +import org.junit.jupiter.api.Test; + +class VirtualTableScanTest extends TestBase { + + @Test + void check() { + VirtualTableScan virtualTableScan = + ImmutableVirtualTableScan.builder() + .initialSchema( + NamedStruct.of( + Arrays.stream( + new String[] { + "string", + "struct", + "struct_field1", + "struct_field2", + "list", + "list_struct_field1", + "map", + "map_key_struct_field1", + "map_value_struct_field1" + }) + .collect(Collectors.toList()), + R.struct( + R.STRING, + R.struct(R.STRING, R.STRING), + R.list(R.struct(R.STRING)), + R.map(R.struct(R.STRING), R.struct(R.STRING))))) + .addRows( + struct( + false, + 
string(false, "string_val"), + struct( + false, + string(false, "struct_field1_val"), + string(false, "struct_field2_val")), + list(false, struct(false, string(false, "list_struct_field1_val"))), + map( + false, + mapOf( + struct(false, string(false, "map_key_struct_field1_val")), + struct(false, string(false, "map_value_struct_field1_val")))))) + .build(); + assertDoesNotThrow(virtualTableScan::check); + } + + private Map mapOf( + Expression.Literal key, Expression.Literal value) { + // Map.of() comes only in Java 9 and the "core" module is on Java 8 + HashMap map = new HashMap<>(); + map.put(key, value); + return map; + } +} diff --git a/core/src/test/java/io/substrait/type/proto/AggregateRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/AggregateRoundtripTest.java index 90c68e1d8..0485716f1 100644 --- a/core/src/test/java/io/substrait/type/proto/AggregateRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/AggregateRoundtripTest.java @@ -6,6 +6,7 @@ import io.substrait.expression.AggregateFunctionInvocation; import io.substrait.expression.Expression; import io.substrait.expression.ExpressionCreator; +import io.substrait.expression.FunctionOption; import io.substrait.expression.ImmutableExpression; import io.substrait.extension.ExtensionCollector; import io.substrait.relation.Aggregate; @@ -13,6 +14,7 @@ import io.substrait.relation.ProtoRelConverter; import io.substrait.relation.RelProtoConverter; import io.substrait.relation.VirtualTableScan; +import io.substrait.type.NamedStruct; import io.substrait.type.TypeCreator; import java.io.IOException; import java.math.BigDecimal; @@ -25,8 +27,12 @@ public class AggregateRoundtripTest extends TestBase { private void assertAggregateRoundtrip(Expression.AggregationInvocation invocation) { var expression = ExpressionCreator.decimal(false, BigDecimal.TEN, 10, 2); Expression.StructLiteral literal = - ImmutableExpression.StructLiteral.builder().from(expression).build(); - var input = VirtualTableScan.builder().addRows(literal).build(); + ImmutableExpression.StructLiteral.builder().addFields(expression).build(); + var input = + VirtualTableScan.builder() + .initialSchema(NamedStruct.of(Arrays.asList("decimal"), R.struct(R.decimal(10, 2)))) + .addRows(literal) + .build(); ExtensionCollector functionCollector = new ExtensionCollector(); var to = new RelProtoConverter(functionCollector); var extensions = defaultExtensionCollection; @@ -41,6 +47,12 @@ private void assertAggregateRoundtrip(Expression.AggregationInvocation invocatio .outputType(TypeCreator.of(false).I64) .aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT) .invocation(invocation) + .options( + Arrays.asList( + FunctionOption.builder() + .name("option") + .addValues("VALUE1", "VALUE2") + .build())) .build()) .build(); diff --git a/core/src/test/java/io/substrait/type/proto/ConsistentPartitionWindowRelRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/ConsistentPartitionWindowRelRoundtripTest.java index 45077c4f7..13036d048 100644 --- a/core/src/test/java/io/substrait/type/proto/ConsistentPartitionWindowRelRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/ConsistentPartitionWindowRelRoundtripTest.java @@ -4,6 +4,7 @@ import io.substrait.TestBase; import io.substrait.expression.Expression; +import io.substrait.expression.FunctionOption; import io.substrait.expression.ImmutableWindowBound; import io.substrait.extension.DefaultExtensionCatalog; import io.substrait.extension.SimpleExtension; @@ -11,7 +12,6 @@ import 
io.substrait.relation.ImmutableConsistentPartitionWindow; import io.substrait.relation.Rel; import java.util.Arrays; -import java.util.Collections; import org.junit.jupiter.api.Test; public class ConsistentPartitionWindowRelRoundtripTest extends TestBase { @@ -36,7 +36,12 @@ void consistentPartitionWindowRoundtrip() { .declaration(windowFunctionDeclaration) // lead(a) .arguments(Arrays.asList(b.fieldReference(input, 0))) - .options(Collections.emptyMap()) + .options( + Arrays.asList( + FunctionOption.builder() + .name("option") + .addValues("VALUE1", "VALUE2") + .build())) .outputType(R.I64) .aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT) .invocation(Expression.AggregationInvocation.ALL) diff --git a/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java index 33081c0df..06ed71dec 100644 --- a/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java @@ -52,7 +52,7 @@ public class ExtensionRoundtripTest extends TestBase { final AdvancedExtension commonExtension = AdvancedExtension.builder() .enhancement(new StringHolder("COMMON ENHANCEMENT")) - .optimization(new StringHolder("COMMON OPTIMIZATION")) + .addOptimizations(new StringHolder("COMMON OPTIMIZATION")) .build(); final StringHolder detail = new StringHolder("DETAIL"); @@ -60,7 +60,7 @@ public class ExtensionRoundtripTest extends TestBase { final AdvancedExtension relExtension = AdvancedExtension.builder() .enhancement(new StringHolder("REL ENHANCEMENT")) - .optimization(new StringHolder("REL OPTIMIZATION")) + .addOptimizations(new StringHolder("REL OPTIMIZATION")) .build(); @Override @@ -74,6 +74,7 @@ protected void verifyRoundTrip(Rel rel) { void virtualTable() { Rel rel = VirtualTableScan.builder() + .initialSchema(NamedStruct.of(Collections.emptyList(), R.struct())) .addRows(Expression.StructLiteral.builder().fields(Collections.emptyList()).build()) .commonExtension(commonExtension) .extension(relExtension) diff --git a/core/src/test/java/io/substrait/type/proto/ReadRelRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/ReadRelRoundtripTest.java index 7b085f67a..028a07e07 100644 --- a/core/src/test/java/io/substrait/type/proto/ReadRelRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/ReadRelRoundtripTest.java @@ -44,7 +44,10 @@ void emptyScan() { void virtualTable() { var virtTable = VirtualTableScan.builder() - .addAllDfsNames(Stream.of("column1", "column2").collect(Collectors.toList())) + .initialSchema( + NamedStruct.of( + Stream.of("column1", "column2").collect(Collectors.toList()), + R.struct(R.I64, R.I64))) .addRows( ExpressionCreator.struct( false, ExpressionCreator.i64(false, 1), ExpressionCreator.i64(false, 2))) diff --git a/gradle.properties b/gradle.properties index 31cf17e7d..0b1e19b21 100644 --- a/gradle.properties +++ b/gradle.properties @@ -13,14 +13,20 @@ org.jetbrains.gradle.plugin.idea-ext.version=0.5 kotlin.version=1.5.31 com.github.vlsi.vlsi-release-plugins.version=1.74 -calcite.version=1.28.0 -junit5.version=5.8.1 -protobuf.version=3.17.1 -slf4j.version=1.7.25 -jackson.version=2.12.4 +# library version +antlr.version=4.13.1 +calcite.version=1.37.0 +guava.version=32.1.3-jre +immutables.version=2.10.1 +jackson.version=2.16.1 +junit.version=5.8.1 +protobuf.version=3.25.3 +slf4j.version=2.0.13 +sparkbundle.version=3.4 +spark.version=3.4.2 #version that is going to be updated automatically by 
releases -version = 0.28.0 +version = 0.37.0 #signing SIGNING_KEY_ID = 193EAE47 diff --git a/gradle/wrapper/gradle-wrapper.jar b/gradle/wrapper/gradle-wrapper.jar index 7454180f2..e6441136f 100644 Binary files a/gradle/wrapper/gradle-wrapper.jar and b/gradle/wrapper/gradle-wrapper.jar differ diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index aa991fcea..a4413138c 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,5 +1,7 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-7.4.2-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-8.8-bin.zip +networkTimeout=10000 +validateDistributionUrl=true zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists diff --git a/gradlew b/gradlew index 1b6c78733..b740cf133 100755 --- a/gradlew +++ b/gradlew @@ -55,7 +55,7 @@ # Darwin, MinGW, and NonStop. # # (3) This script is generated from the Groovy template -# https://github.com/gradle/gradle/blob/master/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt +# https://github.com/gradle/gradle/blob/HEAD/platforms/jvm/plugins-application/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt # within the Gradle project. # # You can find Gradle at https://github.com/gradle/gradle/. @@ -80,13 +80,11 @@ do esac done -APP_HOME=$( cd "${APP_HOME:-./}" && pwd -P ) || exit - -APP_NAME="Gradle" +# This is normally unused +# shellcheck disable=SC2034 APP_BASE_NAME=${0##*/} - -# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. -DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' +# Discard cd standard output in case $CDPATH is set (https://github.com/gradle/gradle/issues/25036) +APP_HOME=$( cd "${APP_HOME:-./}" > /dev/null && pwd -P ) || exit # Use the maximum available, or set MAX_FD != -1 to use that value. MAX_FD=maximum @@ -133,22 +131,29 @@ location of your Java installation." fi else JAVACMD=java - which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + if ! command -v java >/dev/null 2>&1 + then + die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. Please set the JAVA_HOME variable in your environment to match the location of your Java installation." + fi fi # Increase the maximum file descriptors if we can. if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then case $MAX_FD in #( max*) + # In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC2039,SC3045 MAX_FD=$( ulimit -H -n ) || warn "Could not query maximum file descriptor limit" esac case $MAX_FD in #( '' | soft) :;; #( *) + # In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC2039,SC3045 ulimit -n "$MAX_FD" || warn "Could not set maximum file descriptor limit to $MAX_FD" esac @@ -193,11 +198,15 @@ if "$cygwin" || "$msys" ; then done fi -# Collect all arguments for the java command; -# * $DEFAULT_JVM_OPTS, $JAVA_OPTS, and $GRADLE_OPTS can contain fragments of -# shell script including quotes and variable substitutions, so put them in -# double quotes to make sure that they get re-expanded; and -# * put everything else in single quotes, so that it's not re-expanded. + +# Add default JVM options here. 
You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' + +# Collect all arguments for the java command: +# * DEFAULT_JVM_OPTS, JAVA_OPTS, JAVA_OPTS, and optsEnvironmentVar are not allowed to contain shell fragments, +# and any embedded shellness will be escaped. +# * For example: A user cannot expect ${Hostname} to be expanded, as it is an environment variable and will be +# treated as '${Hostname}' itself on the command line. set -- \ "-Dorg.gradle.appname=$APP_BASE_NAME" \ @@ -205,6 +214,12 @@ set -- \ org.gradle.wrapper.GradleWrapperMain \ "$@" +# Stop when "xargs" is not available. +if ! command -v xargs >/dev/null 2>&1 +then + die "xargs is not available" +fi + # Use "xargs" to parse quoted args. # # With -n1 it outputs one arg per line, with the quotes and backslashes removed. diff --git a/gradlew.bat b/gradlew.bat index ac1b06f93..7101f8e46 100644 --- a/gradlew.bat +++ b/gradlew.bat @@ -14,7 +14,7 @@ @rem limitations under the License. @rem -@if "%DEBUG%" == "" @echo off +@if "%DEBUG%"=="" @echo off @rem ########################################################################## @rem @rem Gradle startup script for Windows @@ -25,7 +25,8 @@ if "%OS%"=="Windows_NT" setlocal set DIRNAME=%~dp0 -if "%DIRNAME%" == "" set DIRNAME=. +if "%DIRNAME%"=="" set DIRNAME=. +@rem This is normally unused set APP_BASE_NAME=%~n0 set APP_HOME=%DIRNAME% @@ -40,13 +41,13 @@ if defined JAVA_HOME goto findJavaFromJavaHome set JAVA_EXE=java.exe %JAVA_EXE% -version >NUL 2>&1 -if "%ERRORLEVEL%" == "0" goto execute +if %ERRORLEVEL% equ 0 goto execute -echo. -echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. -echo. -echo Please set the JAVA_HOME variable in your environment to match the -echo location of your Java installation. +echo. 1>&2 +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 1>&2 +echo. 1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 1>&2 goto fail @@ -56,11 +57,11 @@ set JAVA_EXE=%JAVA_HOME%/bin/java.exe if exist "%JAVA_EXE%" goto execute -echo. -echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% -echo. -echo Please set the JAVA_HOME variable in your environment to match the -echo location of your Java installation. +echo. 1>&2 +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 1>&2 +echo. 1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 1>&2 goto fail @@ -75,13 +76,15 @@ set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar :end @rem End local scope for the variables with windows NT shell -if "%ERRORLEVEL%"=="0" goto mainEnd +if %ERRORLEVEL% equ 0 goto mainEnd :fail rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of rem the _cmd.exe /c_ return code! -if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 -exit /b 1 +set EXIT_CODE=%ERRORLEVEL% +if %EXIT_CODE% equ 0 set EXIT_CODE=1 +if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE% +exit /b %EXIT_CODE% :mainEnd if "%OS%"=="Windows_NT" endlocal diff --git a/isthmus-cli/README.md b/isthmus-cli/README.md new file mode 100644 index 000000000..431159b0c --- /dev/null +++ b/isthmus-cli/README.md @@ -0,0 +1,389 @@ +# Isthmus-CLI + +## Overview + +Isthmus-CLI provides a native command-line interface to drive the [Isthmus](../isthmus) library. 
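+The CLI drives the same conversion entry points that the Isthmus library exposes to Java programs. A minimal sketch of the programmatic route (this assumes the `SqlToSubstrait#execute(String sql, List<String> createStatements)` method that `IsthmusEntryPoint` calls into; check the Isthmus module for the current signature):
+
+```
+import io.substrait.isthmus.SqlToSubstrait;
+import io.substrait.proto.Plan;
+import java.util.Arrays;
+
+public class SqlToSubstraitExample {
+  public static void main(String[] args) throws Exception {
+    // The CREATE TABLE statements play the role of the CLI's -c/--create options.
+    SqlToSubstrait converter = new SqlToSubstrait();
+    Plan plan =
+        converter.execute(
+            "SELECT lastName, firstName FROM Persons WHERE zip = 90210",
+            Arrays.asList("CREATE TABLE Persons ( firstName VARCHAR, lastName VARCHAR, zip INT )"));
+    System.out.println(plan); // the Substrait plan, printed as protobuf text
+  }
+}
+```
+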
This can be used to convert SQL queries to [Substrait Protobuf](https://substrait.io/serialization/binary_serialization/) and SQL expressions to [Extended Expressions](https://substrait.io/expressions/extended_expression/) using the Calcite SQL compiler. + +## Build + +Isthmus can be built as a native executable via Graal + +``` +./gradlew nativeImage +``` + +## Usage + +### Version + +``` +$ ./isthmus-cli/build/graal/isthmus --version + +isthmus 0.1 +``` + +### Help + +``` +$ ./isthmus-cli/build/graal/isthmus --help + +Usage: isthmus [-hmV] [--crossjoinpolicy=] + [--outputformat=] + [--sqlconformancemode=] + [-c=]... [-e=...]... [] +Convert SQL Queries and SQL Expressions to Substrait + [] A SQL query + -c, --create= + One or multiple create table statements e.g. CREATE + TABLE T1(foo int, bar bigint) + --crossjoinpolicy= + One of built-in Calcite SQL compatibility modes: + KEEP_AS_CROSS_JOIN, CONVERT_TO_INNER_JOIN + -e, --expression=... + One or more SQL expressions e.g. col + 1 + -h, --help Show this help message and exit. + -m, --multistatement Allow multiple statements terminated with a semicolon + --outputformat= + Set the output format for the generated plan: + PROTOJSON, PROTOTEXT, BINARY + --sqlconformancemode= + One of built-in Calcite SQL compatibility modes: + DEFAULT, LENIENT, BABEL, STRICT_92, STRICT_99, + PRAGMATIC_99, BIG_QUERY, MYSQL_5, ORACLE_10, + ORACLE_12, STRICT_2003, PRAGMATIC_2003, PRESTO, + SQL_SERVER_2008 + -V, --version Print version information and exit. +``` + +## Example + +### SQL to Substrait Plan + +``` +> $ ./isthmus-cli/build/graal/isthmus \ + -c "CREATE TABLE Persons ( firstName VARCHAR, lastName VARCHAR, zip INT )" \ + "SELECT lastName, firstName FROM Persons WHERE zip = 90210" + +{ + "extensionUris": [{ + "extensionUriAnchor": 1, + "uri": "/functions_comparison.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 0, + "name": "equal:any1_any1" + } + }], + "relations": [{ + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [3, 4] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["FIRSTNAME", "LASTNAME", "ZIP"], + "struct": { + "types": [{ + "string": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "string": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["PERSONS"] + } + } + }, + "condition": { + "scalarFunction": { + "functionReference": 0, + "args": [{ + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + }, { + "literal": { + "i32": 90210, + "nullable": false, + "typeVariationReference": 0 + } + }], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [] + } + } + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }] + } + }, + "names": ["LASTNAME", "FIRSTNAME"] + } + }], + "expectedTypeUrls": [] +} +``` + +### SQL Expression to Substrait Extended 
Expression + +#### Projection + +``` +$ ./isthmus-cli/build/graal/isthmus -c "CREATE TABLE NATION (N_NATIONKEY BIGINT NOT NULL, N_NAME CHAR(25), N_REGIONKEY BIGINT NOT NULL, N_COMMENT VARCHAR(152))" \ + -e "N_REGIONKEY + 10" + +{ + "extensionUris": [{ + "extensionUriAnchor": 1, + "uri": "/functions_arithmetic.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 0, + "name": "add:i64_i64" + } + }], + "referredExpr": [{ + "expression": { + "scalarFunction": { + "functionReference": 0, + "args": [], + "outputType": { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "i32": 10, + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + }], + "options": [] + } + }, + "outputNames": ["new-column"] + }], + "baseSchema": { + "names": ["N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT"], + "struct": { + "types": [{ + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "varchar": { + "length": 152, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "expectedTypeUrls": [] +} +``` + +#### Filter + +``` +$ ./isthmus-cli/build/graal/isthmus -c "CREATE TABLE NATION (N_NATIONKEY BIGINT NOT NULL, N_NAME CHAR(25), N_REGIONKEY BIGINT NOT NULL, N_COMMENT VARCHAR(152))" \ + -e "N_REGIONKEY > 10" + +{ + "extensionUris": [{ + "extensionUriAnchor": 1, + "uri": "/functions_comparison.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 0, + "name": "gt:any_any" + } + }], + "referredExpr": [{ + "expression": { + "scalarFunction": { + "functionReference": 0, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "i32": 10, + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + }], + "options": [] + } + }, + "outputNames": ["new-column"] + }], + "baseSchema": { + "names": ["N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT"], + "struct": { + "types": [{ + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "varchar": { + "length": 152, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }], + "typeVariationReference": 0, + "nullability": 
"NULLABILITY_REQUIRED" + } + }, + "expectedTypeUrls": [] +} +``` diff --git a/isthmus-cli/build.gradle.kts b/isthmus-cli/build.gradle.kts new file mode 100644 index 000000000..b84301fe6 --- /dev/null +++ b/isthmus-cli/build.gradle.kts @@ -0,0 +1,139 @@ +plugins { + id("java") + id("idea") + id("com.palantir.graal") version "0.10.0" + id("com.diffplug.spotless") version "6.11.0" +} + +java { + toolchain { languageVersion.set(JavaLanguageVersion.of(17)) } + withJavadocJar() + withSourcesJar() +} + +val CALCITE_VERSION = properties.get("calcite.version") +val GUAVA_VERSION = properties.get("guava.version") +val IMMUTABLES_VERSION = properties.get("immutables.version") +val JACKSON_VERSION = properties.get("jackson.version") +val JUNIT_VERSION = properties.get("junit.version") +val PROTOBUF_VERSION = properties.get("protobuf.version") +val SLF4J_VERSION = properties.get("slf4j.version") + +dependencies { + implementation(project(":core")) + implementation(project(":isthmus")) + implementation("org.apache.calcite:calcite-core:${CALCITE_VERSION}") + implementation("org.apache.calcite:calcite-server:${CALCITE_VERSION}") + testImplementation("org.junit.jupiter:junit-jupiter:${JUNIT_VERSION}") + implementation("org.reflections:reflections:0.9.12") + implementation("com.google.guava:guava:${GUAVA_VERSION}") + implementation("org.graalvm.sdk:graal-sdk:22.1.0") + implementation("info.picocli:picocli:4.7.5") + annotationProcessor("info.picocli:picocli-codegen:4.7.5") + implementation("com.fasterxml.jackson.core:jackson-databind:${JACKSON_VERSION}") + implementation("com.google.protobuf:protobuf-java-util:${PROTOBUF_VERSION}") { + exclude("com.google.guava", "guava") + .because("Brings in Guava for Android, which we don't want (and breaks multimaps).") + } + implementation("org.immutables:value-annotations:${IMMUTABLES_VERSION}") + annotationProcessor("org.immutables:value:${IMMUTABLES_VERSION}") + testImplementation("org.apache.calcite:calcite-plus:${CALCITE_VERSION}") + annotationProcessor("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2") + compileOnly("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2") + runtimeOnly("org.slf4j:slf4j-jdk14:${SLF4J_VERSION}") +} + +val initializeAtBuildTime = + listOf( + "com.google.common.base.Platform", + "com.google.common.base.Preconditions", + "com.google.common.cache.CacheBuilder", + "com.google.common.cache.LocalCache", + "com.google.common.collect.CollectCollectors", + "com.google.common.collect.ImmutableRangeSet", + "com.google.common.collect.ImmutableSortedMap", + "com.google.common.collect.Platform", + "com.google.common.collect.Range", + "com.google.common.collect.RegularImmutableMap", + "com.google.common.collect.RegularImmutableSortedSet", + "com.google.common.math.IntMath", + "com.google.common.math.IntMath\$1", + "com.google.common.primitives.Primitives", + "com.google.common.util.concurrent.AbstractFuture", + "com.google.common.util.concurrent.AbstractFuture\$UnsafeAtomicHelper", + "com.google.common.util.concurrent.SettableFuture", + "io.substrait.isthmus.cli.InitializeAtBuildTime", + "io.substrait.isthmus.metadata.LambdaHandlerCache", + "io.substrait.isthmus.metadata.LambdaMetadataSupplier", + "io.substrait.isthmus.metadata.LegacyToLambdaGenerator", + "org.apache.calcite.config.CalciteSystemProperty", + "org.apache.calcite.rel.metadata.BuiltInMetadata\$AllPredicates", + "org.apache.calcite.rel.metadata.BuiltInMetadata\$Collation", + "org.apache.calcite.rel.metadata.BuiltInMetadata\$ColumnOrigin", + 
"org.apache.calcite.rel.metadata.BuiltInMetadata\$ColumnUniqueness", + "org.apache.calcite.rel.metadata.BuiltInMetadata\$CumulativeCost", + "org.apache.calcite.rel.metadata.BuiltInMetadata\$DistinctRowCount", + "org.apache.calcite.rel.metadata.BuiltInMetadata\$Distribution", + "org.apache.calcite.rel.metadata.BuiltInMetadata\$ExplainVisibility", + "org.apache.calcite.rel.metadata.BuiltInMetadata\$ExpressionLineage", + "org.apache.calcite.rel.metadata.BuiltInMetadata\$LowerBoundCost", + "org.apache.calcite.rel.metadata.BuiltInMetadata\$MaxRowCount", + "org.apache.calcite.rel.metadata.BuiltInMetadata\$Memory", + "org.apache.calcite.rel.metadata.BuiltInMetadata\$MinRowCount", + "org.apache.calcite.rel.metadata.BuiltInMetadata\$NodeTypes", + "org.apache.calcite.rel.metadata.BuiltInMetadata\$NonCumulativeCost", + "org.apache.calcite.rel.metadata.BuiltInMetadata\$Parallelism", + "org.apache.calcite.rel.metadata.BuiltInMetadata\$PercentageOriginalRows", + "org.apache.calcite.rel.metadata.BuiltInMetadata\$PopulationSize", + "org.apache.calcite.rel.metadata.BuiltInMetadata\$Predicates", + "org.apache.calcite.rel.metadata.BuiltInMetadata\$RowCount", + "org.apache.calcite.rel.metadata.BuiltInMetadata\$Selectivity", + "org.apache.calcite.rel.metadata.BuiltInMetadata\$Size", + "org.apache.calcite.rel.metadata.BuiltInMetadata\$TableReferences", + "org.apache.calcite.rel.metadata.BuiltInMetadata\$UniqueKeys", + "org.apache.calcite.rel.metadata.RelMdAllPredicates", + "org.apache.calcite.rel.metadata.RelMdCollation", + "org.apache.calcite.rel.metadata.RelMdColumnOrigins", + "org.apache.calcite.rel.metadata.RelMdColumnUniqueness", + "org.apache.calcite.rel.metadata.RelMdDistinctRowCount", + "org.apache.calcite.rel.metadata.RelMdDistribution", + "org.apache.calcite.rel.metadata.RelMdExplainVisibility", + "org.apache.calcite.rel.metadata.RelMdExpressionLineage", + "org.apache.calcite.rel.metadata.RelMdLowerBoundCost", + "org.apache.calcite.rel.metadata.RelMdMaxRowCount", + "org.apache.calcite.rel.metadata.RelMdMemory", + "org.apache.calcite.rel.metadata.RelMdMinRowCount", + "org.apache.calcite.rel.metadata.RelMdNodeTypes", + "org.apache.calcite.rel.metadata.RelMdParallelism", + "org.apache.calcite.rel.metadata.RelMdPercentageOriginalRows", + "org.apache.calcite.rel.metadata.RelMdPopulationSize", + "org.apache.calcite.rel.metadata.RelMdPredicates", + "org.apache.calcite.rel.metadata.RelMdRowCount", + "org.apache.calcite.rel.metadata.RelMdSelectivity", + "org.apache.calcite.rel.metadata.RelMdSize", + "org.apache.calcite.rel.metadata.RelMdTableReferences", + "org.apache.calcite.rel.metadata.RelMdUniqueKeys", + "org.apache.calcite.util.Pair", + "org.apache.calcite.util.ReflectUtil", + "org.apache.calcite.util.Util", + "org.apache.commons.codec.language.Soundex", + "org.slf4j.LoggerFactory", + "org.slf4j.impl.JDK14LoggerAdapter", + "org.slf4j.impl.StaticLoggerBinder", + ) + .joinToString(",") + +graal { + mainClass("io.substrait.isthmus.cli.IsthmusEntryPoint") + outputName("isthmus") + graalVersion("22.1.0") + javaVersion("17") + option("--no-fallback") + option("--initialize-at-build-time=${initializeAtBuildTime}") + option("-H:IncludeResources=.*yaml") + option("--report-unsupported-elements-at-runtime") + option("-H:+ReportExceptionStackTraces") + option("-H:DynamicProxyConfigurationFiles=proxies.json") + option("--features=io.substrait.isthmus.cli.RegisterAtRuntime") + option("-J--enable-preview") +} diff --git a/isthmus/proxies.json b/isthmus-cli/proxies.json similarity index 100% rename from 
isthmus/proxies.json rename to isthmus-cli/proxies.json diff --git a/isthmus/src/main/java/io/substrait/isthmus/InitializeAtBuildTime.java b/isthmus-cli/src/main/java/io/substrait/isthmus/cli/InitializeAtBuildTime.java similarity index 53% rename from isthmus/src/main/java/io/substrait/isthmus/InitializeAtBuildTime.java rename to isthmus-cli/src/main/java/io/substrait/isthmus/cli/InitializeAtBuildTime.java index d4b5fca78..01c35b535 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/InitializeAtBuildTime.java +++ b/isthmus-cli/src/main/java/io/substrait/isthmus/cli/InitializeAtBuildTime.java @@ -1,3 +1,3 @@ -package io.substrait.isthmus; +package io.substrait.isthmus.cli; public class InitializeAtBuildTime {} diff --git a/isthmus/src/main/java/io/substrait/isthmus/IsthmusEntryPoint.java b/isthmus-cli/src/main/java/io/substrait/isthmus/cli/IsthmusEntryPoint.java similarity index 95% rename from isthmus/src/main/java/io/substrait/isthmus/IsthmusEntryPoint.java rename to isthmus-cli/src/main/java/io/substrait/isthmus/cli/IsthmusEntryPoint.java index eac6fbed3..37a380562 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/IsthmusEntryPoint.java +++ b/isthmus-cli/src/main/java/io/substrait/isthmus/cli/IsthmusEntryPoint.java @@ -1,4 +1,4 @@ -package io.substrait.isthmus; +package io.substrait.isthmus.cli; import static picocli.CommandLine.Command; import static picocli.CommandLine.Option; @@ -9,6 +9,10 @@ import com.google.protobuf.TextFormat; import com.google.protobuf.util.JsonFormat; import io.substrait.extension.SimpleExtension; +import io.substrait.isthmus.FeatureBoard; +import io.substrait.isthmus.ImmutableFeatureBoard; +import io.substrait.isthmus.SqlExpressionToSubstrait; +import io.substrait.isthmus.SqlToSubstrait; import io.substrait.isthmus.SubstraitRelVisitor.CrossJoinPolicy; import io.substrait.proto.ExtendedExpression; import io.substrait.proto.Plan; diff --git a/isthmus/src/main/java/io/substrait/isthmus/RegisterAtRuntime.java b/isthmus-cli/src/main/java/io/substrait/isthmus/cli/RegisterAtRuntime.java similarity index 94% rename from isthmus/src/main/java/io/substrait/isthmus/RegisterAtRuntime.java rename to isthmus-cli/src/main/java/io/substrait/isthmus/cli/RegisterAtRuntime.java index 676a12a88..416574208 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/RegisterAtRuntime.java +++ b/isthmus-cli/src/main/java/io/substrait/isthmus/cli/RegisterAtRuntime.java @@ -1,4 +1,4 @@ -package io.substrait.isthmus; +package io.substrait.isthmus.cli; import com.google.protobuf.Empty; import com.google.protobuf.GeneratedMessageV3; @@ -108,17 +108,17 @@ private static void register(Class c) { } private static void registerByAnnotation(Reflections reflections, Class c) { - reflections.getTypesAnnotatedWith(c).stream() + reflections + .getTypesAnnotatedWith(c) .forEach( inner -> { register(inner); - reflections.getSubTypesOf(c).stream().forEach(RegisterAtRuntime::register); + reflections.getSubTypesOf(c).forEach(RegisterAtRuntime::register); }); } private static void registerByParent(Reflections reflections, Class c) { register(c); - reflections.getSubTypesOf(c).stream().forEach(RegisterAtRuntime::register); + reflections.getSubTypesOf(c).forEach(RegisterAtRuntime::register); } - ; } diff --git a/isthmus/src/test/java/io/substrait/isthmus/IsthmusEntryPointTest.java b/isthmus-cli/src/test/java/io/substrait/isthmus/cli/IsthmusEntryPointTest.java similarity index 96% rename from isthmus/src/test/java/io/substrait/isthmus/IsthmusEntryPointTest.java rename to 
isthmus-cli/src/test/java/io/substrait/isthmus/cli/IsthmusEntryPointTest.java index 262b31fe3..a99a8f61f 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/IsthmusEntryPointTest.java +++ b/isthmus-cli/src/test/java/io/substrait/isthmus/cli/IsthmusEntryPointTest.java @@ -1,10 +1,11 @@ -package io.substrait.isthmus; +package io.substrait.isthmus.cli; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; +import io.substrait.isthmus.FeatureBoard; import io.substrait.isthmus.SubstraitRelVisitor.CrossJoinPolicy; import org.apache.calcite.sql.validate.SqlConformance; import org.apache.calcite.sql.validate.SqlConformanceEnum; diff --git a/isthmus/src/test/script/smoke.sh b/isthmus-cli/src/test/script/smoke.sh similarity index 97% rename from isthmus/src/test/script/smoke.sh rename to isthmus-cli/src/test/script/smoke.sh index 7ecbc2ec6..e7dd7a77d 100755 --- a/isthmus/src/test/script/smoke.sh +++ b/isthmus-cli/src/test/script/smoke.sh @@ -1,9 +1,12 @@ #!/bin/bash + +set -eu -o pipefail + parent_path=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P ) cd "${parent_path}" CMD=../../../build/graal/isthmus LINEITEM="CREATE TABLE LINEITEM (L_ORDERKEY BIGINT NOT NULL, L_PARTKEY BIGINT NOT NULL, L_SUPPKEY BIGINT NOT NULL, L_LINENUMBER INTEGER, L_QUANTITY DECIMAL, L_EXTENDEDPRICE DECIMAL, L_DISCOUNT DECIMAL, L_TAX DECIMAL, L_RETURNFLAG CHAR(1), L_LINESTATUS CHAR(1), L_SHIPDATE DATE, L_COMMITDATE DATE, L_RECEIPTDATE DATE, L_SHIPINSTRUCT CHAR(25), L_SHIPMODE CHAR(10), L_COMMENT VARCHAR(44))" -echo $LINEITEM +echo "${LINEITEM}" #set -x # SQL Query - Simple diff --git a/isthmus/src/test/script/tpch_smoke.sh b/isthmus-cli/src/test/script/tpch_smoke.sh similarity index 56% rename from isthmus/src/test/script/tpch_smoke.sh rename to isthmus-cli/src/test/script/tpch_smoke.sh index ff48e0e02..a6f63bff9 100755 --- a/isthmus/src/test/script/tpch_smoke.sh +++ b/isthmus-cli/src/test/script/tpch_smoke.sh @@ -1,24 +1,27 @@ #!/bin/bash + +set -eu -o pipefail + parent_path=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P ) cd "${parent_path}" CMD=../../../build/graal/isthmus -TPCH="../resources/tpch/" +TPCH="../../../../isthmus/src/test/resources/tpch" -DDL=`cat ${TPCH}/schema.sql` +DDL=$(cat ${TPCH}/schema.sql) QUERY_FOLDER="${TPCH}/queries" ##for QUERYNUM in {1..22}; do # TODO: failed query: 8 12 15. 
15 failed due to comments QUERY_TO_RUN=(1 2 3 4 5 6 7 9 10 11 13 14 16 17 18 19 20 21 22) for QUERY_NUM in "${QUERY_TO_RUN[@]}"; do - if [ $QUERY_NUM -lt 10 ]; then - QUERY=`cat ${QUERY_FOLDER}/0${QUERY_NUM}.sql` + if [ "${QUERY_NUM}" -lt 10 ]; then + QUERY=$(cat "${QUERY_FOLDER}/0${QUERY_NUM}.sql") else - QUERY=`cat ${QUERY_FOLDER}/${QUERY_NUM}.sql` + QUERY=$(cat "${QUERY_FOLDER}/${QUERY_NUM}.sql") fi - echo "Processing tpc-h query", $QUERY_NUM - echo $QUERY + echo "Processing tpc-h query ${QUERY_NUM}" + echo "${QUERY}" $CMD "${QUERY}" --create "${DDL}" done diff --git a/isthmus/README.md b/isthmus/README.md index 21d9da919..827575b9f 100644 --- a/isthmus/README.md +++ b/isthmus/README.md @@ -2,389 +2,7 @@ ## Overview -Substrait Isthmus is a Java library which enables serializing SQL queries to [Substrait Protobuf](https://substrait.io/serialization/binary_serialization/) and SQL expressions to [Extended Expressions](https://substrait.io/expressions/extended_expression/) via +Substrait Isthmus is a Java library which enables serializing SQL queries to [Substrait Protobuf](https://substrait.io/serialization/binary_serialization/) and SQL expressions to [Extended Expressions](https://substrait.io/expressions/extended_expression/) using the Calcite SQL compiler. Optionally, you can leverage the Calcite RelNode to Substrait Plan translator as an IR translation. -## Build - -Isthmus can be built as a native executable via Graal - -``` -./gradlew nativeImage -``` - -## Usage - -### Version - -``` -$ ./isthmus/build/graal/isthmus --version - -isthmus 0.1 -``` - -### Help - -``` -$ ./isthmus/build/graal/isthmus --help - -Usage: isthmus [-hmV] [--crossjoinpolicy=] - [--outputformat=] - [--sqlconformancemode=] - [-c=]... [-e=...]... [] -Convert SQL Queries and SQL Expressions to Substrait - [] A SQL query - -c, --create= - One or multiple create table statements e.g. CREATE - TABLE T1(foo int, bar bigint) - --crossjoinpolicy= - One of built-in Calcite SQL compatibility modes: - KEEP_AS_CROSS_JOIN, CONVERT_TO_INNER_JOIN - -e, --expression=... - One or more SQL expressions e.g. col + 1 - -h, --help Show this help message and exit. - -m, --multistatement Allow multiple statements terminated with a semicolon - --outputformat= - Set the output format for the generated plan: - PROTOJSON, PROTOTEXT, BINARY - --sqlconformancemode= - One of built-in Calcite SQL compatibility modes: - DEFAULT, LENIENT, BABEL, STRICT_92, STRICT_99, - PRAGMATIC_99, BIG_QUERY, MYSQL_5, ORACLE_10, - ORACLE_12, STRICT_2003, PRAGMATIC_2003, PRESTO, - SQL_SERVER_2008 - -V, --version Print version information and exit. 
-``` - -## Example - -### SQL to Substrait Plan - -``` -> $ ./isthmus/build/graal/isthmus \ - -c "CREATE TABLE Persons ( firstName VARCHAR, lastName VARCHAR, zip INT )" \ - "SELECT lastName, firstName FROM Persons WHERE zip = 90210" - -{ - "extensionUris": [{ - "extensionUriAnchor": 1, - "uri": "/functions_comparison.yaml" - }], - "extensions": [{ - "extensionFunction": { - "extensionUriReference": 1, - "functionAnchor": 0, - "name": "equal:any1_any1" - } - }], - "relations": [{ - "root": { - "input": { - "project": { - "common": { - "emit": { - "outputMapping": [3, 4] - } - }, - "input": { - "filter": { - "common": { - "direct": { - } - }, - "input": { - "read": { - "common": { - "direct": { - } - }, - "baseSchema": { - "names": ["FIRSTNAME", "LASTNAME", "ZIP"], - "struct": { - "types": [{ - "string": { - "typeVariationReference": 0, - "nullability": "NULLABILITY_NULLABLE" - } - }, { - "string": { - "typeVariationReference": 0, - "nullability": "NULLABILITY_NULLABLE" - } - }, { - "i32": { - "typeVariationReference": 0, - "nullability": "NULLABILITY_NULLABLE" - } - }], - "typeVariationReference": 0, - "nullability": "NULLABILITY_REQUIRED" - } - }, - "namedTable": { - "names": ["PERSONS"] - } - } - }, - "condition": { - "scalarFunction": { - "functionReference": 0, - "args": [{ - "selection": { - "directReference": { - "structField": { - "field": 2 - } - }, - "rootReference": { - } - } - }, { - "literal": { - "i32": 90210, - "nullable": false, - "typeVariationReference": 0 - } - }], - "outputType": { - "bool": { - "typeVariationReference": 0, - "nullability": "NULLABILITY_NULLABLE" - } - }, - "arguments": [] - } - } - } - }, - "expressions": [{ - "selection": { - "directReference": { - "structField": { - "field": 1 - } - }, - "rootReference": { - } - } - }, { - "selection": { - "directReference": { - "structField": { - "field": 0 - } - }, - "rootReference": { - } - } - }] - } - }, - "names": ["LASTNAME", "FIRSTNAME"] - } - }], - "expectedTypeUrls": [] -} -``` - -### SQL Expression to Substrait Extended Expression - -#### Projection - -``` -$ ./isthmus/build/graal/isthmus -c "CREATE TABLE NATION (N_NATIONKEY BIGINT NOT NULL, N_NAME CHAR(25), N_REGIONKEY BIGINT NOT NULL, N_COMMENT VARCHAR(152))" \ - -e "N_REGIONKEY + 10" - -{ - "extensionUris": [{ - "extensionUriAnchor": 1, - "uri": "/functions_arithmetic.yaml" - }], - "extensions": [{ - "extensionFunction": { - "extensionUriReference": 1, - "functionAnchor": 0, - "name": "add:i64_i64" - } - }], - "referredExpr": [{ - "expression": { - "scalarFunction": { - "functionReference": 0, - "args": [], - "outputType": { - "i64": { - "typeVariationReference": 0, - "nullability": "NULLABILITY_REQUIRED" - } - }, - "arguments": [{ - "value": { - "selection": { - "directReference": { - "structField": { - "field": 2 - } - }, - "rootReference": { - } - } - } - }, { - "value": { - "cast": { - "type": { - "i64": { - "typeVariationReference": 0, - "nullability": "NULLABILITY_REQUIRED" - } - }, - "input": { - "literal": { - "i32": 10, - "nullable": false, - "typeVariationReference": 0 - } - }, - "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" - } - } - }], - "options": [] - } - }, - "outputNames": ["new-column"] - }], - "baseSchema": { - "names": ["N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT"], - "struct": { - "types": [{ - "i64": { - "typeVariationReference": 0, - "nullability": "NULLABILITY_REQUIRED" - } - }, { - "fixedChar": { - "length": 25, - "typeVariationReference": 0, - "nullability": "NULLABILITY_NULLABLE" - } - }, { - "i64": { - 
"typeVariationReference": 0, - "nullability": "NULLABILITY_REQUIRED" - } - }, { - "varchar": { - "length": 152, - "typeVariationReference": 0, - "nullability": "NULLABILITY_NULLABLE" - } - }], - "typeVariationReference": 0, - "nullability": "NULLABILITY_REQUIRED" - } - }, - "expectedTypeUrls": [] -} -``` - -#### Filter - -``` -$ ./isthmus/build/graal/isthmus -c "CREATE TABLE NATION (N_NATIONKEY BIGINT NOT NULL, N_NAME CHAR(25), N_REGIONKEY BIGINT NOT NULL, N_COMMENT VARCHAR(152))" \ - -e "N_REGIONKEY > 10" - -{ - "extensionUris": [{ - "extensionUriAnchor": 1, - "uri": "/functions_comparison.yaml" - }], - "extensions": [{ - "extensionFunction": { - "extensionUriReference": 1, - "functionAnchor": 0, - "name": "gt:any_any" - } - }], - "referredExpr": [{ - "expression": { - "scalarFunction": { - "functionReference": 0, - "args": [], - "outputType": { - "bool": { - "typeVariationReference": 0, - "nullability": "NULLABILITY_REQUIRED" - } - }, - "arguments": [{ - "value": { - "selection": { - "directReference": { - "structField": { - "field": 2 - } - }, - "rootReference": { - } - } - } - }, { - "value": { - "cast": { - "type": { - "i64": { - "typeVariationReference": 0, - "nullability": "NULLABILITY_REQUIRED" - } - }, - "input": { - "literal": { - "i32": 10, - "nullable": false, - "typeVariationReference": 0 - } - }, - "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" - } - } - }], - "options": [] - } - }, - "outputNames": ["new-column"] - }], - "baseSchema": { - "names": ["N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT"], - "struct": { - "types": [{ - "i64": { - "typeVariationReference": 0, - "nullability": "NULLABILITY_REQUIRED" - } - }, { - "fixedChar": { - "length": 25, - "typeVariationReference": 0, - "nullability": "NULLABILITY_NULLABLE" - } - }, { - "i64": { - "typeVariationReference": 0, - "nullability": "NULLABILITY_REQUIRED" - } - }, { - "varchar": { - "length": 152, - "typeVariationReference": 0, - "nullability": "NULLABILITY_NULLABLE" - } - }], - "typeVariationReference": 0, - "nullability": "NULLABILITY_REQUIRED" - } - }, - "expectedTypeUrls": [] -} -``` +The capability provided by this library can be accessed using a command-line interface, provided by [isthmus-cli](../isthmus-cli). 
diff --git a/isthmus/build.gradle.kts b/isthmus/build.gradle.kts index cf3f8b332..4de0daef1 100644 --- a/isthmus/build.gradle.kts +++ b/isthmus/build.gradle.kts @@ -4,9 +4,8 @@ plugins { `maven-publish` id("java") id("idea") - id("com.palantir.graal") version "0.10.0" id("com.diffplug.spotless") version "6.11.0" - id("com.github.johnrengelman.shadow") version "7.1.2" + id("com.github.johnrengelman.shadow") version "8.1.1" signing } @@ -43,8 +42,8 @@ publishing { repositories { maven { name = "local" - val releasesRepoUrl = "$buildDir/repos/releases" - val snapshotsRepoUrl = "$buildDir/repos/snapshots" + val releasesRepoUrl = layout.buildDirectory.dir("repos/releases") + val snapshotsRepoUrl = layout.buildDirectory.dir("repos/snapshots") url = uri(if (version.toString().endsWith("SNAPSHOT")) snapshotsRepoUrl else releasesRepoUrl) } } @@ -71,56 +70,45 @@ java { withSourcesJar() } -var CALCITE_VERSION = "1.34.0" +val CALCITE_VERSION = properties.get("calcite.version") +val GUAVA_VERSION = properties.get("guava.version") +val IMMUTABLES_VERSION = properties.get("immutables.version") +val JACKSON_VERSION = properties.get("jackson.version") +val JUNIT_VERSION = properties.get("junit.version") +val SLF4J_VERSION = properties.get("slf4j.version") +val PROTOBUF_VERSION = properties.get("protobuf.version") dependencies { implementation(project(":core")) implementation("org.apache.calcite:calcite-core:${CALCITE_VERSION}") implementation("org.apache.calcite:calcite-server:${CALCITE_VERSION}") - implementation("org.junit.jupiter:junit-jupiter:5.9.2") + testImplementation("org.junit.jupiter:junit-jupiter:${JUNIT_VERSION}") implementation("org.reflections:reflections:0.9.12") - implementation("com.google.guava:guava:29.0-jre") - implementation("org.graalvm.sdk:graal-sdk:22.1.0") - implementation("info.picocli:picocli:4.7.5") - annotationProcessor("info.picocli:picocli-codegen:4.7.5") - implementation("com.fasterxml.jackson.core:jackson-databind:2.13.4") - implementation("com.fasterxml.jackson.core:jackson-annotations:2.13.4") - implementation("com.fasterxml.jackson.datatype:jackson-datatype-jdk8:2.13.4") - implementation("com.google.protobuf:protobuf-java-util:3.17.3") { + implementation("com.google.guava:guava:${GUAVA_VERSION}") + implementation("com.fasterxml.jackson.core:jackson-databind:${JACKSON_VERSION}") + implementation("com.fasterxml.jackson.core:jackson-annotations:${JACKSON_VERSION}") + implementation("com.fasterxml.jackson.datatype:jackson-datatype-jdk8:${JACKSON_VERSION}") + implementation("com.google.protobuf:protobuf-java-util:${PROTOBUF_VERSION}") { exclude("com.google.guava", "guava") .because("Brings in Guava for Android, which we don't want (and breaks multimaps).") } implementation("com.google.code.findbugs:jsr305:3.0.2") implementation("com.github.ben-manes.caffeine:caffeine:3.0.4") - implementation("org.immutables:value-annotations:2.8.8") - annotationProcessor("org.immutables:value:2.8.8") + implementation("org.immutables:value-annotations:${IMMUTABLES_VERSION}") + implementation("org.slf4j:slf4j-api:${SLF4J_VERSION}") + annotationProcessor("org.immutables:value:${IMMUTABLES_VERSION}") testImplementation("org.apache.calcite:calcite-plus:${CALCITE_VERSION}") annotationProcessor("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2") compileOnly("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2") } -graal { - mainClass("io.substrait.isthmus.IsthmusEntryPoint") - outputName("isthmus") - graalVersion("22.1.0") - javaVersion("17") - option("--no-fallback") - option( - 
"--initialize-at-build-time=io.substrait.isthmus.InitializeAtBuildTime,org.slf4j.impl.StaticLoggerBinder,com.google.common.math.IntMath\$1,com.google.common.base.Platform,com.google.common.util.concurrent.AbstractFuture\$UnsafeAtomicHelper,com.google.common.collect.ImmutableSortedMap,com.google.common.math.IntMath,com.google.common.collect.RegularImmutableSortedSet,com.google.common.cache.LocalCache,com.google.common.collect.Range,org.apache.commons.codec.language.Soundex,com.google.common.collect.ImmutableRangeSet,org.slf4j.LoggerFactory,com.google.common.collect.Platform,com.google.common.util.concurrent.SettableFuture,com.google.common.util.concurrent.AbstractFuture,com.google.common.util.concurrent.AbstractFuture,com.google.common.cache.CacheBuilder,com.google.common.base.Preconditions,com.google.common.collect.RegularImmutableMap,org.slf4j.impl.JDK14LoggerAdapter,org.apache.calcite.rel.metadata.RelMdColumnUniqueness,org.apache.calcite.rel.metadata.BuiltInMetadata\$ColumnOrigin,io.substrait.isthmus.metadata.LambdaMetadataSupplier,org.apache.calcite.rel.metadata.BuiltInMetadata\$PopulationSize,org.apache.calcite.rel.metadata.BuiltInMetadata\$Size,org.apache.calcite.rel.metadata.BuiltInMetadata\$UniqueKeys,org.apache.calcite.rel.metadata.RelMdColumnOrigins,org.apache.calcite.rel.metadata.RelMdExplainVisibility,org.apache.calcite.rel.metadata.RelMdMemory,org.apache.calcite.rel.metadata.RelMdExpressionLineage,org.apache.calcite.rel.metadata.RelMdDistinctRowCount,org.apache.calcite.rel.metadata.BuiltInMetadata\$RowCount,org.apache.calcite.rel.metadata.BuiltInMetadata\$PercentageOriginalRows,org.apache.calcite.util.Pair,org.apache.calcite.rel.metadata.BuiltInMetadata\$ExpressionLineage,org.apache.calcite.rel.metadata.BuiltInMetadata\$MinRowCount,com.google.common.primitives.Primitives,org.apache.calcite.rel.metadata.BuiltInMetadata\$Selectivity,org.apache.calcite.rel.metadata.BuiltInMetadata\$Parallelism,org.apache.calcite.rel.metadata.RelMdUniqueKeys,org.apache.calcite.rel.metadata.RelMdParallelism,org.apache.calcite.rel.metadata.RelMdPercentageOriginalRows,org.apache.calcite.rel.metadata.BuiltInMetadata\$Predicates,org.apache.calcite.rel.metadata.BuiltInMetadata\$Distribution,org.apache.calcite.config.CalciteSystemProperty,org.apache.calcite.rel.metadata.BuiltInMetadata\$NonCumulativeCost,org.apache.calcite.util.Util,org.apache.calcite.rel.metadata.RelMdAllPredicates,io.substrait.isthmus.metadata.LambdaHandlerCache,org.apache.calcite.rel.metadata.BuiltInMetadata\$TableReferences,org.apache.calcite.rel.metadata.RelMdNodeTypes,org.apache.calcite.rel.metadata.RelMdCollation,org.apache.calcite.rel.metadata.RelMdSelectivity,org.apache.calcite.rel.metadata.BuiltInMetadata\$NodeTypes,org.apache.calcite.rel.metadata.RelMdPredicates,org.apache.calcite.rel.metadata.BuiltInMetadata\$DistinctRowCount,org.apache.calcite.rel.metadata.RelMdRowCount,org.apache.calcite.rel.metadata.BuiltInMetadata\$MaxRowCount,org.apache.calcite.rel.metadata.BuiltInMetadata\$AllPredicates,org.apache.calcite.rel.metadata.RelMdMaxRowCount,org.apache.calcite.rel.metadata.RelMdLowerBoundCost,org.apache.calcite.rel.metadata.BuiltInMetadata\$ExplainVisibility,org.apache.calcite.rel.metadata.BuiltInMetadata\$ColumnUniqueness,org.apache.calcite.rel.metadata.RelMdPopulationSize,org.apache.calcite.rel.metadata.BuiltInMetadata\$Memory,org.apache.calcite.rel.metadata.RelMdMinRowCount,org.apache.calcite.rel.metadata.RelMdSize,org.apache.calcite.rel.metadata.BuiltInMetadata\$LowerBoundCost,org.apache.calcite.rel.metadata.RelMdTableRefere
nces,org.apache.calcite.rel.metadata.RelMdDistribution,io.substrait.isthmus.metadata.LegacyToLambdaGenerator,org.apache.calcite.rel.metadata.BuiltInMetadata\$CumulativeCost,org.apache.calcite.rel.metadata.BuiltInMetadata\$Collation" - ) - option("-H:IncludeResources=.*yaml") - option("--report-unsupported-elements-at-runtime") - option("-H:+ReportExceptionStackTraces") - option("-H:DynamicProxyConfigurationFiles=proxies.json") - option("--features=io.substrait.isthmus.RegisterAtRuntime") - option("-J--enable-preview") -} - tasks { named("shadowJar") { archiveBaseName.set("isthmus") manifest { attributes(mapOf("Main-Class" to "io.substrait.isthmus.PlanEntryPoint")) } } + + classes { dependsOn(":core:shadowJar") } } tasks { build { dependsOn(shadowJar) } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java index 034e3417a..399c99678 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java @@ -317,6 +317,7 @@ private AggregateCall fromMeasure(Aggregate.Measure measure) { distinct, false, false, + Collections.emptyList(), argIndex, filterArg, null, diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java index ce8612b9b..cd10a7ef3 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java @@ -29,6 +29,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.OptionalLong; import java.util.stream.Collectors; import java.util.stream.Stream; import org.apache.calcite.rel.RelFieldCollation; @@ -135,7 +136,7 @@ public Rel visit(org.apache.calcite.rel.core.Values values) { return ExpressionCreator.struct(false, fields); }) .collect(Collectors.toUnmodifiableList()); - return VirtualTableScan.builder().addAllDfsNames(type.names()).addAllRows(structs).build(); + return VirtualTableScan.builder().initialSchema(type).addAllRows(structs).build(); } @Override @@ -283,22 +284,33 @@ public Rel visit(org.apache.calcite.rel.core.Match match) { @Override public Rel visit(org.apache.calcite.rel.core.Sort sort) { - var input = apply(sort.getInput()); - var fields = - sort.getCollation().getFieldCollations().stream() - .map(t -> toSortField(t, input.getRecordType())) - .collect(java.util.stream.Collectors.toList()); - var convertedSort = Sort.builder().addAllSortFields(fields).input(input).build(); - if (sort.fetch == null && sort.offset == null) { - return convertedSort; + Rel input = apply(sort.getInput()); + Rel output = input; + + // The Calcite Sort relation combines sorting along with offset and fetch/limit + // Sorting is applied BEFORE the offset and limit are applied + // Substrait splits this functionality into two different relations: SortRel, FetchRel + // Add the SortRel to the relation tree first to match Calcite's application order + if (!sort.getCollation().getFieldCollations().isEmpty()) { + List<Expression.SortField> fields = + sort.getCollation().getFieldCollations().stream() + .map(t -> toSortField(t, input.getRecordType())) + .collect(java.util.stream.Collectors.toList()); + output = Sort.builder().addAllSortFields(fields).input(output).build(); } - var offset = Optional.ofNullable(sort.offset).map(r -> asLong(r)).orElse(0L); - var builder =
Fetch.builder().input(convertedSort).offset(offset); - if (sort.fetch == null) { - return builder.build(); + + if (sort.fetch != null || sort.offset != null) { + Long offset = Optional.ofNullable(sort.offset).map(this::asLong).orElse(0L); + OptionalLong count = + Optional.ofNullable(sort.fetch) + .map(r -> OptionalLong.of(asLong(r))) + .orElse(OptionalLong.empty()); + + var builder = Fetch.builder().input(output).offset(offset).count(count); + output = builder.build(); } - return builder.count(asLong(sort.fetch)).build(); + return output; } private long asLong(RexNode rex) { diff --git a/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java index 92ccecbe4..e4ce67009 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java @@ -71,8 +71,8 @@ private Type toSubstrait(RelDataType type, List names) { case SMALLINT -> creator.I16; case INTEGER -> creator.I32; case BIGINT -> creator.I64; - case FLOAT -> creator.FP32; - case DOUBLE -> creator.FP64; + case REAL -> creator.FP32; + case FLOAT, DOUBLE -> creator.FP64; case DECIMAL -> { if (type.getPrecision() > 38) { throw new UnsupportedOperationException( @@ -89,27 +89,9 @@ private Type toSubstrait(RelDataType type, List names) { } case SYMBOL -> creator.STRING; case DATE -> creator.DATE; - case TIME -> { - if (type.getPrecision() != 6) { - throw new UnsupportedOperationException( - "unsupported time precision " + type.getPrecision()); - } - yield creator.TIME; - } - case TIMESTAMP -> { - if (type.getPrecision() != 6) { - throw new UnsupportedOperationException( - "unsupported timestamp precision " + type.getPrecision()); - } - yield creator.TIMESTAMP; - } - case TIMESTAMP_WITH_LOCAL_TIME_ZONE -> { - if (type.getPrecision() != 6) { - throw new UnsupportedOperationException( - "unsupported timestamptz precision " + type.getPrecision()); - } - yield creator.TIMESTAMP_TZ; - } + case TIME -> creator.TIME; + case TIMESTAMP -> creator.precisionTimestamp(type.getPrecision()); + case TIMESTAMP_WITH_LOCAL_TIME_ZONE -> creator.precisionTimestampTZ(type.getPrecision()); case INTERVAL_YEAR, INTERVAL_YEAR_MONTH, INTERVAL_MONTH -> creator.INTERVAL_YEAR; case INTERVAL_DAY, INTERVAL_DAY_HOUR, @@ -203,7 +185,7 @@ public RelDataType visit(Type.I64 expr) { @Override public RelDataType visit(Type.FP32 expr) { - return t(n(expr), SqlTypeName.FLOAT); + return t(n(expr), SqlTypeName.REAL); } @Override @@ -241,6 +223,31 @@ public RelDataType visit(Type.Timestamp expr) { return t(n(expr), SqlTypeName.TIMESTAMP, 6); } + @Override + public RelDataType visit(Type.PrecisionTimestamp expr) { + int maxPrecision = typeFactory.getTypeSystem().getMaxPrecision(SqlTypeName.TIMESTAMP); + if (expr.precision() > maxPrecision) { + throw new UnsupportedOperationException( + String.format( + "unsupported precision_timestamp precision %s, max precision in Calcite type system is set to %s", + expr.precision(), maxPrecision)); + } + return t(n(expr), SqlTypeName.TIMESTAMP, expr.precision()); + } + + @Override + public RelDataType visit(Type.PrecisionTimestampTZ expr) throws RuntimeException { + int maxPrecision = + typeFactory.getTypeSystem().getMaxPrecision(SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE); + if (expr.precision() > maxPrecision) { + throw new UnsupportedOperationException( + String.format( + "unsupported precision_timestamp_tz precision %s, max precision in Calcite type system is set to %s", + expr.precision(), 
maxPrecision)); + } + return t(n(expr), SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE, expr.precision()); + } + @Override public RelDataType visit(Type.IntervalYear expr) { return typeFactory.createTypeWithNullability( diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java index 48555c19c..52dcf78ab 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java @@ -24,13 +24,22 @@ public class CallConverters { public static Function CAST = typeConverter -> (call, visitor) -> { - if (call.getKind() != SqlKind.CAST) { - return null; + Expression.FailureBehavior failureBehavior; + switch (call.getKind()) { + case CAST: + failureBehavior = Expression.FailureBehavior.THROW_EXCEPTION; + break; + case SAFE_CAST: + failureBehavior = Expression.FailureBehavior.RETURN_NULL; + break; + default: + return null; } return ExpressionCreator.cast( typeConverter.toSubstrait(call.getType()), - visitor.apply(call.getOperands().get(0))); + visitor.apply(call.getOperands().get(0)), + failureBehavior); }; /** diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java index 9d00af02b..8f0cf24ec 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java @@ -4,6 +4,7 @@ import io.substrait.expression.AbstractExpressionVisitor; import io.substrait.expression.EnumArg; import io.substrait.expression.Expression; +import io.substrait.expression.Expression.FailureBehavior; import io.substrait.expression.Expression.SingleOrList; import io.substrait.expression.Expression.Switch; import io.substrait.expression.FieldReference; @@ -478,8 +479,9 @@ private String convert(FunctionArg a) { @Override public RexNode visit(Expression.Cast expr) throws RuntimeException { + var safeCast = expr.failureBehavior() == FailureBehavior.RETURN_NULL; return rexBuilder.makeAbstractCast( - typeConverter.toCalcite(typeFactory, expr.getType()), expr.input().accept(this)); + typeConverter.toCalcite(typeFactory, expr.getType()), expr.input().accept(this), safeCast); } @Override diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/FieldSelectionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/FieldSelectionConverter.java index c5e93b5a9..ab6233f54 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/FieldSelectionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/FieldSelectionConverter.java @@ -32,10 +32,12 @@ public Optional convert( var reference = call.getOperands().get(1); if (reference.getKind() != SqlKind.LITERAL || !(reference instanceof RexLiteral)) { - logger.warn( - "Found item operator without literal kind/type. This isn't handled well. Reference was {} with toString {}.", - reference.getKind().name(), - reference); + logger + .atWarn() + .log( + "Found item operator without literal kind/type. This isn't handled well. 
Reference was {} with toString {}.", + reference.getKind().name(), + reference); return Optional.empty(); } @@ -99,13 +101,13 @@ private Optional toInt(Expression.Literal l) { } else if (l instanceof Expression.I64Literal i64) { return Optional.of((int) i64.value()); } - logger.warn("Literal expected to be int type but was not. {}.", l); + logger.atWarn().log("Literal expected to be int type but was not. {}.", l); return Optional.empty(); } public Optional toString(Expression.Literal l) { if (!(l instanceof Expression.FixedCharLiteral)) { - logger.warn("Literal expected to be char type but was not. {}", l); + logger.atWarn().log("Literal expected to be char type but was not. {}", l); return Optional.empty(); } diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java index f938ad7f5..8d6ae849c 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java @@ -81,7 +81,7 @@ public FunctionConverter( for (String key : alm.keySet()) { var sigs = calciteOperators.get(key); if (sigs == null) { - logger.info("Dropping function due to no binding: {}", key); + logger.atInfo().log("Dropping function due to no binding: {}", key); continue; } @@ -390,15 +390,14 @@ public Optional attemptMatch(C call, Function topLevelCo } if (singularInputType.isPresent()) { - Optional leastRestrictive = matchByLeastRestrictive(call, outputType, operands); - if (leastRestrictive.isPresent()) { - return leastRestrictive; - } - Optional coerced = matchCoerced(call, outputType, operands); if (coerced.isPresent()) { return coerced; } + Optional leastRestrictive = matchByLeastRestrictive(call, outputType, operands); + if (leastRestrictive.isPresent()) { + return leastRestrictive; + } } return Optional.empty(); } @@ -487,7 +486,7 @@ private static List coerceArguments(List arguments, Type private static Expression coerceArgument(Expression argument, Type type) { var typeMatches = isMatch(type, argument.getType()); if (!typeMatches) { - return ExpressionCreator.cast(type, argument); + return ExpressionCreator.cast(type, argument, Expression.FailureBehavior.THROW_EXCEPTION); } return argument; } diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/IgnoreNullableAndParameters.java b/isthmus/src/main/java/io/substrait/isthmus/expression/IgnoreNullableAndParameters.java index d9adb487d..a170268ff 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/IgnoreNullableAndParameters.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/IgnoreNullableAndParameters.java @@ -121,6 +121,18 @@ public Boolean visit(Type.Decimal type) { return typeToMatch instanceof Type.Decimal || typeToMatch instanceof ParameterizedType.Decimal; } + @Override + public Boolean visit(Type.PrecisionTimestamp type) { + return typeToMatch instanceof Type.PrecisionTimestamp + || typeToMatch instanceof ParameterizedType.PrecisionTimestamp; + } + + @Override + public Boolean visit(Type.PrecisionTimestampTZ type) { + return typeToMatch instanceof Type.PrecisionTimestampTZ + || typeToMatch instanceof ParameterizedType.PrecisionTimestampTZ; + } + @Override public Boolean visit(Type.Struct type) { return typeToMatch instanceof Type.Struct || typeToMatch instanceof ParameterizedType.Struct; @@ -159,6 +171,18 @@ public Boolean visit(ParameterizedType.Decimal expr) throws RuntimeException { return typeToMatch 
instanceof Type.Decimal || typeToMatch instanceof ParameterizedType.Decimal; } + @Override + public Boolean visit(ParameterizedType.PrecisionTimestamp expr) throws RuntimeException { + return typeToMatch instanceof Type.PrecisionTimestamp + || typeToMatch instanceof ParameterizedType.PrecisionTimestamp; + } + + @Override + public Boolean visit(ParameterizedType.PrecisionTimestampTZ expr) throws RuntimeException { + return typeToMatch instanceof Type.PrecisionTimestampTZ + || typeToMatch instanceof ParameterizedType.PrecisionTimestampTZ; + } + @Override public Boolean visit(ParameterizedType.Struct expr) throws RuntimeException { return typeToMatch instanceof Type.Struct || typeToMatch instanceof ParameterizedType.Struct; diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConverter.java index 649f0c792..433284b24 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConverter.java @@ -104,8 +104,8 @@ public Expression.Literal convert(RexLiteral literal) { } throw new UnsupportedOperationException("Unable to handle char type: " + val); } - case DOUBLE -> fp64(n, bd(literal).doubleValue()); - case FLOAT -> fp32(n, bd(literal).floatValue()); + case FLOAT, DOUBLE -> fp64(n, bd(literal).doubleValue()); + case REAL -> fp32(n, bd(literal).floatValue()); case DECIMAL -> { BigDecimal bd = bd(literal); diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java index 5de28714d..2bc7ec534 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java @@ -16,6 +16,8 @@ import org.apache.calcite.rex.RexDynamicParam; import org.apache.calcite.rex.RexFieldAccess; import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexLambda; +import org.apache.calcite.rex.RexLambdaRef; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexLocalRef; import org.apache.calcite.rex.RexNode; @@ -193,4 +195,14 @@ public Expression visitLocalRef(RexLocalRef localRef) { public Expression visitPatternFieldRef(RexPatternFieldRef fieldRef) { throw new UnsupportedOperationException("RexPatternFieldRef not supported"); } + + @Override + public Expression visitLambda(RexLambda rexLambda) { + throw new UnsupportedOperationException("RexLambda not supported"); + } + + @Override + public Expression visitLambdaRef(RexLambdaRef rexLambdaRef) { + throw new UnsupportedOperationException("RexLambdaRef not supported"); + } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/ArithmeticFunctionTest.java b/isthmus/src/test/java/io/substrait/isthmus/ArithmeticFunctionTest.java index 9f9beeacf..823ef6a48 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ArithmeticFunctionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ArithmeticFunctionTest.java @@ -8,7 +8,7 @@ public class ArithmeticFunctionTest extends PlanTestBase { static List CREATES = List.of( - "CREATE TABLE numbers (i8 TINYINT, i16 SMALLINT, i32 INT, i64 BIGINT, fp32 FLOAT, fp64 DOUBLE)"); + "CREATE TABLE numbers (i8 TINYINT, i16 SMALLINT, i32 INT, i64 BIGINT, fp32 REAL, fp64 DOUBLE)"); @ParameterizedTest @ValueSource(strings = {"i8", "i16", "i32", "i64", "fp32", "fp64"}) diff --git 
a/isthmus/src/test/java/io/substrait/isthmus/CalciteCallTest.java b/isthmus/src/test/java/io/substrait/isthmus/CalciteCallTest.java index 39a602b66..1c7e8d844 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/CalciteCallTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/CalciteCallTest.java @@ -65,7 +65,10 @@ public void coerceNumericOp() { func -> { // check that there is a cast for the incorrect argument type. assertEquals( - ExpressionCreator.cast(TypeCreator.REQUIRED.I64, ExpressionCreator.i32(false, 20)), + ExpressionCreator.cast( + TypeCreator.REQUIRED.I64, + ExpressionCreator.i32(false, 20), + Expression.FailureBehavior.THROW_EXCEPTION), func.arguments().get(0)); }, false); // TODO: implicit calcite cast diff --git a/isthmus/src/test/java/io/substrait/isthmus/CalciteLiteralTest.java b/isthmus/src/test/java/io/substrait/isthmus/CalciteLiteralTest.java index deff5bd0e..5029b678f 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/CalciteLiteralTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/CalciteLiteralTest.java @@ -70,7 +70,7 @@ void tI64() { @Test void tFP32() { - bitest(fp32(false, 4.44F), c(4.44F, SqlTypeName.FLOAT)); + bitest(fp32(false, 4.44F), c(4.44F, SqlTypeName.REAL)); } @Test @@ -78,6 +78,11 @@ void tFP64() { bitest(fp64(false, 4.45F), c(4.45F, SqlTypeName.DOUBLE)); } + @Test + void tFloatFP64() { + test(fp64(false, 4.45F), c(4.45F, SqlTypeName.FLOAT)); + } + @Test void tStr() { bitest(string(false, "my test"), c("my test", SqlTypeName.VARCHAR)); diff --git a/isthmus/src/test/java/io/substrait/isthmus/CalciteTypeTest.java b/isthmus/src/test/java/io/substrait/isthmus/CalciteTypeTest.java index 36fae668d..20cbd962f 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/CalciteTypeTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/CalciteTypeTest.java @@ -51,7 +51,7 @@ void i64(boolean nullable) { @ParameterizedTest @ValueSource(booleans = {true, false}) void fp32(boolean nullable) { - testType(Type.withNullability(nullable).FP32, SqlTypeName.FLOAT, nullable); + testType(Type.withNullability(nullable).FP32, SqlTypeName.REAL, nullable); } @ParameterizedTest @@ -60,6 +60,15 @@ void fp64(boolean nullable) { testType(Type.withNullability(nullable).FP64, SqlTypeName.DOUBLE, nullable); } + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void calciteFloatToFp64(boolean nullable) { + assertEquals( + Type.withNullability(nullable).FP64, + TypeConverter.DEFAULT.toSubstrait( + type.createTypeWithNullability(type.createSqlType(SqlTypeName.FLOAT), nullable))); + } + @ParameterizedTest @ValueSource(booleans = {true, false}) void date(boolean nullable) { @@ -74,18 +83,26 @@ void time(boolean nullable) { @ParameterizedTest @ValueSource(booleans = {true, false}) - void timestamp(boolean nullable) { - testType(Type.withNullability(nullable).TIMESTAMP, SqlTypeName.TIMESTAMP, nullable, 6); + void precisionTimeStamp(boolean nullable) { + for (int precision : new int[] {0, 3, 6}) { + testType( + Type.withNullability(nullable).precisionTimestamp(precision), + SqlTypeName.TIMESTAMP, + nullable, + precision); + } } @ParameterizedTest @ValueSource(booleans = {true, false}) - void timestamptz(boolean nullable) { - testType( - Type.withNullability(nullable).TIMESTAMP_TZ, - SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE, - nullable, - 6); + void precisionTimestamptz(boolean nullable) { + for (int precision : new int[] {0, 3, 6}) { + testType( + Type.withNullability(nullable).precisionTimestampTZ(precision), + SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE, + 
nullable, + precision); + } } @ParameterizedTest diff --git a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java index ff4367335..4772468ad 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java @@ -38,8 +38,6 @@ /** Verify that custom functions can convert from Substrait to Calcite and back. */ public class CustomFunctionTest extends PlanTestBase { - static final TypeCreator R = TypeCreator.of(false); - static final TypeCreator N = TypeCreator.of(true); // Define custom functions in a "functions_custom.yaml" extension static final String NAMESPACE = "/functions_custom"; diff --git a/isthmus/src/test/java/io/substrait/isthmus/ExpressionConvertabilityTest.java b/isthmus/src/test/java/io/substrait/isthmus/ExpressionConvertabilityTest.java index 58a39f613..0e877be3b 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ExpressionConvertabilityTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ExpressionConvertabilityTest.java @@ -6,6 +6,7 @@ import io.substrait.dsl.SubstraitBuilder; import io.substrait.expression.Expression; +import io.substrait.expression.ExpressionCreator; import io.substrait.expression.proto.ExpressionProtoConverter; import io.substrait.extension.ExtensionCollector; import io.substrait.isthmus.expression.ExpressionRexConverter; @@ -105,6 +106,24 @@ public void switchExpression() { expression); } + @Test + public void castFailureCondition() { + Rel rel = + b.project( + input -> + List.of( + ExpressionCreator.cast( + R.I64, + b.fieldReference(input, 0), + Expression.FailureBehavior.THROW_EXCEPTION), + ExpressionCreator.cast( + R.I32, b.fieldReference(input, 0), Expression.FailureBehavior.RETURN_NULL)), + b.remap(1, 2), + b.namedScan(List.of("test"), List.of("col1"), List.of(R.STRING))); + + assertFullRoundTrip(rel); + } + void assertExpressionEquality(Expression expected, Expression actual) { // go the extra mile and convert both inputs to protobuf // helps verify that the protobuf conversion is not broken diff --git a/isthmus/src/test/java/io/substrait/isthmus/FetchTest.java b/isthmus/src/test/java/io/substrait/isthmus/FetchTest.java new file mode 100644 index 000000000..8cfeae86f --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/FetchTest.java @@ -0,0 +1,34 @@ +package io.substrait.isthmus; + +import io.substrait.dsl.SubstraitBuilder; +import io.substrait.relation.Rel; +import io.substrait.type.TypeCreator; +import java.util.List; +import org.junit.jupiter.api.Test; + +public class FetchTest extends PlanTestBase { + + static final TypeCreator R = TypeCreator.of(false); + + final SubstraitBuilder b = new SubstraitBuilder(extensions); + + final Rel TABLE = b.namedScan(List.of("test"), List.of("col1"), List.of(R.STRING)); + + @Test + void limitOnly() { + Rel rel = b.limit(50, TABLE); + assertFullRoundTrip(rel); + } + + @Test + void offsetOnly() { + Rel rel = b.offset(50, TABLE); + assertFullRoundTrip(rel); + } + + @Test + void offsetAndLimit() { + Rel rel = b.fetch(50, 10, TABLE); + assertFullRoundTrip(rel); + } +} diff --git a/isthmus/src/test/java/io/substrait/isthmus/LogarithmicFunctionTest.java b/isthmus/src/test/java/io/substrait/isthmus/LogarithmicFunctionTest.java index bb999cee6..30091d0ad 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/LogarithmicFunctionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/LogarithmicFunctionTest.java @@ -8,7 +8,7 @@ 
public class LogarithmicFunctionTest extends PlanTestBase { static List CREATES = List.of( - "CREATE TABLE numbers (i8 TINYINT, i16 SMALLINT, i32 INT, i64 BIGINT, fp32 FLOAT, fp64 DOUBLE)"); + "CREATE TABLE numbers (i8 TINYINT, i16 SMALLINT, i32 INT, i64 BIGINT, fp32 REAL, fp64 DOUBLE)"); @ParameterizedTest @ValueSource(strings = {"fp32", "fp64"}) diff --git a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java index 657c9eeea..82025ebba 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java +++ b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java @@ -8,6 +8,7 @@ import com.google.common.annotations.Beta; import com.google.common.base.Charsets; import com.google.common.io.Resources; +import io.substrait.dsl.SubstraitBuilder; import io.substrait.extension.ExtensionCollector; import io.substrait.extension.SimpleExtension; import io.substrait.plan.Plan; @@ -17,6 +18,7 @@ import io.substrait.relation.Rel; import io.substrait.relation.RelProtoConverter; import io.substrait.type.Type; +import io.substrait.type.TypeCreator; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; @@ -46,6 +48,9 @@ public class PlanTestBase { protected final RelBuilder builder = creator.createRelBuilder(); protected final RexBuilder rex = creator.rex(); protected final RelDataTypeFactory typeFactory = creator.typeFactory(); + protected final SubstraitBuilder substraitBuilder = new SubstraitBuilder(extensions); + protected static final TypeCreator R = TypeCreator.of(false); + protected static final TypeCreator N = TypeCreator.of(true); public static String asString(String resource) throws IOException { return Resources.toString(Resources.getResource(resource), Charsets.UTF_8); diff --git a/isthmus/src/test/java/io/substrait/isthmus/RoundingFunctionTest.java b/isthmus/src/test/java/io/substrait/isthmus/RoundingFunctionTest.java index 8838a111c..e8d500dfa 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/RoundingFunctionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/RoundingFunctionTest.java @@ -8,7 +8,7 @@ public class RoundingFunctionTest extends PlanTestBase { static List CREATES = List.of( - "CREATE TABLE numbers (i8 TINYINT, i16 SMALLINT, i32 INT, i64 BIGINT, fp32 FLOAT, fp64 DOUBLE)"); + "CREATE TABLE numbers (i8 TINYINT, i16 SMALLINT, i32 INT, i64 BIGINT, fp32 REAL, fp64 DOUBLE)"); @ParameterizedTest @ValueSource(strings = {"fp32", "fp64"}) diff --git a/isthmus/src/test/java/io/substrait/isthmus/WindowFunctionTest.java b/isthmus/src/test/java/io/substrait/isthmus/WindowFunctionTest.java index e31ceafbf..56eb645fe 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/WindowFunctionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/WindowFunctionTest.java @@ -2,7 +2,12 @@ import static org.junit.jupiter.api.Assertions.assertThrows; +import io.substrait.expression.Expression; +import io.substrait.expression.WindowBound; +import io.substrait.extension.DefaultExtensionCatalog; +import io.substrait.relation.Rel; import java.io.IOException; +import java.util.List; import org.apache.calcite.sql.parser.SqlParseException; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; @@ -19,6 +24,16 @@ void rowNumber() throws IOException, SqlParseException { assertFullRoundTrip("select O_ORDERKEY, row_number() over () from ORDERS"); } + @Test + void lag() throws IOException, SqlParseException { + assertFullRoundTrip("select O_TOTALPRICE, LAG(O_TOTALPRICE, 1) 
over () from ORDERS"); + } + + @Test + void lead() throws IOException, SqlParseException { + assertFullRoundTrip("select O_TOTALPRICE, LEAD(O_TOTALPRICE, 1) over () from ORDERS"); + } + @ParameterizedTest @ValueSource(strings = {"rank", "dense_rank", "percent_rank"}) void rankFunctions(String rankFunction) throws IOException, SqlParseException { @@ -170,4 +185,76 @@ void rejectQueriesWithIgnoreNulls() { var query = "select last_value(L_LINENUMBER) ignore nulls over () from lineitem"; assertThrows(IllegalArgumentException.class, () -> assertFullRoundTrip(query)); } + + @ParameterizedTest + @ValueSource(strings = {"lag", "lead"}) + void lagLeadFunctions(String function) { + Rel rel = + substraitBuilder.project( + input -> + List.of( + substraitBuilder.windowFn( + DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, + String.format("%s:any", function), + R.FP64, + Expression.AggregationPhase.INITIAL_TO_RESULT, + Expression.AggregationInvocation.ALL, + Expression.WindowBoundsType.ROWS, + WindowBound.Preceding.UNBOUNDED, + WindowBound.Following.CURRENT_ROW, + substraitBuilder.fieldReference(input, 0))), + substraitBuilder.remap(1), + substraitBuilder.namedScan(List.of("window_test"), List.of("a"), List.of(R.FP64))); + + assertFullRoundTrip(rel); + } + + @ParameterizedTest + @ValueSource(strings = {"lag", "lead"}) + void lagLeadWithOffset(String function) { + Rel rel = + substraitBuilder.project( + input -> + List.of( + substraitBuilder.windowFn( + DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, + String.format("%s:any_i32", function), + R.FP64, + Expression.AggregationPhase.INITIAL_TO_RESULT, + Expression.AggregationInvocation.ALL, + Expression.WindowBoundsType.RANGE, + WindowBound.Preceding.UNBOUNDED, + WindowBound.Following.UNBOUNDED, + substraitBuilder.fieldReference(input, 0), + substraitBuilder.i32(1))), + substraitBuilder.remap(1), + substraitBuilder.namedScan(List.of("window_test"), List.of("a"), List.of(R.FP64))); + + assertFullRoundTrip(rel); + } + + @ParameterizedTest + @ValueSource(strings = {"lag", "lead"}) + void lagLeadWithOffsetAndDefault(String function) { + Rel rel = + substraitBuilder.project( + input -> + List.of( + substraitBuilder.windowFn( + DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, + String.format("%s:any_i32_any", function), + R.I64, + Expression.AggregationPhase.INITIAL_TO_RESULT, + Expression.AggregationInvocation.ALL, + Expression.WindowBoundsType.ROWS, + WindowBound.Preceding.UNBOUNDED, + WindowBound.Following.CURRENT_ROW, + substraitBuilder.fieldReference(input, 0), + substraitBuilder.i32(1), + substraitBuilder.fp64(100.0))), + substraitBuilder.remap(1), + substraitBuilder.namedScan(List.of("window_test"), List.of("a"), List.of(R.FP64))); + + assertFullRoundTrip(rel); + } } diff --git a/readme.md b/readme.md index 31c33d348..4ae53a2fc 100644 --- a/readme.md +++ b/readme.md @@ -3,6 +3,7 @@ Substrait Java is a project that makes it easier to build [Substrait](https://substrait.io/) plans through Java. The project has two main parts: 1) **Core** is the module that supports building Substrait plans directly through Java. This is much easier than manipulating the Substrait protobuf directly. It has no direct support for going from SQL to Substrait (that's covered by the second part) 2) **Isthmus** is the module that allows going from SQL to a Substrait plan. Both Java APIs and a top level script for conversion are present. Not all SQL is supported yet by this module, but a lot is. For example, all of the TPC-H queries and all but a few of the TPC-DS queries are translatable. 
+3) **Spark** is the module that provides an API for translating a Substrait plan to and from a Spark query plan. The most commonly used logical relations are supported, including those generated from all of the TPC-H queries, but there are currently some gaps in support that prevent all of the TPC-DS queries from being translatable. ## Building After you've cloned the project through git, Substrait Java is built with a tool called [Gradle](https://gradle.org/). To build, execute the following: @@ -20,7 +21,16 @@ A good way to get started is to experiment with building Substrait plans for you Another way to get an idea of what Substrait plans look like is to use our script that generates Substrait plans for all the TPC-H queries: ``` -./isthmus/src/test/script/tpch_smoke.sh +./isthmus-cli/src/test/script/tpch_smoke.sh +``` + +## Logging +This project uses the [SLF4J](https://www.slf4j.org/) logging API. If you are using the Substrait Java core component as a dependency in your own project, you should consider including an appropriate [SLF4J logging provider](https://www.slf4j.org/manual.html#swapping) in your runtime classpath. If you do not include a logging provider in your classpath, the code will still work correctly but you will not receive any logging output and might see the following warning in your standard error output: + +``` +SLF4J(W): No SLF4J providers were found. +SLF4J(W): Defaulting to no-operation (NOP) logger implementation +SLF4J(W): See https://www.slf4j.org/codes.html#noProviders for further details. ``` ## Getting Involved diff --git a/settings.gradle.kts b/settings.gradle.kts index 7c34e4695..224c6b509 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -1,6 +1,6 @@ rootProject.name = "substrait" -include("bom", "core", "isthmus") +include("bom", "core", "isthmus", "isthmus-cli", "spark") pluginManagement { plugins { diff --git a/spark/build.gradle.kts b/spark/build.gradle.kts new file mode 100644 index 000000000..0501cf51e --- /dev/null +++ b/spark/build.gradle.kts @@ -0,0 +1,113 @@ +plugins { + `maven-publish` + id("java") + id("scala") + id("idea") + id("com.diffplug.spotless") version "6.11.0" + signing +} + +publishing { + publications { + create("maven-publish") { + from(components["java"]) + + pom { + name.set("Substrait Java") + description.set( + "Create a well-defined, cross-language specification for data compute operations" + ) + url.set("https://github.com/substrait-io/substrait-java") + licenses { + license { + name.set("The Apache License, Version 2.0") + url.set("http://www.apache.org/licenses/LICENSE-2.0.txt") + } + } + developers { + developer { + // TBD Get the list of + } + } + scm { + connection.set("scm:git:git://github.com:substrait-io/substrait-java.git") + developerConnection.set("scm:git:ssh://github.com:substrait-io/substrait-java") + url.set("https://github.com/substrait-io/substrait-java/") + } + } + } + } + repositories { + maven { + name = "local" + val releasesRepoUrl = layout.buildDirectory.dir("repos/releases") + val snapshotsRepoUrl = layout.buildDirectory.dir("repos/snapshots") + url = uri(if (version.toString().endsWith("SNAPSHOT")) snapshotsRepoUrl else releasesRepoUrl) + } + } +} + +signing { + setRequired({ gradle.taskGraph.hasTask("publishToSonatype") }) + val signingKeyId = + System.getenv("SIGNING_KEY_ID").takeUnless { it.isNullOrEmpty() } + ?: extra["SIGNING_KEY_ID"].toString() + val signingPassword = + System.getenv("SIGNING_PASSWORD").takeUnless { it.isNullOrEmpty() } + ?: extra["SIGNING_PASSWORD"].toString() + 
+  val signingKey =
+    System.getenv("SIGNING_KEY").takeUnless { it.isNullOrEmpty() }
+      ?: extra["SIGNING_KEY"].toString()
+  useInMemoryPgpKeys(signingKeyId, signingKey, signingPassword)
+  sign(publishing.publications["maven-publish"])
+}
+
+configurations.all {
+  if (name.startsWith("incrementalScalaAnalysis")) {
+    setExtendsFrom(emptyList())
+  }
+}
+
+java {
+  toolchain { languageVersion.set(JavaLanguageVersion.of(17)) }
+  withJavadocJar()
+  withSourcesJar()
+}
+
+tasks.withType<ScalaCompile>() {
+  targetCompatibility = ""
+  scalaCompileOptions.additionalParameters = listOf("-release:17")
+}
+
+var SLF4J_VERSION = properties.get("slf4j.version")
+var SPARKBUNDLE_VERSION = properties.get("sparkbundle.version")
+var SPARK_VERSION = properties.get("spark.version")
+
+sourceSets {
+  main { scala { setSrcDirs(listOf("src/main/scala", "src/main/spark-${SPARKBUNDLE_VERSION}")) } }
+  test { scala { setSrcDirs(listOf("src/test/scala", "src/test/spark-3.2", "src/main/scala")) } }
+}
+
+dependencies {
+  implementation(project(":core"))
+  implementation("org.scala-lang:scala-library:2.12.16")
+  implementation("org.apache.spark:spark-core_2.12:${SPARK_VERSION}")
+  implementation("org.apache.spark:spark-sql_2.12:${SPARK_VERSION}")
+  implementation("org.apache.spark:spark-catalyst_2.12:${SPARK_VERSION}")
+  implementation("org.slf4j:slf4j-api:${SLF4J_VERSION}")
+
+  testImplementation("org.scalatest:scalatest_2.12:3.2.18")
+  testRuntimeOnly("org.junit.platform:junit-platform-engine:1.10.0")
+  testRuntimeOnly("org.junit.platform:junit-platform-launcher:1.10.0")
+  testRuntimeOnly("org.scalatestplus:junit-5-10_2.12:3.2.18.0")
+  testImplementation("org.apache.spark:spark-core_2.12:${SPARK_VERSION}:tests")
+  testImplementation("org.apache.spark:spark-sql_2.12:${SPARK_VERSION}:tests")
+  testImplementation("org.apache.spark:spark-catalyst_2.12:${SPARK_VERSION}:tests")
+}
+
+tasks {
+  test {
+    dependsOn(":core:shadowJar")
+    useJUnitPlatform { includeEngines("scalatest") }
+  }
+}
diff --git a/spark/src/main/resources/spark.yml b/spark/src/main/resources/spark.yml
new file mode 100644
index 000000000..e398aa3a6
--- /dev/null
+++ b/spark/src/main/resources/spark.yml
@@ -0,0 +1,34 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+%YAML 1.2
+---
+scalar_functions:
+  -
+    name: year
+    description: Returns the year component of the date/timestamp
+    impls:
+      - args:
+          - value: date
+        return: i32
+  -
+    name: unscaled
+    description: >-
+      Return the unscaled Long value of a Decimal, assuming it fits in a Long.
+      Note: this expression is internal and created only by the optimizer,
+      we don't need to do type check for it.
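+    # Each impl below uses the Substrait simple-extension signature format:
+    # `args` lists the argument types and `return` names the result type, so
+    # this entry declares unscaled(decimal) -> i64.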
+    impls:
+      - args:
+          - value: "DECIMAL<P,S>"
+        return: i64
diff --git a/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala b/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala
new file mode 100644
index 000000000..421a040d2
--- /dev/null
+++ b/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.substrait.debug
+
+import io.substrait.spark.DefaultExpressionVisitor
+
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+
+import io.substrait.expression.{Expression, FieldReference}
+import io.substrait.expression.Expression.{DateLiteral, DecimalLiteral, I32Literal, StrLiteral}
+import io.substrait.function.ToTypeString
+import io.substrait.util.DecimalUtil
+
+import scala.collection.JavaConverters.asScalaBufferConverter
+
+class ExpressionToString extends DefaultExpressionVisitor[String] {
+
+  override def visit(expr: DecimalLiteral): String = {
+    val value = expr.value.toByteArray
+    val decimal = DecimalUtil.getBigDecimalFromBytes(value, expr.scale, 16)
+    decimal.toString
+  }
+
+  override def visit(expr: StrLiteral): String = {
+    expr.value()
+  }
+
+  override def visit(expr: I32Literal): String = {
+    expr.value().toString
+  }
+
+  override def visit(expr: DateLiteral): String = {
+    DateTimeUtils.toJavaDate(expr.value()).toString
+  }
+
+  override def visit(expr: FieldReference): String = {
+    withFieldReference(expr)(i => "$" + i.toString)
+  }
+
+  override def visit(expr: Expression.SingleOrList): String = {
+    expr.toString
+  }
+
+  override def visit(expr: Expression.ScalarFunctionInvocation): String = {
+    val args = expr
+      .arguments()
+      .asScala
+      .zipWithIndex
+      .map {
+        case (arg, i) =>
+          arg.accept(expr.declaration(), i, this)
+      }
+      .mkString(",")
+
+    s"${expr.declaration().key()}[${expr.outputType().accept(ToTypeString.INSTANCE)}]($args)"
+  }
+
+  override def visit(expr: Expression.UserDefinedLiteral): String = {
+    expr.toString
+  }
+}
diff --git a/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala b/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala
new file mode 100644
index 000000000..9f4f5c9f8
--- /dev/null
+++ b/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.substrait.debug + +import io.substrait.spark.DefaultRelVisitor + +import io.substrait.relation._ + +import scala.collection.mutable + +class RelToVerboseString(addSuffix: Boolean) extends DefaultRelVisitor[String] { + + private val expressionStringConverter = new ExpressionToString + + private def stringBuilder(rel: Rel, remapLength: Int): mutable.StringBuilder = { + val nodeName = rel.getClass.getSimpleName.replaceAll("Immutable", "") + val builder: mutable.StringBuilder = new mutable.StringBuilder(s"$nodeName[") + rel.getRemap.ifPresent(remap => builder.append("remap=").append(remap)) + if (builder.length > remapLength) builder.append(", ") + builder + } + + private def withBuilder(rel: Rel, remapLength: Int)(f: mutable.StringBuilder => Unit): String = { + val builder = stringBuilder(rel, remapLength) + f(builder) + builder.append("]").toString + } + + def apply(rel: Rel, maxFields: Int): String = { + rel.accept(this) + } + + override def visit(fetch: Fetch): String = { + withBuilder(fetch, 7)( + builder => { + builder.append("offset=").append(fetch.getOffset) + fetch.getCount.ifPresent( + count => { + builder.append(", ") + builder.append("count=").append(count) + }) + }) + } + override def visit(sort: Sort): String = { + withBuilder(sort, 5)( + builder => { + builder.append("sortFields=").append(sort.getSortFields) + }) + } + + override def visit(join: Join): String = { + withBuilder(join, 5)( + builder => { + join.getCondition.ifPresent( + condition => { + builder.append("condition=").append(condition) + builder.append(", ") + }) + + join.getPostJoinFilter.ifPresent( + postJoinFilter => { + builder.append("postJoinFilter=").append(postJoinFilter) + builder.append(", ") + }) + builder.append("joinType=").append(join.getJoinType) + }) + } + + override def visit(filter: Filter): String = { + withBuilder(filter, 7)( + builder => { + builder.append(filter.getCondition.accept(expressionStringConverter)) + }) + } + + def fillReadRel(read: AbstractReadRel, builder: mutable.StringBuilder): Unit = { + builder.append("initialSchema=").append(read.getInitialSchema) + read.getFilter.ifPresent( + filter => { + builder.append(", ") + builder.append("filter=").append(filter) + }) + read.getCommonExtension.ifPresent( + commonExtension => { + builder.append(", ") + builder.append("commonExtension=").append(commonExtension) + }) + } + override def visit(namedScan: NamedScan): String = { + withBuilder(namedScan, 10)( + builder => { + fillReadRel(namedScan, builder) + builder.append(", ") + builder.append("names=").append(namedScan.getNames) + + namedScan.getExtension.ifPresent( + extension => { + builder.append(", ") + builder.append("extension=").append(extension) + }) + }) + } + + override def visit(emptyScan: EmptyScan): String = { + withBuilder(emptyScan, 10)( + builder => { + fillReadRel(emptyScan, builder) + }) + } + + override def visit(project: Project): String = { + withBuilder(project, 8)( + builder => { + builder + .append("expressions=") + .append(project.getExpressions) + }) + } + + override def visit(aggregate: Aggregate): String = { + withBuilder(aggregate, 
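+ // remapLength (10 here) is the length of the "Aggregate[" prefix written by
+ // stringBuilder; ", " is appended only once a remap has extended the builder
+ // past that prefix.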
10)( + builder => { + builder + .append("groupings=") + .append(aggregate.getGroupings) + .append(", ") + .append("measures=") + .append(aggregate.getMeasures) + }) + } + + override def visit(localFiles: LocalFiles): String = { + withBuilder(localFiles, 10)( + builder => { + builder + .append("items=") + .append(localFiles.getItems) + }) + } +} + +object RelToVerboseString { + val verboseStringWithSuffix = new RelToVerboseString(true) + val verboseString = new RelToVerboseString(false) +} diff --git a/spark/src/main/scala/io/substrait/debug/TreePrinter.scala b/spark/src/main/scala/io/substrait/debug/TreePrinter.scala new file mode 100644 index 000000000..cd50f412b --- /dev/null +++ b/spark/src/main/scala/io/substrait/debug/TreePrinter.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.substrait.debug + +import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat +import org.apache.spark.sql.internal.SQLConf + +import RelToVerboseString.{verboseString, verboseStringWithSuffix} +import io.substrait.relation +import io.substrait.relation.Rel + +import scala.collection.JavaConverters.asScalaBufferConverter + +trait TreePrinter[T] { + def tree(t: T): String +} + +object TreePrinter { + + implicit object SubstraitRel extends TreePrinter[relation.Rel] { + override def tree(t: Rel): String = TreePrinter.tree(t) + } + + final def tree(rel: relation.Rel): String = treeString(rel, verbose = true) + + final def treeString( + rel: relation.Rel, + verbose: Boolean, + addSuffix: Boolean = false, + maxFields: Int = SQLConf.get.maxToStringFields, + printOperatorId: Boolean = false): String = { + val concat = new PlanStringConcat() + treeString(rel, concat.append, verbose, addSuffix, maxFields, printOperatorId) + concat.toString + } + + def treeString( + rel: relation.Rel, + append: String => Unit, + verbose: Boolean, + addSuffix: Boolean, + maxFields: Int, + printOperatorId: Boolean): Unit = { + generateTreeString(rel, 0, Nil, append, verbose, "", addSuffix, maxFields, printOperatorId) + } + + /** + * Appends the string representation of this node and its children to the given Writer. + * + * The `i`-th element in `lastChildren` indicates whether the ancestor of the current node at + * depth `i + 1` is the last child of its own parent node. The depth of the root node is 0, and + * `lastChildren` for the root node should be empty. 
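+ *
+ * For example, a join whose left input is a filtered scan renders as:
+ *
+ * Join[...]
+ * :- Filter[...]
+ * : +- NamedScan[...]
+ * +- NamedScan[...]
+ *
+ * where ":- " marks a child with later siblings and "+- " marks the last child
+ * of its parent.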
+ */ + def generateTreeString( + rel: relation.Rel, + depth: Int, + lastChildren: Seq[Boolean], + append: String => Unit, + verbose: Boolean, + prefix: String = "", + addSuffix: Boolean = false, + maxFields: Int, + printNodeId: Boolean, + indent: Int = 0): Unit = { + + append(" " * indent) + if (depth > 0) { + lastChildren.init.foreach(isLast => append(if (isLast) " " else ": ")) + append(if (lastChildren.last) "+- " else ":- ") + } + + val str = if (verbose) { + if (addSuffix) verboseStringWithSuffix(rel, maxFields) else verboseString(rel, maxFields) + } else { + "" + } + append(prefix) + append(str) + append("\n") + + val children = rel.getInputs.asScala + if (children.nonEmpty) { + children.init.foreach( + generateTreeString( + _, + depth + 1, + lastChildren :+ false, + append, + verbose, + prefix, + addSuffix, + maxFields, + printNodeId = printNodeId, + indent = indent)) + + generateTreeString( + children.last, + depth + 1, + lastChildren :+ true, + append, + verbose, + prefix, + addSuffix, + maxFields, + printNodeId = printNodeId, + indent = indent) + } + } +} diff --git a/spark/src/main/scala/io/substrait/spark/DefaultExpressionVisitor.scala b/spark/src/main/scala/io/substrait/spark/DefaultExpressionVisitor.scala new file mode 100644 index 000000000..d0d2e0d00 --- /dev/null +++ b/spark/src/main/scala/io/substrait/spark/DefaultExpressionVisitor.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.substrait.spark + +import io.substrait.`type`.Type +import io.substrait.expression._ +import io.substrait.extension.SimpleExtension + +class DefaultExpressionVisitor[T] + extends AbstractExpressionVisitor[T, RuntimeException] + with FunctionArg.FuncArgVisitor[T, RuntimeException] { + + override def visitFallback(expr: Expression): T = + throw new UnsupportedOperationException( + s"Expression type ${expr.getClass.getCanonicalName} " + + s"not handled by visitor type ${getClass.getCanonicalName}.") + + override def visitType(fnDef: SimpleExtension.Function, argIdx: Int, t: Type): T = + throw new UnsupportedOperationException( + s"FunctionArg $t not handled by visitor type ${getClass.getCanonicalName}.") + + override def visitEnumArg(fnDef: SimpleExtension.Function, argIdx: Int, e: EnumArg): T = + throw new UnsupportedOperationException( + s"EnumArg(value=${e.value()}) not handled by visitor type ${getClass.getCanonicalName}.") + + protected def withFieldReference(fieldReference: FieldReference)(f: Int => T): T = { + if (fieldReference.isSimpleRootReference) { + val segment = fieldReference.segments().get(0) + segment match { + case s: FieldReference.StructField => f(s.offset()) + case _ => throw new IllegalArgumentException(s"Unhandled type: $segment") + } + } else { + visitFallback(fieldReference) + } + } + + override def visitExpr(fnDef: SimpleExtension.Function, argIdx: Int, e: Expression): T = + e.accept(this) + + override def visit(userDefinedLiteral: Expression.UserDefinedLiteral): T = { + visitFallback(userDefinedLiteral) + } +} diff --git a/spark/src/main/scala/io/substrait/spark/DefaultRelVisitor.scala b/spark/src/main/scala/io/substrait/spark/DefaultRelVisitor.scala new file mode 100644 index 000000000..7f1e181b5 --- /dev/null +++ b/spark/src/main/scala/io/substrait/spark/DefaultRelVisitor.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.substrait.spark + +import io.substrait.relation +import io.substrait.relation.AbstractRelVisitor + +class DefaultRelVisitor[T] extends AbstractRelVisitor[T, RuntimeException] { + + override def visitFallback(rel: relation.Rel): T = + throw new UnsupportedOperationException( + s"Type ${rel.getClass.getCanonicalName}" + + s" not handled by visitor type ${getClass.getCanonicalName}.") +} diff --git a/spark/src/main/scala/io/substrait/spark/HasOutputStack.scala b/spark/src/main/scala/io/substrait/spark/HasOutputStack.scala new file mode 100644 index 000000000..3ff416991 --- /dev/null +++ b/spark/src/main/scala/io/substrait/spark/HasOutputStack.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.substrait.spark + +import scala.collection.mutable + +trait HasOutputStack[T] { + private val outputStack = mutable.Stack[T]() + def currentOutput: T = outputStack.top + def pushOutput(e: T): Unit = outputStack.push(e) + def popOutput(): T = outputStack.pop() +} diff --git a/spark/src/main/scala/io/substrait/spark/SparkExtension.scala b/spark/src/main/scala/io/substrait/spark/SparkExtension.scala new file mode 100644 index 000000000..0d6d84b71 --- /dev/null +++ b/spark/src/main/scala/io/substrait/spark/SparkExtension.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.substrait.spark + +import io.substrait.spark.expression.ToAggregateFunction + +import io.substrait.extension.SimpleExtension + +import java.util.Collections + +import scala.collection.JavaConverters +import scala.collection.JavaConverters.asScalaBufferConverter + +object SparkExtension { + private val SparkImpls: SimpleExtension.ExtensionCollection = + SimpleExtension.load(Collections.singletonList("/spark.yml")) + + private val EXTENSION_COLLECTION: SimpleExtension.ExtensionCollection = + SimpleExtension.loadDefaults() + + lazy val SparkScalarFunctions: Seq[SimpleExtension.ScalarFunctionVariant] = { + val ret = new collection.mutable.ArrayBuffer[SimpleExtension.ScalarFunctionVariant]() + ret.appendAll(EXTENSION_COLLECTION.scalarFunctions().asScala) + ret.appendAll(SparkImpls.scalarFunctions().asScala) + ret + } + + val toAggregateFunction: ToAggregateFunction = ToAggregateFunction( + JavaConverters.asScalaBuffer(EXTENSION_COLLECTION.aggregateFunctions())) +} diff --git a/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala b/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala new file mode 100644 index 000000000..9522042ee --- /dev/null +++ b/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.substrait.spark
+
+import io.substrait.`type`.{NamedStruct, Type, TypeVisitor}
+import io.substrait.function.TypeExpression
+import io.substrait.utils.Util
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
+import org.apache.spark.sql.types._
+
+import scala.collection.JavaConverters
+import scala.collection.JavaConverters.asScalaBufferConverter
+
+private class ToSparkType
+  extends TypeVisitor.TypeThrowsVisitor[DataType, RuntimeException]("Unknown expression type.") {
+
+  override def visit(expr: Type.I32): DataType = IntegerType
+  override def visit(expr: Type.I64): DataType = LongType
+
+  override def visit(expr: Type.FP32): DataType = FloatType
+  override def visit(expr: Type.FP64): DataType = DoubleType
+
+  override def visit(expr: Type.Decimal): DataType =
+    DecimalType(expr.precision(), expr.scale())
+
+  override def visit(expr: Type.Date): DataType = DateType
+
+  override def visit(expr: Type.Str): DataType = StringType
+
+  override def visit(expr: Type.FixedChar): DataType = StringType
+
+  override def visit(expr: Type.VarChar): DataType = StringType
+}
+class ToSubstraitType {
+
+  def convert(typeExpression: TypeExpression): DataType = {
+    typeExpression.accept(new ToSparkType)
+  }
+
+  def convert(dataType: DataType, nullable: Boolean): Option[Type] = {
+    convert(dataType, Seq.empty, nullable)
+  }
+
+  def apply(dataType: DataType, nullable: Boolean): Type = {
+    convert(dataType, Seq.empty, nullable)
+      .getOrElse(
+        throw new UnsupportedOperationException(s"Unable to convert the type ${dataType.typeName}"))
+  }
+
+  protected def convert(dataType: DataType, names: Seq[String], nullable: Boolean): Option[Type] = {
+    val creator = Type.withNullability(nullable)
+    dataType match {
+      case BooleanType => Some(creator.BOOLEAN)
+      case ByteType => Some(creator.I8)
+      case ShortType => Some(creator.I16)
+      case IntegerType => Some(creator.I32)
+      case LongType => Some(creator.I64)
+      case FloatType => Some(creator.FP32)
+      case DoubleType => Some(creator.FP64)
+      case decimal: DecimalType if decimal.precision <= 38 =>
+        Some(creator.decimal(decimal.precision, decimal.scale))
+      case charType: CharType => Some(creator.fixedChar(charType.length))
+      case varcharType: VarcharType => Some(creator.varChar(varcharType.length))
+      case StringType => Some(creator.STRING)
+      case DateType => Some(creator.DATE)
+      // Spark's TimestampType has instant (time-zone aware) semantics, while
+      // TimestampNTZType carries no time zone, so they map to TIMESTAMP_TZ and
+      // TIMESTAMP respectively.
+      case TimestampType => Some(creator.TIMESTAMP_TZ)
+      case TimestampNTZType => Some(creator.TIMESTAMP)
+      case BinaryType => Some(creator.BINARY)
+      case ArrayType(elementType, containsNull) =>
+        convert(elementType, Seq.empty, containsNull).map(creator.list)
+      case MapType(keyType, valueType, valueContainsNull) =>
+        convert(keyType, Seq.empty, nullable = false)
+          .flatMap(
+            keyT =>
+              convert(valueType, Seq.empty, valueContainsNull)
+                .map(valueT => creator.map(keyT,
valueT))) + case _ => + None + } + } + def toNamedStruct(output: Seq[Attribute]): Option[NamedStruct] = { + val names = JavaConverters.seqAsJavaList(output.map(_.name)) + val creator = Type.withNullability(false) + Util + .seqToOption(output.map(a => convert(a.dataType, a.nullable))) + .map(l => creator.struct(JavaConverters.asJavaIterable(l))) + .map(NamedStruct.of(names, _)) + } + def toNamedStruct(schema: StructType): NamedStruct = { + val creator = Type.withNullability(false) + val names = new java.util.ArrayList[String] + val children = new java.util.ArrayList[Type] + schema.fields.foreach( + field => { + names.add(field.name) + children.add(apply(field.dataType, field.nullable)) + }) + val struct = creator.struct(children) + NamedStruct.of(names, struct) + } + + def toStructType(namedStruct: NamedStruct): StructType = { + StructType( + fields = namedStruct + .struct() + .fields() + .asScala + .map(t => (t, convert(t))) + .zip(namedStruct.names().asScala) + .map { case ((t, d), name) => StructField(name, d, t.nullable()) } + ) + } + + def toAttribute(namedStruct: NamedStruct): Seq[AttributeReference] = { + namedStruct + .struct() + .fields() + .asScala + .map(t => (t, convert(t))) + .zip(namedStruct.names().asScala) + .map { case ((t, d), name) => StructField(name, d, t.nullable()) } + .map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()) + } +} + +object ToSubstraitType extends ToSubstraitType diff --git a/spark/src/main/scala/io/substrait/spark/expression/FunctionConverter.scala b/spark/src/main/scala/io/substrait/spark/expression/FunctionConverter.scala new file mode 100644 index 000000000..7b0c2b473 --- /dev/null +++ b/spark/src/main/scala/io/substrait/spark/expression/FunctionConverter.scala @@ -0,0 +1,302 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.substrait.spark.expression + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.analysis.{AnsiTypeCoercion, TypeCoercion} +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.types.DataType + +import com.google.common.collect.{ArrayListMultimap, Multimap} +import io.substrait.`type`.Type +import io.substrait.expression.{Expression => SExpression, ExpressionCreator, FunctionArg} +import io.substrait.expression.Expression.FailureBehavior +import io.substrait.extension.SimpleExtension +import io.substrait.function.{ParameterizedType, ToTypeString} +import io.substrait.spark.ToSubstraitType +import io.substrait.utils.Util + +import java.{util => ju} + +import scala.annotation.tailrec +import scala.collection.JavaConverters +import scala.collection.JavaConverters.collectionAsScalaIterableConverter + +abstract class FunctionConverter[F <: SimpleExtension.Function, T](functions: Seq[F]) + extends Logging { + + protected val (signatures, substraitFuncKeyToSig) = init(functions) + + def generateBinding( + sparkExp: Expression, + function: F, + arguments: Seq[FunctionArg], + outputType: Type): T + def getSigs: Seq[Sig] + + private def init( + functions: Seq[F]): (ju.Map[Class[_], FunctionFinder[F, T]], Multimap[String, Sig]) = { + val alm = ArrayListMultimap.create[String, F]() + functions.foreach(f => alm.put(f.name().toLowerCase(ju.Locale.ROOT), f)) + + val sparkExpressions = ArrayListMultimap.create[String, Sig]() + getSigs.foreach(f => sparkExpressions.put(f.name, f)) + val matcherMap = + new ju.IdentityHashMap[Class[_], FunctionFinder[F, T]] + + JavaConverters + .asScalaSet(alm.keySet()) + .foreach( + key => { + val sigs = sparkExpressions.get(key) + if (sigs == null) { + logInfo("Dropping function due to no binding:" + key) + } else { + JavaConverters + .asScalaBuffer(sigs) + .foreach( + sig => { + val implList = alm.get(key) + if (implList != null && !implList.isEmpty) { + matcherMap + .put(sig.expClass, createFinder(key, JavaConverters.asScalaBuffer(implList))) + } + }) + } + }) + val keyMap = ArrayListMultimap.create[String, Sig] + + alm.entries.asScala.foreach( + entry => + sparkExpressions + .get(entry.getKey) + .asScala + .foreach(keyMap.put(entry.getValue.key(), _))) + + (matcherMap, keyMap) + } + + def getSparkExpressionFromSubstraitFunc(key: String, outputType: Type): Option[Sig] = { + val sigs = substraitFuncKeyToSig.get(key) + sigs.size() match { + case 0 => None + case 1 => Some(sigs.iterator().next()) + case _ => None + } + } + private def createFinder(name: String, functions: Seq[F]): FunctionFinder[F, T] = { + new FunctionFinder[F, T]( + name, + functions + .flatMap( + func => + if (func.requiredArguments().size() != func.args().size()) { + Seq( + func.key() -> func, + SimpleExtension.Function.constructKey(name, func.requiredArguments()) -> func) + } else { + Seq(func.key() -> func) + }) + .toMap, + FunctionFinder.getSingularInputType(functions), + parent = this + ) + } +} + +object FunctionFinder extends SQLConfHelper { + + /** + * Returns the most general of a set of types (that is, one type to which they can all be cast), + * or [[None]] if conversion is not possible. The result may be a new type that is less + * restrictive than any of the input types, e.g. leastRestrictive(INT, NUMERIC(3, 2)) + * could be NUMERIC(12, 2). 
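+ * For example, in Spark terms leastRestrictive(Seq(IntegerType, DoubleType))
+ * returns Some(DoubleType) under both legacy and ANSI type coercion.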
+ * + * @param types + * input types to be combined using union (not null, not empty) + * @return + * canonical union type descriptor + */ + def leastRestrictive(types: Seq[DataType]): Option[DataType] = { + val typeCoercion = if (conf.ansiEnabled) { + AnsiTypeCoercion + } else { + TypeCoercion + } + typeCoercion.findWiderCommonType(types) + } + + /** + * If some of the function variants for this function name have single, repeated argument type, we + * will attempt to find matches using these patterns and least-restrictive casting. + * + *
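+ * For example, the Substrait "add" variants (add:i8_i8, add:i16_i16, add:i32_i32,
+ * and so on) all declare one repeated value-argument type, so an (i16, i32) call
+ * can be matched by widening both operands to i32.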

If this exists, the function finder will attempt to find a least-restrictive match using + * these. + */ + def getSingularInputType[F <: SimpleExtension.Function]( + functions: Seq[F]): Option[SingularArgumentMatcher[F]] = { + + @tailrec + def determineFirstType( + first: ParameterizedType, + index: Int, + list: ju.List[SimpleExtension.Argument]): ParameterizedType = + if (index >= list.size()) { + first + } else { + list.get(index) match { + case argument: SimpleExtension.ValueArgument => + val pt = argument.value() + val first_or_pt = if (first == null) pt else first + if (first == null || isMatch(first, pt)) { + determineFirstType(first_or_pt, index + 1, list) + } else { + null + } + case _ => null + } + } + + val matchers = functions + .map(f => (f, determineFirstType(null, 0, f.requiredArguments()))) + .filter(_._2 != null) + .map(f => singular(f._1, f._2)) + + matchers.size match { + case 0 => None + case 1 => Some(matchers.head) + case _ => Some(chained(matchers)) + } + } + + private def isMatch( + inputType: ParameterizedType, + parameterizedType: ParameterizedType): Boolean = { + if (parameterizedType.isWildcard) { + true + } else { + inputType.accept(new IgnoreNullableAndParameters(parameterizedType)) + } + } + + private def isMatch(inputType: Type, parameterizedType: ParameterizedType): Boolean = { + if (parameterizedType.isWildcard) { + true + } else { + inputType.accept(new IgnoreNullableAndParameters(parameterizedType)) + } + } + + def singular[F <: SimpleExtension.Function]( + function: F, + t: ParameterizedType): SingularArgumentMatcher[F] = + (inputType: Type, outputType: Type) => if (isMatch(inputType, t)) Some(function) else None + + def collectFirst[F <: SimpleExtension.Function]( + matchers: Seq[SingularArgumentMatcher[F]], + inputType: Type, + outputType: Type): Option[F] = { + val iter = matchers.iterator + while (iter.hasNext) { + val s = iter.next() + val result = s.apply(inputType, outputType) + if (result.isDefined) { + return result + } + } + None + } + + def chained[F <: SimpleExtension.Function]( + matchers: Seq[SingularArgumentMatcher[F]]): SingularArgumentMatcher[F] = + (inputType: Type, outputType: Type) => collectFirst(matchers, inputType, outputType) +} + +trait SingularArgumentMatcher[F <: SimpleExtension.Function] extends ((Type, Type) => Option[F]) + +class FunctionFinder[F <: SimpleExtension.Function, T]( + val name: String, + val directMap: Map[String, F], + val singularInputType: Option[SingularArgumentMatcher[F]], + val parent: FunctionConverter[F, T]) { + + def attemptMatch(expression: Expression, operands: Seq[SExpression]): Option[T] = { + + val opTypes = operands.map(_.getType) + val outputType = ToSubstraitType.apply(expression.dataType, expression.nullable) + val opTypesStr = opTypes.map(t => t.accept(ToTypeString.INSTANCE)) + + val possibleKeys = + Util.crossProduct(opTypesStr.map(s => Seq(s))).map(list => list.mkString("_")) + + val directMatchKey = possibleKeys + .map(name + ":" + _) + .find(k => directMap.contains(k)) + + if (directMatchKey.isDefined) { + val variant = directMap(directMatchKey.get) + variant.validateOutputType(JavaConverters.bufferAsJavaList(operands.toBuffer), outputType) + val funcArgs: Seq[FunctionArg] = operands + Option(parent.generateBinding(expression, variant, funcArgs, outputType)) + } else if (singularInputType.isDefined) { + val types = expression match { + case agg: AggregateExpression => agg.aggregateFunction.children.map(_.dataType) + case other => other.children.map(_.dataType) + } + val nullable = 
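+ // The coerced invocation keeps Spark's nullability semantics: the result is
+ // nullable if any child of the expression is nullable.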
expression.children.exists(e => e.nullable) + FunctionFinder + .leastRestrictive(types) + .flatMap( + leastRestrictive => { + val leastRestrictiveSubstraitT = + ToSubstraitType.apply(leastRestrictive, nullable = nullable) + singularInputType + .flatMap(f => f(leastRestrictiveSubstraitT, outputType)) + .map( + declaration => { + val coercedArgs = coerceArguments(operands, leastRestrictiveSubstraitT) + declaration.validateOutputType( + JavaConverters.bufferAsJavaList(coercedArgs.toBuffer), + outputType) + val funcArgs: Seq[FunctionArg] = coercedArgs + parent.generateBinding(expression, declaration, funcArgs, outputType) + }) + }) + } else { + None + } + } + + /** + * Coerced types according to an expected output type. Coercion is only done for type mismatches, + * not for nullability or parameter mismatches. + */ + private def coerceArguments(arguments: Seq[SExpression], t: Type): Seq[SExpression] = { + arguments.map( + a => { + if (FunctionFinder.isMatch(t, a.getType)) { + a + } else { + ExpressionCreator.cast(t, a, FailureBehavior.THROW_EXCEPTION) + } + }) + } + + def allowedArgCount(count: Int): Boolean = true +} diff --git a/spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala b/spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala new file mode 100644 index 000000000..08326454e --- /dev/null +++ b/spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.substrait.spark.expression + +import org.apache.spark.sql.catalyst.analysis.FunctionRegistryBase +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ + +import scala.reflect.ClassTag + +case class Sig(expClass: Class[_], name: String, builder: Seq[Expression] => Expression) { + def makeCall(args: Seq[Expression]): Expression = + builder(args) +} + +class FunctionMappings { + + private def s[T <: Expression: ClassTag](name: String): Sig = { + val builder = FunctionRegistryBase.build[T](name, None)._2 + Sig(scala.reflect.classTag[T].runtimeClass, name, builder) + } + + val SCALAR_SIGS: Seq[Sig] = Seq( + s[Add]("add"), + s[Subtract]("subtract"), + s[Multiply]("multiply"), + s[Divide]("divide"), + s[And]("and"), + s[Or]("or"), + s[Not]("not"), + s[LessThan]("lt"), + s[LessThanOrEqual]("lte"), + s[GreaterThan]("gt"), + s[GreaterThanOrEqual]("gte"), + s[EqualTo]("equal"), + // s[BitwiseXor]("xor"), + s[IsNull]("is_null"), + s[IsNotNull]("is_not_null"), + s[EndsWith]("ends_with"), + s[Like]("like"), + s[Contains]("contains"), + s[StartsWith]("starts_with"), + s[Substring]("substring"), + s[Year]("year"), + + // internal + s[UnscaledValue]("unscaled") + ) + + val AGGREGATE_SIGS: Seq[Sig] = Seq( + s[Sum]("sum"), + s[Average]("avg"), + s[Count]("count"), + s[Min]("min"), + s[Max]("max"), + s[HyperLogLogPlusPlus]("approx_count_distinct") + ) + + lazy val scalar_functions_map: Map[Class[_], Sig] = SCALAR_SIGS.map(s => (s.expClass, s)).toMap + lazy val aggregate_functions_map: Map[Class[_], Sig] = + AGGREGATE_SIGS.map(s => (s.expClass, s)).toMap +} + +object FunctionMappings extends FunctionMappings diff --git a/spark/src/main/scala/io/substrait/spark/expression/IgnoreNullableAndParameters.scala b/spark/src/main/scala/io/substrait/spark/expression/IgnoreNullableAndParameters.scala new file mode 100644 index 000000000..962a98b16 --- /dev/null +++ b/spark/src/main/scala/io/substrait/spark/expression/IgnoreNullableAndParameters.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.substrait.spark.expression + +import io.substrait.`type`.Type +import io.substrait.function.{ParameterizedType, ParameterizedTypeVisitor} + +import scala.annotation.nowarn + +class IgnoreNullableAndParameters(val typeToMatch: ParameterizedType) + extends ParameterizedTypeVisitor[Boolean, RuntimeException] { + + override def visit(`type`: Type.Bool): Boolean = typeToMatch.isInstanceOf[Type.Bool] + + override def visit(`type`: Type.I8): Boolean = typeToMatch.isInstanceOf[Type.I8] + + override def visit(`type`: Type.I16): Boolean = typeToMatch.isInstanceOf[Type.I16] + + override def visit(`type`: Type.I32): Boolean = typeToMatch.isInstanceOf[Type.I32] + + override def visit(`type`: Type.I64): Boolean = typeToMatch.isInstanceOf[Type.I64] + + override def visit(`type`: Type.FP32): Boolean = typeToMatch.isInstanceOf[Type.FP32] + + override def visit(`type`: Type.FP64): Boolean = typeToMatch.isInstanceOf[Type.FP64] + + override def visit(`type`: Type.Str): Boolean = typeToMatch.isInstanceOf[Type.Str] + + override def visit(`type`: Type.Binary): Boolean = typeToMatch.isInstanceOf[Type.Binary] + + override def visit(`type`: Type.Date): Boolean = typeToMatch.isInstanceOf[Type.Date] + + override def visit(`type`: Type.Time): Boolean = typeToMatch.isInstanceOf[Type.Time] + + @nowarn + override def visit(`type`: Type.TimestampTZ): Boolean = typeToMatch.isInstanceOf[Type.TimestampTZ] + + @nowarn + override def visit(`type`: Type.Timestamp): Boolean = typeToMatch.isInstanceOf[Type.Timestamp] + + override def visit(`type`: Type.IntervalYear): Boolean = + typeToMatch.isInstanceOf[Type.IntervalYear] + + override def visit(`type`: Type.IntervalDay): Boolean = typeToMatch.isInstanceOf[Type.IntervalDay] + + override def visit(`type`: Type.UUID): Boolean = typeToMatch.isInstanceOf[Type.UUID] + + override def visit(`type`: Type.FixedChar): Boolean = + typeToMatch.isInstanceOf[Type.FixedChar] || typeToMatch + .isInstanceOf[ParameterizedType.FixedChar] + + override def visit(`type`: Type.VarChar): Boolean = + typeToMatch.isInstanceOf[Type.VarChar] || typeToMatch.isInstanceOf[ParameterizedType.VarChar] + + override def visit(`type`: Type.FixedBinary): Boolean = + typeToMatch.isInstanceOf[Type.FixedBinary] || typeToMatch + .isInstanceOf[ParameterizedType.FixedBinary] + + override def visit(`type`: Type.Decimal): Boolean = + typeToMatch.isInstanceOf[Type.Decimal] || typeToMatch.isInstanceOf[ParameterizedType.Decimal] + + override def visit(`type`: Type.Struct): Boolean = + typeToMatch.isInstanceOf[Type.Struct] || typeToMatch.isInstanceOf[ParameterizedType.Struct] + + override def visit(`type`: Type.ListType): Boolean = + typeToMatch.isInstanceOf[Type.ListType] || typeToMatch.isInstanceOf[ParameterizedType.ListType] + + override def visit(`type`: Type.Map): Boolean = + typeToMatch.isInstanceOf[Type.Map] || typeToMatch.isInstanceOf[ParameterizedType.Map] + + override def visit(`type`: Type.UserDefined): Boolean = + typeToMatch.isInstanceOf[Type.UserDefined] + + @throws[RuntimeException] + override def visit(expr: ParameterizedType.FixedChar): Boolean = + typeToMatch.isInstanceOf[Type.FixedChar] || typeToMatch + .isInstanceOf[ParameterizedType.FixedChar] + + @throws[RuntimeException] + override def visit(expr: ParameterizedType.VarChar): Boolean = + typeToMatch.isInstanceOf[Type.VarChar] || typeToMatch.isInstanceOf[ParameterizedType.VarChar] + + @throws[RuntimeException] + override def visit(expr: ParameterizedType.FixedBinary): Boolean = + typeToMatch.isInstanceOf[Type.FixedBinary] || typeToMatch + 
.isInstanceOf[ParameterizedType.FixedBinary] + + @throws[RuntimeException] + override def visit(expr: ParameterizedType.Decimal): Boolean = + typeToMatch.isInstanceOf[Type.Decimal] || typeToMatch.isInstanceOf[ParameterizedType.Decimal] + + @throws[RuntimeException] + override def visit(expr: ParameterizedType.Struct): Boolean = + typeToMatch.isInstanceOf[Type.Struct] || typeToMatch.isInstanceOf[ParameterizedType.Struct] + + @throws[RuntimeException] + override def visit(expr: ParameterizedType.ListType): Boolean = + typeToMatch.isInstanceOf[Type.ListType] || typeToMatch.isInstanceOf[ParameterizedType.ListType] + + @throws[RuntimeException] + override def visit(expr: ParameterizedType.Map): Boolean = + typeToMatch.isInstanceOf[Type.Map] || typeToMatch.isInstanceOf[ParameterizedType.Map] + + @throws[RuntimeException] + override def visit(stringLiteral: ParameterizedType.StringLiteral): Boolean = false + + @throws[RuntimeException] + override def visit(precisionTimestamp: ParameterizedType.PrecisionTimestamp): Boolean = + typeToMatch.isInstanceOf[ParameterizedType.PrecisionTimestamp] + + @throws[RuntimeException] + override def visit(precisionTimestampTZ: ParameterizedType.PrecisionTimestampTZ): Boolean = + typeToMatch.isInstanceOf[ParameterizedType.PrecisionTimestampTZ] + + @throws[RuntimeException] + override def visit(precisionTimestamp: Type.PrecisionTimestamp): Boolean = + typeToMatch.isInstanceOf[Type.PrecisionTimestamp] + + @throws[RuntimeException] + override def visit(precisionTimestampTZ: Type.PrecisionTimestampTZ): Boolean = + typeToMatch.isInstanceOf[Type.PrecisionTimestampTZ] +} diff --git a/spark/src/main/scala/io/substrait/spark/expression/ToAggregateFunction.scala b/spark/src/main/scala/io/substrait/spark/expression/ToAggregateFunction.scala new file mode 100644 index 000000000..0c5b50c6c --- /dev/null +++ b/spark/src/main/scala/io/substrait/spark/expression/ToAggregateFunction.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.substrait.spark.expression + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.aggregate._ + +import io.substrait.`type`.Type +import io.substrait.expression.{AggregateFunctionInvocation, Expression => SExpression, ExpressionCreator, FunctionArg} +import io.substrait.extension.SimpleExtension + +import java.util.Collections + +import scala.collection.JavaConverters + +abstract class ToAggregateFunction(functions: Seq[SimpleExtension.AggregateFunctionVariant]) + extends FunctionConverter[SimpleExtension.AggregateFunctionVariant, AggregateFunctionInvocation]( + functions) { + + override def generateBinding( + sparkExp: Expression, + function: SimpleExtension.AggregateFunctionVariant, + arguments: Seq[FunctionArg], + outputType: Type): AggregateFunctionInvocation = { + + val sparkAggregate = sparkExp.asInstanceOf[AggregateExpression] + + ExpressionCreator.aggregateFunction( + function, + outputType, + ToAggregateFunction.fromSpark(sparkAggregate.mode), + Collections.emptyList[SExpression.SortField](), + ToAggregateFunction.fromSpark(sparkAggregate.isDistinct), + JavaConverters.asJavaIterable(arguments) + ) + } + + def convert( + expression: AggregateExpression, + operands: Seq[SExpression]): Option[AggregateFunctionInvocation] = { + Option(signatures.get(expression.aggregateFunction.getClass)) + .filter(m => m.allowedArgCount(2)) + .flatMap(m => m.attemptMatch(expression, operands)) + } + + def apply( + expression: AggregateExpression, + operands: Seq[SExpression]): AggregateFunctionInvocation = { + convert(expression, operands).getOrElse( + throw new UnsupportedOperationException( + s"Unable to find binding for call ${expression.aggregateFunction}")) + } +} + +object ToAggregateFunction { + def fromSpark(mode: AggregateMode): SExpression.AggregationPhase = mode match { + case Partial => SExpression.AggregationPhase.INITIAL_TO_INTERMEDIATE + case PartialMerge => SExpression.AggregationPhase.INTERMEDIATE_TO_INTERMEDIATE + case Final => SExpression.AggregationPhase.INTERMEDIATE_TO_RESULT + case Complete => SExpression.AggregationPhase.INITIAL_TO_RESULT + case other => throw new UnsupportedOperationException(s"not currently supported: $other.") + } + def toSpark(phase: SExpression.AggregationPhase): AggregateMode = phase match { + case SExpression.AggregationPhase.INITIAL_TO_INTERMEDIATE => Partial + case SExpression.AggregationPhase.INTERMEDIATE_TO_INTERMEDIATE => PartialMerge + case SExpression.AggregationPhase.INTERMEDIATE_TO_RESULT => Final + case SExpression.AggregationPhase.INITIAL_TO_RESULT => Complete + } + def fromSpark(isDistinct: Boolean): SExpression.AggregationInvocation = if (isDistinct) { + SExpression.AggregationInvocation.DISTINCT + } else { + SExpression.AggregationInvocation.ALL + } + + def toSpark(innovation: SExpression.AggregationInvocation): Boolean = innovation match { + case SExpression.AggregationInvocation.DISTINCT => true + case _ => false + } + + def apply(functions: Seq[SimpleExtension.AggregateFunctionVariant]): ToAggregateFunction = { + new ToAggregateFunction(functions) { + override def getSigs: Seq[Sig] = FunctionMappings.AGGREGATE_SIGS + } + } + +} diff --git a/spark/src/main/scala/io/substrait/spark/expression/ToScalarFunction.scala b/spark/src/main/scala/io/substrait/spark/expression/ToScalarFunction.scala new file mode 100644 index 000000000..cd23611ec --- /dev/null +++ b/spark/src/main/scala/io/substrait/spark/expression/ToScalarFunction.scala @@ -0,0 +1,56 @@ +/* + * Licensed to 
the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.substrait.spark.expression + +import org.apache.spark.sql.catalyst.expressions.Expression + +import io.substrait.`type`.Type +import io.substrait.expression.{Expression => SExpression, FunctionArg} +import io.substrait.extension.SimpleExtension + +import scala.collection.JavaConverters + +abstract class ToScalarFunction(functions: Seq[SimpleExtension.ScalarFunctionVariant]) + extends FunctionConverter[SimpleExtension.ScalarFunctionVariant, SExpression](functions) { + + override def generateBinding( + sparkExp: Expression, + function: SimpleExtension.ScalarFunctionVariant, + arguments: Seq[FunctionArg], + outputType: Type): SExpression = { + SExpression.ScalarFunctionInvocation + .builder() + .outputType(outputType) + .declaration(function) + .addAllArguments(JavaConverters.asJavaIterable(arguments)) + .build() + } + + def convert(expression: Expression, operands: Seq[SExpression]): Option[SExpression] = { + Option(signatures.get(expression.getClass)) + .filter(m => m.allowedArgCount(2)) + .flatMap(m => m.attemptMatch(expression, operands)) + } +} + +object ToScalarFunction { + def apply(functions: Seq[SimpleExtension.ScalarFunctionVariant]): ToScalarFunction = { + new ToScalarFunction(functions) { + override def getSigs: Seq[Sig] = FunctionMappings.SCALAR_SIGS + } + } +} diff --git a/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala b/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala new file mode 100644 index 000000000..a4ee9aaee --- /dev/null +++ b/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.substrait.spark.expression + +import io.substrait.spark.{DefaultExpressionVisitor, HasOutputStack, ToSubstraitType} +import io.substrait.spark.logical.ToLogicalPlan + +import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Cast, Expression, In, Literal, NamedExpression, ScalarSubquery} +import org.apache.spark.sql.types.Decimal +import org.apache.spark.unsafe.types.UTF8String + +import io.substrait.`type`.{StringTypeVisitor, Type} +import io.substrait.{expression => exp} +import io.substrait.expression.{Expression => SExpression} +import io.substrait.util.DecimalUtil +import org.apache.spark.substrait.SparkTypeUtil + +import scala.collection.JavaConverters.asScalaBufferConverter + +class ToSparkExpression( + val scalarFunctionConverter: ToScalarFunction, + val toLogicalPlan: Option[ToLogicalPlan] = None) + extends DefaultExpressionVisitor[Expression] + with HasOutputStack[Seq[NamedExpression]] { + + override def visit(expr: SExpression.BoolLiteral): Expression = { + if (expr.value()) { + Literal.TrueLiteral + } else { + Literal.FalseLiteral + } + } + override def visit(expr: SExpression.I32Literal): Expression = { + Literal(expr.value(), ToSubstraitType.convert(expr.getType)) + } + + override def visit(expr: SExpression.I64Literal): Expression = { + Literal(expr.value(), ToSubstraitType.convert(expr.getType)) + } + + override def visit(expr: SExpression.FP64Literal): Expression = { + Literal(expr.value(), ToSubstraitType.convert(expr.getType)) + } + + override def visit(expr: SExpression.StrLiteral): Expression = { + Literal(UTF8String.fromString(expr.value()), ToSubstraitType.convert(expr.getType)) + } + + override def visit(expr: SExpression.FixedCharLiteral): Expression = { + Literal(UTF8String.fromString(expr.value()), ToSubstraitType.convert(expr.getType)) + } + + override def visit(expr: SExpression.VarCharLiteral): Expression = { + Literal(UTF8String.fromString(expr.value()), ToSubstraitType.convert(expr.getType)) + } + + override def visit(expr: SExpression.DecimalLiteral): Expression = { + val value = expr.value.toByteArray + val decimal = DecimalUtil.getBigDecimalFromBytes(value, expr.scale, 16) + Literal(Decimal(decimal), ToSubstraitType.convert(expr.getType)) + } + override def visit(expr: SExpression.DateLiteral): Expression = { + Literal(expr.value(), ToSubstraitType.convert(expr.getType)) + } + + override def visit(expr: SExpression.Cast): Expression = { + val childExp = expr.input().accept(this) + Cast(childExp, ToSubstraitType.convert(expr.getType)) + } + + override def visit(expr: exp.FieldReference): Expression = { + withFieldReference(expr)(i => currentOutput(i).clone()) + } + override def visit(expr: SExpression.IfThen): Expression = { + val branches = expr + .ifClauses() + .asScala + .map( + ifClause => { + val predicate = ifClause.condition().accept(this) + val elseValue = ifClause.`then`().accept(this) + (predicate, elseValue) + }) + val default = expr.elseClause().accept(this) match { + case l: Literal if l.nullable => None + case other => Some(other) + } + CaseWhen(branches, default) + } + + override def visit(expr: SExpression.ScalarSubquery): Expression = { + val rel = expr.input() + val dataType = ToSubstraitType.convert(expr.getType) + toLogicalPlan + .map( + relConverter => { + val plan = rel.accept(relConverter) + require(plan.resolved) + val result = ScalarSubquery(plan) + SparkTypeUtil.sameType(result.dataType, dataType) + result + }) + .getOrElse(visitFallback(expr)) + } + + override def visit(expr: SExpression.SingleOrList): 
Expression = { + val value = expr.condition().accept(this) + val list = expr.options().asScala.map(e => e.accept(this)) + In(value, list) + } + override def visit(expr: SExpression.ScalarFunctionInvocation): Expression = { + val eArgs = expr.arguments().asScala + val args = eArgs.zipWithIndex.map { + case (arg, i) => + arg.accept(expr.declaration(), i, this) + } + + scalarFunctionConverter + .getSparkExpressionFromSubstraitFunc(expr.declaration().key(), expr.outputType()) + .flatMap(sig => Option(sig.makeCall(args))) + .getOrElse({ + val msg = String.format( + "Unable to convert scalar function %s(%s).", + expr.declaration.name, + expr.arguments.asScala + .map { + case ea: exp.EnumArg => ea.value.toString + case e: SExpression => e.getType.accept(new StringTypeVisitor) + case t: Type => t.accept(new StringTypeVisitor) + case a => throw new IllegalStateException("Unexpected value: " + a) + } + .mkString(", ") + ) + throw new IllegalArgumentException(msg) + }) + } +} diff --git a/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitExpression.scala b/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitExpression.scala new file mode 100644 index 000000000..4048829b3 --- /dev/null +++ b/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitExpression.scala @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.substrait.spark.expression + +import io.substrait.spark.{HasOutputStack, ToSubstraitType} + +import org.apache.spark.sql.catalyst.expressions._ + +import io.substrait.expression.{Expression => SExpression, ExpressionCreator, FieldReference, ImmutableExpression} +import io.substrait.expression.Expression.FailureBehavior +import io.substrait.utils.Util +import org.apache.spark.substrait.SparkTypeUtil + +import scala.collection.JavaConverters.asJavaIterableConverter + +/** The builder to generate substrait expressions from catalyst expressions. 
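+ *
+ * A usage sketch (hypothetical names; a concrete subclass supplies `toScalarFunction`):
+ * {{{
+ *   val converter: ToSubstraitExpression = ...
+ *   val substraitExpr = converter(sparkExpr, childPlan.output) // throws if unsupported
+ * }}}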
*/ +abstract class ToSubstraitExpression extends HasOutputStack[Seq[Attribute]] { + + object ScalarFunction { + def unapply(e: Expression): Option[Seq[Expression]] = e match { + case BinaryExpression(left, right) => Some(Seq(left, right)) + case UnaryExpression(child) => Some(Seq(child)) + case t: TernaryExpression => Some(Seq(t.first, t.second, t.third)) + case _ => None + } + } + + type OutputT = Seq[Attribute] + + protected val toScalarFunction: ToScalarFunction + + protected def default(e: Expression): Option[SExpression] = { + throw new UnsupportedOperationException(s"Unable to convert the expression $e") + } + + def apply(e: Expression, output: OutputT = Nil): SExpression = { + convert(e, output).getOrElse( + throw new UnsupportedOperationException(s"Unable to convert the expression $e") + ) + } + def convert(expr: Expression, output: OutputT = Nil): Option[SExpression] = { + pushOutput(output) + try { + translateUp(expr) + } finally { + popOutput() + } + } + + protected def translateSubQuery(expr: PlanExpression[_]): Option[SExpression] = default(expr) + + protected def translateAttribute(a: AttributeReference): Option[SExpression] = { + val bindReference = + BindReferences.bindReference[Expression](a, currentOutput, allowFailures = false) + if (bindReference == a) { + default(a) + } else { + Some( + FieldReference.newRootStructReference( + bindReference.asInstanceOf[BoundReference].ordinal, + ToSubstraitType.apply(a.dataType, a.nullable)) + ) + } + } + + protected def translateCaseWhen( + branches: Seq[(Expression, Expression)], + elseValue: Option[Expression]): Option[SExpression] = { + val cases = + for ((predicate, trueValue) <- branches) + yield translateUp(predicate).flatMap( + p => + translateUp(trueValue).map( + t => { + ImmutableExpression.IfClause.builder + .condition(p) + .`then`(t) + .build() + })) + val sparkElse = elseValue.getOrElse(Literal.create(null, branches.head._2.dataType)) + Util + .seqToOption(cases.toList) + .flatMap( + caseConditions => + translateUp(sparkElse).map( + defaultResult => { + ExpressionCreator.ifThenStatement(defaultResult, caseConditions.asJava) + })) + } + protected def translateIn(value: Expression, list: Seq[Expression]): Option[SExpression] = { + Util + .seqToOption(list.map(translateUp).toList) + .flatMap( + inList => + translateUp(value).map( + inValue => { + SExpression.SingleOrList + .builder() + .condition(inValue) + .options(inList.asJava) + .build() + })) + } + + protected def translateUp(expr: Expression): Option[SExpression] = { + expr match { + case c @ Cast(child, dataType, _, _) => + translateUp(child) + .map(ExpressionCreator + .cast(ToSubstraitType.apply(dataType, c.nullable), _, FailureBehavior.THROW_EXCEPTION)) + case c @ CheckOverflow(child, dataType, _) => + // CheckOverflow similar with cast + translateUp(child) + .map( + childExpr => { + if (SparkTypeUtil.sameType(dataType, child.dataType)) { + childExpr + } else { + ExpressionCreator.cast( + ToSubstraitType.apply(dataType, c.nullable), + childExpr, + FailureBehavior.THROW_EXCEPTION) + } + }) + case SubstraitLiteral(substraitLiteral) => Some(substraitLiteral) + case a: AttributeReference if currentOutput.nonEmpty => translateAttribute(a) + case a: Alias => translateUp(a.child) + case p + if p.getClass.getCanonicalName.equals( // removed in spark-3.3 + "org.apache.spark.sql.catalyst.expressions.PromotePrecision") => + translateUp(p.children.head) + case CaseWhen(branches, elseValue) => translateCaseWhen(branches, elseValue) + case scalar @ ScalarFunction(children) => + 
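+          // Translate every child first: seqToOption yields None if any child
+          // fails, so one unsupported child aborts the whole function conversion.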
Util + .seqToOption(children.map(translateUp)) + .flatMap(toScalarFunction.convert(scalar, _)) + case In(value, list) => translateIn(value, list) + case p: PlanExpression[_] => translateSubQuery(p) + case other => default(other) + } + } +} diff --git a/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitLiteral.scala b/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitLiteral.scala new file mode 100644 index 000000000..95633c15b --- /dev/null +++ b/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitLiteral.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.substrait.spark.expression + +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +import io.substrait.expression.{Expression => SExpression} +import io.substrait.expression.ExpressionCreator._ +import io.substrait.spark.ToSubstraitType + +class ToSubstraitLiteral { + + object Nonnull { + private def sparkDecimal2Substrait( + d: Decimal, + precision: Int, + scale: Int): SExpression.Literal = + decimal(false, d.toJavaBigDecimal, precision, scale) + + val _bool: Boolean => SExpression.Literal = bool(false, _) + val _i8: Byte => SExpression.Literal = i8(false, _) + val _i16: Short => SExpression.Literal = i16(false, _) + val _i32: Int => SExpression.Literal = i32(false, _) + val _i64: Long => SExpression.Literal = i64(false, _) + val _fp32: Float => SExpression.Literal = fp32(false, _) + val _fp64: Double => SExpression.Literal = fp64(false, _) + val _decimal: (Decimal, Int, Int) => SExpression.Literal = sparkDecimal2Substrait + val _date: Int => SExpression.Literal = date(false, _) + val _string: String => SExpression.Literal = string(false, _) + } + + private def convertWithValue(literal: Literal): Option[SExpression.Literal] = { + Option.apply( + literal match { + case Literal(b: Boolean, BooleanType) => Nonnull._bool(b) + case Literal(b: Byte, ByteType) => Nonnull._i8(b) + case Literal(s: Short, ShortType) => Nonnull._i16(s) + case Literal(i: Integer, IntegerType) => Nonnull._i32(i) + case Literal(l: Long, LongType) => Nonnull._i64(l) + case Literal(f: Float, FloatType) => Nonnull._fp32(f) + case Literal(d: Double, DoubleType) => Nonnull._fp64(d) + case Literal(d: Decimal, dataType: DecimalType) => + Nonnull._decimal(d, dataType.precision, dataType.scale) + case Literal(d: Integer, DateType) => Nonnull._date(d) + case Literal(u: UTF8String, StringType) => Nonnull._string(u.toString) + case _ => null + } + ) + } + + def convert(literal: Literal): Option[SExpression.Literal] = { + if (literal.nullable) { + ToSubstraitType + .convert(literal.dataType, nullable = true) + .map(typedNull) + } else { + convertWithValue(literal) + } + } + + 
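+  /** Like [[convert]], but throws when the literal has no Substrait mapping. */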
def apply(literal: Literal): SExpression.Literal = { + convert(literal) + .getOrElse( + throw new UnsupportedOperationException( + s"Unable to convert the type ${literal.dataType.typeName}")) + } +} + +object ToSubstraitLiteral extends ToSubstraitLiteral + +object SubstraitLiteral { + def unapply(literal: Literal): Option[SExpression.Literal] = { + ToSubstraitLiteral.convert(literal) + } +} diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala new file mode 100644 index 000000000..908a8aa0d --- /dev/null +++ b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala @@ -0,0 +1,279 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.substrait.spark.logical + +import io.substrait.spark.{DefaultRelVisitor, SparkExtension, ToSubstraitType} +import io.substrait.spark.expression._ + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedRelation} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction} +import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, RightOuter} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util.toPrettySQL +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, InMemoryFileIndex, LogicalRelation} +import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat +import org.apache.spark.sql.types.{DataTypes, IntegerType, StructType} + +import io.substrait.`type`.{StringTypeVisitor, Type} +import io.substrait.{expression => exp} +import io.substrait.expression.{Expression => SExpression} +import io.substrait.plan.Plan +import io.substrait.relation +import io.substrait.relation.LocalFiles +import org.apache.hadoop.fs.Path + +import scala.collection.JavaConverters.asScalaBufferConverter +import scala.collection.mutable.ArrayBuffer + +/** + * RelVisitor to convert Substrait Rel plan to [[LogicalPlan]]. Unsupported Rel node will call + * visitFallback and throw UnsupportedOperationException. 
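+ *
+ * Usage sketch (assuming `plan` is a Substrait [[Plan]] POJO with one root):
+ * {{{
+ *   val logical: LogicalPlan = new ToLogicalPlan(spark).convert(plan)
+ * }}}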
+ */ +class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan] { + + private val expressionConverter = + new ToSparkExpression(ToScalarFunction(SparkExtension.SparkScalarFunctions), Some(this)) + + private def fromMeasure(measure: relation.Aggregate.Measure): AggregateExpression = { + // this functions is called in createParentwithChild + val function = measure.getFunction + var arguments = function.arguments().asScala.zipWithIndex.map { + case (arg, i) => + arg.accept(function.declaration(), i, expressionConverter) + } + if (function.declaration.name == "count" && function.arguments.size == 0) { + // HACK - count() needs to be rewritten as count(1) + arguments = ArrayBuffer(Literal(1)) + } + + val aggregateFunction = SparkExtension.toAggregateFunction + .getSparkExpressionFromSubstraitFunc(function.declaration.key, function.outputType) + .map(sig => sig.makeCall(arguments)) + .map(_.asInstanceOf[AggregateFunction]) + .getOrElse({ + val msg = String.format( + "Unable to convert Aggregate function %s(%s).", + function.declaration.name, + function.arguments.asScala + .map { + case ea: exp.EnumArg => ea.value.toString + case e: SExpression => e.getType.accept(new StringTypeVisitor) + case t: Type => t.accept(new StringTypeVisitor) + case a => throw new IllegalStateException("Unexpected value: " + a) + } + .mkString(", ") + ) + throw new IllegalArgumentException(msg) + }) + AggregateExpression( + aggregateFunction, + ToAggregateFunction.toSpark(function.aggregationPhase()), + ToAggregateFunction.toSpark(function.invocation()), + None + ) + } + + private def toNamedExpression(e: Expression): NamedExpression = e match { + case ne: NamedExpression => ne + case other => Alias(other, toPrettySQL(other))() + } + + override def visit(aggregate: relation.Aggregate): LogicalPlan = { + require(aggregate.getGroupings.size() == 1) + val child = aggregate.getInput.accept(this) + withChild(child) { + val groupBy = aggregate.getGroupings + .get(0) + .getExpressions + .asScala + .map(expr => expr.accept(expressionConverter)) + + val outputs = groupBy.map(toNamedExpression) + val aggregateExpressions = + aggregate.getMeasures.asScala.map(fromMeasure).map(toNamedExpression) + Aggregate(groupBy, outputs ++= aggregateExpressions, child) + } + } + + override def visit(join: relation.Join): LogicalPlan = { + val left = join.getLeft.accept(this) + val right = join.getRight.accept(this) + withChild(left, right) { + val condition = Option(join.getCondition.orElse(null)) + .map(_.accept(expressionConverter)) + + val joinType = join.getJoinType match { + case relation.Join.JoinType.INNER => Inner + case relation.Join.JoinType.LEFT => LeftOuter + case relation.Join.JoinType.RIGHT => RightOuter + case relation.Join.JoinType.OUTER => FullOuter + case relation.Join.JoinType.SEMI => LeftSemi + case relation.Join.JoinType.ANTI => LeftAnti + case relation.Join.JoinType.UNKNOWN => + throw new UnsupportedOperationException("Unknown join type is not supported") + } + Join(left, right, joinType, condition, hint = JoinHint.NONE) + } + } + + override def visit(join: relation.Cross): LogicalPlan = { + val left = join.getLeft.accept(this) + val right = join.getRight.accept(this) + withChild(left, right) { + // TODO: Support different join types here when join types are added to cross rel for BNLJ + // Currently, this will change both cross and inner join types to inner join + Join(left, right, Inner, Option(null), hint = JoinHint.NONE) + } + } + + private def toSortOrder(sortField: SExpression.SortField): 
SortOrder = { + val expression = sortField.expr().accept(expressionConverter) + val (direction, nullOrdering) = sortField.direction() match { + case SExpression.SortDirection.ASC_NULLS_FIRST => (Ascending, NullsFirst) + case SExpression.SortDirection.DESC_NULLS_FIRST => (Descending, NullsFirst) + case SExpression.SortDirection.ASC_NULLS_LAST => (Ascending, NullsLast) + case SExpression.SortDirection.DESC_NULLS_LAST => (Descending, NullsLast) + case other => + throw new RuntimeException(s"Unexpected Expression.SortDirection enum: $other !") + } + SortOrder(expression, direction, nullOrdering, Seq.empty) + } + override def visit(fetch: relation.Fetch): LogicalPlan = { + val child = fetch.getInput.accept(this) + val limit = Literal(fetch.getCount.getAsLong.intValue(), IntegerType) + fetch.getOffset match { + case 1L => GlobalLimit(limitExpr = limit, child = child) + case -1L => LocalLimit(limitExpr = limit, child = child) + case _ => visitFallback(fetch) + } + } + override def visit(sort: relation.Sort): LogicalPlan = { + val child = sort.getInput.accept(this) + withChild(child) { + val sortOrders = sort.getSortFields.asScala.map(toSortOrder) + Sort(sortOrders, global = true, child) + } + } + + override def visit(project: relation.Project): LogicalPlan = { + val child = project.getInput.accept(this) + val (output, createProject) = child match { + case a: Aggregate => (a.aggregateExpressions, false) + case other => (other.output, true) + } + + withOutput(output) { + val projectList = + project.getExpressions.asScala + .map(expr => expr.accept(expressionConverter)) + .map(toNamedExpression) + if (createProject) { + Project(projectList, child) + } else { + val aggregate: Aggregate = child.asInstanceOf[Aggregate] + aggregate.copy(aggregateExpressions = projectList) + } + } + } + + override def visit(filter: relation.Filter): LogicalPlan = { + val child = filter.getInput.accept(this) + withChild(child) { + val condition = filter.getCondition.accept(expressionConverter) + Filter(condition, child) + } + } + + override def visit(emptyScan: relation.EmptyScan): LogicalPlan = { + LocalRelation(ToSubstraitType.toAttribute(emptyScan.getInitialSchema)) + } + override def visit(namedScan: relation.NamedScan): LogicalPlan = { + resolve(UnresolvedRelation(namedScan.getNames.asScala)) match { + case m: MultiInstanceRelation => m.newInstance() + case other => other + } + } + + override def visit(localFiles: LocalFiles): LogicalPlan = { + val schema = ToSubstraitType.toStructType(localFiles.getInitialSchema) + val output = schema.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()) + new LogicalRelation( + relation = HadoopFsRelation( + location = new InMemoryFileIndex( + spark, + localFiles.getItems.asScala.map(i => new Path(i.getPath.get())), + Map(), + Some(schema)), + partitionSchema = new StructType(), + dataSchema = schema, + bucketSpec = None, + fileFormat = new CSVFileFormat(), + options = Map() + )(spark), + output = output, + catalogTable = None, + isStreaming = false + ) + } + + private def withChild(child: LogicalPlan*)(body: => LogicalPlan): LogicalPlan = { + val output = child.flatMap(_.output) + withOutput(output)(body) + } + + private def withOutput(output: Seq[NamedExpression])(body: => LogicalPlan): LogicalPlan = { + expressionConverter.pushOutput(output) + try { + body + } finally { + expressionConverter.popOutput() + } + } + private def resolve(plan: LogicalPlan): LogicalPlan = { + val qe = new QueryExecution(spark, plan) + qe.analyzed match { + case SubqueryAlias(_, 
child) => child + case other => other + } + } + + def convert(plan: Plan): LogicalPlan = { + val root = plan.getRoots.get(0) + val names = root.getNames.asScala + val output = names.map(name => AttributeReference(name, DataTypes.StringType)()) + withOutput(output) { + val logicalPlan = root.getInput.accept(this); + val projectList: List[NamedExpression] = logicalPlan.output.zipWithIndex + .map( + z => { + val (e, i) = z; + if (e.name == names(i)) { + e + } else { + Alias(e, names(i))() + } + }) + .toList + val wrapper = Project(projectList, logicalPlan) + require(wrapper.resolved) + wrapper + } + } +} diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala new file mode 100644 index 000000000..4085860f4 --- /dev/null +++ b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala @@ -0,0 +1,405 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.substrait.spark.logical + +import io.substrait.spark.{SparkExtension, ToSubstraitType} +import io.substrait.spark.expression._ + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.catalog.HiveTableRelation +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation} +import org.apache.spark.sql.types.StructType +import ToSubstraitType.toNamedStruct + +import io.substrait.{proto, relation} +import io.substrait.debug.TreePrinter +import io.substrait.expression.{Expression => SExpression, ExpressionCreator} +import io.substrait.extension.ExtensionCollector +import io.substrait.plan.{ImmutablePlan, ImmutableRoot, Plan} +import io.substrait.relation.RelProtoConverter +import io.substrait.relation.files.{FileFormat, ImmutableFileOrFiles} +import io.substrait.relation.files.FileOrFiles.PathType + +import java.util.Collections + +import scala.collection.JavaConverters.asJavaIterableConverter +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging { + + private val toSubstraitExp = new WithLogicalSubQuery(this) + + private val TRUE = ExpressionCreator.bool(false, true) + + override def default(p: LogicalPlan): relation.Rel = p match { + case p: LeafNode => convertReadOperator(p) + case s: SubqueryAlias => visit(s.child) + case other => t(other) + } + + private def fromGroupSet( + e: Seq[Expression], + output: 
Seq[Attribute]): relation.Aggregate.Grouping = { + + relation.Aggregate.Grouping.builder + .addAllExpressions(e.map(toExpression(output)).asJava) + .build() + } + + private def fromAggCall( + expression: AggregateExpression, + output: Seq[Attribute]): relation.Aggregate.Measure = { + val substraitExps = expression.aggregateFunction.children.map(toExpression(output)) + val invocation = + SparkExtension.toAggregateFunction.apply(expression, substraitExps) + relation.Aggregate.Measure.builder.function(invocation).build() + } + + private def collectAggregates( + resultExpressions: Seq[NamedExpression], + aggExprToOutputOrdinal: mutable.HashMap[Expression, Int]): Seq[AggregateExpression] = { + var ordinal = 0 + resultExpressions.flatMap { + expr => + expr.collect { + // Do not push down duplicated aggregate expressions. For example, + // `SELECT max(a) + 1, max(a) + 2 FROM ...`, we should only push down one + // `max(a)` to the data source. + case agg: AggregateExpression if !aggExprToOutputOrdinal.contains(agg.canonicalized) => + aggExprToOutputOrdinal(agg.canonicalized) = ordinal + ordinal += 1 + agg + } + } + } + + private def translateAggregation( + groupBy: Seq[Expression], + aggregates: Seq[AggregateExpression], + output: Seq[Attribute], + input: relation.Rel): relation.Aggregate = { + val groupings = Collections.singletonList(fromGroupSet(groupBy, output)) + val aggCalls = aggregates.map(fromAggCall(_, output)).asJava + + relation.Aggregate.builder + .input(input) + .addAllGroupings(groupings) + .addAllMeasures(aggCalls) + .build + } + + /** + * The current substrait [[relation.Aggregate]] can't specify output, but spark [[Aggregate]] + * allow. So To support #1 select max(b) from table group by a, and #2 select + * a, max(b) + 1 from table group by a, We need create [[Project]] on top of [[Aggregate]] + * to correctly support it. 
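+ *
+ * e.g. (sketch) `select a, max(b) + 1 from table group by a` becomes
+ * Project[group_col_0, agg_func_0 + 1] over Aggregate[groupings = [a], measures = [max(b)]].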
+ * + * TODO: support [[Rollup]] and [[GroupingSets]] + */ + override def visitAggregate(agg: Aggregate): relation.Rel = { + val input = visit(agg.child) + val actualResultExprs = agg.aggregateExpressions + val actualGroupExprs = agg.groupingExpressions + + val aggExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int] + val aggregates = collectAggregates(actualResultExprs, aggExprToOutputOrdinal) + val aggOutputMap = aggregates.zipWithIndex.map { + case (e, i) => + AttributeReference(s"agg_func_$i", e.dataType)() -> e + } + val aggOutput = aggOutputMap.map(_._1) + + // collect group by + val groupByExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int] + actualGroupExprs.zipWithIndex.foreach { + case (expr, ordinal) => + if (!groupByExprToOutputOrdinal.contains(expr.canonicalized)) { + groupByExprToOutputOrdinal(expr.canonicalized) = ordinal + } + } + val groupOutputMap = actualGroupExprs.zipWithIndex.map { + case (e, i) => + AttributeReference(s"group_col_$i", e.dataType)() -> e + } + val groupOutput = groupOutputMap.map(_._1) + + val substraitAgg = translateAggregation(actualGroupExprs, aggregates, agg.child.output, input) + val newOutput = groupOutput ++ aggOutput + + val projectExpressions = actualResultExprs.map { + expr => + expr.transformDown { + case agg: AggregateExpression => + val ordinal = aggExprToOutputOrdinal(agg.canonicalized) + aggOutput(ordinal) + case expr if groupByExprToOutputOrdinal.contains(expr.canonicalized) => + val ordinal = groupByExprToOutputOrdinal(expr.canonicalized) + groupOutput(ordinal) + } + } + val projects = projectExpressions.map(toExpression(newOutput)) + + relation.Project.builder + .remap(relation.Rel.Remap.offset(newOutput.size, projects.size)) + .expressions(projects.asJava) + .input(substraitAgg) + .build() + } + + private def asLong(e: Expression): Long = e match { + case IntegerLiteral(limit) => limit + case other => throw new UnsupportedOperationException(s"Unknown type: $other") + } + + private def fetchBuilder(limit: Long, global: Boolean): relation.ImmutableFetch.Builder = { + val offset = if (global) 1L else -1L + relation.Fetch + .builder() + .count(limit) + .offset(offset) + } + override def visitGlobalLimit(p: GlobalLimit): relation.Rel = { + fetchBuilder(asLong(p.limitExpr), global = true) + .input(visit(p.child)) + .build() + } + + override def visitLocalLimit(p: LocalLimit): relation.Rel = { + fetchBuilder(asLong(p.limitExpr), global = false) + .input(visit(p.child)) + .build() + } + + override def visitFilter(p: Filter): relation.Rel = { + val condition = toExpression(p.child.output)(p.condition) + relation.Filter.builder().condition(condition).input(visit(p.child)).build() + } + + private def toSubstraitJoin(joinType: JoinType): relation.Join.JoinType = joinType match { + case Inner | Cross => relation.Join.JoinType.INNER + case LeftOuter => relation.Join.JoinType.LEFT + case RightOuter => relation.Join.JoinType.RIGHT + case FullOuter => relation.Join.JoinType.OUTER + case LeftSemi => relation.Join.JoinType.SEMI + case LeftAnti => relation.Join.JoinType.ANTI + case other => throw new UnsupportedOperationException(s"Unsupported join type $other") + } + + override def visitJoin(p: Join): relation.Rel = { + val left = visit(p.left) + val right = visit(p.right) + val condition = p.condition.map(toExpression(p.left.output ++ p.right.output)).getOrElse(TRUE) + val joinType = toSubstraitJoin(p.joinType) + + if (joinType == relation.Join.JoinType.INNER && TRUE == condition) { + relation.Cross.builder + .left(left) + 
.right(right) + .build + } else { + relation.Join.builder + .condition(condition) + .joinType(joinType) + .left(left) + .right(right) + .build + } + } + + override def visitProject(p: Project): relation.Rel = { + val expressions = p.projectList.map(toExpression(p.child.output)).toList + + relation.Project.builder + .remap(relation.Rel.Remap.offset(p.child.output.size, expressions.size)) + .expressions(expressions.asJava) + .input(visit(p.child)) + .build() + } + + private def toSortField(output: Seq[Attribute] = Nil)(order: SortOrder): SExpression.SortField = { + val direction = (order.direction, order.nullOrdering) match { + case (Ascending, NullsFirst) => SExpression.SortDirection.ASC_NULLS_FIRST + case (Descending, NullsFirst) => SExpression.SortDirection.DESC_NULLS_FIRST + case (Ascending, NullsLast) => SExpression.SortDirection.ASC_NULLS_LAST + case (Descending, NullsLast) => SExpression.SortDirection.DESC_NULLS_LAST + } + val expr = toExpression(output)(order.child) + SExpression.SortField.builder().expr(expr).direction(direction).build() + } + override def visitSort(sort: Sort): relation.Rel = { + val input = visit(sort.child) + val fields = sort.order.map(toSortField(sort.child.output)).asJava + relation.Sort.builder.addAllSortFields(fields).input(input).build + } + + private def toExpression(output: Seq[Attribute])(e: Expression): SExpression = { + toSubstraitExp(e, output) + } + + private def buildNamedScan(schema: StructType, tableNames: List[String]): relation.NamedScan = { + val namedStruct = toNamedStruct(schema) + + val namedScan = relation.NamedScan.builder + .initialSchema(namedStruct) + .addAllNames(tableNames.asJava) + .build + namedScan + } + private def buildVirtualTableScan(localRelation: LocalRelation): relation.AbstractReadRel = { + val namedStruct = toNamedStruct(localRelation.schema) + + if (localRelation.data.isEmpty) { + relation.EmptyScan.builder().initialSchema(namedStruct).build() + } else { + relation.VirtualTableScan + .builder() + .initialSchema(namedStruct) + .addAllRows( + localRelation.data + .map( + row => { + var idx = 0 + val buf = new ArrayBuffer[SExpression.Literal](row.numFields) + while (idx < row.numFields) { + val l = Literal.apply(row.get(idx, localRelation.schema(idx).dataType)) + buf += ToSubstraitLiteral.apply(l) + idx += 1 + } + ExpressionCreator.struct(false, buf.asJava) + }) + .asJava) + .build() + } + } + + private def buildLocalFileScan(fsRelation: HadoopFsRelation): relation.AbstractReadRel = { + val namedStruct = toNamedStruct(fsRelation.schema) + + val ff = new FileFormat.ParquetReadOptions { + override def toString: String = "csv" // TODO this is hardcoded at the moment + } + + relation.LocalFiles + .builder() + .initialSchema(namedStruct) + .addAllItems( + fsRelation.location.inputFiles + .map( + file => { + ImmutableFileOrFiles + .builder() + .fileFormat(ff) + .partitionIndex(0) + .start(0) + .length(fsRelation.sizeInBytes) + .path(file) + .pathType(PathType.URI_FILE) + .build() + }) + .toList + .asJava + ) + .build() + } + + /** Read Operator: https://substrait.io/relations/logical_relations/#read-operator */ + private def convertReadOperator(plan: LeafNode): relation.AbstractReadRel = { + var tableNames: List[String] = null + plan match { + case logicalRelation: LogicalRelation if logicalRelation.catalogTable.isDefined => + tableNames = logicalRelation.catalogTable.get.identifier.unquotedString.split("\\.").toList + buildNamedScan(logicalRelation.schema, tableNames) + case dataSourceV2ScanRelation: DataSourceV2ScanRelation => + 
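+        // identifiers are dot-separated (e.g. catalog.schema.table); split into name parts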
tableNames = dataSourceV2ScanRelation.relation.identifier.get.toString.split("\\.").toList + buildNamedScan(dataSourceV2ScanRelation.schema, tableNames) + case dataSourceV2Relation: DataSourceV2Relation => + tableNames = dataSourceV2Relation.identifier.get.toString.split("\\.").toList + buildNamedScan(dataSourceV2Relation.schema, tableNames) + case hiveTableRelation: HiveTableRelation => + tableNames = hiveTableRelation.tableMeta.identifier.unquotedString.split("\\.").toList + buildNamedScan(hiveTableRelation.schema, tableNames) + case localRelation: LocalRelation => buildVirtualTableScan(localRelation) + case logicalRelation: LogicalRelation => + logicalRelation.relation match { + case fsRelation: HadoopFsRelation => + buildLocalFileScan(fsRelation) + case _ => + throw new UnsupportedOperationException( + s"******* Unable to convert the plan to a substrait relation: " + + s"${logicalRelation.relation.toString}") + } + case _ => + throw new UnsupportedOperationException( + s"******* Unable to convert the plan to a substrait NamedScan: $plan") + } + } + def convert(p: LogicalPlan): Plan = { + val rel = visit(p) + ImmutablePlan.builder + .roots( + Collections.singletonList( + ImmutableRoot.builder().input(rel).addAllNames(p.output.map(_.name).asJava).build() + )) + .build() + } + + def tree(p: LogicalPlan): String = { + TreePrinter.tree(visit(p)) + } + + def toProtoSubstrait(p: LogicalPlan): Array[Byte] = { + val substraitRel = visit(p) + + val extensionCollector = new ExtensionCollector + val relProtoConverter = new RelProtoConverter(extensionCollector) + val builder = proto.Plan + .newBuilder() + .addRelations( + proto.PlanRel + .newBuilder() + .setRel(substraitRel + .accept(relProtoConverter)) + ) + extensionCollector.addExtensionsToPlan(builder) + builder.build().toByteArray + } +} + +private[logical] class WithLogicalSubQuery(toSubstraitRel: ToSubstraitRel) + extends ToSubstraitExpression { + override protected val toScalarFunction: ToScalarFunction = + ToScalarFunction(SparkExtension.SparkScalarFunctions) + + override protected def translateSubQuery(expr: PlanExpression[_]): Option[SExpression] = { + expr match { + case s: ScalarSubquery if s.outerAttrs.isEmpty && s.joinCond.isEmpty => + val rel = toSubstraitRel.visit(s.plan) + Some( + SExpression.ScalarSubquery.builder + .input(rel) + .`type`(ToSubstraitType.apply(s.dataType, s.nullable)) + .build()) + case other => default(other) + } + } +} diff --git a/spark/src/main/scala/io/substrait/utils/Util.scala b/spark/src/main/scala/io/substrait/utils/Util.scala new file mode 100644 index 000000000..165d59953 --- /dev/null +++ b/spark/src/main/scala/io/substrait/utils/Util.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.substrait.utils + +import scala.annotation.tailrec +import scala.collection.mutable.ArrayBuffer + +object Util { + + /** + * Compute the cartesian product for n lists. + * + *
Based on Soln by + * Thomas Preissler + */ + def crossProduct[T](lists: Seq[Seq[T]]): Seq[Seq[T]] = { + + /** list [a, b], element 1 => list + element => [a, b, 1] */ + val appendElementToList: (Seq[T], T) => Seq[T] = + (list, element) => list :+ element + + /** ([a, b], [1, 2]) ==> [a, b, 1], [a, b, 2] */ + val appendAndGen: (Seq[T], Seq[T]) => Seq[Seq[T]] = + (list, elemsToAppend) => elemsToAppend.map(e => appendElementToList(list, e)) + + val firstListToJoin = lists.head + val startProduct = appendAndGen(new ArrayBuffer[T], firstListToJoin) + + /** ([ [a, b], [c, d] ], [1, 2]) -> [a, b, 1], [a, b, 2], [c, d, 1], [c, d, 2] */ + val appendAndGenLists: (Seq[Seq[T]], Seq[T]) => Seq[Seq[T]] = + (products, toJoin) => products.flatMap(product => appendAndGen(product, toJoin)) + lists.tail.foldLeft(startProduct)(appendAndGenLists) + } + + def seqToOption[T](s: Seq[Option[T]]): Option[Seq[T]] = { + @tailrec + def seqToOptionHelper(s: Seq[Option[T]], accum: Seq[T] = Seq[T]()): Option[Seq[T]] = { + s match { + case Some(head) :: Nil => + Option(accum :+ head) + case Some(head) :: tail => + seqToOptionHelper(tail, accum :+ head) + case _ => None + } + } + seqToOptionHelper(s) + } + +} diff --git a/spark/src/main/scala/org/apache/spark/substrait/SparkTypeUtil.scala b/spark/src/main/scala/org/apache/spark/substrait/SparkTypeUtil.scala new file mode 100644 index 000000000..7af796444 --- /dev/null +++ b/spark/src/main/scala/org/apache/spark/substrait/SparkTypeUtil.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.substrait + +import org.apache.spark.sql.types.DataType + +object SparkTypeUtil { + + def sameType(left: DataType, right: DataType): Boolean = { + left.sameType(right) + } + +} diff --git a/spark/src/main/spark-3.2/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala b/spark/src/main/spark-3.2/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala new file mode 100644 index 000000000..836a087f1 --- /dev/null +++ b/spark/src/main/spark-3.2/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.substrait.spark.logical + +import org.apache.spark.sql.catalyst.plans.logical._ + +import io.substrait.relation +import io.substrait.relation.Rel + +class AbstractLogicalPlanVisitor extends LogicalPlanVisitor[relation.Rel] { + + protected def t(p: LogicalPlan): relation.Rel = + throw new UnsupportedOperationException(s"Unable to convert the LogicalPlan ${p.nodeName}") + + override def visitDistinct(p: Distinct): relation.Rel = t(p) + + override def visitExcept(p: Except): relation.Rel = t(p) + + override def visitExpand(p: Expand): relation.Rel = t(p) + + override def visitRepartition(p: Repartition): relation.Rel = t(p) + + override def visitRepartitionByExpr(p: RepartitionByExpression): relation.Rel = t(p) + + override def visitSample(p: Sample): relation.Rel = t(p) + + override def visitScriptTransform(p: ScriptTransformation): relation.Rel = t(p) + + override def visitUnion(p: Union): relation.Rel = t(p) + + override def visitWindow(p: Window): relation.Rel = t(p) + + override def visitTail(p: Tail): relation.Rel = t(p) + + override def visitGenerate(p: Generate): relation.Rel = t(p) + + override def visitGlobalLimit(p: GlobalLimit): relation.Rel = t(p) + + override def visitIntersect(p: Intersect): relation.Rel = t(p) + + override def visitLocalLimit(p: LocalLimit): relation.Rel = t(p) + + override def visitPivot(p: Pivot): relation.Rel = t(p) + + override def default(p: LogicalPlan): Rel = t(p) + + override def visitAggregate(p: Aggregate): Rel = t(p) + + override def visitFilter(p: Filter): Rel = t(p) + + override def visitJoin(p: Join): Rel = t(p) + + override def visitProject(p: Project): Rel = t(p) + + override def visitSort(sort: Sort): Rel = t(sort) + + override def visitWithCTE(p: WithCTE): Rel = t(p) +} diff --git a/spark/src/main/spark-3.3/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala b/spark/src/main/spark-3.3/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala new file mode 100644 index 000000000..345cb215f --- /dev/null +++ b/spark/src/main/spark-3.3/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.substrait.spark.logical + +import org.apache.spark.sql.catalyst.plans.logical._ + +import io.substrait.relation +import io.substrait.relation.Rel + +class AbstractLogicalPlanVisitor extends LogicalPlanVisitor[relation.Rel] { + + protected def t(p: LogicalPlan): relation.Rel = + throw new UnsupportedOperationException(s"Unable to convert the expression ${p.nodeName}") + + override def visitDistinct(p: Distinct): relation.Rel = t(p) + + override def visitExcept(p: Except): relation.Rel = t(p) + + override def visitExpand(p: Expand): relation.Rel = t(p) + + override def visitRepartition(p: Repartition): relation.Rel = t(p) + + override def visitRepartitionByExpr(p: RepartitionByExpression): relation.Rel = t(p) + + override def visitSample(p: Sample): relation.Rel = t(p) + + override def visitScriptTransform(p: ScriptTransformation): relation.Rel = t(p) + + override def visitUnion(p: Union): relation.Rel = t(p) + + override def visitWindow(p: Window): relation.Rel = t(p) + + override def visitTail(p: Tail): relation.Rel = t(p) + + override def visitGenerate(p: Generate): relation.Rel = t(p) + + override def visitGlobalLimit(p: GlobalLimit): relation.Rel = t(p) + + override def visitIntersect(p: Intersect): relation.Rel = t(p) + + override def visitLocalLimit(p: LocalLimit): relation.Rel = t(p) + + override def visitPivot(p: Pivot): relation.Rel = t(p) + + override def default(p: LogicalPlan): Rel = t(p) + + override def visitAggregate(p: Aggregate): Rel = t(p) + + override def visitFilter(p: Filter): Rel = t(p) + + override def visitJoin(p: Join): Rel = t(p) + + override def visitProject(p: Project): Rel = t(p) + + override def visitSort(sort: Sort): Rel = t(sort) + + override def visitWithCTE(p: WithCTE): Rel = t(p) + + override def visitRebalancePartitions(p: RebalancePartitions): Rel = t(p) +} diff --git a/spark/src/main/spark-3.4/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala b/spark/src/main/spark-3.4/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala new file mode 100644 index 000000000..ec3ee78e8 --- /dev/null +++ b/spark/src/main/spark-3.4/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.substrait.spark.logical + +import org.apache.spark.sql.catalyst.plans.logical._ + +import io.substrait.relation +import io.substrait.relation.Rel + +class AbstractLogicalPlanVisitor extends LogicalPlanVisitor[relation.Rel] { + + protected def t(p: LogicalPlan): relation.Rel = + throw new UnsupportedOperationException(s"Unable to convert the expression ${p.nodeName}") + + override def visitDistinct(p: Distinct): relation.Rel = t(p) + + override def visitExcept(p: Except): relation.Rel = t(p) + + override def visitExpand(p: Expand): relation.Rel = t(p) + + override def visitRepartition(p: Repartition): relation.Rel = t(p) + + override def visitRepartitionByExpr(p: RepartitionByExpression): relation.Rel = t(p) + + override def visitSample(p: Sample): relation.Rel = t(p) + + override def visitScriptTransform(p: ScriptTransformation): relation.Rel = t(p) + + override def visitUnion(p: Union): relation.Rel = t(p) + + override def visitWindow(p: Window): relation.Rel = t(p) + + override def visitTail(p: Tail): relation.Rel = t(p) + + override def visitGenerate(p: Generate): relation.Rel = t(p) + + override def visitGlobalLimit(p: GlobalLimit): relation.Rel = t(p) + + override def visitIntersect(p: Intersect): relation.Rel = t(p) + + override def visitLocalLimit(p: LocalLimit): relation.Rel = t(p) + + override def visitPivot(p: Pivot): relation.Rel = t(p) + + override def default(p: LogicalPlan): Rel = t(p) + + override def visitAggregate(p: Aggregate): Rel = t(p) + + override def visitFilter(p: Filter): Rel = t(p) + + override def visitJoin(p: Join): Rel = t(p) + + override def visitProject(p: Project): Rel = t(p) + + override def visitSort(sort: Sort): Rel = t(sort) + + override def visitWithCTE(p: WithCTE): Rel = t(p) + + override def visitOffset(p: Offset): Rel = t(p) + + override def visitRebalancePartitions(p: RebalancePartitions): Rel = t(p) +} diff --git a/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala b/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala new file mode 100644 index 000000000..4fa9ec263 --- /dev/null +++ b/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.substrait.spark + +import io.substrait.spark.logical.{ToLogicalPlan, ToSubstraitRel} + +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util.resourceToString +import org.apache.spark.sql.test.SharedSparkSession + +import io.substrait.debug.TreePrinter +import io.substrait.extension.ExtensionCollector +import io.substrait.plan.{Plan, PlanProtoConverter, ProtoPlanConverter} +import io.substrait.proto +import io.substrait.relation.RelProtoConverter +import org.scalactic.Equality +import org.scalactic.source.Position +import org.scalatest.Succeeded +import org.scalatest.compatible.Assertion +import org.scalatest.exceptions.{StackDepthException, TestFailedException} + +trait SubstraitPlanTestBase { self: SharedSparkSession => + + implicit class PlainEquality[T: TreePrinter](actual: T) { + // Like should equal, but does not try to mark diffs in strings with square brackets, + // so that IntelliJ can show a proper diff. + def shouldEqualPlainly(expected: T)(implicit equality: Equality[T]): Assertion = + if (!equality.areEqual(actual, expected)) { + + throw new TestFailedException( + (e: StackDepthException) => + Some( + s"${implicitly[TreePrinter[T]].tree(actual)}" + + s" did not equal ${implicitly[TreePrinter[T]].tree(expected)}"), + None, + Position.here + ) + } else Succeeded + } + + def sqlToProtoPlan(sql: String): proto.Plan = { + val convert = new ToSubstraitRel() + val logicalPlan = plan(sql) + val substraitRel = convert.visit(logicalPlan) + + val extensionCollector = new ExtensionCollector + val relProtoConverter = new RelProtoConverter(extensionCollector) + val builder = proto.Plan + .newBuilder() + .addRelations( + proto.PlanRel + .newBuilder() + .setRoot( + proto.RelRoot + .newBuilder() + .setInput(substraitRel + .accept(relProtoConverter)) + ) + ) + extensionCollector.addExtensionsToPlan(builder) + builder.build() + } + + def assertProtoPlanRoundrip(sql: String): Plan = { + val protoPlan1 = sqlToProtoPlan(sql) + val plan = new ProtoPlanConverter().from(protoPlan1) + val protoPlan2 = new PlanProtoConverter().toProto(plan) + assertResult(protoPlan1)(protoPlan2) + assertResult(1)(plan.getRoots.size()) + plan + } + + def assertSqlSubstraitRelRoundTrip(query: String): LogicalPlan = { + // TODO need a more robust way of testing this than round-tripping. 
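+    // Round trip: SQL -> LogicalPlan -> Substrait POJO -> LogicalPlan -> Substrait POJO;
+    // the two POJO plans must compare equal for the conversion to be considered sound.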
+ val logicalPlan = plan(query) + val pojoRel = new ToSubstraitRel().visit(logicalPlan) + val converter = new ToLogicalPlan(spark = spark); + val logicalPlan2 = pojoRel.accept(converter); + require(logicalPlan2.resolved); + val pojoRel2 = new ToSubstraitRel().visit(logicalPlan2) + + pojoRel2.shouldEqualPlainly(pojoRel) + logicalPlan2 + } + + def plan(sql: String): LogicalPlan = { + spark.sql(sql).queryExecution.optimizedPlan + } + + def assertPlanRoundrip(plan: Plan): Unit = { + val protoPlan1 = new PlanProtoConverter().toProto(plan) + val protoPlan2 = new PlanProtoConverter().toProto(new ProtoPlanConverter().from(protoPlan1)) + assertResult(protoPlan1)(protoPlan2) + } + + def testQuery(group: String, query: String, suffix: String = ""): Unit = { + val queryString = resourceToString( + s"$group/$query.sql", + classLoader = Thread.currentThread().getContextClassLoader) + assert(queryString != null) + assertSqlSubstraitRelRoundTrip(queryString) + } +} diff --git a/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala b/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala new file mode 100644 index 000000000..7cfb3cd2d --- /dev/null +++ b/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.substrait.spark + +import org.apache.spark.sql.TPCDSBase +import org.apache.spark.sql.internal.SQLConf + +class TPCDSPlan extends TPCDSBase with SubstraitPlanTestBase { + + private val runAllQueriesIncludeFailed = false + override def beforeAll(): Unit = { + super.beforeAll() + sparkContext.setLogLevel("WARN") + + conf.setConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED, false) + // introduced in spark 3.4 + spark.conf.set("spark.sql.readSideCharPadding", "false") + } + + // "q9" failed in spark 3.3 + val successfulSQL: Set[String] = Set("q41", "q62", "q93", "q96", "q99") + + tpcdsQueries.foreach { + q => + if (runAllQueriesIncludeFailed || successfulSQL.contains(q)) { + test(s"check simplified (tpcds-v1.4/$q)") { + testQuery("tpcds", q) + } + } else { + ignore(s"check simplified (tpcds-v1.4/$q)") { + testQuery("tpcds", q) + } + } + } + + ignore("window") { + val qry = s"""(SELECT + | item_sk, + | rank() + | OVER ( + | ORDER BY rank_col DESC) rnk + | FROM (SELECT + | ss_item_sk item_sk, + | avg(ss_net_profit) rank_col + | FROM store_sales ss1 + | WHERE ss_store_sk = 4 + | GROUP BY ss_item_sk + | HAVING avg(ss_net_profit) > 0.9 * (SELECT avg(ss_net_profit) rank_col + | FROM store_sales + | WHERE ss_store_sk = 4 + | AND ss_addr_sk IS NULL + | GROUP BY ss_store_sk)) V2) """.stripMargin + assertSqlSubstraitRelRoundTrip(qry) + } +} diff --git a/spark/src/test/scala/io/substrait/spark/TPCHPlan.scala b/spark/src/test/scala/io/substrait/spark/TPCHPlan.scala new file mode 100644 index 000000000..df5ca4f81 --- /dev/null +++ b/spark/src/test/scala/io/substrait/spark/TPCHPlan.scala @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.substrait.spark + +import org.apache.spark.sql.TPCHBase + +class TPCHPlan extends TPCHBase with SubstraitPlanTestBase { + + override def beforeAll(): Unit = { + super.beforeAll() + sparkContext.setLogLevel("WARN") + } + + tpchQueries.foreach { + q => + test(s"check simplified (tpch/$q)") { + testQuery("tpch", q) + } + } + + test("Decimal") { + assertSqlSubstraitRelRoundTrip("select l_returnflag," + + " sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) from lineitem group by l_returnflag") + } + + test("simpleJoin") { + assertSqlSubstraitRelRoundTrip( + "select l_partkey + l_orderkey, l_extendedprice * 0.1 + 100.0, o_orderkey from lineitem " + + "join orders on l_orderkey = o_orderkey where l_shipdate < date '1998-01-01' ") + assertSqlSubstraitRelRoundTrip( + "select l_partkey + l_orderkey, l_extendedprice * 0.1 + 100.0, o_orderkey from lineitem " + + "left join orders on l_orderkey = o_orderkey where l_shipdate < date '1998-01-01' ") + assertSqlSubstraitRelRoundTrip( + "select l_partkey + l_orderkey, l_extendedprice * 0.1 + 100.0, o_orderkey from lineitem " + + "right join orders on l_orderkey = o_orderkey where l_shipdate < date '1998-01-01' ") + assertSqlSubstraitRelRoundTrip( + "select l_partkey + l_orderkey, l_extendedprice * 0.1 + 100.0, o_orderkey from lineitem " + + "full join orders on l_orderkey = o_orderkey where l_shipdate < date '1998-01-01' ") + } + + test("simpleOrderByClause") { + assertSqlSubstraitRelRoundTrip( + "select l_partkey from lineitem where l_shipdate < date '1998-01-01' " + + "order by l_shipdate, l_discount") + assertSqlSubstraitRelRoundTrip( + "select l_partkey from lineitem where l_shipdate < date '1998-01-01' " + + "order by l_shipdate asc, l_discount desc") + assertSqlSubstraitRelRoundTrip( + "select l_partkey from lineitem where l_shipdate < date '1998-01-01' " + + "order by l_shipdate asc, l_discount desc limit 100") + assertSqlSubstraitRelRoundTrip( + "select l_partkey from lineitem where l_shipdate < date '1998-01-01' limit 100") + assertSqlSubstraitRelRoundTrip( + "select l_partkey from lineitem where l_shipdate < date '1998-01-01' " + + "order by l_shipdate asc, l_discount desc nulls first") + assertSqlSubstraitRelRoundTrip( + "select l_partkey from lineitem where l_shipdate < date '1998-01-01' " + + "order by l_shipdate asc, l_discount desc nulls last") + } + + ignore("simpleOffsetClause") { // TODO need to implement the 'offset' clause for this to pass + assertSqlSubstraitRelRoundTrip( + "select l_partkey from lineitem where l_shipdate < date '1998-01-01' " + + "order by l_shipdate asc, l_discount desc limit 100 offset 1000") + } + + test("simpleTest") { + val query = "select p_size from part where p_partkey > cast(100 as bigint)" + assertSqlSubstraitRelRoundTrip(query) + } + + test("simpleTest2") { + val query = "select l_partkey, l_discount from lineitem where l_orderkey > cast(100 as bigint)" + assertSqlSubstraitRelRoundTrip(query) + } + + test("simpleTestAgg") { + assertSqlSubstraitRelRoundTrip( + "select l_partkey, count(l_tax), COUNT(distinct l_discount) from lineitem group by l_partkey") + + assertSqlSubstraitRelRoundTrip( + "select count(l_tax), COUNT(distinct l_discount)" + + " from lineitem group by l_partkey + l_orderkey") + + assertSqlSubstraitRelRoundTrip( + "select l_partkey + l_orderkey, count(l_tax), COUNT(distinct l_discount)" + + " from lineitem group by l_partkey + l_orderkey") + } + + ignore("avg(distinct)") { + assertSqlSubstraitRelRoundTrip( + "select l_partkey, sum(l_tax), sum(distinct l_tax)," + + " 
+        " avg(l_discount), avg(distinct l_discount) from lineitem group by l_partkey")
+  }
+
+  test("simpleTestAgg3") {
+    assertSqlSubstraitRelRoundTrip(
+      "select l_partkey, sum(l_extendedprice * (1.0-l_discount)) from lineitem group by l_partkey")
+  }
+
+  ignore("simpleTestAggFilter") {
+    assertSqlSubstraitRelRoundTrip(
+      "select sum(l_tax) filter(WHERE l_orderkey > l_partkey) from lineitem")
+    // cast added to avoid differences caused by implicit casts
+    assertSqlSubstraitRelRoundTrip(
+      "select sum(l_tax) filter(WHERE l_orderkey > cast(10.0 as bigint)) from lineitem")
+  }
+
+  test("simpleTestAggNoGB") {
+    assertSqlSubstraitRelRoundTrip("select count(l_tax), count(distinct l_discount) from lineitem")
+  }
+
+  test("simpleTestApproxCountDistinct") {
+    val query = "select approx_count_distinct(l_tax) from lineitem"
+    assertSqlSubstraitRelRoundTrip(query)
+  }
+
+  test("simpleTestDateInterval") {
+    assertSqlSubstraitRelRoundTrip(
+      "select l_partkey+l_orderkey, l_shipdate from lineitem where l_shipdate < date '1998-01-01' ")
+    assertSqlSubstraitRelRoundTrip(
+      "select l_partkey+l_orderkey, l_shipdate from lineitem " +
+        "where l_shipdate < date '1998-01-01' + interval '3' month ")
+    assertSqlSubstraitRelRoundTrip(
+      "select l_partkey+l_orderkey, l_shipdate from lineitem " +
+        "where l_shipdate < date '1998-01-01' + interval '1' year")
+    assertSqlSubstraitRelRoundTrip(
+      "select l_partkey+l_orderkey, l_shipdate from lineitem " +
+        "where l_shipdate < date '1998-01-01' + interval '1-3' year to month")
+  }
+
+  test("simpleTestDecimal") {
+    assertSqlSubstraitRelRoundTrip(
+      "select l_partkey + l_orderkey, l_extendedprice * 0.1 + 100.0 from lineitem" +
+        " where l_shipdate < date '1998-01-01' ")
+  }
+
+  ignore("simpleTestGroupingSets [has Expand]") {
+    assertSqlSubstraitRelRoundTrip(
+      "select sum(l_discount) from lineitem group by grouping sets " +
+        "((l_orderkey, L_COMMITDATE), l_shipdate)")
+
+    assertSqlSubstraitRelRoundTrip(
+      "select sum(l_discount) from lineitem group by grouping sets " +
+        "((l_orderkey, L_COMMITDATE), l_shipdate), l_linestatus")
+
+    assertSqlSubstraitRelRoundTrip(
+      "select sum(l_discount) from lineitem group by grouping sets " +
+        "((l_orderkey, L_COMMITDATE), l_shipdate, ())")
+
+    assertSqlSubstraitRelRoundTrip(
+      "select sum(l_discount) from lineitem group by grouping sets " +
+        "((l_orderkey, L_COMMITDATE), l_shipdate, ()), l_linestatus")
+    assertSqlSubstraitRelRoundTrip(
+      "select sum(l_discount) from lineitem group by grouping sets " +
+        "((l_orderkey, L_COMMITDATE), (l_orderkey, L_COMMITDATE, l_linestatus), l_shipdate, ())")
+  }
+
+  test("tpch_q1_variant") {
+    // differs from tpch_q1: 1) order by clause removed; 2) interval date literal removed
+    assertSqlSubstraitRelRoundTrip(
+      "select\n" +
+        " l_returnflag,\n" +
+        " l_linestatus,\n" +
+        " sum(l_quantity) as sum_qty,\n" +
+        " sum(l_extendedprice) as sum_base_price,\n" +
+        " sum(l_extendedprice * (1.0 - l_discount)) as sum_disc_price,\n" +
+        " sum(l_extendedprice * (1.0 - l_discount) * (1.0 + l_tax)) as sum_charge,\n" +
+        " avg(l_quantity) as avg_qty,\n" +
+        " avg(l_extendedprice) as avg_price,\n" +
+        " avg(l_discount) as avg_disc,\n" +
+        " count(*) as count_order\n" +
+        "from\n" +
+        " lineitem\n" +
+        "where\n" +
+        " l_shipdate <= date '1998-12-01' \n" +
+        "group by\n" +
+        " l_returnflag,\n" +
+        " l_linestatus\n")
+  }
+}
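
SubstraitPlanTestBase and its assertSqlSubstraitRelRoundTrip helper are not part of this hunk; judging only from the call sites above, the helper presumably converts the SQL to a Spark logical plan, translates it to a Substrait rel and back, and asserts the plans agree. A minimal sketch of how further coverage could be added against that contract — the suite name and the exists-subquery are illustrative assumptions, not part of this change:

package io.substrait.spark

import org.apache.spark.sql.TPCHBase

// Illustrative only: exercises the same round-trip helper as TPCHPlan above,
// reusing the TPC-H tables defined in TPCHBase.
class SubquerySketch extends TPCHBase with SubstraitPlanTestBase {
  test("existsSubquery") {
    assertSqlSubstraitRelRoundTrip(
      "select o_orderkey from orders where exists " +
        "(select 1 from lineitem where l_orderkey = o_orderkey)")
  }
}
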
diff --git a/spark/src/test/scala/io/substrait/spark/expression/ArithmeticExpressionSuite.scala b/spark/src/test/scala/io/substrait/spark/expression/ArithmeticExpressionSuite.scala
new file mode 100644
index 000000000..f94230b10
--- /dev/null
+++ b/spark/src/test/scala/io/substrait/spark/expression/ArithmeticExpressionSuite.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.substrait.spark.expression
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.{IntegerType, LongType}
+
+import io.substrait.`type`.TypeCreator
+import io.substrait.expression.{Expression => SExpression, ExpressionCreator}
+import io.substrait.expression.Expression.FailureBehavior
+
+class ArithmeticExpressionSuite extends SparkFunSuite with SubstraitExpressionTestBase {
+
+  test("+ (Add)") {
+    runTest(
+      "add:i64_i64",
+      Add(Literal(1), Literal(2L)),
+      func => {
+        assertResult(true)(func.arguments().get(1).isInstanceOf[SExpression.I64Literal])
+        assertResult(
+          ExpressionCreator.cast(
+            TypeCreator.REQUIRED.I64,
+            ExpressionCreator.i32(false, 1),
+            FailureBehavior.THROW_EXCEPTION
+          ))(func.arguments().get(0))
+      },
+      bidirectional = false
+    ) // TODO: implicit calcite cast
+
+    runTest(
+      "add:i64_i64",
+      Add(Cast(Literal(1), LongType), Literal(2L)),
+      func => {},
+      bidirectional = true)
+
+    runTest("add:i32_i32", Add(Literal(1), Cast(Literal(2L), IntegerType)))
+
+    runTest(
+      "add:i32_i32",
+      Add(Literal(1), Literal(2)),
+      func => {
+        assertResult(true)(func.arguments().get(0).isInstanceOf[SExpression.I32Literal])
+        assertResult(true)(func.arguments().get(1).isInstanceOf[SExpression.I32Literal])
+      },
+      bidirectional = true
+    )
+  }
+}
diff --git a/spark/src/test/scala/io/substrait/spark/expression/PredicateSuite.scala b/spark/src/test/scala/io/substrait/spark/expression/PredicateSuite.scala
new file mode 100644
index 000000000..254ba99c2
--- /dev/null
+++ b/spark/src/test/scala/io/substrait/spark/expression/PredicateSuite.scala
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.substrait.spark.expression
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions.{And, Literal}
+
+class PredicateSuite extends SparkFunSuite with SubstraitExpressionTestBase {
+
+  test("And") {
+    runTest("and:bool", And(Literal(true), Literal(false)))
+  }
+}
diff --git a/spark/src/test/scala/io/substrait/spark/expression/SubstraitExpressionTestBase.scala b/spark/src/test/scala/io/substrait/spark/expression/SubstraitExpressionTestBase.scala
new file mode 100644
index 000000000..45de335bc
--- /dev/null
+++ b/spark/src/test/scala/io/substrait/spark/expression/SubstraitExpressionTestBase.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.substrait.spark.expression
+
+import io.substrait.spark.SparkExtension
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+
+import io.substrait.expression.{Expression => SExpression}
+import org.scalatest.Assertions.assertResult
+
+trait SubstraitExpressionTestBase {
+
+  private val toSparkExpression =
+    new ToSparkExpression(ToScalarFunction(SparkExtension.SparkScalarFunctions))
+
+  private val toSubstraitExpression = new ToSubstraitExpression {
+    override protected val toScalarFunction: ToScalarFunction =
+      ToScalarFunction(SparkExtension.SparkScalarFunctions)
+  }
+
+  protected def runTest(expectedName: String, expression: Expression): Unit = {
+    runTest(expectedName, expression, func => {}, bidirectional = true)
+  }
+
+  protected def runTest(
+      expectedName: String,
+      expression: Expression,
+      f: SExpression.ScalarFunctionInvocation => Unit,
+      bidirectional: Boolean): Unit = {
+    val substraitExp = toSubstraitExpression(expression)
+      .asInstanceOf[SExpression.ScalarFunctionInvocation]
+    assertResult(expectedName)(substraitExp.declaration().key())
+    f(substraitExp)
+
+    if (bidirectional) {
+      val convertedExpression = substraitExp.accept(toSparkExpression)
+      assertResult(expression)(convertedExpression)
+    }
+  }
+}
diff --git a/spark/src/test/scala/io/substrait/spark/expression/YamlTest.scala b/spark/src/test/scala/io/substrait/spark/expression/YamlTest.scala
new file mode 100644
index 000000000..e855e0d47
--- /dev/null
+++ b/spark/src/test/scala/io/substrait/spark/expression/YamlTest.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.substrait.spark.expression
+
+import io.substrait.spark.SparkExtension
+
+import org.apache.spark.SparkFunSuite
+
+class YamlTest extends SparkFunSuite {
+
+  test("has_year_definition") {
+    assert(
+      SparkExtension.SparkScalarFunctions
+        .map(f => f.key())
+        .exists(p => p.equals("year:date")))
+    assert(
+      SparkExtension.SparkScalarFunctions
+        .map(f => f.key())
+        .exists(p => p.equals("unscaled:dec")))
+  }
+}
diff --git a/spark/src/test/spark-3.2/org/apache/spark/sql/TPCBase.scala b/spark/src/test/spark-3.2/org/apache/spark/sql/TPCBase.scala
new file mode 100644
index 000000000..c2c0beacb
--- /dev/null
+++ b/spark/src/test/spark-3.2/org/apache/spark/sql/TPCBase.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+
+trait TPCBase extends SharedSparkSession {
+
+  protected def injectStats: Boolean = false
+
+  override protected def sparkConf: SparkConf = {
+    if (injectStats) {
+      super.sparkConf
+        .set(SQLConf.MAX_TO_STRING_FIELDS, Int.MaxValue)
+        .set(SQLConf.CBO_ENABLED, true)
+        .set(SQLConf.PLAN_STATS_ENABLED, true)
+        .set(SQLConf.JOIN_REORDER_ENABLED, true)
+    } else {
+      super.sparkConf.set(SQLConf.MAX_TO_STRING_FIELDS, Int.MaxValue)
+    }
+  }
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    createTables()
+  }
+
+  override def afterAll(): Unit = {
+    dropTables()
+    super.afterAll()
+  }
+
+  protected def createTables(): Unit
+
+  protected def dropTables(): Unit
+}
diff --git a/spark/src/test/spark-3.2/org/apache/spark/sql/TPCHBase.scala b/spark/src/test/spark-3.2/org/apache/spark/sql/TPCHBase.scala
new file mode 100644
index 000000000..c3247c5ce
--- /dev/null
+++ b/spark/src/test/spark-3.2/org/apache/spark/sql/TPCHBase.scala
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.TableIdentifier
+
+trait TPCHBase extends TPCBase {
+
+  override def createTables(): Unit = {
+    tpchCreateTable.values.foreach(sql => spark.sql(sql))
+  }
+
+  override def dropTables(): Unit = {
+    tpchCreateTable.keys.foreach {
+      tableName => spark.sessionState.catalog.dropTable(TableIdentifier(tableName), true, true)
+    }
+  }
+
+  val tpchCreateTable = Map(
+    "orders" ->
+      """
+        |CREATE TABLE `orders` (
+        |`o_orderkey` BIGINT, `o_custkey` BIGINT, `o_orderstatus` STRING,
+        |`o_totalprice` DECIMAL(10,0), `o_orderdate` DATE, `o_orderpriority` STRING,
+        |`o_clerk` STRING, `o_shippriority` INT, `o_comment` STRING)
+        |USING parquet
+      """.stripMargin,
+    "nation" ->
+      """
+        |CREATE TABLE `nation` (
+        |`n_nationkey` BIGINT, `n_name` STRING, `n_regionkey` BIGINT, `n_comment` STRING)
+        |USING parquet
+      """.stripMargin,
+    "region" ->
+      """
+        |CREATE TABLE `region` (
+        |`r_regionkey` BIGINT, `r_name` STRING, `r_comment` STRING)
+        |USING parquet
+      """.stripMargin,
+    "part" ->
+      """
+        |CREATE TABLE `part` (`p_partkey` BIGINT, `p_name` STRING, `p_mfgr` STRING,
+        |`p_brand` STRING, `p_type` STRING, `p_size` INT, `p_container` STRING,
+        |`p_retailprice` DECIMAL(10,0), `p_comment` STRING)
+        |USING parquet
+      """.stripMargin,
+    "partsupp" ->
+      """
+        |CREATE TABLE `partsupp` (`ps_partkey` BIGINT, `ps_suppkey` BIGINT,
+        |`ps_availqty` INT, `ps_supplycost` DECIMAL(10,0), `ps_comment` STRING)
+        |USING parquet
+      """.stripMargin,
+    "customer" ->
+      """
+        |CREATE TABLE `customer` (`c_custkey` BIGINT, `c_name` STRING, `c_address` STRING,
+        |`c_nationkey` BIGINT, `c_phone` STRING, `c_acctbal` DECIMAL(10,0),
+        |`c_mktsegment` STRING, `c_comment` STRING)
+        |USING parquet
+      """.stripMargin,
+    "supplier" ->
+      """
+        |CREATE TABLE `supplier` (`s_suppkey` BIGINT, `s_name` STRING, `s_address` STRING,
+        |`s_nationkey` BIGINT, `s_phone` STRING, `s_acctbal` DECIMAL(10,0), `s_comment` STRING)
+        |USING parquet
+      """.stripMargin,
+    "lineitem" ->
+      """
+        |CREATE TABLE `lineitem` (`l_orderkey` BIGINT, `l_partkey` BIGINT, `l_suppkey` BIGINT,
+        |`l_linenumber` INT, `l_quantity` DECIMAL(10,0), `l_extendedprice` DECIMAL(10,0),
+        |`l_discount` DECIMAL(10,0), `l_tax` DECIMAL(10,0), `l_returnflag` STRING,
+        |`l_linestatus` STRING, `l_shipdate` DATE, `l_commitdate` DATE, `l_receiptdate` DATE,
+        |`l_shipinstruct` STRING, `l_shipmode` STRING, `l_comment` STRING)
+        |USING parquet
+      """.stripMargin
+  )
+
+  val tpchQueries = Seq(
+    "q1",
+    "q2",
+    "q3",
+    "q4",
+    "q5",
+    "q6",
+    "q7",
+    "q8",
+    "q9",
+    "q10",
+    "q11",
+    "q12",
+    "q13",
+    "q14",
+    "q15",
+    "q16",
+    "q17",
+    "q18",
+    "q19",
+    "q20",
+    "q21",
+    "q22")
+}
diff --git a/substrait b/substrait
index 52e81a9fe..a68c1ac62 160000
--- a/substrait
+++ b/substrait
@@ -1 +1 @@
-Subproject commit 52e81a9fe725881036eddaa77ae0dba8b2ad6f83
+Subproject commit a68c1ac62f92d703da624cb8ac0cef854dd2b35f
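
For the expression-level suites, SubstraitExpressionTestBase (above) exposes runTest in two arities: a single-assertion form and a form that also inspects the resulting ScalarFunctionInvocation and optionally round-trips it back to a Catalyst expression. A short sketch of how another predicate could be covered with the simple overload; the Substrait function key "or:bool" is assumed by analogy with the "and:bool" key used in PredicateSuite, not something this change registers:

package io.substrait.spark.expression

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{Literal, Or}

// Illustrative only: follows the PredicateSuite pattern from this diff.
class OrPredicateSketch extends SparkFunSuite with SubstraitExpressionTestBase {
  test("Or") {
    // converts the Catalyst Or to a Substrait scalar function invocation,
    // asserts its declaration key, then round-trips it back to Catalyst
    runTest("or:bool", Or(Literal(true), Literal(false)))
  }
}
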