diff --git a/datafusion/functions-nested/src/distance.rs b/datafusion/functions-nested/src/distance.rs index fa9394c73bcb..aedcbe148f5b 100644 --- a/datafusion/functions-nested/src/distance.rs +++ b/datafusion/functions-nested/src/distance.rs @@ -17,17 +17,14 @@ //! [ScalarUDFImpl] definitions for array_distance function. -use crate::utils::{downcast_arg, make_scalar_function}; +use crate::utils::{convert_to_f64_array, downcast_arg, make_scalar_function}; use arrow_array::{ Array, ArrayRef, Float64Array, LargeListArray, ListArray, OffsetSizeTrait, }; use arrow_schema::DataType; use arrow_schema::DataType::{FixedSizeList, Float64, LargeList, List}; use core::any::type_name; -use datafusion_common::cast::{ - as_float32_array, as_float64_array, as_generic_list_array, as_int32_array, - as_int64_array, -}; +use datafusion_common::cast::as_generic_list_array; use datafusion_common::utils::coerced_fixed_size_list_to_list; use datafusion_common::DataFusionError; use datafusion_common::{exec_err, Result}; @@ -203,29 +200,3 @@ fn compute_array_distance( Ok(Some(sum_squares.sqrt())) } - -/// Converts an array of any numeric type to a Float64Array. -fn convert_to_f64_array(array: &ArrayRef) -> Result { - match array.data_type() { - DataType::Float64 => Ok(as_float64_array(array)?.clone()), - DataType::Float32 => { - let array = as_float32_array(array)?; - let converted: Float64Array = - array.iter().map(|v| v.map(|v| v as f64)).collect(); - Ok(converted) - } - DataType::Int64 => { - let array = as_int64_array(array)?; - let converted: Float64Array = - array.iter().map(|v| v.map(|v| v as f64)).collect(); - Ok(converted) - } - DataType::Int32 => { - let array = as_int32_array(array)?; - let converted: Float64Array = - array.iter().map(|v| v.map(|v| v as f64)).collect(); - Ok(converted) - } - _ => exec_err!("Unsupported array type for conversion to Float64Array"), - } -} diff --git a/datafusion/functions-nested/src/dot_product.rs b/datafusion/functions-nested/src/dot_product.rs new file mode 100644 index 000000000000..af2fab900c35 --- /dev/null +++ b/datafusion/functions-nested/src/dot_product.rs @@ -0,0 +1,201 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [ScalarUDFImpl] definitions for dot_product function. + +use crate::utils::{convert_to_f64_array, downcast_arg, make_scalar_function}; +use arrow_array::{ + Array, ArrayRef, Float64Array, LargeListArray, ListArray, OffsetSizeTrait, +}; +use arrow_schema::DataType; +use arrow_schema::DataType::{FixedSizeList, Float64, LargeList, List}; +use core::any::type_name; +use datafusion_common::cast::as_generic_list_array; +use datafusion_common::utils::coerced_fixed_size_list_to_list; +use datafusion_common::DataFusionError; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; +use std::sync::Arc; + +make_udf_expr_and_func!( + ArrayDotProduct, + array_dot_product, + array, + "returns the dot product between two numeric arrays.", + array_dot_product_udf +); + +#[derive(Debug)] +pub(super) struct ArrayDotProduct { + signature: Signature, + aliases: Vec, +} + +impl ArrayDotProduct { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + aliases: vec!["list_dot_product".to_string()], + } + } +} + +impl ScalarUDFImpl for ArrayDotProduct { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "array_dot_product" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + match arg_types[0] { + List(_) | LargeList(_) | FixedSizeList(_, _) => Ok(Float64), + _ => exec_err!("The array_dot_product function can only accept List/LargeList/FixedSizeList."), + } + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.len() != 2 { + return exec_err!("array_dot_product expects exactly two arguments"); + } + let mut result = Vec::new(); + for arg_type in arg_types { + match arg_type { + List(_) | LargeList(_) | FixedSizeList(_, _) => result.push(coerced_fixed_size_list_to_list(arg_type)), + _ => return exec_err!("The array_dot_product function can only accept List/LargeList/FixedSizeList."), + } + } + + Ok(result) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(array_dot_product_inner)(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +pub fn array_dot_product_inner(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_dot_product expects exactly two arguments"); + } + + match (&args[0].data_type(), &args[1].data_type()) { + (List(_), List(_)) => general_array_dot_product::(args), + (LargeList(_), LargeList(_)) => general_array_dot_product::(args), + (array_type1, array_type2) => { + exec_err!("array_dot_product does not support types '{array_type1:?}' and '{array_type2:?}'") + } + } +} + +fn general_array_dot_product( + arrays: &[ArrayRef], +) -> Result { + let list_array1 = as_generic_list_array::(&arrays[0])?; + let list_array2 = as_generic_list_array::(&arrays[1])?; + + let result = list_array1 + .iter() + .zip(list_array2.iter()) + .map(|(arr1, arr2)| compute_array_dot_product(arr1, arr2)) + .collect::>()?; + + Ok(Arc::new(result) as ArrayRef) +} + +/// Computes the dot product between two arrays +fn compute_array_dot_product( + arr1: Option, + arr2: Option, +) -> Result> { + let value1 = match arr1 { + Some(arr) => arr, + None => return Ok(None), + }; + let value2 = match arr2 { + Some(arr) => arr, + None => return Ok(None), + }; + + let mut value1 = value1; + let mut value2 = value2; + + loop { + match value1.data_type() { + List(_) => { + if downcast_arg!(value1, ListArray).null_count() > 0 { + return Ok(None); + } + value1 = downcast_arg!(value1, ListArray).value(0); + } + LargeList(_) => { + if downcast_arg!(value1, LargeListArray).null_count() > 0 { + return Ok(None); + } + value1 = downcast_arg!(value1, LargeListArray).value(0); + } + _ => break, + } + + match value2.data_type() { + List(_) => { + if downcast_arg!(value2, ListArray).null_count() > 0 { + return Ok(None); + } + value2 = downcast_arg!(value2, ListArray).value(0); + } + LargeList(_) => { + if downcast_arg!(value2, LargeListArray).null_count() > 0 { + return Ok(None); + } + value2 = downcast_arg!(value2, LargeListArray).value(0); + } + _ => break, + } + } + + // Check for NULL values inside the arrays + if value1.null_count() != 0 || value2.null_count() != 0 { + return Ok(None); + } + + let values1 = convert_to_f64_array(&value1)?; + let values2 = convert_to_f64_array(&value2)?; + + if values1.len() != values2.len() { + return exec_err!("Both arrays must have the same length"); + } + + let sum_products: f64 = values1 + .iter() + .zip(values2.iter()) + .map(|(v1, v2)| v1.unwrap_or(0.0) * v2.unwrap_or(0.0)) + .sum(); + + Ok(Some(sum_products)) +} diff --git a/datafusion/functions-nested/src/lib.rs b/datafusion/functions-nested/src/lib.rs index 301ddb36fc56..0cea583ed0e7 100644 --- a/datafusion/functions-nested/src/lib.rs +++ b/datafusion/functions-nested/src/lib.rs @@ -35,6 +35,7 @@ pub mod cardinality; pub mod concat; pub mod dimension; pub mod distance; +pub mod dot_product; pub mod empty; pub mod except; pub mod expr_ext; @@ -77,6 +78,7 @@ pub mod expr_fn { pub use super::dimension::array_dims; pub use super::dimension::array_ndims; pub use super::distance::array_distance; + pub use super::dot_product::array_dot_product; pub use super::empty::array_empty; pub use super::except::array_except; pub use super::extract::array_any_value; @@ -137,6 +139,7 @@ pub fn all_default_nested_functions() -> Vec> { empty::array_empty_udf(), length::array_length_udf(), distance::array_distance_udf(), + dot_product::array_dot_product_udf(), flatten::flatten_udf(), sort::array_sort_udf(), repeat::array_repeat_udf(), diff --git a/datafusion/functions-nested/src/utils.rs b/datafusion/functions-nested/src/utils.rs index 0765f6cd237d..27d3f33d48ef 100644 --- a/datafusion/functions-nested/src/utils.rs +++ b/datafusion/functions-nested/src/utils.rs @@ -22,16 +22,18 @@ use std::sync::Arc; use arrow::{array::ArrayRef, datatypes::DataType}; use arrow_array::{ - Array, BooleanArray, GenericListArray, ListArray, OffsetSizeTrait, Scalar, - UInt32Array, + Array, BooleanArray, Float64Array, GenericListArray, ListArray, OffsetSizeTrait, + Scalar, UInt32Array, }; use arrow_buffer::OffsetBuffer; use arrow_schema::{Field, Fields}; -use datafusion_common::cast::{as_large_list_array, as_list_array}; -use datafusion_common::{exec_err, internal_err, plan_err, Result, ScalarValue}; - use core::any::type_name; +use datafusion_common::cast::{ + as_float32_array, as_float64_array, as_int32_array, as_int64_array, + as_large_list_array, as_list_array, +}; use datafusion_common::DataFusionError; +use datafusion_common::{exec_err, internal_err, plan_err, Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation}; macro_rules! downcast_arg { @@ -268,6 +270,32 @@ pub(crate) fn get_map_entry_field(data_type: &DataType) -> Result<&Fields> { } } +/// Converts an array of any numeric type to a Float64Array. +pub(crate) fn convert_to_f64_array(array: &ArrayRef) -> Result { + match array.data_type() { + DataType::Float64 => Ok(as_float64_array(array)?.clone()), + DataType::Float32 => { + let array = as_float32_array(array)?; + let converted: Float64Array = + array.iter().map(|v| v.map(|v| v as f64)).collect(); + Ok(converted) + } + DataType::Int64 => { + let array = as_int64_array(array)?; + let converted: Float64Array = + array.iter().map(|v| v.map(|v| v as f64)).collect(); + Ok(converted) + } + DataType::Int32 => { + let array = as_int32_array(array)?; + let converted: Float64Array = + array.iter().map(|v| v.map(|v| v as f64)).collect(); + Ok(converted) + } + _ => exec_err!("Unsupported array type for conversion to Float64Array"), + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index b7d60b50586d..ad0d864188f8 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -661,6 +661,38 @@ FROM arrays_distance_table ; +statement ok +CREATE TABLE arrays_dot_product_table +AS VALUES + (make_array(1, 2, 3), make_array(1, 2, 3), make_array(1.1, 2.2, 3.3) , make_array(1.1, NULL, 3.3)), + (make_array(1, 2, 3), make_array(4, 5, 6), make_array(4.4, 5.5, 6.6), make_array(4.4, NULL, 6.6)), + (make_array(1, 2, 3), make_array(7, 8, 9), make_array(7.7, 8.8, 9.9), make_array(7.7, NULL, 9.9)), + (make_array(1, 2, 3), make_array(10, 11, 12), make_array(10.1, 11.2, 12.3), make_array(10.1, NULL, 12.3)) +; + +statement ok +CREATE TABLE large_arrays_dot_product_table +AS + SELECT + arrow_cast(column1, 'LargeList(Int64)') AS column1, + arrow_cast(column2, 'LargeList(Int64)') AS column2, + arrow_cast(column3, 'LargeList(Float64)') AS column3, + arrow_cast(column4, 'LargeList(Float64)') AS column4 +FROM arrays_dot_product_table +; + +statement ok +CREATE TABLE fixed_size_arrays_dot_product_table +AS + SELECT + arrow_cast(column1, 'FixedSizeList(3, Int64)') AS column1, + arrow_cast(column2, 'FixedSizeList(3, Int64)') AS column2, + arrow_cast(column3, 'FixedSizeList(3, Float64)') AS column3, + arrow_cast(column4, 'FixedSizeList(3, Float64)') AS column4 +FROM arrays_dot_product_table +; + + # Array literal ## boolean coercion is not supported @@ -5208,6 +5240,85 @@ NULL 1 1 2 NULL 1 2 1 NULL + +query RRR +select array_dot_product([2], [3]), list_dot_product([1], [2]), list_dot_product([1], [-2]); +---- +6 2 -2 + +query error +select list_dot_product([1], [1, 2]); + +query R +select array_dot_product([[1, 1]], [1, 2]); +---- +3 + +query R +select array_dot_product([[1, 1]], [[1, 2]]); +---- +3 + +query R +select list_dot_product([[1, 1]], [[1, 2]]); +---- +3 + +query RR +select array_dot_product([1, 1, 0, 0], [2, 2, 1, 1]), list_dot_product([1, 2, 3], [1, 2, 3]); +---- +4 14 + +query RR +select array_dot_product([1.0, 1, 0, 0], [2, 2.0, 1, 1]), list_dot_product([1, 2.0, 3], [1, 2, 3]); +---- +4 14 + +query R +select list_dot_product([1, 1, NULL, 0], [2, 2, NULL, NULL]); +---- +NULL + +query R +select list_dot_product([NULL, NULL], [NULL, NULL]); +---- +NULL + +query R +select list_dot_product([1.0, 2.0, 3.0], [1.0, 2.0, 3.5]) AS dot_product; +---- +15.5 + +query R +select list_dot_product([1, 2, 3], [1, 2, 3]) AS dot_product; +---- +14 + +# array_dot_product with columns +query RRR +select array_dot_product(column1, column2), array_dot_product(column1, column3), array_dot_product(column1, column4) from arrays_dot_product_table; +---- +14 15.4 NULL +32 35.2 NULL +50 55 NULL +68 69.4 NULL + +query RRR +select array_dot_product(column1, column2), array_dot_product(column1, column3), array_dot_product(column1, column4) from large_arrays_dot_product_table; +---- +14 15.4 NULL +32 35.2 NULL +50 55 NULL +68 69.4 NULL + +query RRR +select array_dot_product(column1, column2), array_dot_product(column1, column3), array_dot_product(column1, column4) from fixed_size_arrays_dot_product_table; +---- +14 15.4 NULL +32 35.2 NULL +50 55 NULL +68 69.4 NULL + ## array_has/array_has_all/array_has_any # If lhs is empty, return false diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 636a33100ed0..cafe171c518e 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -2094,6 +2094,7 @@ to_unixtime(expression[, ..., format_n]) - [array_dims](#array_dims) - [array_distance](#array_distance) - [array_distinct](#array_distinct) +- [array_dot_product](#array_dot_product) - [array_has](#array_has) - [array_has_all](#array_has_all) - [array_has_any](#array_has_any) @@ -2138,6 +2139,7 @@ to_unixtime(expression[, ..., format_n]) - [list_dims](#list_dims) - [list_distance](#list_distance) - [list_distinct](#list_distinct) +- [list_dot_product](#list_dot_product) - [list_element](#list_element) - [list_except](#list_except) - [list_extract](#list_extract) @@ -2472,6 +2474,36 @@ array_distinct(array) - list_distinct +### `array_dot_product` + +Returns the dot product between two input arrays of equal length. + +``` +array_dot_product(array1, array2) +``` + +#### Arguments + +- **array1**: Array expression. + Can be a constant, column, or function, and any combination of array operators. +- **array2**: Array expression. + Can be a constant, column, or function, and any combination of array operators. + +#### Example + +``` +> select array_dot_product([1, 2], [1, 4]); ++---------------------------------------+ +| array_dot_product(List([1,2], [1,4])) | ++---------------------------------------+ +| 9.0 | ++---------------------------------------+ +``` + +#### Aliases + +- list_dot_product + ### `array_element` Extracts the element with the index n from the array. @@ -3292,6 +3324,10 @@ _Alias of [array_distance](#array_distance)._ _Alias of [array_distinct](#array_distinct)._ +### `list_dot_product` + +_Alias of [array_dot_product](#array_dot_product)._ + ### `list_element` _Alias of [array_element](#array_element)._