Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove redundant partial_cmp implementation #1

Merged
merged 2 commits into from
Jan 6, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 43 additions & 88 deletions datafusion/common/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -362,92 +362,12 @@ impl PartialOrd for ScalarValue {
(LargeBinary(v1), LargeBinary(v2)) => v1.partial_cmp(v2),
(LargeBinary(_), _) => None,
// ScalarValue::List / ScalarValue::FixedSizeList / ScalarValue::LargeList are ensure to have length 1
(List(arr1), List(arr2)) => {
assert_eq!(arr1.len(), 1);
assert_eq!(arr2.len(), 1);

if arr1.data_type() != arr2.data_type() {
return None;
}

fn first_array_for_list(arr: &Arc<ListArray>) -> ArrayRef {
arr.value(0)
}

let arr1 = first_array_for_list(arr1);
let arr2 = first_array_for_list(arr2);

let lt_res = arrow::compute::kernels::cmp::lt(&arr1, &arr2).ok()?;
let eq_res = arrow::compute::kernels::cmp::eq(&arr1, &arr2).ok()?;

for j in 0..lt_res.len() {
if lt_res.is_valid(j) && lt_res.value(j) {
return Some(Ordering::Less);
}
if eq_res.is_valid(j) && !eq_res.value(j) {
return Some(Ordering::Greater);
}
}

Some(Ordering::Equal)
}
(List(arr1), List(arr2)) => partial_cmp_list(arr1.as_ref(), arr2.as_ref()),
(FixedSizeList(arr1), FixedSizeList(arr2)) => {
assert_eq!(arr1.len(), 1);
assert_eq!(arr2.len(), 1);

if arr1.data_type() != arr2.data_type() {
return None;
}

fn first_array_for_list(arr: &Arc<FixedSizeListArray>) -> ArrayRef {
arr.value(0)
}

let arr1 = first_array_for_list(arr1);
let arr2 = first_array_for_list(arr2);

let lt_res = arrow::compute::kernels::cmp::lt(&arr1, &arr2).ok()?;
let eq_res = arrow::compute::kernels::cmp::eq(&arr1, &arr2).ok()?;

for j in 0..lt_res.len() {
if lt_res.is_valid(j) && lt_res.value(j) {
return Some(Ordering::Less);
}
if eq_res.is_valid(j) && !eq_res.value(j) {
return Some(Ordering::Greater);
}
}

Some(Ordering::Equal)
partial_cmp_list(arr1.as_ref(), arr2.as_ref())
}
(LargeList(arr1), LargeList(arr2)) => {
assert_eq!(arr1.len(), 1);
assert_eq!(arr2.len(), 1);

if arr1.data_type() != arr2.data_type() {
return None;
}

fn first_array_for_list(arr: &Arc<LargeListArray>) -> ArrayRef {
arr.value(0)
}

let arr1 = first_array_for_list(arr1);
let arr2 = first_array_for_list(arr2);

let lt_res = arrow::compute::kernels::cmp::lt(&arr1, &arr2).ok()?;
let eq_res = arrow::compute::kernels::cmp::eq(&arr1, &arr2).ok()?;

for j in 0..lt_res.len() {
if lt_res.is_valid(j) && lt_res.value(j) {
return Some(Ordering::Less);
}
if eq_res.is_valid(j) && !eq_res.value(j) {
return Some(Ordering::Greater);
}
}

Some(Ordering::Equal)
partial_cmp_list(arr1.as_ref(), arr2.as_ref())
}
(List(_), _) | (LargeList(_), _) | (FixedSizeList(_), _) => None,
(Date32(v1), Date32(v2)) => v1.partial_cmp(v2),
Expand Down Expand Up @@ -513,6 +433,44 @@ impl PartialOrd for ScalarValue {
}
}

/// List/LargeList/FixedSizeList scalars always have a single element
/// array. This function returns that array
fn first_array_for_list(arr: &dyn Array) -> ArrayRef {
assert_eq!(arr.len(), 1);
if let Some(arr) = arr.as_list_opt::<i32>() {
arr.value(0)
} else if let Some(arr) = arr.as_list_opt::<i64>() {
arr.value(0)
} else if let Some(arr) = arr.as_fixed_size_list_opt() {
arr.value(0)
} else {
unreachable!("Since only List / LargeList / FixedSizeList are supported, this should never happen")
}
}

/// Compares two List/LargeList/FixedSizeList scalars
fn partial_cmp_list(arr1: &dyn Array, arr2: &dyn Array) -> Option<Ordering> {
if arr1.data_type() != arr2.data_type() {
return None;
}
let arr1 = first_array_for_list(arr1);
let arr2 = first_array_for_list(arr2);

let lt_res = arrow::compute::kernels::cmp::lt(&arr1, &arr2).ok()?;
let eq_res = arrow::compute::kernels::cmp::eq(&arr1, &arr2).ok()?;

for j in 0..lt_res.len() {
if lt_res.is_valid(j) && lt_res.value(j) {
return Some(Ordering::Less);
}
if eq_res.is_valid(j) && !eq_res.value(j) {
return Some(Ordering::Greater);
}
}

Some(Ordering::Equal)
}

impl Eq for ScalarValue {}

//Float wrapper over f32/f64. Just because we cannot build std::hash::Hash for floats directly we have to do it through type wrapper
Expand Down Expand Up @@ -2100,11 +2058,8 @@ impl ScalarValue {
}

fn list_to_array_of_size(arr: &dyn Array, size: usize) -> Result<ArrayRef> {
let arrays = std::iter::repeat(arr)
.take(size)
.collect::<Vec<_>>();
arrow::compute::concat(arrays.as_slice())
.map_err(|e| arrow_datafusion_err!(e))
let arrays = std::iter::repeat(arr).take(size).collect::<Vec<_>>();
arrow::compute::concat(arrays.as_slice()).map_err(|e| arrow_datafusion_err!(e))
}

/// Retrieve ScalarValue for each row in `array`
Expand Down