From 4bc822b449d62fcc4216793dac934d5302f0ea68 Mon Sep 17 00:00:00 2001
From: Spears Randall
Date: Fri, 15 Dec 2023 16:38:24 -0700
Subject: [PATCH 01/14] Change ScalarValue::List type signature

Also ScalarValue::LargeList and ScalarValue::FixedSizeList
---
 datafusion/common/src/scalar.rs               | 270 +++++++++++-------
 .../simplify_expressions/expr_simplifier.rs   |   4 +-
 .../src/aggregate/array_agg_distinct.rs       |   4 +-
 .../physical-expr/src/aggregate/tdigest.rs    |   5 +-
 datafusion/physical-plan/src/values.rs        |   3 +-
 .../proto/src/logical_plan/from_proto.rs      |   8 +-
 datafusion/proto/src/logical_plan/to_proto.rs | 102 ++++++-
 7 files changed, 280 insertions(+), 116 deletions(-)

diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs
index d730fbf89b72..8f2c6c159a7c 100644
--- a/datafusion/common/src/scalar.rs
+++ b/datafusion/common/src/scalar.rs
@@ -141,13 +141,13 @@ pub enum ScalarValue {
     /// Fixed size list scalar.
     ///
     /// The array must be a FixedSizeListArray with length 1.
-    FixedSizeList(ArrayRef),
+    FixedSizeList(Arc<FixedSizeListArray>),
     /// Represents a single element of a [`ListArray`] as an [`ArrayRef`]
     ///
     /// The array must be a ListArray with length 1.
-    List(ArrayRef),
+    List(Arc<ListArray>),
     /// The array must be a LargeListArray with length 1.
-    LargeList(ArrayRef),
+    LargeList(Arc<LargeListArray>),
     /// Date stored as a signed 32bit int days since UNIX epoch 1970-01-01
     Date32(Option<i32>),
     /// Date stored as a signed 64bit int milliseconds since UNIX epoch 1970-01-01
@@ -359,10 +359,8 @@ impl PartialOrd for ScalarValue {
             (FixedSizeBinary(_, _), _) => None,
             (LargeBinary(v1), LargeBinary(v2)) => v1.partial_cmp(v2),
             (LargeBinary(_), _) => None,
-            (List(arr1), List(arr2))
-            | (FixedSizeList(arr1), FixedSizeList(arr2))
-            | (LargeList(arr1), LargeList(arr2)) => {
-                // ScalarValue::List / ScalarValue::FixedSizeList / ScalarValue::LargeList are ensure to have length 1
+            // ScalarValue::List / ScalarValue::FixedSizeList / ScalarValue::LargeList are ensure to have length 1
+            (List(arr1), List(arr2)) => {
                 assert_eq!(arr1.len(), 1);
                 assert_eq!(arr2.len(), 1);
@@ -370,18 +368,39 @@ impl PartialOrd for ScalarValue {
                     return None;
                 }
-                fn first_array_for_list(arr: &ArrayRef) -> ArrayRef {
-                    if let Some(arr) = arr.as_list_opt::<i32>() {
-                        arr.value(0)
-                    } else if let Some(arr) = arr.as_list_opt::<i64>() {
-                        arr.value(0)
-                    } else if let Some(arr) = arr.as_fixed_size_list_opt() {
-                        arr.value(0)
-                    } else {
-                        unreachable!("Since only List / LargeList / FixedSizeList are supported, this should never happen")
+                fn first_array_for_list(arr: &Arc<ListArray>) -> ArrayRef {
+                    arr.value(0)
+                }
+
+                let arr1 = first_array_for_list(arr1);
+                let arr2 = first_array_for_list(arr2);
+
+                let lt_res = arrow::compute::kernels::cmp::lt(&arr1, &arr2).ok()?;
+                let eq_res = arrow::compute::kernels::cmp::eq(&arr1, &arr2).ok()?;
+
+                for j in 0..lt_res.len() {
+                    if lt_res.is_valid(j) && lt_res.value(j) {
+                        return Some(Ordering::Less);
+                    }
+                    if eq_res.is_valid(j) && !eq_res.value(j) {
+                        return Some(Ordering::Greater);
                     }
                 }
+                Some(Ordering::Equal)
+            }
+            (FixedSizeList(arr1), FixedSizeList(arr2)) => {
+                assert_eq!(arr1.len(), 1);
+                assert_eq!(arr2.len(), 1);
+
+                if arr1.data_type() != arr2.data_type() {
+                    return None;
+                }
+
+                fn first_array_for_list(arr: &Arc<FixedSizeListArray>) -> ArrayRef {
+                    arr.value(0)
+                }
+
                 let arr1 = first_array_for_list(arr1);
                 let arr2 = first_array_for_list(arr2);
@@ -399,6 +418,36 @@ impl PartialOrd for ScalarValue {
                 Some(Ordering::Equal)
             }
+            (LargeList(arr1), LargeList(arr2)) => {
+                assert_eq!(arr1.len(), 1);
+                assert_eq!(arr2.len(), 1);
+
+                if arr1.data_type() !=
arr2.data_type() { + return None; + } + + fn first_array_for_list(arr: &Arc) -> ArrayRef { + arr.value(0) + } + + let arr1 = first_array_for_list(arr1); + let arr2 = first_array_for_list(arr2); + + let lt_res = arrow::compute::kernels::cmp::lt(&arr1, &arr2).ok()?; + let eq_res = arrow::compute::kernels::cmp::eq(&arr1, &arr2).ok()?; + + for j in 0..lt_res.len() { + if lt_res.is_valid(j) && lt_res.value(j) { + return Some(Ordering::Less); + } + if eq_res.is_valid(j) && !eq_res.value(j) { + return Some(Ordering::Greater); + } + } + + Some(Ordering::Equal) + + } (List(_), _) | (LargeList(_), _) | (FixedSizeList(_), _) => None, (Date32(v1), Date32(v2)) => v1.partial_cmp(v2), (Date32(_), _) => None, @@ -516,8 +565,26 @@ impl std::hash::Hash for ScalarValue { Binary(v) => v.hash(state), FixedSizeBinary(_, v) => v.hash(state), LargeBinary(v) => v.hash(state), - List(arr) | LargeList(arr) | FixedSizeList(arr) => { - let arrays = vec![arr.to_owned()]; + List(arr) => { + let arrays = vec![arr.to_owned() as ArrayRef]; + let hashes_buffer = &mut vec![0; arr.len()]; + let random_state = ahash::RandomState::with_seeds(0, 0, 0, 0); + let hashes = + create_hashes(&arrays, &random_state, hashes_buffer).unwrap(); + // Hash back to std::hash::Hasher + hashes.hash(state); + } + LargeList(arr) => { + let arrays = vec![arr.to_owned() as ArrayRef]; + let hashes_buffer = &mut vec![0; arr.len()]; + let random_state = ahash::RandomState::with_seeds(0, 0, 0, 0); + let hashes = + create_hashes(&arrays, &random_state, hashes_buffer).unwrap(); + // Hash back to std::hash::Hasher + hashes.hash(state); + } + FixedSizeList(arr) => { + let arrays = vec![arr.to_owned() as ArrayRef]; let hashes_buffer = &mut vec![0; arr.len()]; let random_state = ahash::RandomState::with_seeds(0, 0, 0, 0); let hashes = @@ -941,9 +1008,9 @@ impl ScalarValue { ScalarValue::Binary(_) => DataType::Binary, ScalarValue::FixedSizeBinary(sz, _) => DataType::FixedSizeBinary(*sz), ScalarValue::LargeBinary(_) => DataType::LargeBinary, - ScalarValue::List(arr) - | ScalarValue::LargeList(arr) - | ScalarValue::FixedSizeList(arr) => arr.data_type().to_owned(), + ScalarValue::List(arr) => arr.data_type().to_owned(), + ScalarValue::LargeList(arr) => arr.data_type().to_owned(), + ScalarValue::FixedSizeList(arr) => arr.data_type().to_owned(), ScalarValue::Date32(_) => DataType::Date32, ScalarValue::Date64(_) => DataType::Date64, ScalarValue::Time32Second(_) => DataType::Time32(TimeUnit::Second), @@ -1146,9 +1213,9 @@ impl ScalarValue { ScalarValue::LargeBinary(v) => v.is_none(), // arr.len() should be 1 for a list scalar, but we don't seem to // enforce that anywhere, so we still check against array length. 
- ScalarValue::List(arr) - | ScalarValue::LargeList(arr) - | ScalarValue::FixedSizeList(arr) => arr.len() == arr.null_count(), + ScalarValue::List(arr) => arr.len() == arr.null_count(), + ScalarValue::LargeList(arr) => arr.len() == arr.null_count(), + ScalarValue::FixedSizeList(arr) => arr.len() == arr.null_count(), ScalarValue::Date32(v) => v.is_none(), ScalarValue::Date64(v) => v.is_none(), ScalarValue::Time32Second(v) => v.is_none(), @@ -1694,17 +1761,16 @@ impl ScalarValue { /// ScalarValue::Int32(Some(2)) /// ]; /// - /// let array = ScalarValue::new_list(&scalars, &DataType::Int32); - /// let result = as_list_array(&array).unwrap(); + /// let result = ScalarValue::new_list(&scalars, &DataType::Int32); /// /// let expected = ListArray::from_iter_primitive::( /// vec![ /// Some(vec![Some(1), None, Some(2)]) /// ]); /// - /// assert_eq!(result, &expected); + /// assert_eq!(*result, expected); /// ``` - pub fn new_list(values: &[ScalarValue], data_type: &DataType) -> ArrayRef { + pub fn new_list(values: &[ScalarValue], data_type: &DataType) -> Arc { let values = if values.is_empty() { new_empty_array(data_type) } else { @@ -1729,17 +1795,16 @@ impl ScalarValue { /// ScalarValue::Int32(Some(2)) /// ]; /// - /// let array = ScalarValue::new_large_list(&scalars, &DataType::Int32); - /// let result = as_large_list_array(&array).unwrap(); + /// let result = ScalarValue::new_large_list(&scalars, &DataType::Int32); /// /// let expected = LargeListArray::from_iter_primitive::( /// vec![ /// Some(vec![Some(1), None, Some(2)]) /// ]); /// - /// assert_eq!(result, &expected); + /// assert_eq!(*result, expected); /// ``` - pub fn new_large_list(values: &[ScalarValue], data_type: &DataType) -> ArrayRef { + pub fn new_large_list(values: &[ScalarValue], data_type: &DataType) -> Arc { let values = if values.is_empty() { new_empty_array(data_type) } else { @@ -1875,10 +1940,22 @@ impl ScalarValue { .collect::(), ), }, - ScalarValue::List(arr) - | ScalarValue::LargeList(arr) - | ScalarValue::FixedSizeList(arr) => { - let arrays = std::iter::repeat(arr.as_ref()) + ScalarValue::List(arr) => { + let arrays = std::iter::repeat(arr.as_ref() as &dyn Array) + .take(size) + .collect::>(); + arrow::compute::concat(arrays.as_slice()) + .map_err(DataFusionError::ArrowError)? + } + ScalarValue::LargeList(arr) => { + let arrays = std::iter::repeat(arr.as_ref() as &dyn Array) + .take(size) + .collect::>(); + arrow::compute::concat(arrays.as_slice()) + .map_err(DataFusionError::ArrowError)? + } + ScalarValue::FixedSizeList(arr) => { + let arrays = std::iter::repeat(arr.as_ref() as &dyn Array) .take(size) .collect::>(); arrow::compute::concat(arrays.as_slice()) @@ -2432,11 +2509,17 @@ impl ScalarValue { ScalarValue::LargeBinary(val) => { eq_array_primitive!(array, index, LargeBinaryArray, val)? } - ScalarValue::List(arr) - | ScalarValue::LargeList(arr) - | ScalarValue::FixedSizeList(arr) => { + ScalarValue::List(arr) => { + let right = array.slice(index, 1); + arr.as_ref() as &dyn Array == &right + } + ScalarValue::LargeList(arr) => { + let right = array.slice(index, 1); + arr.as_ref() as &dyn Array == &right + } + ScalarValue::FixedSizeList(arr) => { let right = array.slice(index, 1); - arr == &right + arr.as_ref() as &dyn Array == &right } ScalarValue::Date32(val) => { eq_array_primitive!(array, index, Date32Array, val)? 
@@ -2560,9 +2643,9 @@ impl ScalarValue { | ScalarValue::LargeBinary(b) => { b.as_ref().map(|b| b.capacity()).unwrap_or_default() } - ScalarValue::List(arr) - | ScalarValue::LargeList(arr) - | ScalarValue::FixedSizeList(arr) => arr.get_array_memory_size(), + ScalarValue::List(arr) => arr.get_array_memory_size(), + ScalarValue::LargeList(arr) => arr.get_array_memory_size(), + ScalarValue::FixedSizeList(arr) => arr.get_array_memory_size(), ScalarValue::Struct(vals, fields) => { vals.as_ref() .map(|vals| { @@ -2864,14 +2947,22 @@ impl TryFrom<&DataType> for ScalarValue { Box::new(value_type.as_ref().try_into()?), ), // `ScalaValue::List` contains single element `ListArray`. - DataType::List(field) => ScalarValue::List(new_null_array( - &DataType::List(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - ))), - 1, - )), + DataType::List(field) => { + ScalarValue::List( + new_null_array( + &DataType::List( + Arc::new( + Field::new( + "item", + field.data_type().clone(), + true, + ) + ) + ), + 1, + ).as_list::().to_owned().into() + ) + }, DataType::Struct(fields) => ScalarValue::Struct(None, fields.clone()), DataType::Null => ScalarValue::Null, _ => { @@ -2936,13 +3027,27 @@ impl fmt::Display for ScalarValue { )?, None => write!(f, "NULL")?, }, - ScalarValue::List(arr) - | ScalarValue::LargeList(arr) - | ScalarValue::FixedSizeList(arr) => { + ScalarValue::List(arr) => { + // ScalarValue List should always have a single element + assert_eq!(arr.len(), 1); + let options = FormatOptions::default().with_display_error(true); + let formatter = ArrayFormatter::try_new(arr.as_ref() as &dyn Array, &options).unwrap(); + let value_formatter = formatter.value(0); + write!(f, "{value_formatter}")? + } + ScalarValue::LargeList(arr) => { + // ScalarValue List should always have a single element + assert_eq!(arr.len(), 1); + let options = FormatOptions::default().with_display_error(true); + let formatter = ArrayFormatter::try_new(arr.as_ref() as &dyn Array, &options).unwrap(); + let value_formatter = formatter.value(0); + write!(f, "{value_formatter}")? + } + ScalarValue::FixedSizeList(arr) => { // ScalarValue List should always have a single element assert_eq!(arr.len(), 1); let options = FormatOptions::default().with_display_error(true); - let formatter = ArrayFormatter::try_new(arr, &options).unwrap(); + let formatter = ArrayFormatter::try_new(arr.as_ref() as &dyn Array, &options).unwrap(); let value_formatter = formatter.value(0); write!(f, "{value_formatter}")? 
} @@ -3181,15 +3286,14 @@ mod tests { ScalarValue::from("data-fusion"), ]; - let array = ScalarValue::new_list(scalars.as_slice(), &DataType::Utf8); + let result = ScalarValue::new_list(scalars.as_slice(), &DataType::Utf8); let expected = array_into_list_array(Arc::new(StringArray::from(vec![ "rust", "arrow", "data-fusion", ]))); - let result = as_list_array(&array); - assert_eq!(result, &expected); + assert_eq!(*result, expected); } fn build_list( @@ -3225,9 +3329,9 @@ mod tests { }; if O::IS_LARGE { - ScalarValue::LargeList(arr) + ScalarValue::LargeList(arr.as_list::().to_owned().into()) } else { - ScalarValue::List(arr) + ScalarValue::List(arr.as_list::().to_owned().into()) } }) .collect() @@ -3300,33 +3404,6 @@ mod tests { assert_eq!(result, &expected); } - #[test] - fn test_list_scalar_eq_to_array() { - let list_array: ArrayRef = - Arc::new(ListArray::from_iter_primitive::(vec![ - Some(vec![Some(0), Some(1), Some(2)]), - None, - Some(vec![None, Some(5)]), - ])); - - let fsl_array: ArrayRef = - Arc::new(FixedSizeListArray::from_iter_primitive::( - vec![ - Some(vec![Some(0), Some(1), Some(2)]), - None, - Some(vec![Some(3), None, Some(5)]), - ], - 3, - )); - - for arr in [list_array, fsl_array] { - for i in 0..arr.len() { - let scalar = ScalarValue::List(arr.slice(i, 1)); - assert!(scalar.eq_array(&arr, i).unwrap()); - } - } - } - #[test] fn scalar_add_trait_test() -> Result<()> { let float_value = ScalarValue::Float64(Some(123.)); @@ -3675,8 +3752,7 @@ mod tests { #[test] fn scalar_list_null_to_array() { - let list_array_ref = ScalarValue::new_list(&[], &DataType::UInt64); - let list_array = as_list_array(&list_array_ref); + let list_array = ScalarValue::new_list(&[], &DataType::UInt64); assert_eq!(list_array.len(), 1); assert_eq!(list_array.values().len(), 0); @@ -3684,8 +3760,7 @@ mod tests { #[test] fn scalar_large_list_null_to_array() { - let list_array_ref = ScalarValue::new_large_list(&[], &DataType::UInt64); - let list_array = as_large_list_array(&list_array_ref); + let list_array = ScalarValue::new_large_list(&[], &DataType::UInt64); assert_eq!(list_array.len(), 1); assert_eq!(list_array.values().len(), 0); @@ -3698,8 +3773,7 @@ mod tests { ScalarValue::UInt64(None), ScalarValue::UInt64(Some(101)), ]; - let list_array_ref = ScalarValue::new_list(&values, &DataType::UInt64); - let list_array = as_list_array(&list_array_ref); + let list_array = ScalarValue::new_list(&values, &DataType::UInt64); assert_eq!(list_array.len(), 1); assert_eq!(list_array.values().len(), 3); @@ -3719,8 +3793,7 @@ mod tests { ScalarValue::UInt64(None), ScalarValue::UInt64(Some(101)), ]; - let list_array_ref = ScalarValue::new_large_list(&values, &DataType::UInt64); - let list_array = as_large_list_array(&list_array_ref); + let list_array = ScalarValue::new_large_list(&values, &DataType::UInt64); assert_eq!(list_array.len(), 1); assert_eq!(list_array.values().len(), 3); @@ -3958,10 +4031,11 @@ mod tests { let data_type = &data_type; let scalar: ScalarValue = data_type.try_into().unwrap(); - let expected = ScalarValue::List(new_null_array( + let expected = ScalarValue::List( + new_null_array( &DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), 1, - )); + ).as_list::().to_owned().into()); assert_eq!(expected, scalar) } @@ -3983,7 +4057,7 @@ mod tests { true, ))), 1, - )); + ).as_list::().to_owned().into()); assert_eq!(expected, scalar) } diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 
e2fbd5e927a1..5459d405f8c1 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -27,7 +27,7 @@ use crate::simplify_expressions::regex::simplify_regex_expr; use crate::simplify_expressions::SimplifyInfo; use arrow::{ - array::new_null_array, + array::{new_null_array, AsArray}, datatypes::{DataType, Field, Schema}, error::ArrowError, record_batch::RecordBatch, @@ -396,7 +396,7 @@ impl<'a> ConstEvaluator<'a> { a.len() ) } else if as_list_array(&a).is_ok() || as_large_list_array(&a).is_ok() { - Ok(ScalarValue::List(a)) + Ok(ScalarValue::List(a.as_list().to_owned().into())) } else { // Non-ListArray ScalarValue::try_from_array(&a, 0) diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs index 1efae424cc69..8fb633a3cd43 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs @@ -186,7 +186,6 @@ mod tests { use arrow::array::{ArrayRef, Int32Array}; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; - use arrow_array::cast::as_list_array; use arrow_array::types::Int32Type; use arrow_array::{Array, ListArray}; use arrow_buffer::OffsetBuffer; @@ -197,8 +196,7 @@ mod tests { fn sort_list_inner(arr: ScalarValue) -> ScalarValue { let arr = match arr { ScalarValue::List(arr) => { - let list_arr = as_list_array(&arr); - list_arr.value(0) + arr.value(0) } _ => { panic!("Expected ScalarValue::List, got {:?}", arr) diff --git a/datafusion/physical-expr/src/aggregate/tdigest.rs b/datafusion/physical-expr/src/aggregate/tdigest.rs index 90f5244f477d..34bf511fa52e 100644 --- a/datafusion/physical-expr/src/aggregate/tdigest.rs +++ b/datafusion/physical-expr/src/aggregate/tdigest.rs @@ -606,11 +606,10 @@ impl TDigest { let centroids: Vec<_> = match &state[5] { ScalarValue::List(arr) => { - let list_array = as_list_array(arr); - let arr = list_array.values(); + let array = arr.values(); let f64arr = - as_primitive_array::(arr).expect("expected f64 array"); + as_primitive_array::(array).expect("expected f64 array"); f64arr .values() .chunks(2) diff --git a/datafusion/physical-plan/src/values.rs b/datafusion/physical-plan/src/values.rs index b624fb362e65..c910e0cabe6f 100644 --- a/datafusion/physical-plan/src/values.rs +++ b/datafusion/physical-plan/src/values.rs @@ -30,6 +30,7 @@ use crate::{ use arrow::array::new_null_array; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; +use arrow_array::cast::AsArray; use datafusion_common::{internal_err, plan_err, DataFusionError, Result, ScalarValue}; use datafusion_execution::TaskContext; @@ -71,7 +72,7 @@ impl ValuesExec { match r { Ok(ColumnarValue::Scalar(scalar)) => Ok(scalar), Ok(ColumnarValue::Array(a)) if a.len() == 1 => { - Ok(ScalarValue::List(a)) + Ok(ScalarValue::List(a.as_list().to_owned().into())) } Ok(ColumnarValue::Array(a)) => { plan_err!( diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 193e0947d6d9..f809609722de 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -32,7 +32,7 @@ use arrow::{ i256, DataType, Field, IntervalMonthDayNanoType, IntervalUnit, Schema, TimeUnit, UnionFields, UnionMode, }, - ipc::{reader::read_record_batch, root_as_message}, + ipc::{reader::read_record_batch, root_as_message}, 
array::AsArray, }; use datafusion::execution::registry::FunctionRegistry; use datafusion_common::{ @@ -721,9 +721,9 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { .map_err(|e| e.context("Decoding ScalarValue::List Value"))?; let arr = record_batch.column(0); match value { - Value::ListValue(_) => Self::List(arr.to_owned()), - Value::LargeListValue(_) => Self::LargeList(arr.to_owned()), - Value::FixedSizeListValue(_) => Self::FixedSizeList(arr.to_owned()), + Value::ListValue(_) => Self::List(arr.as_list::().to_owned().into()), + Value::LargeListValue(_) => Self::LargeList(arr.as_list::().to_owned().into()), + Value::FixedSizeListValue(_) => Self::FixedSizeList(arr.as_fixed_size_list().to_owned().into()), _ => unreachable!(), } } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 2997d147424d..59883080df1e 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -37,7 +37,7 @@ use arrow::{ TimeUnit, UnionMode, }, ipc::writer::{DictionaryTracker, IpcDataGenerator}, - record_batch::RecordBatch, + record_batch::RecordBatch, array::ArrayRef, }; use datafusion_common::{ Column, Constraint, Constraints, DFField, DFSchema, DFSchemaRef, OwnedTableReference, @@ -1159,13 +1159,105 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { } // ScalarValue::List and ScalarValue::FixedSizeList are serialized using // Arrow IPC messages as a single column RecordBatch - ScalarValue::List(arr) - | ScalarValue::LargeList(arr) - | ScalarValue::FixedSizeList(arr) => { + ScalarValue::List(arr) => { // Wrap in a "field_name" column let batch = RecordBatch::try_from_iter(vec![( "field_name", - arr.to_owned(), + arr.to_owned() as ArrayRef, + )]) + .map_err(|e| { + Error::General( format!("Error creating temporary batch while encoding ScalarValue::List: {e}")) + })?; + + let gen = IpcDataGenerator {}; + let mut dict_tracker = DictionaryTracker::new(false); + let (_, encoded_message) = gen + .encoded_batch(&batch, &mut dict_tracker, &Default::default()) + .map_err(|e| { + Error::General(format!( + "Error encoding ScalarValue::List as IPC: {e}" + )) + })?; + + let schema: protobuf::Schema = batch.schema().try_into()?; + + let scalar_list_value = protobuf::ScalarListValue { + ipc_message: encoded_message.ipc_message, + arrow_data: encoded_message.arrow_data, + schema: Some(schema), + }; + + match val { + ScalarValue::List(_) => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::ListValue( + scalar_list_value, + )), + }), + ScalarValue::LargeList(_) => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::LargeListValue( + scalar_list_value, + )), + }), + ScalarValue::FixedSizeList(_) => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::FixedSizeListValue( + scalar_list_value, + )), + }), + _ => unreachable!(), + } + } + ScalarValue::LargeList(arr) => { + // Wrap in a "field_name" column + let batch = RecordBatch::try_from_iter(vec![( + "field_name", + arr.to_owned() as ArrayRef, + )]) + .map_err(|e| { + Error::General( format!("Error creating temporary batch while encoding ScalarValue::List: {e}")) + })?; + + let gen = IpcDataGenerator {}; + let mut dict_tracker = DictionaryTracker::new(false); + let (_, encoded_message) = gen + .encoded_batch(&batch, &mut dict_tracker, &Default::default()) + .map_err(|e| { + Error::General(format!( + "Error encoding ScalarValue::List as IPC: {e}" + )) + })?; + + let schema: protobuf::Schema = 
batch.schema().try_into()?; + + let scalar_list_value = protobuf::ScalarListValue { + ipc_message: encoded_message.ipc_message, + arrow_data: encoded_message.arrow_data, + schema: Some(schema), + }; + + match val { + ScalarValue::List(_) => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::ListValue( + scalar_list_value, + )), + }), + ScalarValue::LargeList(_) => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::LargeListValue( + scalar_list_value, + )), + }), + ScalarValue::FixedSizeList(_) => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::FixedSizeListValue( + scalar_list_value, + )), + }), + _ => unreachable!(), + } + } + ScalarValue::FixedSizeList(arr) => { + // Wrap in a "field_name" column + let batch = RecordBatch::try_from_iter(vec![( + "field_name", + arr.to_owned() as ArrayRef, )]) .map_err(|e| { Error::General( format!("Error creating temporary batch while encoding ScalarValue::List: {e}")) From 6938bd6c9e22932e6cdaf67879884e5880eac8cd Mon Sep 17 00:00:00 2001 From: Spears Randall Date: Fri, 15 Dec 2023 16:52:24 -0700 Subject: [PATCH 02/14] Formatting/cleanup --- datafusion/common/src/scalar.rs | 78 +++++++++++-------- .../src/aggregate/array_agg_distinct.rs | 4 +- .../physical-expr/src/aggregate/tdigest.rs | 1 - .../proto/src/logical_plan/from_proto.rs | 15 +++- datafusion/proto/src/logical_plan/to_proto.rs | 3 +- 5 files changed, 60 insertions(+), 41 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 8f2c6c159a7c..f33a818d272f 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -445,8 +445,7 @@ impl PartialOrd for ScalarValue { } } - Some(Ordering::Equal) - + Some(Ordering::Equal) } (List(_), _) | (LargeList(_), _) | (FixedSizeList(_), _) => None, (Date32(v1), Date32(v2)) => v1.partial_cmp(v2), @@ -1804,7 +1803,10 @@ impl ScalarValue { /// /// assert_eq!(*result, expected); /// ``` - pub fn new_large_list(values: &[ScalarValue], data_type: &DataType) -> Arc { + pub fn new_large_list( + values: &[ScalarValue], + data_type: &DataType, + ) -> Arc { let values = if values.is_empty() { new_empty_array(data_type) } else { @@ -2947,22 +2949,19 @@ impl TryFrom<&DataType> for ScalarValue { Box::new(value_type.as_ref().try_into()?), ), // `ScalaValue::List` contains single element `ListArray`. - DataType::List(field) => { - ScalarValue::List( - new_null_array( - &DataType::List( - Arc::new( - Field::new( - "item", - field.data_type().clone(), - true, - ) - ) - ), - 1, - ).as_list::().to_owned().into() + DataType::List(field) => ScalarValue::List( + new_null_array( + &DataType::List(Arc::new(Field::new( + "item", + field.data_type().clone(), + true, + ))), + 1, ) - }, + .as_list::() + .to_owned() + .into(), + ), DataType::Struct(fields) => ScalarValue::Struct(None, fields.clone()), DataType::Null => ScalarValue::Null, _ => { @@ -3031,7 +3030,9 @@ impl fmt::Display for ScalarValue { // ScalarValue List should always have a single element assert_eq!(arr.len(), 1); let options = FormatOptions::default().with_display_error(true); - let formatter = ArrayFormatter::try_new(arr.as_ref() as &dyn Array, &options).unwrap(); + let formatter = + ArrayFormatter::try_new(arr.as_ref() as &dyn Array, &options) + .unwrap(); let value_formatter = formatter.value(0); write!(f, "{value_formatter}")? 
} @@ -3039,7 +3040,9 @@ impl fmt::Display for ScalarValue { // ScalarValue List should always have a single element assert_eq!(arr.len(), 1); let options = FormatOptions::default().with_display_error(true); - let formatter = ArrayFormatter::try_new(arr.as_ref() as &dyn Array, &options).unwrap(); + let formatter = + ArrayFormatter::try_new(arr.as_ref() as &dyn Array, &options) + .unwrap(); let value_formatter = formatter.value(0); write!(f, "{value_formatter}")? } @@ -3047,7 +3050,9 @@ impl fmt::Display for ScalarValue { // ScalarValue List should always have a single element assert_eq!(arr.len(), 1); let options = FormatOptions::default().with_display_error(true); - let formatter = ArrayFormatter::try_new(arr.as_ref() as &dyn Array, &options).unwrap(); + let formatter = + ArrayFormatter::try_new(arr.as_ref() as &dyn Array, &options) + .unwrap(); let value_formatter = formatter.value(0); write!(f, "{value_formatter}")? } @@ -4033,9 +4038,13 @@ mod tests { let expected = ScalarValue::List( new_null_array( - &DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), - 1, - ).as_list::().to_owned().into()); + &DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + 1, + ) + .as_list::() + .to_owned() + .into(), + ); assert_eq!(expected, scalar) } @@ -4050,14 +4059,19 @@ mod tests { let data_type = &data_type; let scalar: ScalarValue = data_type.try_into().unwrap(); - let expected = ScalarValue::List(new_null_array( - &DataType::List(Arc::new(Field::new( - "item", - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), - true, - ))), - 1, - ).as_list::().to_owned().into()); + let expected = ScalarValue::List( + new_null_array( + &DataType::List(Arc::new(Field::new( + "item", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + true, + ))), + 1, + ) + .as_list::() + .to_owned() + .into(), + ); assert_eq!(expected, scalar) } diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs index 8fb633a3cd43..2d263a42e0ff 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs @@ -195,9 +195,7 @@ mod tests { // arrow::compute::sort cann't sort ListArray directly, so we need to sort the inner primitive array and wrap it back into ListArray. fn sort_list_inner(arr: ScalarValue) -> ScalarValue { let arr = match arr { - ScalarValue::List(arr) => { - arr.value(0) - } + ScalarValue::List(arr) => arr.value(0), _ => { panic!("Expected ScalarValue::List, got {:?}", arr) } diff --git a/datafusion/physical-expr/src/aggregate/tdigest.rs b/datafusion/physical-expr/src/aggregate/tdigest.rs index 34bf511fa52e..78708df94c25 100644 --- a/datafusion/physical-expr/src/aggregate/tdigest.rs +++ b/datafusion/physical-expr/src/aggregate/tdigest.rs @@ -28,7 +28,6 @@ //! 
[Facebook's Folly TDigest]: https://github.com/facebook/folly/blob/main/folly/stats/TDigest.h use arrow::datatypes::DataType; -use arrow_array::cast::as_list_array; use arrow_array::types::Float64Type; use datafusion_common::cast::as_primitive_array; use datafusion_common::Result; diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index f809609722de..70227b7757cb 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -27,12 +27,13 @@ use crate::protobuf::{ OptimizedPhysicalPlanType, PlaceholderNode, RollupNode, }; use arrow::{ + array::AsArray, buffer::Buffer, datatypes::{ i256, DataType, Field, IntervalMonthDayNanoType, IntervalUnit, Schema, TimeUnit, UnionFields, UnionMode, }, - ipc::{reader::read_record_batch, root_as_message}, array::AsArray, + ipc::{reader::read_record_batch, root_as_message}, }; use datafusion::execution::registry::FunctionRegistry; use datafusion_common::{ @@ -721,9 +722,15 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { .map_err(|e| e.context("Decoding ScalarValue::List Value"))?; let arr = record_batch.column(0); match value { - Value::ListValue(_) => Self::List(arr.as_list::().to_owned().into()), - Value::LargeListValue(_) => Self::LargeList(arr.as_list::().to_owned().into()), - Value::FixedSizeListValue(_) => Self::FixedSizeList(arr.as_fixed_size_list().to_owned().into()), + Value::ListValue(_) => { + Self::List(arr.as_list::().to_owned().into()) + } + Value::LargeListValue(_) => { + Self::LargeList(arr.as_list::().to_owned().into()) + } + Value::FixedSizeListValue(_) => { + Self::FixedSizeList(arr.as_fixed_size_list().to_owned().into()) + } _ => unreachable!(), } } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 59883080df1e..9a37cfba1286 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -32,12 +32,13 @@ use crate::protobuf::{ OptimizedLogicalPlanType, OptimizedPhysicalPlanType, PlaceholderNode, RollupNode, }; use arrow::{ + array::ArrayRef, datatypes::{ DataType, Field, IntervalMonthDayNanoType, IntervalUnit, Schema, SchemaRef, TimeUnit, UnionMode, }, ipc::writer::{DictionaryTracker, IpcDataGenerator}, - record_batch::RecordBatch, array::ArrayRef, + record_batch::RecordBatch, }; use datafusion_common::{ Column, Constraint, Constraints, DFField, DFSchema, DFSchemaRef, OwnedTableReference, From a8585b6a33e0d2b3c5046093dbf68e65e1c4ff1d Mon Sep 17 00:00:00 2001 From: Spears Randall Date: Mon, 18 Dec 2023 09:19:25 -0700 Subject: [PATCH 03/14] Remove duplicate match statements --- datafusion/proto/src/logical_plan/to_proto.rs | 70 +++++-------------- 1 file changed, 16 insertions(+), 54 deletions(-) diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 9a37cfba1286..9765bc729143 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1188,24 +1188,11 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { schema: Some(schema), }; - match val { - ScalarValue::List(_) => Ok(protobuf::ScalarValue { - value: Some(protobuf::scalar_value::Value::ListValue( - scalar_list_value, - )), - }), - ScalarValue::LargeList(_) => Ok(protobuf::ScalarValue { - value: Some(protobuf::scalar_value::Value::LargeListValue( - scalar_list_value, - )), - }), - ScalarValue::FixedSizeList(_) => Ok(protobuf::ScalarValue 
{ - value: Some(protobuf::scalar_value::Value::FixedSizeListValue( - scalar_list_value, - )), - }), - _ => unreachable!(), - } + Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::ListValue( + scalar_list_value, + )), + }) } ScalarValue::LargeList(arr) => { // Wrap in a "field_name" column @@ -1235,24 +1222,11 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { schema: Some(schema), }; - match val { - ScalarValue::List(_) => Ok(protobuf::ScalarValue { - value: Some(protobuf::scalar_value::Value::ListValue( - scalar_list_value, - )), - }), - ScalarValue::LargeList(_) => Ok(protobuf::ScalarValue { - value: Some(protobuf::scalar_value::Value::LargeListValue( - scalar_list_value, - )), - }), - ScalarValue::FixedSizeList(_) => Ok(protobuf::ScalarValue { - value: Some(protobuf::scalar_value::Value::FixedSizeListValue( - scalar_list_value, - )), - }), - _ => unreachable!(), - } + Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::LargeListValue( + scalar_list_value, + )), + }) } ScalarValue::FixedSizeList(arr) => { // Wrap in a "field_name" column @@ -1282,24 +1256,12 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { schema: Some(schema), }; - match val { - ScalarValue::List(_) => Ok(protobuf::ScalarValue { - value: Some(protobuf::scalar_value::Value::ListValue( - scalar_list_value, - )), - }), - ScalarValue::LargeList(_) => Ok(protobuf::ScalarValue { - value: Some(protobuf::scalar_value::Value::LargeListValue( - scalar_list_value, - )), - }), - ScalarValue::FixedSizeList(_) => Ok(protobuf::ScalarValue { - value: Some(protobuf::scalar_value::Value::FixedSizeListValue( - scalar_list_value, - )), - }), - _ => unreachable!(), - } + + Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::FixedSizeListValue( + scalar_list_value, + )), + }) } ScalarValue::Date32(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| Value::Date32Value(*s)) From da1bf38280f1ec6d6f8a41e0895ead6c80bb5da4 Mon Sep 17 00:00:00 2001 From: Spears Randall Date: Mon, 18 Dec 2023 09:33:57 -0700 Subject: [PATCH 04/14] Add back scalar eq_array test for List --- datafusion/common/src/scalar.rs | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index f33a818d272f..b4403ec58a4c 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -3409,6 +3409,32 @@ mod tests { assert_eq!(result, &expected); } + #[test] + fn test_list_scalar_eq_to_array() { + let list_array: ArrayRef = + Arc::new(ListArray::from_iter_primitive::(vec![ + Some(vec![Some(0), Some(1), Some(2)]), + None, + Some(vec![None, Some(5)]), + ])); + + let fsl_array: ArrayRef = + Arc::new(ListArray::from_iter_primitive::( + vec![ + Some(vec![Some(0), Some(1), Some(2)]), + None, + Some(vec![Some(3), None, Some(5)]), + ] + )); + + for arr in [list_array, fsl_array] { + for i in 0..arr.len() { + let scalar = ScalarValue::List(arr.slice(i, 1).as_list::().to_owned().into()); + assert!(scalar.eq_array(&arr, i).unwrap()); + } + } + } + #[test] fn scalar_add_trait_test() -> Result<()> { let float_value = ScalarValue::Float64(Some(123.)); From 48bfc92efa7bb4f3239d048ebcb7364e8d353a86 Mon Sep 17 00:00:00 2001 From: Spears Randall Date: Mon, 18 Dec 2023 11:19:54 -0700 Subject: [PATCH 05/14] Formatting --- datafusion/common/src/scalar.rs | 15 +++++++-------- datafusion/proto/src/logical_plan/to_proto.rs | 1 - 2 files changed, 7 insertions(+), 9 deletions(-) diff --git 
a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index b4403ec58a4c..a9cb1305a414 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -3419,17 +3419,16 @@ mod tests { ])); let fsl_array: ArrayRef = - Arc::new(ListArray::from_iter_primitive::( - vec![ - Some(vec![Some(0), Some(1), Some(2)]), - None, - Some(vec![Some(3), None, Some(5)]), - ] - )); + Arc::new(ListArray::from_iter_primitive::(vec![ + Some(vec![Some(0), Some(1), Some(2)]), + None, + Some(vec![Some(3), None, Some(5)]), + ])); for arr in [list_array, fsl_array] { for i in 0..arr.len() { - let scalar = ScalarValue::List(arr.slice(i, 1).as_list::().to_owned().into()); + let scalar = + ScalarValue::List(arr.slice(i, 1).as_list::().to_owned().into()); assert!(scalar.eq_array(&arr, i).unwrap()); } } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 9765bc729143..1fb2edd717b4 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1256,7 +1256,6 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { schema: Some(schema), }; - Ok(protobuf::ScalarValue { value: Some(protobuf::scalar_value::Value::FixedSizeListValue( scalar_list_value, From bf1adafc9b979cd708d1b90168d313a7eb39a492 Mon Sep 17 00:00:00 2001 From: Spears Randall Date: Fri, 5 Jan 2024 17:31:29 -0700 Subject: [PATCH 06/14] Reduce code duplication --- datafusion/common/src/scalar.rs | 61 +++----- datafusion/proto/src/logical_plan/to_proto.rs | 142 ++++++------------ 2 files changed, 70 insertions(+), 133 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index a9cb1305a414..4f281f0ce9ad 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -21,6 +21,7 @@ use std::borrow::Borrow; use std::cmp::Ordering; use std::collections::HashSet; use std::convert::{Infallible, TryInto}; +use std::hash::Hash; use std::str::FromStr; use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc}; @@ -565,31 +566,13 @@ impl std::hash::Hash for ScalarValue { FixedSizeBinary(_, v) => v.hash(state), LargeBinary(v) => v.hash(state), List(arr) => { - let arrays = vec![arr.to_owned() as ArrayRef]; - let hashes_buffer = &mut vec![0; arr.len()]; - let random_state = ahash::RandomState::with_seeds(0, 0, 0, 0); - let hashes = - create_hashes(&arrays, &random_state, hashes_buffer).unwrap(); - // Hash back to std::hash::Hasher - hashes.hash(state); + hash_list(arr.to_owned() as ArrayRef, state); } LargeList(arr) => { - let arrays = vec![arr.to_owned() as ArrayRef]; - let hashes_buffer = &mut vec![0; arr.len()]; - let random_state = ahash::RandomState::with_seeds(0, 0, 0, 0); - let hashes = - create_hashes(&arrays, &random_state, hashes_buffer).unwrap(); - // Hash back to std::hash::Hasher - hashes.hash(state); + hash_list(arr.to_owned() as ArrayRef, state); } FixedSizeList(arr) => { - let arrays = vec![arr.to_owned() as ArrayRef]; - let hashes_buffer = &mut vec![0; arr.len()]; - let random_state = ahash::RandomState::with_seeds(0, 0, 0, 0); - let hashes = - create_hashes(&arrays, &random_state, hashes_buffer).unwrap(); - // Hash back to std::hash::Hasher - hashes.hash(state); + hash_list(arr.to_owned() as ArrayRef, state); } Date32(v) => v.hash(state), Date64(v) => v.hash(state), @@ -622,6 +605,15 @@ impl std::hash::Hash for ScalarValue { } } +fn hash_list(arr: ArrayRef, state: &mut H) { + let arrays = vec![arr.to_owned()]; + let hashes_buffer = &mut vec![0; 
arr.len()]; + let random_state = ahash::RandomState::with_seeds(0, 0, 0, 0); + let hashes = create_hashes(&arrays, &random_state, hashes_buffer).unwrap(); + // Hash back to std::hash::Hasher + hashes.hash(state); +} + /// Return a reference to the values array and the index into it for a /// dictionary array /// @@ -1943,25 +1935,13 @@ impl ScalarValue { ), }, ScalarValue::List(arr) => { - let arrays = std::iter::repeat(arr.as_ref() as &dyn Array) - .take(size) - .collect::>(); - arrow::compute::concat(arrays.as_slice()) - .map_err(DataFusionError::ArrowError)? + Self::list_to_array_of_size(arr.as_ref() as &dyn Array, size)? } ScalarValue::LargeList(arr) => { - let arrays = std::iter::repeat(arr.as_ref() as &dyn Array) - .take(size) - .collect::>(); - arrow::compute::concat(arrays.as_slice()) - .map_err(DataFusionError::ArrowError)? + Self::list_to_array_of_size(arr.as_ref() as &dyn Array, size)? } ScalarValue::FixedSizeList(arr) => { - let arrays = std::iter::repeat(arr.as_ref() as &dyn Array) - .take(size) - .collect::>(); - arrow::compute::concat(arrays.as_slice()) - .map_err(DataFusionError::ArrowError)? + Self::list_to_array_of_size(arr.as_ref() as &dyn Array, size)? } ScalarValue::Date32(e) => { build_array_from_option!(Date32, Date32Array, e, size) @@ -2118,6 +2098,11 @@ impl ScalarValue { } } + fn list_to_array_of_size(arr: &dyn Array, size: usize) -> Result { + let arrays = std::iter::repeat(arr).take(size).collect::>(); + arrow::compute::concat(arrays.as_slice()).map_err(DataFusionError::ArrowError) + } + /// Retrieve ScalarValue for each row in `array` /// /// Example @@ -3037,7 +3022,7 @@ impl fmt::Display for ScalarValue { write!(f, "{value_formatter}")? } ScalarValue::LargeList(arr) => { - // ScalarValue List should always have a single element + // ScalarValue LargeList should always have a single element assert_eq!(arr.len(), 1); let options = FormatOptions::default().with_display_error(true); let formatter = @@ -3047,7 +3032,7 @@ impl fmt::Display for ScalarValue { write!(f, "{value_formatter}")? 
} ScalarValue::FixedSizeList(arr) => { - // ScalarValue List should always have a single element + // ScalarValue FixedSizeList should always have a single element assert_eq!(arr.len(), 1); let options = FormatOptions::default().with_display_error(true); let formatter = diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 1fb2edd717b4..ac8555010c6d 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1161,106 +1161,14 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { // ScalarValue::List and ScalarValue::FixedSizeList are serialized using // Arrow IPC messages as a single column RecordBatch ScalarValue::List(arr) => { - // Wrap in a "field_name" column - let batch = RecordBatch::try_from_iter(vec![( - "field_name", - arr.to_owned() as ArrayRef, - )]) - .map_err(|e| { - Error::General( format!("Error creating temporary batch while encoding ScalarValue::List: {e}")) - })?; - - let gen = IpcDataGenerator {}; - let mut dict_tracker = DictionaryTracker::new(false); - let (_, encoded_message) = gen - .encoded_batch(&batch, &mut dict_tracker, &Default::default()) - .map_err(|e| { - Error::General(format!( - "Error encoding ScalarValue::List as IPC: {e}" - )) - })?; - - let schema: protobuf::Schema = batch.schema().try_into()?; - - let scalar_list_value = protobuf::ScalarListValue { - ipc_message: encoded_message.ipc_message, - arrow_data: encoded_message.arrow_data, - schema: Some(schema), - }; - - Ok(protobuf::ScalarValue { - value: Some(protobuf::scalar_value::Value::ListValue( - scalar_list_value, - )), - }) + encode_scalar_list_value(arr.to_owned() as ArrayRef, val) } ScalarValue::LargeList(arr) => { // Wrap in a "field_name" column - let batch = RecordBatch::try_from_iter(vec![( - "field_name", - arr.to_owned() as ArrayRef, - )]) - .map_err(|e| { - Error::General( format!("Error creating temporary batch while encoding ScalarValue::List: {e}")) - })?; - - let gen = IpcDataGenerator {}; - let mut dict_tracker = DictionaryTracker::new(false); - let (_, encoded_message) = gen - .encoded_batch(&batch, &mut dict_tracker, &Default::default()) - .map_err(|e| { - Error::General(format!( - "Error encoding ScalarValue::List as IPC: {e}" - )) - })?; - - let schema: protobuf::Schema = batch.schema().try_into()?; - - let scalar_list_value = protobuf::ScalarListValue { - ipc_message: encoded_message.ipc_message, - arrow_data: encoded_message.arrow_data, - schema: Some(schema), - }; - - Ok(protobuf::ScalarValue { - value: Some(protobuf::scalar_value::Value::LargeListValue( - scalar_list_value, - )), - }) + encode_scalar_list_value(arr.to_owned() as ArrayRef, val) } ScalarValue::FixedSizeList(arr) => { - // Wrap in a "field_name" column - let batch = RecordBatch::try_from_iter(vec![( - "field_name", - arr.to_owned() as ArrayRef, - )]) - .map_err(|e| { - Error::General( format!("Error creating temporary batch while encoding ScalarValue::List: {e}")) - })?; - - let gen = IpcDataGenerator {}; - let mut dict_tracker = DictionaryTracker::new(false); - let (_, encoded_message) = gen - .encoded_batch(&batch, &mut dict_tracker, &Default::default()) - .map_err(|e| { - Error::General(format!( - "Error encoding ScalarValue::List as IPC: {e}" - )) - })?; - - let schema: protobuf::Schema = batch.schema().try_into()?; - - let scalar_list_value = protobuf::ScalarListValue { - ipc_message: encoded_message.ipc_message, - arrow_data: encoded_message.arrow_data, - schema: Some(schema), - }; - - 
Ok(protobuf::ScalarValue { - value: Some(protobuf::scalar_value::Value::FixedSizeListValue( - scalar_list_value, - )), - }) + encode_scalar_list_value(arr.to_owned() as ArrayRef, val) } ScalarValue::Date32(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| Value::Date32Value(*s)) @@ -1777,3 +1685,47 @@ fn create_proto_scalar protobuf::scalar_value::Value>( Ok(protobuf::ScalarValue { value: Some(value) }) } + +fn encode_scalar_list_value( + arr: ArrayRef, + val: &ScalarValue, +) -> Result { + let batch = RecordBatch::try_from_iter(vec![("field_name", arr)]).map_err(|e| { + Error::General(format!( + "Error creating temporary batch while encoding ScalarValue::List: {e}" + )) + })?; + + let gen = IpcDataGenerator {}; + let mut dict_tracker = DictionaryTracker::new(false); + let (_, encoded_message) = gen + .encoded_batch(&batch, &mut dict_tracker, &Default::default()) + .map_err(|e| { + Error::General(format!("Error encoding ScalarValue::List as IPC: {e}")) + })?; + + let schema: protobuf::Schema = batch.schema().try_into()?; + + let scalar_list_value = protobuf::ScalarListValue { + ipc_message: encoded_message.ipc_message, + arrow_data: encoded_message.arrow_data, + schema: Some(schema), + }; + + match val { + ScalarValue::List(_) => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::ListValue(scalar_list_value)), + }), + ScalarValue::LargeList(_) => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::LargeListValue( + scalar_list_value, + )), + }), + ScalarValue::FixedSizeList(_) => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::FixedSizeListValue( + scalar_list_value, + )), + }), + _ => unreachable!(), + } +} From 73d6776fe955b903a245705440e3e5f5616867e5 Mon Sep 17 00:00:00 2001 From: Spears Randall Date: Fri, 5 Jan 2024 17:40:45 -0700 Subject: [PATCH 07/14] Fix merge conflict --- datafusion/common/src/scalar.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 22b45eb69eb4..a5e40312d300 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -2100,8 +2100,11 @@ impl ScalarValue { } fn list_to_array_of_size(arr: &dyn Array, size: usize) -> Result { - let arrays = std::iter::repeat(arr).take(size).collect::>(); - arrow::compute::concat(arrays.as_slice()).map_err(arrow_datafusion_error!(e)) + let arrays = std::iter::repeat(arr) + .take(size) + .collect::>(); + arrow::compute::concat(arrays.as_slice()) + .map_err(|e| arrow_datafusion_err!(e)) } /// Retrieve ScalarValue for each row in `array` From 41f0d3648621e4b7cc98d59664ffae4c2ed5ccb4 Mon Sep 17 00:00:00 2001 From: Spears Randall Date: Fri, 5 Jan 2024 23:43:20 -0700 Subject: [PATCH 08/14] Fix post-merge compile errors --- datafusion/physical-expr/src/aggregate/count_distinct.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs index f7c13948b2dc..021c33fb94a7 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs @@ -292,7 +292,7 @@ where let arr = Arc::new(PrimitiveArray::::from_iter_values( self.values.iter().cloned(), )) as ArrayRef; - let list = Arc::new(array_into_list_array(arr)) as ArrayRef; + let list = Arc::new(array_into_list_array(arr)); Ok(vec![ScalarValue::List(list)]) } @@ -378,7 +378,7 @@ where let arr = 
Arc::new(PrimitiveArray::::from_iter_values( self.values.iter().map(|v| v.0), )) as ArrayRef; - let list = Arc::new(array_into_list_array(arr)) as ArrayRef; + let list = Arc::new(array_into_list_array(arr)); Ok(vec![ScalarValue::List(list)]) } From c1ec85e219ec0366ec5d8af90421b44609387a3e Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 6 Jan 2024 08:46:30 -0500 Subject: [PATCH 09/14] Remove redundant partial_cmp implementation --- datafusion/common/src/scalar.rs | 132 +++++++++++--------------------- 1 file changed, 44 insertions(+), 88 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index a5e40312d300..808d353a6b93 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -362,92 +362,12 @@ impl PartialOrd for ScalarValue { (LargeBinary(v1), LargeBinary(v2)) => v1.partial_cmp(v2), (LargeBinary(_), _) => None, // ScalarValue::List / ScalarValue::FixedSizeList / ScalarValue::LargeList are ensure to have length 1 - (List(arr1), List(arr2)) => { - assert_eq!(arr1.len(), 1); - assert_eq!(arr2.len(), 1); - - if arr1.data_type() != arr2.data_type() { - return None; - } - - fn first_array_for_list(arr: &Arc) -> ArrayRef { - arr.value(0) - } - - let arr1 = first_array_for_list(arr1); - let arr2 = first_array_for_list(arr2); - - let lt_res = arrow::compute::kernels::cmp::lt(&arr1, &arr2).ok()?; - let eq_res = arrow::compute::kernels::cmp::eq(&arr1, &arr2).ok()?; - - for j in 0..lt_res.len() { - if lt_res.is_valid(j) && lt_res.value(j) { - return Some(Ordering::Less); - } - if eq_res.is_valid(j) && !eq_res.value(j) { - return Some(Ordering::Greater); - } - } - - Some(Ordering::Equal) - } + (List(arr1), List(arr2)) => partial_cmp_list(arr1.as_ref(), arr2.as_ref()), (FixedSizeList(arr1), FixedSizeList(arr2)) => { - assert_eq!(arr1.len(), 1); - assert_eq!(arr2.len(), 1); - - if arr1.data_type() != arr2.data_type() { - return None; - } - - fn first_array_for_list(arr: &Arc) -> ArrayRef { - arr.value(0) - } - - let arr1 = first_array_for_list(arr1); - let arr2 = first_array_for_list(arr2); - - let lt_res = arrow::compute::kernels::cmp::lt(&arr1, &arr2).ok()?; - let eq_res = arrow::compute::kernels::cmp::eq(&arr1, &arr2).ok()?; - - for j in 0..lt_res.len() { - if lt_res.is_valid(j) && lt_res.value(j) { - return Some(Ordering::Less); - } - if eq_res.is_valid(j) && !eq_res.value(j) { - return Some(Ordering::Greater); - } - } - - Some(Ordering::Equal) + partial_cmp_list(arr1.as_ref(), arr2.as_ref()) } (LargeList(arr1), LargeList(arr2)) => { - assert_eq!(arr1.len(), 1); - assert_eq!(arr2.len(), 1); - - if arr1.data_type() != arr2.data_type() { - return None; - } - - fn first_array_for_list(arr: &Arc) -> ArrayRef { - arr.value(0) - } - - let arr1 = first_array_for_list(arr1); - let arr2 = first_array_for_list(arr2); - - let lt_res = arrow::compute::kernels::cmp::lt(&arr1, &arr2).ok()?; - let eq_res = arrow::compute::kernels::cmp::eq(&arr1, &arr2).ok()?; - - for j in 0..lt_res.len() { - if lt_res.is_valid(j) && lt_res.value(j) { - return Some(Ordering::Less); - } - if eq_res.is_valid(j) && !eq_res.value(j) { - return Some(Ordering::Greater); - } - } - - Some(Ordering::Equal) + partial_cmp_list(arr1.as_ref(), arr2.as_ref()) } (List(_), _) | (LargeList(_), _) | (FixedSizeList(_), _) => None, (Date32(v1), Date32(v2)) => v1.partial_cmp(v2), @@ -513,6 +433,45 @@ impl PartialOrd for ScalarValue { } } +/// Compares two List/LargeList/FixedSizeList scalars for equality +fn partial_cmp_list(arr1: &dyn Array, arr2: &dyn Array) -> Option { + 
assert_eq!(arr1.len(), 1); + assert_eq!(arr2.len(), 1); + + if arr1.data_type() != arr2.data_type() { + return None; + } + + fn first_array_for_list(arr: &dyn Array) -> ArrayRef { + if let Some(arr) = arr.as_list_opt::() { + arr.value(0) + } else if let Some(arr) = arr.as_list_opt::() { + arr.value(0) + } else if let Some(arr) = arr.as_fixed_size_list_opt() { + arr.value(0) + } else { + unreachable!("Since only List / LargeList / FixedSizeList are supported, this should never happen") + } + } + + let arr1 = first_array_for_list(arr1); + let arr2 = first_array_for_list(arr2); + + let lt_res = arrow::compute::kernels::cmp::lt(&arr1, &arr2).ok()?; + let eq_res = arrow::compute::kernels::cmp::eq(&arr1, &arr2).ok()?; + + for j in 0..lt_res.len() { + if lt_res.is_valid(j) && lt_res.value(j) { + return Some(Ordering::Less); + } + if eq_res.is_valid(j) && !eq_res.value(j) { + return Some(Ordering::Greater); + } + } + + Some(Ordering::Equal) +} + impl Eq for ScalarValue {} //Float wrapper over f32/f64. Just because we cannot build std::hash::Hash for floats directly we have to do it through type wrapper @@ -2100,11 +2059,8 @@ impl ScalarValue { } fn list_to_array_of_size(arr: &dyn Array, size: usize) -> Result { - let arrays = std::iter::repeat(arr) - .take(size) - .collect::>(); - arrow::compute::concat(arrays.as_slice()) - .map_err(|e| arrow_datafusion_err!(e)) + let arrays = std::iter::repeat(arr).take(size).collect::>(); + arrow::compute::concat(arrays.as_slice()).map_err(|e| arrow_datafusion_err!(e)) } /// Retrieve ScalarValue for each row in `array` From 2fd67d936ef235ad6363cc929593c9d4cd6359a0 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 6 Jan 2024 08:50:38 -0500 Subject: [PATCH 10/14] improve --- datafusion/common/src/scalar.rs | 33 ++++++++++++++++----------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 808d353a6b93..2324fa79d2c1 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -433,27 +433,26 @@ impl PartialOrd for ScalarValue { } } -/// Compares two List/LargeList/FixedSizeList scalars for equality -fn partial_cmp_list(arr1: &dyn Array, arr2: &dyn Array) -> Option { - assert_eq!(arr1.len(), 1); - assert_eq!(arr2.len(), 1); +/// List/LargeList/FixedSizeList scalars always have a single element +/// array. 
This function returns that array +fn first_array_for_list(arr: &dyn Array) -> ArrayRef { + assert_eq!(arr.len(), 1); + if let Some(arr) = arr.as_list_opt::() { + arr.value(0) + } else if let Some(arr) = arr.as_list_opt::() { + arr.value(0) + } else if let Some(arr) = arr.as_fixed_size_list_opt() { + arr.value(0) + } else { + unreachable!("Since only List / LargeList / FixedSizeList are supported, this should never happen") + } +} +/// Compares two List/LargeList/FixedSizeList scalars +fn partial_cmp_list(arr1: &dyn Array, arr2: &dyn Array) -> Option { if arr1.data_type() != arr2.data_type() { return None; } - - fn first_array_for_list(arr: &dyn Array) -> ArrayRef { - if let Some(arr) = arr.as_list_opt::() { - arr.value(0) - } else if let Some(arr) = arr.as_list_opt::() { - arr.value(0) - } else if let Some(arr) = arr.as_fixed_size_list_opt() { - arr.value(0) - } else { - unreachable!("Since only List / LargeList / FixedSizeList are supported, this should never happen") - } - } - let arr1 = first_array_for_list(arr1); let arr2 = first_array_for_list(arr2); From 84f3fed29ef6a19569db50807aa32210940945ea Mon Sep 17 00:00:00 2001 From: Spears Randall Date: Sat, 6 Jan 2024 09:41:55 -0700 Subject: [PATCH 11/14] Cargo fmt fix --- datafusion/common/src/scalar.rs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index a5e40312d300..76d9cd3a4b7c 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -2100,11 +2100,8 @@ impl ScalarValue { } fn list_to_array_of_size(arr: &dyn Array, size: usize) -> Result { - let arrays = std::iter::repeat(arr) - .take(size) - .collect::>(); - arrow::compute::concat(arrays.as_slice()) - .map_err(|e| arrow_datafusion_err!(e)) + let arrays = std::iter::repeat(arr).take(size).collect::>(); + arrow::compute::concat(arrays.as_slice()).map_err(|e| arrow_datafusion_err!(e)) } /// Retrieve ScalarValue for each row in `array` From 38488af1be35b2fa0e39299b3e0dd0222e8fc273 Mon Sep 17 00:00:00 2001 From: Spears Randall Date: Sat, 6 Jan 2024 09:55:08 -0700 Subject: [PATCH 12/14] Reduce duplication in formatter --- datafusion/common/src/scalar.rs | 43 ++++++++++----------------------- 1 file changed, 13 insertions(+), 30 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 2324fa79d2c1..670190dd1b04 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -2970,36 +2970,9 @@ impl fmt::Display for ScalarValue { )?, None => write!(f, "NULL")?, }, - ScalarValue::List(arr) => { - // ScalarValue List should always have a single element - assert_eq!(arr.len(), 1); - let options = FormatOptions::default().with_display_error(true); - let formatter = - ArrayFormatter::try_new(arr.as_ref() as &dyn Array, &options) - .unwrap(); - let value_formatter = formatter.value(0); - write!(f, "{value_formatter}")? - } - ScalarValue::LargeList(arr) => { - // ScalarValue LargeList should always have a single element - assert_eq!(arr.len(), 1); - let options = FormatOptions::default().with_display_error(true); - let formatter = - ArrayFormatter::try_new(arr.as_ref() as &dyn Array, &options) - .unwrap(); - let value_formatter = formatter.value(0); - write!(f, "{value_formatter}")? 
- } - ScalarValue::FixedSizeList(arr) => { - // ScalarValue FixedSizeList should always have a single element - assert_eq!(arr.len(), 1); - let options = FormatOptions::default().with_display_error(true); - let formatter = - ArrayFormatter::try_new(arr.as_ref() as &dyn Array, &options) - .unwrap(); - let value_formatter = formatter.value(0); - write!(f, "{value_formatter}")? - } + ScalarValue::List(arr) => fmt_list(arr.to_owned() as ArrayRef, f)?, + ScalarValue::LargeList(arr) => fmt_list(arr.to_owned() as ArrayRef, f)?, + ScalarValue::FixedSizeList(arr) => fmt_list(arr.to_owned() as ArrayRef, f)?, ScalarValue::Date32(e) => format_option!(f, e)?, ScalarValue::Date64(e) => format_option!(f, e)?, ScalarValue::Time32Second(e) => format_option!(f, e)?, @@ -3032,6 +3005,16 @@ impl fmt::Display for ScalarValue { } } +fn fmt_list(arr: ArrayRef, f: &mut fmt::Formatter) -> fmt::Result { + // ScalarValue List, LargeList, FixedSizeList should always have a single element + assert_eq!(arr.len(), 1); + let options = FormatOptions::default().with_display_error(true); + let formatter = + ArrayFormatter::try_new(arr.as_ref() as &dyn Array, &options).unwrap(); + let value_formatter = formatter.value(0); + write!(f, "{value_formatter}") +} + impl fmt::Debug for ScalarValue { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { From c212877ab92640bb1b13fb3032cac8efc61525e8 Mon Sep 17 00:00:00 2001 From: Spears Randall Date: Sat, 6 Jan 2024 10:09:12 -0700 Subject: [PATCH 13/14] Reduce more duplication --- datafusion/common/src/scalar.rs | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 670190dd1b04..3910da8fd941 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -2456,16 +2456,13 @@ impl ScalarValue { eq_array_primitive!(array, index, LargeBinaryArray, val)? } ScalarValue::List(arr) => { - let right = array.slice(index, 1); - arr.as_ref() as &dyn Array == &right + Self::eq_array_list(&(arr.to_owned() as ArrayRef), array, index) } ScalarValue::LargeList(arr) => { - let right = array.slice(index, 1); - arr.as_ref() as &dyn Array == &right + Self::eq_array_list(&(arr.to_owned() as ArrayRef), array, index) } ScalarValue::FixedSizeList(arr) => { - let right = array.slice(index, 1); - arr.as_ref() as &dyn Array == &right + Self::eq_array_list(&(arr.to_owned() as ArrayRef), array, index) } ScalarValue::Date32(val) => { eq_array_primitive!(array, index, Date32Array, val)? @@ -2543,6 +2540,11 @@ impl ScalarValue { }) } + fn eq_array_list(arr1: &ArrayRef, arr2: &ArrayRef, index: usize) -> bool { + let right = arr1.slice(index, 1); + arr2.as_ref() as &dyn Array == &right + } + /// Estimate size if bytes including `Self`. 
For values with internal containers such as `String` /// includes the allocated size (`capacity`) rather than the current length (`len`) pub fn size(&self) -> usize { From 588671c76d1686b1be97dc0c4614d9491acd228f Mon Sep 17 00:00:00 2001 From: Spears Randall Date: Sat, 6 Jan 2024 15:15:07 -0700 Subject: [PATCH 14/14] Fix test error --- datafusion/common/src/scalar.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 3910da8fd941..8820ca9942fc 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -2541,8 +2541,8 @@ impl ScalarValue { } fn eq_array_list(arr1: &ArrayRef, arr2: &ArrayRef, index: usize) -> bool { - let right = arr1.slice(index, 1); - arr2.as_ref() as &dyn Array == &right + let right = arr2.slice(index, 1); + arr1 == &right } /// Estimate size if bytes including `Self`. For values with internal containers such as `String`
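
For a sense of how the retyped variants are used downstream, here is a minimal sketch of constructing and unpacking a list scalar, assuming the post-series signatures shown in the hunks above (ScalarValue::new_list returning Arc<ListArray> and ScalarValue::List holding an Arc<ListArray>). Crate paths follow the imports used elsewhere in the diffs and are illustrative rather than verified against a build:

    use std::sync::Arc;

    use arrow::array::{Array, ListArray};
    use arrow::datatypes::DataType;
    use datafusion_common::ScalarValue;

    fn list_scalar_round_trip() {
        // Build a single-row ListArray scalar from plain scalars. After this
        // series, new_list returns Arc<ListArray> instead of an untyped ArrayRef.
        let scalars = vec![
            ScalarValue::Int32(Some(1)),
            ScalarValue::Int32(None),
            ScalarValue::Int32(Some(2)),
        ];
        let list: Arc<ListArray> = ScalarValue::new_list(&scalars, &DataType::Int32);
        let scalar = ScalarValue::List(list);

        // Consumers no longer need as_list_array(&arr): the variant already
        // carries a concretely typed ListArray with exactly one row.
        if let ScalarValue::List(arr) = &scalar {
            assert_eq!(arr.len(), 1);
            let inner = arr.value(0); // ArrayRef holding [1, NULL, 2]
            assert_eq!(inner.len(), 3);
        }
    }

The same pattern applies to LargeList (Arc<LargeListArray>) and FixedSizeList (Arc<FixedSizeListArray>).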