Skip to content

Commit

Permalink
fix: binary_mut should work if only one input array has null buffer (#…
Browse files Browse the repository at this point in the history
…6396)

* fix: binary_mut should work if only one input array has a null buffer

* Avoid copying null buffer in binary_mut

* Update arrow-arith/src/arity.rs

Co-authored-by: Andrew Lamb <[email protected]>

* Update arrow-arith/src/arity.rs

Co-authored-by: Andrew Lamb <[email protected]>

---------

Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
viirya and alamb authored Sep 18, 2024
1 parent d7e8702 commit f5a6382
Showing 1 changed file with 62 additions and 4 deletions.
66 changes: 62 additions & 4 deletions arrow-arith/src/arity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -313,8 +313,6 @@ where
))));
}

let nulls = NullBuffer::union(a.logical_nulls().as_ref(), b.logical_nulls().as_ref());

let mut builder = a.into_builder()?;

builder
Expand All @@ -323,7 +321,12 @@ where
.zip(b.values())
.for_each(|(l, r)| *l = op(*l, *r));

let array_builder = builder.finish().into_data().into_builder().nulls(nulls);
let array = builder.finish();

// Finishing the builder keeps the null buffer from `a` unchanged.
let nulls = NullBuffer::union(array.logical_nulls().as_ref(), b.logical_nulls().as_ref());

let array_builder = array.into_data().into_builder().nulls(nulls);

let array_data = unsafe { array_builder.build_unchecked() };
Ok(Ok(PrimitiveArray::<T>::from(array_data)))
Expand Down Expand Up @@ -413,7 +416,8 @@ where
try_binary_no_nulls_mut(len, a, b, op)
} else {
let nulls =
NullBuffer::union(a.logical_nulls().as_ref(), b.logical_nulls().as_ref()).unwrap();
create_union_null_buffer(a.logical_nulls().as_ref(), b.logical_nulls().as_ref())
.unwrap();

let mut builder = a.into_builder()?;

Expand All @@ -435,6 +439,22 @@ where
}
}

/// Builds the combined validity mask of two optional [`NullBuffer`]s as a
/// buffer that is not shared with either input.
///
/// The resulting mask is the same as `NullBuffer::union(lhs, rhs)`, except
/// that it does not increase the reference count of an input null buffer.
/// Returns `None` only when neither side has a null buffer.
fn create_union_null_buffer(
    lhs: Option<&NullBuffer>,
    rhs: Option<&NullBuffer>,
) -> Option<NullBuffer> {
    // When only one side carries a mask it is paired with itself, so the `&`
    // below is still what produces the result rather than the input being
    // handed back as-is.
    let (left, right) = match (lhs, rhs) {
        (None, None) => return None,
        (Some(only), None) | (None, Some(only)) => (only, only),
        (Some(left), Some(right)) => (left, right),
    };

    Some(NullBuffer::new(left.inner() & right.inner()))
}

/// This intentional inline(never) attribute helps LLVM optimize the loop.
#[inline(never)]
fn try_binary_no_nulls<A: ArrayAccessor, B: ArrayAccessor, F, O>(
Expand Down Expand Up @@ -557,6 +577,25 @@ mod tests {
assert_eq!(c, expected);
}

/// `binary_mut` must give the same result for a right-hand array built from
/// all-`Some` values and for one built with an explicit all-`true` validity
/// mask, when the left-hand array contains a null.
#[test]
fn test_binary_mut_null_buffer() {
    // `binary_mut` takes the left-hand array by value, so build a fresh one per call.
    let lhs = || Int32Array::from(vec![Some(3), Some(4), Some(5), Some(6), None]);

    // Right-hand side created from `Option` values that are all `Some`.
    let rhs_opts = Int32Array::from(vec![Some(10), Some(11), Some(12), Some(13), Some(14)]);
    let first = binary_mut(lhs(), &rhs_opts, |l, r| l + r).unwrap();

    // Same values, this time with an explicitly supplied all-`true` validity mask.
    let rhs_masked = Int32Array::new(
        vec![10, 11, 12, 13, 14].into(),
        Some(vec![true; 5].into()),
    );

    // The outer `unwrap` succeeding means that no copying occurred.
    let second = binary_mut(lhs(), &rhs_masked, |l, r| l + r).unwrap();

    assert_eq!(first.unwrap(), second.unwrap());
}

#[test]
fn test_try_binary_mut() {
let a = Int32Array::from(vec![15, 14, 9, 8, 1]);
Expand Down Expand Up @@ -587,6 +626,25 @@ mod tests {
.expect_err("should got error");
}

/// `try_binary_mut` must give the same result for a right-hand array built
/// from all-`Some` values and for one built with an explicit all-`true`
/// validity mask, when the left-hand array contains a null.
#[test]
fn test_try_binary_mut_null_buffer() {
    // `try_binary_mut` takes the left-hand array by value, so build a fresh one per call.
    let lhs = || Int32Array::from(vec![Some(3), Some(4), Some(5), Some(6), None]);

    // Right-hand side created from `Option` values that are all `Some`.
    let rhs_opts = Int32Array::from(vec![Some(10), Some(11), Some(12), Some(13), Some(14)]);
    let first = try_binary_mut(lhs(), &rhs_opts, |l, r| Ok(l + r)).unwrap();

    // Same values, this time with an explicitly supplied all-`true` validity mask.
    let rhs_masked = Int32Array::new(
        vec![10, 11, 12, 13, 14].into(),
        Some(vec![true; 5].into()),
    );

    // The outer `unwrap` succeeding means that no copying occurred.
    let second = try_binary_mut(lhs(), &rhs_masked, |l, r| Ok(l + r)).unwrap();

    assert_eq!(first.unwrap(), second.unwrap());
}

#[test]
fn test_unary_dict_mut() {
let values = Int32Array::from(vec![Some(10), Some(20), None]);
Expand Down

0 comments on commit f5a6382

Please sign in to comment.