diff --git a/datafusion/functions/src/core/greatest.rs b/datafusion/functions/src/core/greatest.rs new file mode 100644 index 000000000000..185002b0a7cf --- /dev/null +++ b/datafusion/functions/src/core/greatest.rs @@ -0,0 +1,236 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; + +use arrow::array::{make_comparator, Array, ArrayRef, BooleanArray}; +use arrow::compute::kernels::cmp; +use arrow::compute::kernels::zip::zip; +use arrow::compute::SortOptions; +use arrow::datatypes::DataType; +use arrow_buffer::BooleanBuffer; +use datafusion_common::{exec_err, plan_err, Result, ScalarValue}; +use datafusion_expr::{ColumnarValue}; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::binary::type_union_resolution; + +const SORT_OPTIONS: SortOptions = SortOptions { + // We want greatest first + descending: false, + + // NULL will be less than any other value + nulls_first: true, +}; + +#[derive(Debug)] +pub struct GreatestFunc { + signature: Signature, +} + +impl Default for GreatestFunc { + fn default() -> Self { + GreatestFunc::new() + } +} + +impl GreatestFunc { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + } + } +} + +fn get_logical_null_count(arr: &dyn Array) -> usize { + arr.logical_nulls().map(|n| n.null_count()).unwrap_or_default() + +} + +/// Return boolean array where `arr[i] = lhs[i] >= rhs[i]` for all i, where `arr` is the result array +/// Nulls are always considered smaller than any other value +fn get_larger(lhs: &dyn Array, rhs: &dyn Array) -> Result { + // Fast path: + // If both arrays are not nested, have the same length and no nulls, we can use the faster vectorised kernel + // - If both arrays are not nested: Nested types, such as lists, are not supported as the null semantics are not well-defined. + // - both array does not have any nulls: cmp::gt_eq will return null if any of the input is null while we want to return false in that case + if !lhs.data_type().is_nested() && get_logical_null_count(lhs) == 0 && get_logical_null_count(rhs) == 0 { + return cmp::gt_eq(&lhs, &rhs).map_err(|e| e.into()); + } + + let cmp = make_comparator(lhs, rhs, SORT_OPTIONS)?; + + if lhs.len() != rhs.len() { + return exec_err!("All arrays should have the same length for greatest comparison") + } + + let values = BooleanBuffer::collect_bool(lhs.len(), |i| cmp(i, i).is_ge()); + + // No nulls as we only want to keep the values that are larger, its either true or false + Ok(BooleanArray::new(values, None)) +} + +/// Return array where the largest value at each index is kept +fn keep_larger(lhs: ArrayRef, rhs: ArrayRef) -> Result { + // True for values that we should keep from the left array + let keep_lhs = get_larger(lhs.as_ref(), rhs.as_ref())?; + + let larger = zip(&keep_lhs, &lhs, &rhs)?; + + Ok(larger) +} + +fn keep_larger_scalar<'a>(lhs: &'a ScalarValue, rhs: &'a ScalarValue) -> Result<&'a ScalarValue> { + if !lhs.data_type().is_nested() { + return if lhs >= rhs { + Ok(lhs) + } else { + Ok(rhs) + }; + } + + // If complex type we can't compare directly as we want null values to be smaller + let cmp = make_comparator( + lhs.to_array()?.as_ref(), + rhs.to_array()?.as_ref(), + SORT_OPTIONS, + )?; + + if cmp(0, 0).is_ge() { + Ok(lhs) + } else { + Ok(rhs) + } +} + +fn find_coerced_type(data_types: &[DataType]) -> Result { + if data_types.is_empty() { + plan_err!("greatest was called without any arguments. It requires at least 1.") + } else if let Some(coerced_type) = type_union_resolution(data_types) { + Ok(coerced_type) + } else { + plan_err!("Cannot find a common type for arguments") + } +} + +impl ScalarUDFImpl for GreatestFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "greatest" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + find_coerced_type(arg_types) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + if args.is_empty() { + return exec_err!("greatest was called with no arguments. It requires at least 1."); + } + + // Some engines (e.g. SQL Server) allow greatest with single arg, it's a noop + if args.len() == 1 { + return Ok(args[0].clone()) + } + + // Split to scalars and arrays for later optimization + let (scalars, arrays): (Vec<_>, Vec<_>) = args.iter().partition(|x| match x { + ColumnarValue::Scalar(_) => true, + ColumnarValue::Array(_) => false, + }); + + let mut arrays_iter = arrays + .iter() + .map(|x| match x { + ColumnarValue::Array(a) => a, + _ => unreachable!(), + }); + + let first_array = arrays_iter.next(); + + let mut largest: ArrayRef; + + // Optimization: merge all scalars into one to avoid recomputing + if !scalars.is_empty() { + let mut scalars_iter = scalars + .iter() + .map(|x| match x { + ColumnarValue::Scalar(s) => s, + _ => unreachable!(), + }); + + // We have at least one scalar + let mut largest_scalar = scalars_iter.next().unwrap(); + + for scalar in scalars_iter { + largest_scalar = keep_larger_scalar(largest_scalar, scalar)?; + } + + // If we only have scalars, return the largest one + if arrays.is_empty() { + return Ok(ColumnarValue::Scalar(largest_scalar.clone())); + } + + // We have at least one array + let first_array = first_array.unwrap(); + + // Start with the largest value + largest = keep_larger( + first_array.clone(), + largest_scalar.to_array_of_size(first_array.len())? + )?; + } else { + // If we only have arrays, start with the first array + // (We must have at least one array) + largest = first_array.unwrap().clone(); + } + + for array in arrays_iter { + largest = keep_larger(array.clone(), largest)?; + } + + Ok(ColumnarValue::Array(largest)) + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + let coerced_type = find_coerced_type(arg_types)?; + + Ok(vec![coerced_type; arg_types.len()]) + } +} + +#[cfg(test)] +mod test { + use crate::core; + use arrow::datatypes::DataType; + use datafusion_expr::ScalarUDFImpl; + + #[test] + fn test_greatest_return_types_without_common_supertype_in_arg_type() { + let greatest = core::greatest::GreatestFunc::new(); + let return_type = greatest + .return_type(&[DataType::Decimal128(10, 3), DataType::Decimal128(10, 4)]) + .unwrap(); + assert_eq!(return_type, DataType::Decimal128(11, 4)); + } +} diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index cf64c03766cb..7dc0536671cc 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -32,6 +32,7 @@ pub mod nvl2; pub mod planner; pub mod r#struct; pub mod version; +pub mod greatest; // create UDFs make_udf_function!(arrow_cast::ArrowCastFunc, ARROW_CAST, arrow_cast); @@ -43,6 +44,7 @@ make_udf_function!(r#struct::StructFunc, STRUCT, r#struct); make_udf_function!(named_struct::NamedStructFunc, NAMED_STRUCT, named_struct); make_udf_function!(getfield::GetFieldFunc, GET_FIELD, get_field); make_udf_function!(coalesce::CoalesceFunc, COALESCE, coalesce); +make_udf_function!(greatest::GreatestFunc, GREATEST, greatest); make_udf_function!(version::VersionFunc, VERSION, version); pub mod expr_fn { @@ -80,6 +82,10 @@ pub mod expr_fn { coalesce, "Returns `coalesce(args...)`, which evaluates to the value of the first expr which is not NULL", args, + ),( + greatest, + "Returns `greatest(args...)`, which evaluates to the greatest value in the list of expressions or NULL if all the expressions are NULL", + args, )); #[doc = "Returns the value of the field with the given name from the struct"] @@ -106,6 +112,7 @@ pub fn functions() -> Vec> { // calls to `get_field` get_field(), coalesce(), + greatest(), version(), r#struct(), ] diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index a7568d88f797..4b770a19fe20 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -756,3 +756,202 @@ query TT select substr('Andrew Lamb', 1, 6), '|' ---- Andrew | + + +# test for greatest +statement ok +CREATE TABLE t1 (a int, b int, c int) as VALUES +(4, NULL, NULL), +(1, 2, 3), +(3, 1, 2), +(1, NULL, -1), +(NULL, NULL, NULL), +(3, 0, -1); + +query I +SELECT greatest(a, b, c) FROM t1 +---- +4 +3 +3 +1 +NULL +3 + +statement ok +drop table t1 + +query I +SELECT greatest(1) +---- +1 + +query I +SELECT greatest(1, 2) +---- +2 + +query I +SELECT greatest(3, 1) +---- +3 + +query ? +SELECT greatest(NULL) +---- +NULL + +query I +SELECT greatest(1, NULL, -1) +---- +1 + +query I +SELECT greatest((3), (0), (-1)); +---- +3 + +query ? +SELECT greatest([4, 3], [4, 2], [4, 4]); +---- +[4, 4] + +query ? +SELECT greatest([2, 3], [1, 4], [5, 0]); +---- +[5, 0] + +query I +SELECT greatest(1::int, 2::text) +---- +2 + +query R +SELECT greatest(-1, 1, 2.3, 123456789, 3 + 5, -(-4)) +---- +123456789 + +query R +SELECT greatest(-1.123, 1.21313, 2.3, 123456789.321, 3 + 5.3213, -(-4.3213), abs(-9)) +---- +123456789.321 + +query R +SELECT greatest(-1, 1, 2.3, 123456789, 3 + 5, -(-4), abs(-9.0)) +---- +123456789 + + +query error greatest does not support zero arguments +SELECT greatest() + +query I +SELECT greatest(4, 5, 7, 1, 2) +---- +7 + +query I +SELECT greatest(4, NULL, 7, 1, 2) +---- +7 + +query I +SELECT greatest(NULL, NULL, 7, NULL, 2) +---- +7 + +query I +SELECT greatest(NULL, NULL, NULL, NULL, 2) +---- +2 + +query I +SELECT greatest(2, NULL, NULL, NULL, NULL) +---- +2 + +query ? +SELECT greatest(NULL, NULL, NULL) +---- +NULL + +query I +SELECT greatest(2, '4') +---- +4 + +query T +SELECT greatest('foo', 'bar', 'foobar') +---- +foobar + +query R +SELECT greatest(1, 1.2) +---- +1.2 + +statement ok +CREATE TABLE foo (a int) + +statement ok +INSERT INTO foo (a) VALUES (1) + +# Test homogenous functions that can't be constant folded. +query I +SELECT greatest(NULL, a, 5, NULL) FROM foo +---- +5 + +query I +SELECT greatest(NULL, NULL, NULL, a, -1) FROM foo +---- +1 + +statement ok +drop table foo + +query R +select greatest(arrow_cast('NAN','Float64'), arrow_cast('NAN','Float64')) +---- +NaN + +query R +select greatest(arrow_cast('NAN','Float64'), arrow_cast('NAN','Float32')) +---- +NaN + +query R +select greatest(arrow_cast('NAN','Float64'), '+Inf'::Double) +---- +NaN + +query R +select greatest(arrow_cast('NAN','Float64'), NULL) +---- +NaN + +query R +select greatest(NULL, '+Inf'::Double) +---- +Infinity + +query R +select greatest(NULL, '-Inf'::Double) +---- +-Infinity + +statement ok +CREATE TABLE t1 (a double, b double, c double) as VALUES +(1, arrow_cast('NAN', 'Float64'), '+Inf'::Double), +(NULL, arrow_cast('NAN','Float64'), '+Inf'::Double), +(1, '+Inf'::Double, NULL); + +query R +SELECT greatest(a, b, c) FROM t1 +---- +NaN +NaN +Infinity + +statement ok +drop table t1 diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index a379dfc9ec29..a4859dad96c0 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -675,6 +675,26 @@ nvl2(expression1, expression2, expression3) +----------------------------------------+ ``` +## Comparison Functions + +- [greatest](#greatest) + +### `greatest` + +Returns the greatest value in a list of expressions. +Returns _null_ if all expressions are _null_. + +``` +greatest(expression1[, ..., expression_n]) +``` + +#### Arguments + +- **expression1, expression_n**: + Expressions to compare and return the greatest value. + Can be a constant, column, or function, and any combination of arithmetic operators. + Pass as many expression arguments as necessary. + ## String Functions - [ascii](#ascii)