diff --git a/.github/actions/setup-builder/action.yaml b/.github/actions/setup-builder/action.yaml new file mode 100644 index 000000000000..13a3008b74bc --- /dev/null +++ b/.github/actions/setup-builder/action.yaml @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +name: Prepare Rust Builder +description: 'Prepare Rust Build Environment' +inputs: + rust-version: + description: 'version of rust to install (e.g. stable)' + required: true + default: 'stable' +runs: + using: "composite" + steps: + - name: Install Build Dependencies + shell: bash + run: | + apt-get update + apt-get install -y protobuf-compiler + - name: Setup Rust toolchain + shell: bash + run: | + echo "Installing ${{ inputs.rust-version }}" + rustup toolchain install ${{ inputs.rust-version }} + rustup default ${{ inputs.rust-version }} + rustup component add rustfmt diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 38c4f2fc6ea3..21c9cfefed0b 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -54,10 +54,9 @@ jobs: path: /github/home/target key: ${{ runner.os }}-${{ matrix.arch }}-target-cache-${{ matrix.rust }}- - name: Setup Rust toolchain - run: | - rustup toolchain install ${{ matrix.rust }} - rustup default ${{ matrix.rust }} - rustup component add rustfmt + uses: ./.github/actions/setup-builder + with: + rust-version: ${{ matrix.rust }} - name: Build workspace in debug mode run: | cargo build @@ -117,10 +116,9 @@ jobs: # this key equals the ones on `linux-build-lib` for re-use key: ${{ runner.os }}-${{ matrix.arch }}-target-cache-${{ matrix.rust }} - name: Setup Rust toolchain - run: | - rustup toolchain install ${{ matrix.rust }} - rustup default ${{ matrix.rust }} - rustup component add rustfmt + uses: ./.github/actions/setup-builder + with: + rust-version: ${{ matrix.rust }} - name: Run tests run: | export ARROW_TEST_DATA=$(pwd)/testing/data @@ -285,10 +283,9 @@ jobs: echo "LIBRARY_PATH=$LD_LIBRARY_PATH" >> $GITHUB_ENV python -m pip install pyarrow - name: Setup Rust toolchain - run: | - rustup toolchain install ${{ matrix.rust }} - rustup default ${{ matrix.rust }} - rustup component add rustfmt + uses: ./.github/actions/setup-builder + with: + rust-version: ${{ matrix.rust }} - name: Run tests run: | cd datafusion @@ -343,10 +340,12 @@ jobs: # this key equals the ones on `linux-build-lib` for re-use key: ${{ runner.os }}-${{ matrix.arch }}-target-cache-${{ matrix.rust }} - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + with: + rust-version: ${{ matrix.rust }} + - name: Install Clippy run: | - rustup toolchain install ${{ matrix.rust }} - rustup default ${{ matrix.rust }} - rustup component add rustfmt clippy + rustup component add clippy - name: Run clippy run: | cargo clippy --all-targets 
--workspace -- -D warnings @@ -420,10 +419,9 @@ jobs: # this key equals the ones on `linux-build-lib` for re-use key: ${{ runner.os }}-${{ matrix.arch }}-target-cache-${{ matrix.rust }} - name: Setup Rust toolchain - run: | - rustup toolchain install ${{ matrix.rust }} - rustup default ${{ matrix.rust }} - rustup component add rustfmt + uses: ./.github/actions/setup-builder + with: + rust-version: ${{ matrix.rust }} - name: Run tests run: | export ARROW_TEST_DATA=$(pwd)/testing/data @@ -466,9 +464,9 @@ jobs: # this key equals the ones on `linux-build-lib` for re-use key: ${{ runner.os }}-${{ matrix.arch }}-target-cache-${{ matrix.rust }} - name: Setup Rust toolchain - run: | - rustup toolchain install ${{ matrix.rust }} - rustup default ${{ matrix.rust }} + uses: ./.github/actions/setup-builder + with: + rust-version: ${{ matrix.rust }} - name: Install cargo-tomlfmt run: | which cargo-tomlfmt || cargo install cargo-tomlfmt diff --git a/ballista-examples/Cargo.toml b/ballista-examples/Cargo.toml index fb956ea2cd9c..afdd4862b459 100644 --- a/ballista-examples/Cargo.toml +++ b/ballista-examples/Cargo.toml @@ -39,6 +39,6 @@ ballista = { path = "../ballista/rust/client", version = "0.6.0" } datafusion = { path = "../datafusion/core" } futures = "0.3" num_cpus = "1.13.0" -prost = "0.9" +prost = "0.10" tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot"] } -tonic = "0.6" +tonic = "0.7" diff --git a/ballista/rust/core/Cargo.toml b/ballista/rust/core/Cargo.toml index a89ccad788a7..c0b568ef684c 100644 --- a/ballista/rust/core/Cargo.toml +++ b/ballista/rust/core/Cargo.toml @@ -34,7 +34,7 @@ simd = ["datafusion/simd"] [dependencies] ahash = { version = "0.7", default-features = false } -arrow-flight = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "5eec82b6bde95d824bb9f3721789c98e033f812f" } +arrow-flight = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "85c9d642464886bf977cb6b95fad7c28d0b64cc8" } async-trait = "0.1.41" chrono = { version = "0.4", default-features = false } clap = { version = "3", features = ["derive", "cargo"] } @@ -46,16 +46,16 @@ log = "0.4" parking_lot = "0.12" parse_arg = "0.1.3" -prost = "0.9" -prost-types = "0.9" +prost = "0.10" +prost-types = "0.10" serde = { version = "1", features = ["derive"] } sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "ca8131c36325e86fd733f337673e6cae5e946711" } tokio = "1.0" -tonic = "0.6" +tonic = "0.7" uuid = { version = "0.8", features = ["v4"] } [dev-dependencies] tempfile = "3" [build-dependencies] -tonic-build = { version = "0.6" } +tonic-build = { version = "0.7" } diff --git a/ballista/rust/executor/Cargo.toml b/ballista/rust/executor/Cargo.toml index 438391422879..6ddce8981e21 100644 --- a/ballista/rust/executor/Cargo.toml +++ b/ballista/rust/executor/Cargo.toml @@ -33,8 +33,8 @@ snmalloc = ["snmalloc-rs"] [dependencies] anyhow = "1" -arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "5eec82b6bde95d824bb9f3721789c98e033f812f" } -arrow-flight = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "5eec82b6bde95d824bb9f3721789c98e033f812f" } +arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "85c9d642464886bf977cb6b95fad7c28d0b64cc8" } +arrow-flight = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "85c9d642464886bf977cb6b95fad7c28d0b64cc8" } async-trait = "0.1.41" ballista-core = { path = "../core", version = "0.6.0" } chrono = { version = "0.4", default-features = false } @@ -49,7 +49,7 @@ snmalloc-rs = { 
version = "0.2", optional = true } tempfile = "3" tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "parking_lot"] } tokio-stream = { version = "0.1", features = ["net"] } -tonic = "0.6" +tonic = "0.7" uuid = { version = "0.8", features = ["v4"] } [dev-dependencies] diff --git a/ballista/rust/scheduler/Cargo.toml b/ballista/rust/scheduler/Cargo.toml index 25465adf53e5..884573c2dce4 100644 --- a/ballista/rust/scheduler/Cargo.toml +++ b/ballista/rust/scheduler/Cargo.toml @@ -42,7 +42,7 @@ clap = { version = "3", features = ["derive", "cargo"] } configure_me = "0.4.0" datafusion = { path = "../../../datafusion/core", version = "7.0.0" } env_logger = "0.9" -etcd-client = { version = "0.8", optional = true } +etcd-client = { version = "0.9", optional = true } futures = "0.3" http = "0.2" http-body = "0.4" @@ -50,13 +50,13 @@ hyper = "0.14.4" log = "0.4" parking_lot = "0.12" parse_arg = "0.1.3" -prost = "0.9" +prost = "0.10" rand = "0.8" serde = { version = "1", features = ["derive"] } sled_package = { package = "sled", version = "0.34", optional = true } tokio = { version = "1.0", features = ["full"] } tokio-stream = { version = "0.1", features = ["net"], optional = true } -tonic = "0.6" +tonic = "0.7" tower = { version = "0.4" } warp = "0.3" @@ -66,4 +66,4 @@ uuid = { version = "0.8", features = ["v4"] } [build-dependencies] configure_me_codegen = "0.4.1" -tonic-build = { version = "0.6" } +tonic-build = { version = "0.7" } diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index e2c41df24197..0e8795afa9ec 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -28,7 +28,7 @@ repository = "https://github.com/apache/arrow-datafusion" rust-version = "1.59" [dependencies] -arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "5eec82b6bde95d824bb9f3721789c98e033f812f" } +arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "85c9d642464886bf977cb6b95fad7c28d0b64cc8" } ballista = { path = "../ballista/rust/client", version = "0.6.0", optional = true } clap = { version = "3", features = ["derive", "cargo"] } datafusion = { path = "../datafusion/core", version = "7.0.0" } diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml index e1f2c5a24a9c..1e4b592ea77e 100644 --- a/datafusion-examples/Cargo.toml +++ b/datafusion-examples/Cargo.toml @@ -34,11 +34,11 @@ path = "examples/avro_sql.rs" required-features = ["datafusion/avro"] [dev-dependencies] -arrow-flight = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "5eec82b6bde95d824bb9f3721789c98e033f812f" } +arrow-flight = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "85c9d642464886bf977cb6b95fad7c28d0b64cc8" } async-trait = "0.1.41" datafusion = { path = "../datafusion/core" } futures = "0.3" num_cpus = "1.13.0" -prost = "0.9" +prost = "0.10" tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot"] } -tonic = "0.6" +tonic = "0.7" diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml index dd0257427b81..3ae2574be651 100644 --- a/datafusion/common/Cargo.toml +++ b/datafusion/common/Cargo.toml @@ -38,10 +38,10 @@ jit = ["cranelift-module"] pyarrow = ["pyo3"] [dependencies] -arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "5eec82b6bde95d824bb9f3721789c98e033f812f", features = ["prettyprint"] } +arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "85c9d642464886bf977cb6b95fad7c28d0b64cc8", features = ["prettyprint"] } avro-rs = { version = "0.13", 
features = ["snappy"], optional = true } cranelift-module = { version = "0.82.0", optional = true } ordered-float = "2.10" -parquet = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "5eec82b6bde95d824bb9f3721789c98e033f812f", features = ["arrow"], optional = true } +parquet = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "85c9d642464886bf977cb6b95fad7c28d0b64cc8", features = ["arrow"], optional = true } pyo3 = { version = "0.16", optional = true } sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "ca8131c36325e86fd733f337673e6cae5e946711" } diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index d74330aa58fa..20201701066d 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -39,6 +39,8 @@ use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc}; /// This is the single-valued counter-part of arrow’s `Array`. #[derive(Clone)] pub enum ScalarValue { + /// represents `DataType::Null` (castable to/from any other type) + Null, /// true or false value Boolean(Option), /// 32bit float @@ -170,6 +172,8 @@ impl PartialEq for ScalarValue { (IntervalMonthDayNano(_), _) => false, (Struct(v1, t1), Struct(v2, t2)) => v1.eq(v2) && t1.eq(t2), (Struct(_, _), _) => false, + (Null, Null) => true, + (Null, _) => false, } } } @@ -270,6 +274,8 @@ impl PartialOrd for ScalarValue { } } (Struct(_, _), _) => None, + (Null, Null) => Some(Ordering::Equal), + (Null, _) => None, } } } @@ -325,6 +331,8 @@ impl std::hash::Hash for ScalarValue { v.hash(state); t.hash(state); } + // stable hash for Null value + Null => 1.hash(state), } } } @@ -594,6 +602,7 @@ impl ScalarValue { DataType::Interval(IntervalUnit::MonthDayNano) } ScalarValue::Struct(_, fields) => DataType::Struct(fields.as_ref().clone()), + ScalarValue::Null => DataType::Null, } } @@ -623,7 +632,8 @@ impl ScalarValue { pub fn is_null(&self) -> bool { matches!( *self, - ScalarValue::Boolean(None) + ScalarValue::Null + | ScalarValue::Boolean(None) | ScalarValue::UInt8(None) | ScalarValue::UInt16(None) | ScalarValue::UInt32(None) @@ -844,6 +854,7 @@ impl ScalarValue { ScalarValue::iter_to_decimal_array(scalars, precision, scale)?; Arc::new(decimal_array) } + DataType::Null => ScalarValue::iter_to_null_array(scalars), DataType::Boolean => build_array_primitive!(BooleanArray, Boolean), DataType::Float32 => build_array_primitive!(Float32Array, Float32), DataType::Float64 => build_array_primitive!(Float64Array, Float64), @@ -976,6 +987,17 @@ impl ScalarValue { Ok(array) } + fn iter_to_null_array(scalars: impl IntoIterator) -> ArrayRef { + let length = + scalars + .into_iter() + .fold(0usize, |r, element: ScalarValue| match element { + ScalarValue::Null => r + 1, + _ => unreachable!(), + }); + new_null_array(&DataType::Null, length) + } + fn iter_to_decimal_array( scalars: impl IntoIterator, precision: &usize, @@ -1249,6 +1271,7 @@ impl ScalarValue { Arc::new(StructArray::from(field_values)) } }, + ScalarValue::Null => new_null_array(&DataType::Null, size), } } @@ -1274,6 +1297,7 @@ impl ScalarValue { } Ok(match array.data_type() { + DataType::Null => ScalarValue::Null, DataType::Decimal(precision, scale) => { ScalarValue::get_decimal_value_from_array(array, index, precision, scale) } @@ -1519,6 +1543,7 @@ impl ScalarValue { eq_array_primitive!(array, index, IntervalMonthDayNanoArray, val) } ScalarValue::Struct(_, _) => unimplemented!(), + ScalarValue::Null => array.data().is_null(index), } } @@ -1740,6 +1765,7 @@ impl TryFrom<&DataType> for ScalarValue { 
DataType::Struct(fields) => { ScalarValue::Struct(None, Box::new(fields.clone())) } + DataType::Null => ScalarValue::Null, _ => { return Err(DataFusionError::NotImplemented(format!( "Can't create a scalar from data_type \"{:?}\"", @@ -1832,6 +1858,7 @@ impl fmt::Display for ScalarValue { )?, None => write!(f, "NULL")?, }, + ScalarValue::Null => write!(f, "NULL")?, }; Ok(()) } @@ -1899,6 +1926,7 @@ impl fmt::Debug for ScalarValue { None => write!(f, "Struct(NULL)"), } } + ScalarValue::Null => write!(f, "NULL"), } } } diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index a129823ba26c..daba1edae5b1 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -55,7 +55,7 @@ unicode_expressions = ["datafusion-physical-expr/regex_expressions"] [dependencies] ahash = { version = "0.7", default-features = false } -arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "5eec82b6bde95d824bb9f3721789c98e033f812f", features = ["prettyprint"] } +arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "85c9d642464886bf977cb6b95fad7c28d0b64cc8", features = ["prettyprint"] } async-trait = "0.1.41" avro-rs = { version = "0.13", features = ["snappy"], optional = true } chrono = { version = "0.4", default-features = false } @@ -72,7 +72,7 @@ num-traits = { version = "0.2", optional = true } num_cpus = "1.13.0" ordered-float = "2.10" parking_lot = "0.12" -parquet = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "5eec82b6bde95d824bb9f3721789c98e033f812f", features = ["arrow"] } +parquet = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "85c9d642464886bf977cb6b95fad7c28d0b64cc8", features = ["arrow"] } paste = "^1.0" pin-project-lite= "^0.2.7" pyo3 = { version = "0.16", optional = true } diff --git a/datafusion/core/fuzz-utils/Cargo.toml b/datafusion/core/fuzz-utils/Cargo.toml index 5c8e7d327216..1af47949eaae 100644 --- a/datafusion/core/fuzz-utils/Cargo.toml +++ b/datafusion/core/fuzz-utils/Cargo.toml @@ -23,6 +23,6 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "5eec82b6bde95d824bb9f3721789c98e033f812f", features = ["prettyprint"] } +arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "85c9d642464886bf977cb6b95fad7c28d0b64cc8", features = ["prettyprint"] } env_logger = "0.9.0" rand = "0.8" diff --git a/datafusion/core/src/logical_plan/builder.rs b/datafusion/core/src/logical_plan/builder.rs index c83857f1e47d..cd6df27a08ed 100644 --- a/datafusion/core/src/logical_plan/builder.rs +++ b/datafusion/core/src/logical_plan/builder.rs @@ -154,7 +154,7 @@ impl LogicalPlanBuilder { .iter() .enumerate() .map(|(j, expr)| { - if let Expr::Literal(ScalarValue::Utf8(None)) = expr { + if let Expr::Literal(ScalarValue::Null) = expr { nulls.push((i, j)); Ok(field_types[j].clone()) } else { diff --git a/datafusion/core/src/logical_plan/expr_schema.rs b/datafusion/core/src/logical_plan/expr_schema.rs index 1a8adbf04264..da0b6f7eabcc 100644 --- a/datafusion/core/src/logical_plan/expr_schema.rs +++ b/datafusion/core/src/logical_plan/expr_schema.rs @@ -172,7 +172,9 @@ impl ExprSchemable for Expr { } else if let Some(e) = else_expr { e.nullable(input_schema) } else { - Ok(false) + // CASE produces NULL if there is no `else` expr + // (aka when none of the `when_then_exprs` match) + Ok(true) } } Expr::Cast { expr, .. 
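
The expr_schema.rs hunk above encodes the CASE nullability rule. Stated standalone (a sketch, assuming the surrounding code already returns nullable when any THEN arm is nullable):

```rust
/// Sketch of the rule from `Expr::nullable` above: with no ELSE arm the
/// CASE yields NULL whenever no WHEN arm matches, so it is nullable even
/// if every THEN expression is non-nullable.
fn case_nullable(then_nullable: &[bool], else_nullable: Option<bool>) -> bool {
    if then_nullable.iter().any(|&n| n) {
        true
    } else {
        // no ELSE arm: the CASE can always produce NULL
        else_nullable.unwrap_or(true)
    }
}

fn main() {
    assert!(case_nullable(&[false], None)); // the case this hunk fixes
    assert!(!case_nullable(&[false], Some(false)));
}
```
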
} => expr.nullable(input_schema), diff --git a/datafusion/core/src/physical_plan/file_format/json.rs b/datafusion/core/src/physical_plan/file_format/json.rs index a6e7840e7d33..fce9cca50ede 100644 --- a/datafusion/core/src/physical_plan/file_format/json.rs +++ b/datafusion/core/src/physical_plan/file_format/json.rs @@ -16,6 +16,7 @@ // under the License. //! Execution plan for reading line-delimited JSON files +use arrow::json::reader::DecoderOptions; use async_trait::async_trait; use crate::error::{DataFusionError, Result}; @@ -109,12 +110,19 @@ impl ExecutionPlan for NdJsonExec { // The json reader cannot limit the number of records, so `remaining` is ignored. let fun = move |file, _remaining: &Option<usize>| { - Box::new(json::Reader::new( - file, - Arc::clone(&file_schema), - batch_size, - proj.clone(), - )) as BatchIter + // TODO: make DecoderOptions implement Clone so we can + // clone here rather than recreating the options each time + // https://github.com/apache/arrow-rs/issues/1580 + let options = DecoderOptions::new().with_batch_size(batch_size); + + let options = if let Some(proj) = proj.clone() { + options.with_projection(proj) + } else { + options + }; + + Box::new(json::Reader::new(file, Arc::clone(&file_schema), options)) + as BatchIter }; Ok(Box::pin(FileStream::new( diff --git a/datafusion/core/src/physical_plan/file_format/parquet.rs b/datafusion/core/src/physical_plan/file_format/parquet.rs index 310c5e77ee63..1af9d2eacdd3 100644 --- a/datafusion/core/src/physical_plan/file_format/parquet.rs +++ b/datafusion/core/src/physical_plan/file_format/parquet.rs @@ -1233,9 +1233,10 @@ mod tests { .enumerate() .map(|(i, g)| row_group_predicate(g, i)) .collect::<Vec<_>>(); - // no row group is filtered out because the predicate expression can't be evaluated - // when a null array is generated for a statistics column, - assert_eq!(row_group_filter, vec![true, true]); + + // bool = NULL always evaluates to NULL (and thus will not + // pass the predicate). Ideally these should both be false + assert_eq!(row_group_filter, vec![false, true]); Ok(()) } diff --git a/datafusion/core/src/physical_plan/functions.rs b/datafusion/core/src/physical_plan/functions.rs index 4127afe79df3..cdea2c7c16ea 100644 --- a/datafusion/core/src/physical_plan/functions.rs +++ b/datafusion/core/src/physical_plan/functions.rs @@ -63,6 +63,7 @@ macro_rules! make_utf8_to_return_type { Ok(match arg_type { DataType::LargeUtf8 => $largeUtf8Type, DataType::Utf8 => $utf8Type, + DataType::Null => DataType::Null, _ => { // this error is internal as `data_types` should have captured this.
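
The new parquet expectation follows from SQL's three-valued comparison semantics: `bool = NULL` is NULL, and a filter keeps only rows (or row groups) whose predicate is TRUE. A minimal standalone sketch, not DataFusion code:

```rust
/// SQL comparisons are three-valued; NULL on either side yields NULL.
fn sql_eq<T: PartialEq>(a: Option<T>, b: Option<T>) -> Option<bool> {
    match (a, b) {
        (Some(x), Some(y)) => Some(x == y),
        _ => None, // NULL operand => NULL result
    }
}

fn main() {
    assert_eq!(sql_eq(Some(1), Some(1)), Some(true));
    assert_eq!(sql_eq(Some(1), None), None); // `1 = NULL` is NULL
    // a predicate that evaluates to NULL must not keep the row group
    let keep = sql_eq(Some(true), None).unwrap_or(false);
    assert!(!keep);
}
```
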
return Err(DataFusionError::Internal(format!( diff --git a/datafusion/core/src/physical_plan/hash_join.rs b/datafusion/core/src/physical_plan/hash_join.rs index 49783c6f326a..55ad1c947730 100644 --- a/datafusion/core/src/physical_plan/hash_join.rs +++ b/datafusion/core/src/physical_plan/hash_join.rs @@ -837,7 +837,11 @@ fn equal_rows( .iter() .zip(right_arrays) .all(|(l, r)| match l.data_type() { - DataType::Null => true, + DataType::Null => { + // lhs and rhs are both `DataType::Null`, so the equal result + // is dependent on `null_equals_null` + null_equals_null + } DataType::Boolean => { equal_rows_elem!(BooleanArray, l, r, left, right, null_equals_null) } diff --git a/datafusion/core/src/physical_plan/hash_utils.rs b/datafusion/core/src/physical_plan/hash_utils.rs index 4e503b19e7bf..2ca1fa3df9d1 100644 --- a/datafusion/core/src/physical_plan/hash_utils.rs +++ b/datafusion/core/src/physical_plan/hash_utils.rs @@ -39,6 +39,19 @@ fn combine_hashes(l: u64, r: u64) -> u64 { hash.wrapping_mul(37).wrapping_add(r) } +fn hash_null(random_state: &RandomState, hashes_buffer: &'_ mut [u64], mul_col: bool) { + if mul_col { + hashes_buffer.iter_mut().for_each(|hash| { + // stable hash for null value + *hash = combine_hashes(i128::get_hash(&1, random_state), *hash); + }) + } else { + hashes_buffer.iter_mut().for_each(|hash| { + *hash = i128::get_hash(&1, random_state); + }) + } +} + fn hash_decimal128<'a>( array: &ArrayRef, random_state: &RandomState, @@ -284,6 +297,9 @@ pub fn create_hashes<'a>( for col in arrays { match col.data_type() { + DataType::Null => { + hash_null(random_state, hashes_buffer, multi_col); + } DataType::Decimal(_, _) => { hash_decimal128(col, random_state, hashes_buffer, multi_col); } diff --git a/datafusion/core/src/physical_plan/projection.rs b/datafusion/core/src/physical_plan/projection.rs index 4419b6f91276..c8a15f2f3c88 100644 --- a/datafusion/core/src/physical_plan/projection.rs +++ b/datafusion/core/src/physical_plan/projection.rs @@ -215,7 +215,7 @@ fn get_field_metadata( input_schema .field_with_name(name) .ok() - .and_then(|f| f.metadata().as_ref().cloned()) + .and_then(|f| f.metadata().cloned()) } fn stats_projection( @@ -340,7 +340,7 @@ mod tests { )?; let col_field = projection.schema.field(0); - let col_metadata = col_field.metadata().clone().unwrap().clone(); + let col_metadata = col_field.metadata().unwrap().clone(); let data: &str = &col_metadata["testing"]; assert_eq!(data, "test"); diff --git a/datafusion/core/src/physical_plan/sorts/sort.rs b/datafusion/core/src/physical_plan/sorts/sort.rs index 67a6e5fec244..89a99e1a0c24 100644 --- a/datafusion/core/src/physical_plan/sorts/sort.rs +++ b/datafusion/core/src/physical_plan/sorts/sort.rs @@ -804,7 +804,7 @@ mod tests { // explicitlty ensure the metadata is present assert_eq!( result[0].schema().fields()[0].metadata(), - &Some(field_metadata) + Some(&field_metadata) ); assert_eq!(result[0].schema().metadata(), &schema_metadata); diff --git a/datafusion/core/src/physical_plan/table_fun.rs b/datafusion/core/src/physical_plan/table_fun.rs index 3c202dcd71e5..082f1fdd383a 100644 --- a/datafusion/core/src/physical_plan/table_fun.rs +++ b/datafusion/core/src/physical_plan/table_fun.rs @@ -218,7 +218,7 @@ fn get_field_metadata( input_schema .field_with_name(name) .ok() - .and_then(|f| f.metadata().as_ref().cloned()) + .and_then(|f| f.metadata().cloned()) } fn stats_table_fun( diff --git a/datafusion/core/src/sql/planner.rs b/datafusion/core/src/sql/planner.rs index 47eadf26deaf..ca9732fa9fba ---
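
`hash_null` above hashes the constant 1 so that every NULL key lands in the same hash bucket; whether two NULLs then count as equal is decided separately by `null_equals_null` in `equal_rows`. A std-hash sketch of the stable-constant idea (the patch itself uses ahash's `RandomState` and `i128::get_hash`):

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

/// Sketch: hash the constant 1 (as i128, mirroring `hash_null`) so all
/// NULLs collide into one bucket; equality is decided later.
fn null_hash() -> u64 {
    let mut h = DefaultHasher::new();
    1i128.hash(&mut h);
    h.finish()
}

fn main() {
    // stable across rows: every NULL gets the identical hash value
    assert_eq!(null_hash(), null_hash());
}
```
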
a/datafusion/core/src/sql/planner.rs +++ b/datafusion/core/src/sql/planner.rs @@ -1609,7 +1609,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::Value(Value::Number(n, _)) => parse_sql_number(&n), SQLExpr::Value(Value::SingleQuotedString(s)) => Ok(lit(s)), SQLExpr::Value(Value::Null) => { - Ok(Expr::Literal(ScalarValue::Utf8(None))) + Ok(Expr::Literal(ScalarValue::Null)) } SQLExpr::Value(Value::Boolean(n)) => Ok(lit(n)), SQLExpr::UnaryOp { op, expr } => { @@ -1635,7 +1635,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::Value(Value::SingleQuotedString(ref s)) => Ok(lit(s.clone())), SQLExpr::Value(Value::EscapedStringLiteral(ref s)) => Ok(lit(s.clone())), SQLExpr::Value(Value::Boolean(n)) => Ok(lit(n)), - SQLExpr::Value(Value::Null) => Ok(Expr::Literal(ScalarValue::Utf8(None))), + SQLExpr::Value(Value::Null) => Ok(Expr::Literal(ScalarValue::Null)), SQLExpr::Extract { field, expr } => Ok(Expr::ScalarFunction { fun: functions::BuiltinScalarFunction::DatePart, args: vec![ diff --git a/datafusion/core/tests/sql/expr.rs b/datafusion/core/tests/sql/expr.rs index 9ef49d3f79a3..90930f6bb581 100644 --- a/datafusion/core/tests/sql/expr.rs +++ b/datafusion/core/tests/sql/expr.rs @@ -16,6 +16,7 @@ // under the License. use super::*; +use datafusion::datasource::empty::EmptyTable; #[tokio::test] async fn case_when() -> Result<()> { @@ -109,6 +110,106 @@ async fn case_when_else_with_base_expr() -> Result<()> { Ok(()) } +#[tokio::test] +async fn case_when_else_with_null_constant() -> Result<()> { + let ctx = create_case_context()?; + let sql = "SELECT \ + CASE WHEN c1 = 'a' THEN 1 \ + WHEN NULL THEN 2 \ + ELSE 999 END \ + FROM t1"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+----------------------------------------------------------------------------------------+", + "| CASE WHEN #t1.c1 = Utf8(\"a\") THEN Int64(1) WHEN NULL THEN Int64(2) ELSE Int64(999) END |", + "+----------------------------------------------------------------------------------------+", + "| 1 |", + "| 999 |", + "| 999 |", + "| 999 |", + "+----------------------------------------------------------------------------------------+", + ]; + assert_batches_eq!(expected, &actual); + + let sql = "SELECT CASE WHEN NULL THEN 'foo' ELSE 'bar' END"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+------------------------------------------------------+", + "| CASE WHEN NULL THEN Utf8(\"foo\") ELSE Utf8(\"bar\") END |", + "+------------------------------------------------------+", + "| bar |", + "+------------------------------------------------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn case_expr_with_null() -> Result<()> { + let ctx = SessionContext::new(); + let sql = "select case when b is null then null else b end from (select a,b from (values (1,null),(2,3)) as t (a,b)) a;"; + let actual = execute_to_batches(&ctx, sql).await; + + let expected = vec![ + "+------------------------------------------------+", + "| CASE WHEN #a.b IS NULL THEN NULL ELSE #a.b END |", + "+------------------------------------------------+", + "| |", + "| 3 |", + "+------------------------------------------------+", + ]; + assert_batches_eq!(expected, &actual); + + let sql = "select case when b is null then null else b end from (select a,b from (values (1,1),(2,3)) as t (a,b)) a;"; + let actual = execute_to_batches(&ctx, sql).await; + + let expected = vec![ + "+------------------------------------------------+", + "| CASE WHEN #a.b 
IS NULL THEN NULL ELSE #a.b END |", + "+------------------------------------------------+", + "| 1 |", + "| 3 |", + "+------------------------------------------------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +#[tokio::test] +async fn case_expr_with_nulls() -> Result<()> { + let ctx = SessionContext::new(); + let sql = "select case when b is null then null when b < 3 then null when b >=3 then b + 1 else b end from (select a,b from (values (1,null),(1,2),(2,3)) as t (a,b)) a"; + let actual = execute_to_batches(&ctx, sql).await; + + let expected = vec![ + "+--------------------------------------------------------------------------------------------------------------------------+", + "| CASE WHEN #a.b IS NULL THEN NULL WHEN #a.b < Int64(3) THEN NULL WHEN #a.b >= Int64(3) THEN #a.b + Int64(1) ELSE #a.b END |", + "+--------------------------------------------------------------------------------------------------------------------------+", + "| |", + "| |", + "| 4 |", + "+--------------------------------------------------------------------------------------------------------------------------+" + ]; + assert_batches_eq!(expected, &actual); + + let sql = "select case b when 1 then null when 2 then null when 3 then b + 1 else b end from (select a,b from (values (1,null),(1,2),(2,3)) as t (a,b)) a;"; + let actual = execute_to_batches(&ctx, sql).await; + + let expected = vec![ + "+------------------------------------------------------------------------------------------------------------+", + "| CASE #a.b WHEN Int64(1) THEN NULL WHEN Int64(2) THEN NULL WHEN Int64(3) THEN #a.b + Int64(1) ELSE #a.b END |", + "+------------------------------------------------------------------------------------------------------------+", + "| |", + "| |", + "| 4 |", + "+------------------------------------------------------------------------------------------------------------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + #[tokio::test] async fn query_not() -> Result<()> { let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Boolean, true)])); @@ -311,11 +412,11 @@ async fn test_string_concat_operator() -> Result<()> { let sql = "SELECT 'aa' || NULL || 'd'"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+---------------------------------------+", - "| Utf8(\"aa\") || Utf8(NULL) || Utf8(\"d\") |", - "+---------------------------------------+", - "| |", - "+---------------------------------------+", + "+---------------------------------+", + "| Utf8(\"aa\") || NULL || Utf8(\"d\") |", + "+---------------------------------+", + "| |", + "+---------------------------------+", ]; assert_batches_eq!(expected, &actual); @@ -333,6 +434,45 @@ async fn test_string_concat_operator() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_not_expressions() -> Result<()> { + let ctx = SessionContext::new(); + + let sql = "SELECT not(true), not(false)"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+-------------------+--------------------+", + "| NOT Boolean(true) | NOT Boolean(false) |", + "+-------------------+--------------------+", + "| false | true |", + "+-------------------+--------------------+", + ]; + assert_batches_eq!(expected, &actual); + + let sql = "SELECT null, not(null)"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+------+----------+", + "| NULL | NOT NULL |", + "+------+----------+", + "| | |", + "+------+----------+", + ]; + assert_batches_eq!(expected, &actual); + + 
let sql = "SELECT NOT('hi')"; + let result = plan_and_collect(&ctx, sql).await; + match result { + Ok(_) => panic!("expected error"), + Err(e) => { + assert_contains!(e.to_string(), + "NOT 'Literal { value: Utf8(\"hi\") }' can't be evaluated because the expression's type is Utf8, not boolean or NULL" + ); + } + } + Ok(()) +} + #[tokio::test] async fn test_boolean_expressions() -> Result<()> { test_expression!("true", "true"); @@ -460,9 +600,11 @@ async fn test_array_index() -> Result<()> { #[tokio::test] async fn binary_bitwise_shift() -> Result<()> { test_expression!("2 << 10", "2048"); - test_expression!("2048 >> 10", "2"); + test_expression!("2 << NULL", "NULL"); + test_expression!("2048 >> NULL", "NULL"); + Ok(()) } @@ -1174,3 +1316,163 @@ async fn csv_query_sqrt_sqrt() -> Result<()> { assert_float_eq(&expected, &actual); Ok(()) } + +#[tokio::test] +async fn nested_subquery() -> Result<()> { + let ctx = SessionContext::new(); + let schema = Schema::new(vec![ + Field::new("id", DataType::Int16, false), + Field::new("a", DataType::Int16, false), + ]); + let empty_table = Arc::new(EmptyTable::new(Arc::new(schema))); + ctx.register_table("t1", empty_table.clone())?; + ctx.register_table("t2", empty_table)?; + let sql = "SELECT COUNT(*) as cnt \ + FROM (\ + (SELECT id FROM t1) EXCEPT \ + (SELECT id FROM t2)\ + ) foo"; + let actual = execute_to_batches(&ctx, sql).await; + // the purpose of this test is just to make sure the query produces a valid plan + #[rustfmt::skip] + let expected = vec![ + "+-----+", + "| cnt |", + "+-----+", + "| 0 |", + "+-----+" + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn like_nlike_with_null_lt() { + let ctx = SessionContext::new(); + let sql = "SELECT column1 like NULL as col_null, NULL like column1 as null_col from (values('a'), ('b'), (NULL)) as t"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+----------+----------+", + "| col_null | null_col |", + "+----------+----------+", + "| | |", + "| | |", + "| | |", + "+----------+----------+", + ]; + assert_batches_eq!(expected, &actual); + + let sql = "SELECT column1 not like NULL as col_null, NULL not like column1 as null_col from (values('a'), ('b'), (NULL)) as t"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+----------+----------+", + "| col_null | null_col |", + "+----------+----------+", + "| | |", + "| | |", + "| | |", + "+----------+----------+", + ]; + assert_batches_eq!(expected, &actual); +} + +#[tokio::test] +async fn comparisons_with_null_lt() { + let ctx = SessionContext::new(); + + // we expect all the following queries to yield a two null values + let cases = vec![ + // 1. 
Numeric comparison with NULL + "select column1 < NULL from (VALUES (1, 'foo' ,2.3), (2, 'bar', 5.4)) as t", + "select column1 <= NULL from (VALUES (1, 'foo' ,2.3), (2, 'bar', 5.4)) as t", + "select column1 > NULL from (VALUES (1, 'foo' ,2.3), (2, 'bar', 5.4)) as t", + "select column1 >= NULL from (VALUES (1, 'foo' ,2.3), (2, 'bar', 5.4)) as t", + "select column1 = NULL from (VALUES (1, 'foo' ,2.3), (2, 'bar', 5.4)) as t", + "select column1 != NULL from (VALUES (1, 'foo' ,2.3), (2, 'bar', 5.4)) as t", + // 1.1 Float value comparison with NULL + "select column3 < NULL from (VALUES (1, 'foo' ,2.3), (2, 'bar', 5.4)) as t", + // String comparison with NULL + "select column2 < NULL from (VALUES (1, 'foo' ,2.3), (2, 'bar', 5.4)) as t", + // Boolean comparison with NULL + "select column1 < NULL from (VALUES (true), (false)) as t", + // ---- + // ---- same queries, reversed argument order (as they go through + // ---- a different evaluation path) + // ---- + + // 1. Numeric comparison with NULL + "select NULL < column1 from (VALUES (1, 'foo' ,2.3), (2, 'bar', 5.4)) as t", + "select NULL <= column1 from (VALUES (1, 'foo' ,2.3), (2, 'bar', 5.4)) as t", + "select NULL > column1 from (VALUES (1, 'foo' ,2.3), (2, 'bar', 5.4)) as t", + "select NULL >= column1 from (VALUES (1, 'foo' ,2.3), (2, 'bar', 5.4)) as t", + "select NULL = column1 from (VALUES (1, 'foo' ,2.3), (2, 'bar', 5.4)) as t", + "select NULL != column1 from (VALUES (1, 'foo' ,2.3), (2, 'bar', 5.4)) as t", + // 1.1 Float value comparison with NULL + "select NULL < column3 from (VALUES (1, 'foo' ,2.3), (2, 'bar', 5.4)) as t", + // String comparison with NULL + "select NULL < column2 from (VALUES (1, 'foo' ,2.3), (2, 'bar', 5.4)) as t", + // Boolean comparison with NULL + "select NULL < column1 from (VALUES (true), (false)) as t", + ]; + + for sql in cases { + println!("Computing: {}", sql); + + let mut actual = execute_to_batches(&ctx, sql).await; + assert_eq!(actual.len(), 1); + + let batch = actual.pop().unwrap(); + assert_eq!(batch.num_rows(), 2); + assert_eq!(batch.num_columns(), 1); + assert!(batch.columns()[0].is_null(0)); + assert!(batch.columns()[0].is_null(1)); + } +} + +#[tokio::test] +async fn binary_mathematical_operator_with_null_lt() { + let ctx = SessionContext::new(); + + let cases = vec![ + // 1. Integer and NULL + "select column1 + NULL from (VALUES (1, 2.3), (2, 5.4)) as t", + "select column1 - NULL from (VALUES (1, 2.3), (2, 5.4)) as t", + "select column1 * NULL from (VALUES (1, 2.3), (2, 5.4)) as t", + "select column1 / NULL from (VALUES (1, 2.3), (2, 5.4)) as t", + "select column1 % NULL from (VALUES (1, 2.3), (2, 5.4)) as t", + // 2. Float and NULL + "select column2 + NULL from (VALUES (1, 2.3), (2, 5.4)) as t", + "select column2 - NULL from (VALUES (1, 2.3), (2, 5.4)) as t", + "select column2 * NULL from (VALUES (1, 2.3), (2, 5.4)) as t", + "select column2 / NULL from (VALUES (1, 2.3), (2, 5.4)) as t", + "select column2 % NULL from (VALUES (1, 2.3), (2, 5.4)) as t", + // ---- + // ---- same queries, reversed argument order + // ---- + // 3. NULL and Integer + "select NULL + column1 from (VALUES (1, 2.3), (2, 5.4)) as t", + "select NULL - column1 from (VALUES (1, 2.3), (2, 5.4)) as t", + "select NULL * column1 from (VALUES (1, 2.3), (2, 5.4)) as t", + "select NULL / column1 from (VALUES (1, 2.3), (2, 5.4)) as t", + "select NULL % column1 from (VALUES (1, 2.3), (2, 5.4)) as t", + // 4. 
NULL and Float + "select NULL + column2 from (VALUES (1, 2.3), (2, 5.4)) as t", + "select NULL - column2 from (VALUES (1, 2.3), (2, 5.4)) as t", + "select NULL * column2 from (VALUES (1, 2.3), (2, 5.4)) as t", + "select NULL / column2 from (VALUES (1, 2.3), (2, 5.4)) as t", + "select NULL % column2 from (VALUES (1, 2.3), (2, 5.4)) as t", + ]; + + for sql in cases { + println!("Computing: {}", sql); + + let mut actual = execute_to_batches(&ctx, sql).await; + assert_eq!(actual.len(), 1); + + let batch = actual.pop().unwrap(); + assert_eq!(batch.num_rows(), 2); + assert_eq!(batch.num_columns(), 1); + assert!(batch.columns()[0].is_null(0)); + assert!(batch.columns()[0].is_null(1)); + } +} diff --git a/datafusion/core/tests/sql/functions.rs b/datafusion/core/tests/sql/functions.rs index 226bb8159d78..ae86aeb17458 100644 --- a/datafusion/core/tests/sql/functions.rs +++ b/datafusion/core/tests/sql/functions.rs @@ -197,11 +197,11 @@ async fn coalesce_static_value_with_null() -> Result<()> { let sql = "SELECT COALESCE(NULL, 'test')"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+-----------------------------------+", - "| coalesce(Utf8(NULL),Utf8(\"test\")) |", - "+-----------------------------------+", - "| test |", - "+-----------------------------------+", + "+-----------------------------+", + "| coalesce(NULL,Utf8(\"test\")) |", + "+-----------------------------+", + "| test |", + "+-----------------------------+", ]; assert_batches_eq!(expected, &actual); Ok(()) diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index 4859041579d0..007de1bfe751 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -829,7 +829,11 @@ async fn inner_join_nulls() { let sql = "SELECT * FROM (SELECT null AS id1) t1 INNER JOIN (SELECT null AS id2) t2 ON id1 = id2"; - let expected = vec!["++", "++"]; + #[rustfmt::skip] + let expected = vec![ + "++", + "++", + ]; let ctx = create_join_context_qualified().unwrap(); let actual = execute_to_batches(&ctx, sql).await; diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs index 09c0fc99bb99..1af11a494ab0 100644 --- a/datafusion/core/tests/sql/select.rs +++ b/datafusion/core/tests/sql/select.rs @@ -398,15 +398,37 @@ async fn select_distinct_from() { 1 IS NOT DISTINCT FROM CAST(NULL as INT) as c, 1 IS NOT DISTINCT FROM 1 as d, NULL IS DISTINCT FROM NULL as e, - NULL IS NOT DISTINCT FROM NULL as f + NULL IS NOT DISTINCT FROM NULL as f, + NULL is DISTINCT FROM 1 as g, + NULL is NOT DISTINCT FROM 1 as h "; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+------+-------+-------+------+-------+------+", - "| a | b | c | d | e | f |", - "+------+-------+-------+------+-------+------+", - "| true | false | false | true | false | true |", - "+------+-------+-------+------+-------+------+", + "+------+-------+-------+------+-------+------+------+-------+", + "| a | b | c | d | e | f | g | h |", + "+------+-------+-------+------+-------+------+------+-------+", + "| true | false | false | true | false | true | true | false |", + "+------+-------+-------+------+-------+------+------+-------+", + ]; + assert_batches_eq!(expected, &actual); + + let sql = "select + NULL IS DISTINCT FROM NULL as a, + NULL IS NOT DISTINCT FROM NULL as b, + NULL is DISTINCT FROM 1 as c, + NULL is NOT DISTINCT FROM 1 as d, + 1 IS DISTINCT FROM CAST(NULL as INT) as e, + 1 IS DISTINCT FROM 1 as f, + 1 IS NOT DISTINCT FROM CAST(NULL as INT) as g, + 1 IS 
NOT DISTINCT FROM 1 as h + "; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+-------+------+------+-------+------+-------+-------+------+", + "| a | b | c | d | e | f | g | h |", + "+-------+------+------+-------+------+-------+-------+------+", + "| false | true | true | false | true | false | false | true |", + "+-------+------+------+-------+------+-------+-------+------+", ]; assert_batches_eq!(expected, &actual); } diff --git a/datafusion/cube_ext/Cargo.toml b/datafusion/cube_ext/Cargo.toml index 8a2e47af027f..5e4c1d8cce03 100644 --- a/datafusion/cube_ext/Cargo.toml +++ b/datafusion/cube_ext/Cargo.toml @@ -35,7 +35,7 @@ name = "cube_ext" path = "src/lib.rs" [dependencies] -arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "5eec82b6bde95d824bb9f3721789c98e033f812f", features = ["prettyprint"] } +arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "85c9d642464886bf977cb6b95fad7c28d0b64cc8", features = ["prettyprint"] } datafusion-common = { path = "../common", version = "7.0.0" } datafusion-expr = { path = "../expr", version = "7.0.0" } chrono = { version = "0.4.16", package = "chrono", default-features = false, features = ["clock"]} diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index 038a16fe0ba9..66d0a8ea2ad5 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -36,6 +36,6 @@ path = "src/lib.rs" [dependencies] ahash = { version = "0.7", default-features = false } -arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "5eec82b6bde95d824bb9f3721789c98e033f812f", features = ["prettyprint"] } +arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "85c9d642464886bf977cb6b95fad7c28d0b64cc8", features = ["prettyprint"] } datafusion-common = { path = "../common", version = "7.0.0" } sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "ca8131c36325e86fd733f337673e6cae5e946711" } diff --git a/datafusion/expr/src/type_coercion.rs b/datafusion/expr/src/type_coercion.rs new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/datafusion/jit/Cargo.toml b/datafusion/jit/Cargo.toml index 06a6c616d1e2..6533bd53d03a 100644 --- a/datafusion/jit/Cargo.toml +++ b/datafusion/jit/Cargo.toml @@ -36,7 +36,7 @@ path = "src/lib.rs" jit = [] [dependencies] -arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "5eec82b6bde95d824bb9f3721789c98e033f812f" } +arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "85c9d642464886bf977cb6b95fad7c28d0b64cc8" } cranelift = "0.82.0" cranelift-jit = "0.82.0" cranelift-module = "0.82.0" diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index 4c48b93c56e5..ed55af6ed933 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -40,7 +40,7 @@ unicode_expressions = ["unicode-segmentation"] [dependencies] ahash = { version = "0.7", default-features = false } -arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "5eec82b6bde95d824bb9f3721789c98e033f812f", features = ["prettyprint"] } +arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "85c9d642464886bf977cb6b95fad7c28d0b64cc8", features = ["prettyprint"] } blake2 = { version = "^0.10.2", optional = true } blake3 = { version = "1.0", optional = true } chrono = { version = "0.4", default-features = false } diff --git a/datafusion/physical-expr/src/coercion_rule/binary_rule.rs b/datafusion/physical-expr/src/coercion_rule/binary_rule.rs index 
8ab5b6309eda..fd3684ba9ba8 100644 --- a/datafusion/physical-expr/src/coercion_rule/binary_rule.rs +++ b/datafusion/physical-expr/src/coercion_rule/binary_rule.rs @@ -85,12 +85,14 @@ pub(crate) fn coerce_types( fn bitwise_coercion(left_type: &DataType, right_type: &DataType) -> Option<DataType> { use arrow::datatypes::DataType::*; - if !is_numeric(left_type) || !is_numeric(right_type) { + if !both_numeric_or_null_and_numeric(left_type, right_type) { return None; } + if left_type == right_type && !is_dictionary(left_type) { return Some(left_type.clone()); } + // TODO support other data type match (left_type, right_type) { (Int64, _) | (_, Int64) => Some(Int64), @@ -112,6 +114,8 @@ fn comparison_eq_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> bool { } } +/// Determine if at least one of lhs and rhs is numeric, and the other is NULL or numeric +fn both_numeric_or_null_and_numeric(lhs_type: &DataType, rhs_type: &DataType) -> bool { match (lhs_type, rhs_type) { (_, DataType::Null) => is_numeric(lhs_type), (DataType::Null, _) => is_numeric(rhs_type), _ => is_numeric(lhs_type) && is_numeric(rhs_type), } } /// Coercion rules for dictionary values (aka the type of the dictionary itself) fn dictionary_value_coercion( lhs_type: &DataType, @@ -436,6 +450,7 @@ fn string_boolean_equality_coercion( fn like_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> { string_coercion(lhs_type, rhs_type) .or_else(|| dictionary_coercion(lhs_type, rhs_type)) + .or_else(|| null_coercion(lhs_type, rhs_type)) } /// Coercion rules for Temporal columns: the type that both lhs and rhs can be @@ -538,6 +553,7 @@ fn eq_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> { numerical_coercion(lhs_type, rhs_type) .or_else(|| dictionary_coercion(lhs_type, rhs_type)) .or_else(|| temporal_coercion(lhs_type, rhs_type)) + .or_else(|| null_coercion(lhs_type, rhs_type)) } /// Coercion rule for interval @@ -557,6 +573,28 @@ pub fn interval_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> +fn null_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> { + match (lhs_type, rhs_type) { + (DataType::Null, _) => { + if can_cast_types(&DataType::Null, rhs_type) { + Some(rhs_type.clone()) + } else { + None + } + } + (_, DataType::Null) => { + if can_cast_types(&DataType::Null, lhs_type) { + Some(lhs_type.clone()) + } else { + None + } + } + _ => None, + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 24018a74bc64..068b1823219e 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -16,7 +16,6 @@ // under the License. 
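
Before the binary.rs changes, note how the `null_coercion` rule added above plays out: a NULL operand adopts the other side's type whenever arrow can cast `DataType::Null` to it. A standalone mirror of the rule (arrow's `can_cast_types` is real; the mini function and the Int64 case are illustrative, assuming Null→Int64 casting is supported, which this patch relies on):

```rust
use arrow::compute::can_cast_types;
use arrow::datatypes::DataType;

/// Standalone mirror of `null_coercion`: the NULL side adopts the other
/// side's type when arrow can cast DataType::Null to it.
fn null_coercion(lhs: &DataType, rhs: &DataType) -> Option<DataType> {
    match (lhs, rhs) {
        (DataType::Null, other) | (other, DataType::Null) => {
            can_cast_types(&DataType::Null, other).then(|| other.clone())
        }
        _ => None,
    }
}

fn main() {
    // `NULL = 1` is planned as an Int64 comparison
    assert_eq!(
        null_coercion(&DataType::Null, &DataType::Int64),
        Some(DataType::Int64)
    );
}
```
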
use std::convert::TryInto; -use std::ops::{BitAnd, BitOr}; use std::{any::Any, sync::Arc}; use arrow::array::TimestampMillisecondArray; @@ -26,10 +25,6 @@ use arrow::compute::kernels::arithmetic::{ multiply_scalar, subtract, subtract_scalar, }; use arrow::compute::kernels::boolean::{and_kleene, not, or_kleene}; -use arrow::compute::kernels::comparison::{ - eq_bool_scalar, gt_bool_scalar, gt_eq_bool_scalar, lt_bool_scalar, lt_eq_bool_scalar, - neq_bool_scalar, -}; use arrow::compute::kernels::comparison::{ eq_dyn_bool_scalar, gt_dyn_bool_scalar, gt_eq_dyn_bool_scalar, lt_dyn_bool_scalar, lt_eq_dyn_bool_scalar, neq_dyn_bool_scalar, @@ -46,13 +41,10 @@ use arrow::compute::kernels::comparison::{ eq_scalar, gt_eq_scalar, gt_scalar, lt_eq_scalar, lt_scalar, neq_scalar, }; use arrow::compute::kernels::comparison::{ - eq_utf8_scalar, gt_eq_utf8_scalar, gt_utf8_scalar, ilike_utf8_scalar, - like_utf8_scalar, lt_eq_utf8_scalar, lt_utf8_scalar, neq_utf8_scalar, - nilike_utf8_scalar, nlike_utf8_scalar, regexp_is_match_utf8_scalar, -}; -use arrow::compute::kernels::comparison::{ - ilike_utf8, like_utf8, nilike_utf8, nlike_utf8, regexp_is_match_utf8, + ilike_utf8, ilike_utf8_scalar, like_utf8_scalar, nilike_utf8, nilike_utf8_scalar, + nlike_utf8_scalar, regexp_is_match_utf8_scalar, }; +use arrow::compute::kernels::comparison::{like_utf8, nlike_utf8, regexp_is_match_utf8}; use arrow::datatypes::{ArrowNumericType, DataType, Schema, TimeUnit}; use arrow::error::ArrowError::DivideByZero; use arrow::record_batch::RecordBatch; @@ -66,14 +58,6 @@ use datafusion_common::{DataFusionError, Result}; use datafusion_expr::ColumnarValue; use datafusion_expr::Operator; -// TODO move to arrow_rs -// https://github.com/apache/arrow-rs/issues/1312 -fn as_decimal_array(arr: &dyn Array) -> &DecimalArray { - arr.as_any() - .downcast_ref::<DecimalArray>() - .expect("Unable to downcast to typed array to DecimalArray") -} - /// create a `dyn_op` wrapper function for the specified operation /// that call the underlying dyn_op arrow kernel if the type is /// supported, and translates ArrowError to DataFusionError @@ -399,7 +383,7 @@ fn modulus_decimal(left: &DecimalArray, right: &DecimalArray) -> Result<DecimalArray> macro_rules! binary_bitwise_array_op { - ($LEFT:expr, $RIGHT:expr, $METHOD:tt, $ARRAY_TYPE:ident, $TYPE:ty) => {{ + ($LEFT:expr, $RIGHT:expr, $METHOD:expr, $ARRAY_TYPE:ident) => {{ let len = $LEFT.len(); let left = $LEFT.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap(); let right = $RIGHT.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap(); @@ -409,7 +393,7 @@ macro_rules! binary_bitwise_array_op { if left.is_null(i) || right.is_null(i) { None } else { - Some(left.value(i).$METHOD(right.value(i) as $TYPE)) + Some($METHOD(left.value(i), right.value(i))) } }) .collect::<$ARRAY_TYPE>(); @@ -421,7 +405,7 @@ macro_rules! binary_bitwise_array_op { /// like int64, int32. /// It is used to do bitwise operation on an array with a scalar. macro_rules! binary_bitwise_array_scalar { - ($LEFT:expr, $RIGHT:expr, $METHOD:tt, $ARRAY_TYPE:ident, $TYPE:ty) => {{ + ($LEFT:expr, $RIGHT:expr, $METHOD:expr, $ARRAY_TYPE:ident, $TYPE:ty) => {{ let len = $LEFT.len(); let array = $LEFT.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap(); let scalar = $RIGHT; @@ -435,7 +419,7 @@ macro_rules! binary_bitwise_array_scalar { if array.is_null(i) { None } else { - Some(array.value(i).$METHOD(right)) + Some($METHOD(array.value(i), right)) } }) .collect::<$ARRAY_TYPE>(); @@ -447,16 +431,16 @@ macro_rules! 
binary_bitwise_array_scalar { fn bitwise_and(left: ArrayRef, right: ArrayRef) -> Result<ArrayRef> { match &left.data_type() { DataType::Int8 => { - binary_bitwise_array_op!(left, right, bitand, Int8Array, i8) + binary_bitwise_array_op!(left, right, |a, b| a & b, Int8Array) } DataType::Int16 => { - binary_bitwise_array_op!(left, right, bitand, Int16Array, i16) + binary_bitwise_array_op!(left, right, |a, b| a & b, Int16Array) } DataType::Int32 => { - binary_bitwise_array_op!(left, right, bitand, Int32Array, i32) + binary_bitwise_array_op!(left, right, |a, b| a & b, Int32Array) } DataType::Int64 => { - binary_bitwise_array_op!(left, right, bitand, Int64Array, i64) + binary_bitwise_array_op!(left, right, |a, b| a & b, Int64Array) } other => Err(DataFusionError::Internal(format!( "Data type {:?} not supported for binary operation '{}' on dyn arrays", @@ -469,16 +453,16 @@ fn bitwise_or(left: ArrayRef, right: ArrayRef) -> Result<ArrayRef> { match &left.data_type() { DataType::Int8 => { - binary_bitwise_array_op!(left, right, bitor, Int8Array, i8) + binary_bitwise_array_op!(left, right, |a, b| a | b, Int8Array) } DataType::Int16 => { - binary_bitwise_array_op!(left, right, bitor, Int16Array, i16) + binary_bitwise_array_op!(left, right, |a, b| a | b, Int16Array) } DataType::Int32 => { - binary_bitwise_array_op!(left, right, bitor, Int32Array, i32) + binary_bitwise_array_op!(left, right, |a, b| a | b, Int32Array) } DataType::Int64 => { - binary_bitwise_array_op!(left, right, bitor, Int64Array, i64) + binary_bitwise_array_op!(left, right, |a, b| a | b, Int64Array) } other => Err(DataFusionError::Internal(format!( "Data type {:?} not supported for binary operation '{}' on dyn arrays", @@ -491,16 +475,36 @@ fn bitwise_shift_right(left: ArrayRef, right: ArrayRef) -> Result<ArrayRef> { match &left.data_type() { DataType::Int8 => { - binary_bitwise_array_op!(left, right, wrapping_shr, Int8Array, u32) + binary_bitwise_array_op!( + left, + right, + |a: i8, b: i8| a.wrapping_shr(b as u32), + Int8Array + ) } DataType::Int16 => { - binary_bitwise_array_op!(left, right, wrapping_shr, Int16Array, u32) + binary_bitwise_array_op!( + left, + right, + |a: i16, b: i16| a.wrapping_shr(b as u32), + Int16Array + ) } DataType::Int32 => { - binary_bitwise_array_op!(left, right, wrapping_shr, Int32Array, u32) + binary_bitwise_array_op!( + left, + right, + |a: i32, b: i32| a.wrapping_shr(b as u32), + Int32Array + ) } DataType::Int64 => { - binary_bitwise_array_op!(left, right, wrapping_shr, Int64Array, u32) + binary_bitwise_array_op!( + left, + right, + |a: i64, b: i64| a.wrapping_shr(b as u32), + Int64Array + ) } other => Err(DataFusionError::Internal(format!( "Data type {:?} not supported for binary operation '{}' on dyn arrays", @@ -513,16 +517,36 @@ fn bitwise_shift_left(left: ArrayRef, right: ArrayRef) -> Result<ArrayRef> { match &left.data_type() { DataType::Int8 => { - binary_bitwise_array_op!(left, right, wrapping_shl, Int8Array, u32) + binary_bitwise_array_op!( + left, + right, + |a: i8, b: i8| a.wrapping_shl(b as u32), + Int8Array + ) } DataType::Int16 => { - binary_bitwise_array_op!(left, right, wrapping_shl, Int16Array, u32) + binary_bitwise_array_op!( + left, + right, + |a: i16, b: i16| a.wrapping_shl(b as u32), + Int16Array + ) } DataType::Int32 => { - binary_bitwise_array_op!(left, right, wrapping_shl, Int32Array, u32) + binary_bitwise_array_op!( + left, + right, 
+ |a: i32, b: i32| a.wrapping_shl(b as u32), + Int32Array + ) } DataType::Int64 => { - binary_bitwise_array_op!(left, right, wrapping_shl, Int64Array, u32) + binary_bitwise_array_op!( + left, + right, + |a: i64, b: i64| a.wrapping_shl(b as u32), + Int64Array + ) } other => Err(DataFusionError::Internal(format!( "Data type {:?} not supported for binary operation '{}' on dyn arrays", @@ -565,16 +589,16 @@ fn bitwise_and_scalar( ) -> Option<Result<ArrayRef>> { let result = match array.data_type() { DataType::Int8 => { - binary_bitwise_array_scalar!(array, scalar, bitand, Int8Array, i8) + binary_bitwise_array_scalar!(array, scalar, |a, b| a & b, Int8Array, i8) } DataType::Int16 => { - binary_bitwise_array_scalar!(array, scalar, bitand, Int16Array, i16) + binary_bitwise_array_scalar!(array, scalar, |a, b| a & b, Int16Array, i16) } DataType::Int32 => { - binary_bitwise_array_scalar!(array, scalar, bitand, Int32Array, i32) + binary_bitwise_array_scalar!(array, scalar, |a, b| a & b, Int32Array, i32) } DataType::Int64 => { - binary_bitwise_array_scalar!(array, scalar, bitand, Int64Array, i64) + binary_bitwise_array_scalar!(array, scalar, |a, b| a & b, Int64Array, i64) } other => Err(DataFusionError::Internal(format!( "Data type {:?} not supported for binary operation '{}' on dyn arrays", @@ -588,16 +612,16 @@ fn bitwise_or_scalar(array: &dyn Array, scalar: ScalarValue) -> Option<Result<ArrayRef>> { let result = match array.data_type() { DataType::Int8 => { - binary_bitwise_array_scalar!(array, scalar, bitor, Int8Array, i8) + binary_bitwise_array_scalar!(array, scalar, |a, b| a | b, Int8Array, i8) } DataType::Int16 => { - binary_bitwise_array_scalar!(array, scalar, bitor, Int16Array, i16) + binary_bitwise_array_scalar!(array, scalar, |a, b| a | b, Int16Array, i16) } DataType::Int32 => { - binary_bitwise_array_scalar!(array, scalar, bitor, Int32Array, i32) + binary_bitwise_array_scalar!(array, scalar, |a, b| a | b, Int32Array, i32) } DataType::Int64 => { - binary_bitwise_array_scalar!(array, scalar, bitor, Int64Array, i64) + binary_bitwise_array_scalar!(array, scalar, |a, b| a | b, Int64Array, i64) } other => Err(DataFusionError::Internal(format!( "Data type {:?} not supported for binary operation '{}' on dyn arrays", @@ -614,16 +638,40 @@ fn bitwise_shift_left_scalar( ) -> Option<Result<ArrayRef>> { let result = match array.data_type() { DataType::Int8 => { - binary_bitwise_array_scalar!(array, scalar, wrapping_shl, Int8Array, u32) + binary_bitwise_array_scalar!( + array, + scalar, + |a: i8, b: i8| a.wrapping_shl(b as u32), + Int8Array, + i8 + ) } DataType::Int16 => { - binary_bitwise_array_scalar!(array, scalar, wrapping_shl, Int16Array, u32) + binary_bitwise_array_scalar!( + array, + scalar, + |a: i16, b: i16| a.wrapping_shl(b as u32), + Int16Array, + i16 + ) } DataType::Int32 => { - binary_bitwise_array_scalar!(array, scalar, wrapping_shl, Int32Array, u32) + binary_bitwise_array_scalar!( + array, + scalar, + |a: i32, b: i32| a.wrapping_shl(b as u32), + Int32Array, + i32 + ) } DataType::Int64 => { - binary_bitwise_array_scalar!(array, scalar, wrapping_shl, Int64Array, u32) + binary_bitwise_array_scalar!( + array, + scalar, + |a: i64, b: i64| a.wrapping_shl(b as u32), + Int64Array, + i64 + ) } other => Err(DataFusionError::Internal(format!( "Data type {:?} not supported for binary operation '{}' on dyn arrays", @@ -640,16 +688,40 @@ fn bitwise_shift_right_scalar( ) -> Option<Result<ArrayRef>> { let result = match array.data_type() { DataType::Int8 => { - binary_bitwise_array_scalar!(array, scalar, wrapping_shr, Int8Array, u32) + 
binary_bitwise_array_scalar!( + array, + scalar, + |a: i8, b: i8| a.wrapping_shr(b as u32), + Int8Array, + i8 + ) } DataType::Int16 => { - binary_bitwise_array_scalar!(array, scalar, wrapping_shr, Int16Array, u32) + binary_bitwise_array_scalar!( + array, + scalar, + |a: i16, b: i16| a.wrapping_shr(b as u32), + Int16Array, + i16 + ) } DataType::Int32 => { - binary_bitwise_array_scalar!(array, scalar, wrapping_shr, Int32Array, u32) + binary_bitwise_array_scalar!( + array, + scalar, + |a: i32, b: i32| a.wrapping_shr(b as u32), + Int32Array, + i32 + ) } DataType::Int64 => { - binary_bitwise_array_scalar!(array, scalar, wrapping_shr, Int64Array, u32) + binary_bitwise_array_scalar!( + array, + scalar, + |a: i64, b: i64| a.wrapping_shr(b as u32), + Int64Array, + i64 + ) } other => Err(DataFusionError::Internal(format!( "Data type {:?} not supported for binary operation '{}' on dyn arrays", @@ -718,6 +790,20 @@ macro_rules! compute_decimal_op { }}; } +macro_rules! compute_null_op { + ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ + let ll = $LEFT + .as_any() + .downcast_ref::<$DT>() + .expect("compute_null_op failed to downcast array"); + let rr = $RIGHT + .as_any() + .downcast_ref::<$DT>() + .expect("compute_null_op failed to downcast array"); + Ok(Arc::new(paste::expr! {[<$OP _null>]}(&ll, &rr)?)) + }}; +} + /// Invoke a compute kernel on a pair of binary data arrays macro_rules! compute_utf8_op { ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ @@ -735,7 +821,7 @@ macro_rules! compute_utf8_op { /// Invoke a compute kernel on a data array and a scalar value macro_rules! compute_utf8_op_scalar { - ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ + ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident, $OP_TYPE:expr) => {{ let ll = $LEFT .as_any() .downcast_ref::<$DT>() @@ -745,6 +831,8 @@ macro_rules! compute_utf8_op_scalar { &ll, &string_value, )?)) + } else if $RIGHT.is_null() { + Ok(Arc::new(new_null_array($OP_TYPE, $LEFT.len()))) } else { Err(DataFusionError::Internal(format!( "compute_utf8_op_scalar for '{}' failed to cast literal value {}",
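The `$RIGHT.is_null()` branch added above changes the contract for `col <op> NULL` from an internal error to a typed all-null result. The building block is arrow's `new_null_array`, which the hunk already calls; in isolation:

```rust
use arrow::array::{new_null_array, ArrayRef};
use arrow::datatypes::DataType;

// An all-null Boolean column of the given length: SQL's "unknown" result
// for a comparison against NULL. new_null_array marks every slot of the
// requested type as null in the validity bitmap.
fn null_comparison_result(len: usize) -> ArrayRef {
    new_null_array(&DataType::Boolean, len)
}
```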
@@ -757,41 +845,22 @@ macro_rules! compute_utf8_op_scalar { /// Invoke a compute kernel on a data array and a scalar value macro_rules! compute_utf8_op_dyn_scalar { - ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ + ($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr) => {{ if let Some(string_value) = $RIGHT { Ok(Arc::new(paste::expr! {[<$OP _dyn_utf8_scalar>]}( $LEFT, &string_value, )?)) } else { - Err(DataFusionError::Internal(format!( - "compute_utf8_op_scalar for '{}' failed with literal 'none' value", - stringify!($OP), - ))) + // when $RIGHT is NULL, generate a NULL array of type $OP_TYPE + Ok(Arc::new(new_null_array($OP_TYPE, $LEFT.len()))) } }}; } -/// Invoke a compute kernel on a boolean data array and a scalar value -macro_rules! compute_bool_op_scalar { - ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ - use std::convert::TryInto; - let ll = $LEFT - .as_any() - .downcast_ref::<$DT>() - .expect("compute_op failed to downcast array"); - // generate the scalar function name, such as lt_scalar, from the $OP parameter - // (which could have a value of lt) and the suffix _scalar - Ok(Arc::new(paste::expr! {[<$OP _bool_scalar>]}( - &ll, - $RIGHT.try_into()?, - )?)) - }}; -} - /// Invoke a compute kernel on a boolean data array and a scalar value macro_rules! compute_bool_op_dyn_scalar { - ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ + ($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr) => {{ // generate the scalar function name, such as lt_dyn_bool_scalar, from the $OP parameter // (which could have a value of lt) and the suffix _scalar if let Some(b) = $RIGHT { @@ -800,10 +869,8 @@ macro_rules! compute_bool_op_dyn_scalar { b, )?)) } else { - Err(DataFusionError::Internal(format!( - "compute_utf8_op_scalar for '{}' failed with literal 'none' value", - stringify!($OP), - ))) + // when $RIGHT is NULL, generate a NULL array of type $OP_TYPE + Ok(Arc::new(new_null_array($OP_TYPE, $LEFT.len()))) } }}; } @@ -836,23 +903,26 @@ macro_rules! compute_bool_op { /// LEFT is array, RIGHT is scalar value macro_rules! compute_op_scalar { ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ - let ll = $LEFT - .as_any() - .downcast_ref::<$DT>() - .expect("compute_op failed to downcast array"); - // generate the scalar function name, such as lt_scalar, from the $OP parameter - // (which could have a value of lt) and the suffix _scalar - Ok(Arc::new(paste::expr! {[<$OP _scalar>]}( - &ll, - $RIGHT.try_into()?, - )?)) + if $RIGHT.is_null() { + Ok(Arc::new(new_null_array($LEFT.data_type(), $LEFT.len()))) + } else { + let ll = $LEFT + .as_any() + .downcast_ref::<$DT>() + .expect("compute_op failed to downcast array"); + Ok(Arc::new(paste::expr! {[<$OP _scalar>]}( + &ll, + $RIGHT.try_into()?, + )?)) + } }}; } /// Invoke a dyn compute kernel on a data array and a scalar value -/// LEFT is Primitive or Dictionart array of numeric values, RIGHT is scalar value +/// LEFT is Primitive or Dictionary array of numeric values, RIGHT is scalar value +/// OP_TYPE is the return type of the scalar function macro_rules! compute_op_dyn_scalar { - ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ + ($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr) => {{ // generate the scalar function name, such as lt_dyn_scalar, from the $OP parameter // (which could have a value of lt_dyn) and the suffix _scalar if let Some(value) = $RIGHT { @@ -861,10 +931,8 @@ macro_rules! compute_op_dyn_scalar { value, )?)) } else { - Err(DataFusionError::Internal(format!( - "compute_utf8_op_scalar for '{}' failed with literal 'none' value", - stringify!($OP), - ))) + // when $RIGHT is NULL, generate a NULL array of type $OP_TYPE + Ok(Arc::new(new_null_array($OP_TYPE, $LEFT.len()))) } }}; } @@ -894,9 +962,9 @@ macro_rules! compute_op { } macro_rules! binary_string_array_op_scalar { - ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ + ($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr) => {{ let result: Result<ArrayRef> = match $LEFT.data_type() { - DataType::Utf8 => compute_utf8_op_scalar!($LEFT, $RIGHT, $OP, StringArray), + DataType::Utf8 => compute_utf8_op_scalar!($LEFT, $RIGHT, $OP, StringArray, $OP_TYPE), other => Err(DataFusionError::Internal(format!( "Data type {:?} not supported for scalar operation '{}' on string array", other, stringify!($OP)
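In `compute_op_scalar!` above, the `is_null()` guard has to run before `try_into()`: a typed NULL such as `ScalarValue::Int32(None)` has no native value to convert, so without the guard NULL literals would still error out. A quick check of that assumption against the datafusion-common API:

```rust
use datafusion_common::ScalarValue;

fn main() {
    // A typed NULL: the Int32 variant holding no value.
    let null_i32 = ScalarValue::Int32(None);
    assert!(null_i32.is_null());

    // A concrete value is not null.
    assert!(!ScalarValue::Int32(Some(2)).is_null());
}
```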
@@ -970,58 +1038,13 @@ macro_rules! binary_primitive_array_op_scalar { }}; } -/// The binary_array_op_scalar macro includes types that extend beyond the primitive, -/// such as Utf8 strings. -#[macro_export] -macro_rules! binary_array_op_scalar { - ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ - let result: Result<ArrayRef> = match $LEFT.data_type() { - DataType::Decimal(_,_) => compute_decimal_op_scalar!($LEFT, $RIGHT, $OP, DecimalArray), - DataType::Int8 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int8Array), - DataType::Int16 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int16Array), - DataType::Int32 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int32Array), - DataType::Int64 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int64Array), - DataType::UInt8 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt8Array), - DataType::UInt16 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt16Array), - DataType::UInt32 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt32Array), - DataType::UInt64 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt64Array), - DataType::Float32 => compute_op_scalar!($LEFT, $RIGHT, $OP, Float32Array), - DataType::Float64 => compute_op_scalar!($LEFT, $RIGHT, $OP, Float64Array), - DataType::Utf8 => compute_utf8_op_scalar!($LEFT, $RIGHT, $OP, StringArray), - DataType::Timestamp(TimeUnit::Nanosecond, _) => { - compute_op_scalar!($LEFT, $RIGHT, $OP, TimestampNanosecondArray) - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - compute_op_scalar!($LEFT, $RIGHT, $OP, TimestampMicrosecondArray) - } - DataType::Timestamp(TimeUnit::Millisecond, _) => { - compute_op_scalar!($LEFT, $RIGHT, $OP, TimestampMillisecondArray) - } - DataType::Timestamp(TimeUnit::Second, _) => { - compute_op_scalar!($LEFT, $RIGHT, $OP, TimestampSecondArray) - } - DataType::Date32 => { - compute_op_scalar!($LEFT, $RIGHT, $OP, Date32Array) - } - DataType::Date64 => { - compute_op_scalar!($LEFT, $RIGHT, $OP, Date64Array) - } - DataType::Boolean => compute_bool_op_scalar!($LEFT, $RIGHT, $OP, BooleanArray), - other => Err(DataFusionError::Internal(format!( - "Data type {:?} not supported for scalar operation '{}' on dyn array", - other, stringify!($OP) - ))), - }; - Some(result) - }}; -} - /// The binary_array_op macro includes types that extend beyond the primitive, /// such as Utf8 strings. #[macro_export] macro_rules! binary_array_op { ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ match $LEFT.data_type() { + DataType::Null => compute_null_op!($LEFT, $RIGHT, $OP, NullArray), DataType::Decimal(_,_) => compute_decimal_op!($LEFT, $RIGHT, $OP, DecimalArray), DataType::Int8 => compute_op!($LEFT, $RIGHT, $OP, Int8Array), DataType::Int16 => compute_op!($LEFT, $RIGHT, $OP, Int16Array), @@ -1272,22 +1295,22 @@ impl PhysicalExpr for BinaryExpr { /// such as Utf8 strings. #[macro_export] macro_rules! binary_array_op_dyn_scalar { - ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ + ($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr) => {{ let result: Result<ArrayRef> = match $RIGHT { - ScalarValue::Boolean(b) => compute_bool_op_dyn_scalar!($LEFT, b, $OP), + ScalarValue::Boolean(b) => compute_bool_op_dyn_scalar!($LEFT, b, $OP, $OP_TYPE), ScalarValue::Decimal128(..)
=> compute_decimal_op_scalar!($LEFT, $RIGHT, $OP, DecimalArray), - ScalarValue::Utf8(v) => compute_utf8_op_dyn_scalar!($LEFT, v, $OP), - ScalarValue::LargeUtf8(v) => compute_utf8_op_dyn_scalar!($LEFT, v, $OP), - ScalarValue::Int8(v) => compute_op_dyn_scalar!($LEFT, v, $OP), - ScalarValue::Int16(v) => compute_op_dyn_scalar!($LEFT, v, $OP), - ScalarValue::Int32(v) => compute_op_dyn_scalar!($LEFT, v, $OP), - ScalarValue::Int64(v) => compute_op_dyn_scalar!($LEFT, v, $OP), - ScalarValue::UInt8(v) => compute_op_dyn_scalar!($LEFT, v, $OP), - ScalarValue::UInt16(v) => compute_op_dyn_scalar!($LEFT, v, $OP), - ScalarValue::UInt32(v) => compute_op_dyn_scalar!($LEFT, v, $OP), - ScalarValue::UInt64(v) => compute_op_dyn_scalar!($LEFT, v, $OP), - ScalarValue::Float32(_) => compute_op_scalar!($LEFT, $RIGHT, $OP, Float32Array), - ScalarValue::Float64(_) => compute_op_scalar!($LEFT, $RIGHT, $OP, Float64Array), + ScalarValue::Utf8(v) => compute_utf8_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), + ScalarValue::LargeUtf8(v) => compute_utf8_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), + ScalarValue::Int8(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), + ScalarValue::Int16(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), + ScalarValue::Int32(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), + ScalarValue::Int64(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), + ScalarValue::UInt8(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), + ScalarValue::UInt16(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), + ScalarValue::UInt32(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), + ScalarValue::UInt64(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), + ScalarValue::Float32(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), + ScalarValue::Float64(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), ScalarValue::Date32(_) => compute_op_scalar!($LEFT, $RIGHT, $OP, Date32Array), ScalarValue::Date64(_) => compute_op_scalar!($LEFT, $RIGHT, $OP, Date64Array), ScalarValue::TimestampSecond(..) => compute_op_scalar!($LEFT, $RIGHT, $OP, TimestampSecondArray), @@ -1300,6 +1323,20 @@ macro_rules! binary_array_op_dyn_scalar { }} } +/// Compares the array with the scalar value for equality, sometimes +/// used in other kernels +pub(crate) fn array_eq_scalar(lhs: &dyn Array, rhs: &ScalarValue) -> Result<ArrayRef> { + binary_array_op_dyn_scalar!(lhs, rhs.clone(), eq, &DataType::Boolean).ok_or_else( || { + DataFusionError::Internal(format!( + "Data type {:?} and scalar {:?} not supported for array_eq_scalar", + lhs.data_type(), + rhs.get_datatype() + )) + }, + )? +} +
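A usage sketch for the new helper; `array_eq_scalar` is `pub(crate)`, so this only compiles from inside the physical-expr crate, and the values are illustrative:

```rust
use arrow::array::{Array, Int32Array};
use datafusion_common::{Result, ScalarValue};

fn demo() -> Result<()> {
    let lhs = Int32Array::from(vec![Some(1), Some(2), None]);
    // Equality mask as an ArrayRef backed by a BooleanArray:
    // [false, true, null].
    let mask = array_eq_scalar(&lhs, &ScalarValue::Int32(Some(2)))?;
    assert_eq!(mask.len(), 3);
    Ok(())
}
```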
impl BinaryExpr { /// Evaluate the expression of the left input is an array and /// right is literal - use scalar operations @@ -1308,36 +1345,37 @@ impl BinaryExpr { array: &dyn Array, scalar: &ScalarValue, ) -> Result<Option<Result<ArrayRef>>> { + let bool_type = &DataType::Boolean; let scalar_result = match &self.op { Operator::Lt => { - binary_array_op_dyn_scalar!(array, scalar.clone(), lt) + binary_array_op_dyn_scalar!(array, scalar.clone(), lt, bool_type) } Operator::LtEq => { - binary_array_op_dyn_scalar!(array, scalar.clone(), lt_eq) + binary_array_op_dyn_scalar!(array, scalar.clone(), lt_eq, bool_type) } Operator::Gt => { - binary_array_op_dyn_scalar!(array, scalar.clone(), gt) + binary_array_op_dyn_scalar!(array, scalar.clone(), gt, bool_type) } Operator::GtEq => { - binary_array_op_dyn_scalar!(array, scalar.clone(), gt_eq) + binary_array_op_dyn_scalar!(array, scalar.clone(), gt_eq, bool_type) } Operator::Eq => { - binary_array_op_dyn_scalar!(array, scalar.clone(), eq) + binary_array_op_dyn_scalar!(array, scalar.clone(), eq, bool_type) } Operator::NotEq => { - binary_array_op_dyn_scalar!(array, scalar.clone(), neq) + binary_array_op_dyn_scalar!(array, scalar.clone(), neq, bool_type) } Operator::Like => { - binary_string_array_op_scalar!(array, scalar.clone(), like) + binary_string_array_op_scalar!(array, scalar.clone(), like, bool_type) } Operator::NotLike => { - binary_string_array_op_scalar!(array, scalar.clone(), nlike) + binary_string_array_op_scalar!(array, scalar.clone(), nlike, bool_type) } Operator::ILike => { - binary_string_array_op_scalar!(array, scalar.clone(), ilike) + binary_string_array_op_scalar!(array, scalar.clone(), ilike, bool_type) } Operator::NotILike => { - binary_string_array_op_scalar!(array, scalar.clone(), nilike) + binary_string_array_op_scalar!(array, scalar.clone(), nilike, bool_type) } Operator::Plus => { binary_primitive_array_op_scalar!(array, scalar.clone(), add) @@ -1404,14 +1442,25 @@ impl BinaryExpr { scalar: &ScalarValue, array: &ArrayRef, ) -> Result<Option<Result<ArrayRef>>> { + let bool_type = &DataType::Boolean; let scalar_result = match &self.op { - Operator::Lt => binary_array_op_scalar!(array, scalar.clone(), gt), - Operator::LtEq => binary_array_op_scalar!(array, scalar.clone(), gt_eq), - Operator::Gt => binary_array_op_scalar!(array, scalar.clone(), lt), - Operator::GtEq => binary_array_op_scalar!(array, scalar.clone(), lt_eq), - Operator::Eq => binary_array_op_scalar!(array, scalar.clone(), eq), + Operator::Lt => { + binary_array_op_dyn_scalar!(array, scalar.clone(), gt, bool_type) + } + Operator::LtEq => { + binary_array_op_dyn_scalar!(array, scalar.clone(), gt_eq, bool_type) + } + Operator::Gt => { + binary_array_op_dyn_scalar!(array, scalar.clone(), lt, bool_type) + } + Operator::GtEq => { + binary_array_op_dyn_scalar!(array, scalar.clone(), lt_eq, bool_type) + } + Operator::Eq => { + binary_array_op_dyn_scalar!(array, scalar.clone(), eq, bool_type) + } Operator::NotEq => { - binary_array_op_scalar!(array, scalar.clone(), neq) + binary_array_op_dyn_scalar!(array, scalar.clone(), neq, bool_type) } // if scalar operation is not supported - fallback to array implementation _ => None,
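The operator flip in evaluate_scalar_array above follows from the operand order: the scalar sits on the left of the original expression, so `5 < col` must run as `col > 5` once the array takes the left-hand side. A hypothetical helper making explicit the mapping the patch writes inline:

```rust
use datafusion_expr::Operator;

// Mirror a comparison so the array operand ends up on the left-hand
// side. Eq and NotEq are symmetric and pass through unchanged.
fn swap_operator(op: Operator) -> Operator {
    match op {
        Operator::Lt => Operator::Gt,
        Operator::LtEq => Operator::GtEq,
        Operator::Gt => Operator::Lt,
        Operator::GtEq => Operator::LtEq,
        other => other,
    }
}
```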
@@ -1437,7 +1486,16 @@ Operator::GtEq => gt_eq_dyn(&left, &right), Operator::Eq => eq_dyn(&left, &right), Operator::NotEq => neq_dyn(&left, &right), - Operator::IsDistinctFrom => binary_array_op!(left, right, is_distinct_from), + Operator::IsDistinctFrom => { + match (left_data_type, right_data_type) { + // exchange lhs and rhs when lhs is Null, since `binary_array_op` + // always tries to downcast the array according to the $LEFT expression + (DataType::Null, _) => { + binary_array_op!(right, left, is_distinct_from) + } + _ => binary_array_op!(left, right, is_distinct_from), + } + } Operator::IsNotDistinctFrom => { binary_array_op!(left, right, is_not_distinct_from) } @@ -1514,6 +1572,23 @@ fn is_distinct_from_utf8( .collect()) } +fn is_distinct_from_null(left: &NullArray, _right: &NullArray) -> Result<BooleanArray> { + let length = left.len(); + make_boolean_array(length, false) +} + +fn is_not_distinct_from_null( + left: &NullArray, + _right: &NullArray, +) -> Result<BooleanArray> { + let length = left.len(); + make_boolean_array(length, true) +} + +fn make_boolean_array(length: usize, value: bool) -> Result<BooleanArray> { + Ok((0..length).map(|_| Some(value)).collect()) +} + fn is_not_distinct_from( left: &PrimitiveArray<T>, right: &PrimitiveArray<T>,
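On the Null kernels just above: IS DISTINCT FROM treats two NULLs as not distinct, so comparing two all-null arrays must yield constant false (and IS NOT DISTINCT FROM constant true), never a null result. A standalone equivalent of the patch's make_boolean_array, for reference:

```rust
use arrow::array::BooleanArray;

// A constant, fully valid Boolean column. Collecting Option<bool> items
// lets arrow build the validity bitmap; every slot here is Some, i.e.
// non-null.
fn constant_bool(len: usize, value: bool) -> BooleanArray {
    (0..len).map(|_| Some(value)).collect()
}
```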
diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 02f1d3c384cc..126187adec5a 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -113,13 +113,19 @@ macro_rules! if_then_else { .as_ref() .as_any() .downcast_ref::<$ARRAY_TYPE>() - .expect("true_values downcast failed"); + .expect(&format!( + "true_values downcast failed to {}", + stringify!($ARRAY_TYPE) + )); let false_values = $FALSE .as_ref() .as_any() .downcast_ref::<$ARRAY_TYPE>() - .expect("false_values downcast failed"); + .expect(&format!( + "false_values downcast failed to {}", + stringify!($ARRAY_TYPE) + )); let mut builder = <$BUILDER_TYPE>::new($BOOLS.len()); for i in 0..$BOOLS.len() { @@ -254,7 +260,7 @@ impl CaseExpr { /// [ELSE result] /// END fn case_when_with_expr(&self, batch: &RecordBatch) -> Result<ColumnarValue> { - let return_type = self.when_then_expr[0].1.data_type(&batch.schema())?; + let return_type = self.data_type(&*batch.schema())?; let expr = self.expr.as_ref().unwrap(); let base_value = expr.evaluate(batch)?; let base_value = base_value.into_array(batch.num_rows()); @@ -275,7 +281,12 @@ impl CaseExpr { let then_value = self.when_then_expr[i] .1 .evaluate_selection(batch, &when_match)?; - let then_value = then_value.into_array(batch.num_rows()); + let then_value = match then_value { + ColumnarValue::Scalar(value) if value.is_null() => { + new_null_array(&return_type, batch.num_rows()) + } + _ => then_value.into_array(batch.num_rows()), + }; current_value = if_then_else(&when_match, then_value, current_value, &return_type)?; @@ -306,7 +317,7 @@ impl CaseExpr { /// [ELSE result] /// END fn case_when_no_expr(&self, batch: &RecordBatch) -> Result<ColumnarValue> { - let return_type = self.when_then_expr[0].1.data_type(&batch.schema())?; + let return_type = self.data_type(&*batch.schema())?; // start with nulls as default output let mut current_value = new_null_array(&return_type, batch.num_rows()); @@ -315,6 +326,13 @@ impl CaseExpr { let when_value = self.when_then_expr[i] .0 .evaluate_selection(batch, &remainder)?; + // Treat a NULL WHEN value as false + let when_value = match when_value { + ColumnarValue::Scalar(value) if value.is_null() => { + continue; + } + _ => when_value, + }; let when_value = when_value.into_array(batch.num_rows()); let when_value = when_value .as_ref() @@ -325,7 +343,12 @@ let then_value = self.when_then_expr[i] .1 .evaluate_selection(batch, when_value)?; - let then_value = then_value.into_array(batch.num_rows()); + let then_value = match then_value { + ColumnarValue::Scalar(value) if value.is_null() => { + new_null_array(&return_type, batch.num_rows()) + } + _ => then_value.into_array(batch.num_rows()), + }; current_value = if_then_else(when_value, then_value, current_value, &return_type)?; @@ -359,7 +382,23 @@ impl PhysicalExpr for CaseExpr { } fn data_type(&self, input_schema: &Schema) -> Result<DataType> { - self.when_then_expr[0].1.data_type(input_schema) + // since all THEN results have the same data type, we can choose the first + // non-null one as the return data type + let mut data_type = DataType::Null; + for i in 0..self.when_then_expr.len() { + data_type = self.when_then_expr[i].1.data_type(input_schema)?; + if data_type != DataType::Null { + break; + } + } + // if all THEN results are null, fall back to the data type of the else expr, if present + if data_type == DataType::Null { + if let Some(e) = &self.else_expr { + data_type = e.data_type(input_schema)?; + } + } + + Ok(data_type) } fn nullable(&self, input_schema: &Schema) -> Result<bool> { @@ -374,7 +413,9 @@ } else if let Some(e) = &self.else_expr { e.nullable(input_schema) } else { - Ok(false) + // CASE produces NULL if there is no `else` expr + // (i.e. when none of the `when_then_exprs` match) + Ok(true) } }
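The reworked CaseExpr::data_type above scans the THEN arms for the first non-Null type and falls back to the ELSE type only when every THEN arm is a NULL literal. Its intent, restated as a small free function over plain DataTypes (a sketch; the patch itself works on PhysicalExprs and a Schema):

```rust
use arrow::datatypes::DataType;

// First non-Null THEN type wins; otherwise the ELSE type; otherwise
// Null, meaning every branch of the CASE was a NULL literal.
fn case_return_type(then_types: &[DataType], else_type: Option<DataType>) -> DataType {
    then_types
        .iter()
        .find(|t| **t != DataType::Null)
        .cloned()
        .or(else_type)
        .unwrap_or(DataType::Null)
}
```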
diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 2aee0d87dbde..92a9d64c1420 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -451,6 +451,10 @@ impl PhysicalExpr for InListExpr { DataType::LargeUtf8 => { self.compare_utf8::<i64>(array, list_values, self.negated) } + DataType::Null => { + let null_array = new_null_array(&DataType::Boolean, array.len()); + Ok(ColumnarValue::Array(Arc::new(null_array))) + } datatype => Result::Err(DataFusionError::NotImplemented(format!( "InList does not support datatype {:?}.", datatype diff --git a/datafusion/physical-expr/src/expressions/nullif.rs b/datafusion/physical-expr/src/expressions/nullif.rs index f3f72096600f..86c65fec1b8e 100644 --- a/datafusion/physical-expr/src/expressions/nullif.rs +++ b/datafusion/physical-expr/src/expressions/nullif.rs @@ -17,19 +17,17 @@ use std::sync::Arc; -use crate::expressions::binary::{eq_decimal, eq_decimal_scalar}; use arrow::array::Array; use arrow::array::*; +use arrow::compute::eq_dyn; use arrow::compute::kernels::boolean::nullif; -use arrow::compute::kernels::comparison::{ - eq, eq_bool, eq_bool_scalar, eq_scalar, eq_utf8, eq_utf8_scalar, -}; -use arrow::datatypes::{DataType, TimeUnit}; +use arrow::datatypes::DataType; use cube_ext::nullif_func_str; -use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::ColumnarValue; +use super::binary::array_eq_scalar; + /// Invoke a compute kernel on a primitive array and a Boolean Array macro_rules! compute_bool_array_op { ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ @@ -88,7 +86,7 @@ pub fn nullif_func(args: &[ColumnarValue]) -> Result<ColumnarValue> { match (lhs, rhs) { (ColumnarValue::Array(lhs), ColumnarValue::Scalar(rhs)) => { - let cond_array = binary_array_op_scalar!(lhs, rhs.clone(), eq).unwrap()?; + let cond_array = array_eq_scalar(lhs, rhs)?; let array = primitive_bool_array_op!(lhs, *cond_array, nullif)?; @@ -96,10 +94,10 @@ pub fn nullif_func(args: &[ColumnarValue]) -> Result<ColumnarValue> { } (ColumnarValue::Array(lhs), ColumnarValue::Array(rhs)) => { // Get args0 == args1 evaluated and produce a boolean array - let cond_array = binary_array_op!(lhs, rhs, eq)?; + let cond_array = eq_dyn(lhs, rhs)?; // Now, invoke nullif on the result - let array = primitive_bool_array_op!(lhs, *cond_array, nullif)?; + let array = primitive_bool_array_op!(lhs, cond_array, nullif)?; Ok(ColumnarValue::Array(array)) } _ => Err(DataFusionError::NotImplemented( @@ -130,7 +128,7 @@ pub static SUPPORTED_NULLIF_TYPES: &[DataType] = &[ #[cfg(test)] mod tests { use super::*; - use datafusion_common::Result; + use datafusion_common::{Result, ScalarValue}; #[test] fn nullif_int32() -> Result<()> { diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index 5b1cdae72cb2..bafc327125c3 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -36,7 +36,7 @@ path = "src/lib.rs" [dependencies] datafusion = { path = "../core", version = "7.0.0" } -prost = "0.9" +prost = "0.10" [build-dependencies] -tonic-build = { version = "0.6" } +tonic-build = { version = "0.7" }
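One closing observation on the nullif.rs hunk: eq_dyn, which the patch imports, returns a BooleanArray directly rather than an ArrayRef, which is why the old `*cond_array` dereference disappears. A minimal check of that behavior, assuming the pinned arrow-rs fork matches upstream here:

```rust
use arrow::array::Int32Array;
use arrow::compute::eq_dyn;
use arrow::error::Result;

fn main() -> Result<()> {
    let a = Int32Array::from(vec![Some(1), None, Some(3)]);
    let b = Int32Array::from(vec![Some(1), Some(2), Some(4)]);
    // BooleanArray [true, null, false]: nulls propagate slot by slot.
    let mask = eq_dyn(&a, &b)?;
    assert!(mask.value(0));
    assert!(!mask.value(2));
    Ok(())
}
```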