Skip to content

Commit

Permalink
Merge branch 'union_logical_nulls' of https://github.com/gstvg/arrow-rs
Browse files Browse the repository at this point in the history
… into union_logical_nulls
  • Loading branch information
gstvg committed Sep 30, 2024
2 parents b51bc86 + 0b9a443 commit 74a3b20
Show file tree
Hide file tree
Showing 80 changed files with 1,943 additions and 406 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ on:
- arrow/**

jobs:

integration:
name: Archery test With other arrows
runs-on: ubuntu-latest
Expand All @@ -65,6 +64,7 @@ jobs:
ARROW_INTEGRATION_GO: ON
ARROW_INTEGRATION_JAVA: ON
ARROW_INTEGRATION_JS: ON
ARCHERY_INTEGRATION_TARGET_IMPLEMENTATIONS: "rust"
ARCHERY_INTEGRATION_WITH_NANOARROW: "1"
# https://github.com/apache/arrow/pull/38403/files#r1371281630
ARCHERY_INTEGRATION_WITH_RUST: "1"
Expand Down Expand Up @@ -118,9 +118,9 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
rust: [ stable ]
# PyArrow 13 was the last version prior to introduction to Arrow PyCapsules
pyarrow: [ "13", "14" ]
rust: [stable]
# PyArrow 15 was the first version to introduce StringView/BinaryView support
pyarrow: ["15", "16", "17"]
steps:
- uses: actions/checkout@v4
with:
Expand Down
83 changes: 73 additions & 10 deletions arrow-arith/src/arity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -313,8 +313,6 @@ where
))));
}

let nulls = NullBuffer::union(a.logical_nulls().as_ref(), b.logical_nulls().as_ref());

let mut builder = a.into_builder()?;

builder
Expand All @@ -323,14 +321,21 @@ where
.zip(b.values())
.for_each(|(l, r)| *l = op(*l, *r));

let array_builder = builder.finish().into_data().into_builder().nulls(nulls);
let array = builder.finish();

// The builder has the null buffer from `a`, it is not changed.
let nulls = NullBuffer::union(array.logical_nulls().as_ref(), b.logical_nulls().as_ref());

let array_builder = array.into_data().into_builder().nulls(nulls);

let array_data = unsafe { array_builder.build_unchecked() };
Ok(Ok(PrimitiveArray::<T>::from(array_data)))
}

/// Applies the provided fallible binary operation across `a` and `b`, returning any error,
/// and collecting the results into a [`PrimitiveArray`]. If any index is null in either `a`
/// Applies the provided fallible binary operation across `a` and `b`.
///
/// This will return any error encountered, or collect the results into
/// a [`PrimitiveArray`]. If any index is null in either `a`
/// or `b`, the corresponding index in the result will also be null
///
/// Like [`try_unary`] the function is only evaluated for non-null indices
Expand Down Expand Up @@ -381,12 +386,15 @@ where
}

/// Applies the provided fallible binary operation across `a` and `b` by mutating the mutable
/// [`PrimitiveArray`] `a` with the results, returning any error. If any index is null in
/// either `a` or `b`, the corresponding index in the result will also be null
/// [`PrimitiveArray`] `a` with the results.
///
/// Like [`try_unary`] the function is only evaluated for non-null indices
/// Returns any error encountered, or collects the results into a [`PrimitiveArray`] as return
/// value. If any index is null in either `a` or `b`, the corresponding index in the result will
/// also be null.
///
/// Like [`try_unary`] the function is only evaluated for non-null indices.
///
/// See [`binary_mut`] for errors and buffer reuse information
/// See [`binary_mut`] for errors and buffer reuse information.
pub fn try_binary_mut<T, F>(
a: PrimitiveArray<T>,
b: &PrimitiveArray<T>,
Expand All @@ -413,7 +421,8 @@ where
try_binary_no_nulls_mut(len, a, b, op)
} else {
let nulls =
NullBuffer::union(a.logical_nulls().as_ref(), b.logical_nulls().as_ref()).unwrap();
create_union_null_buffer(a.logical_nulls().as_ref(), b.logical_nulls().as_ref())
.unwrap();

let mut builder = a.into_builder()?;

Expand All @@ -435,6 +444,22 @@ where
}
}

/// Builds the combined null mask of two optional [`NullBuffer`]s as a buffer
/// that is not shared with either input.
///
/// The resulting mask matches `NullBuffer::union(lhs, rhs)`, but the inputs'
/// reference counts are left untouched.
///
/// Returns `None` only when both inputs are `None`.
fn create_union_null_buffer(
    lhs: Option<&NullBuffer>,
    rhs: Option<&NullBuffer>,
) -> Option<NullBuffer> {
    let validity = match (lhs, rhs) {
        (None, None) => return None,
        (Some(left), Some(right)) => left.inner() & right.inner(),
        // AND-ing the single buffer with itself keeps its bits while producing
        // a buffer that is not shared with the input.
        (Some(only), None) | (None, Some(only)) => only.inner() & only.inner(),
    };
    Some(NullBuffer::new(validity))
}

/// This intentional inline(never) attribute helps LLVM optimize the loop.
#[inline(never)]
fn try_binary_no_nulls<A: ArrayAccessor, B: ArrayAccessor, F, O>(
Expand Down Expand Up @@ -557,6 +582,25 @@ mod tests {
assert_eq!(c, expected);
}

#[test]
fn test_binary_mut_null_buffer() {
    // First run: right-hand side built from `Option` values, all of them `Some`.
    let lhs = Int32Array::from(vec![Some(3), Some(4), Some(5), Some(6), None]);
    let rhs = Int32Array::from(vec![Some(10), Some(11), Some(12), Some(13), Some(14)]);
    let from_options = binary_mut(lhs, &rhs, |x, y| x + y).unwrap();

    // Second run: same values, but the right-hand side carries an explicit
    // all-valid null buffer.
    let lhs = Int32Array::from(vec![Some(3), Some(4), Some(5), Some(6), None]);
    let rhs = Int32Array::new(
        vec![10, 11, 12, 13, 14].into(),
        Some(vec![true; 5].into()),
    );

    // unwrap here means that no copying occurred
    let from_explicit_nulls = binary_mut(lhs, &rhs, |x, y| x + y).unwrap();

    assert_eq!(from_options.unwrap(), from_explicit_nulls.unwrap());
}

#[test]
fn test_try_binary_mut() {
let a = Int32Array::from(vec![15, 14, 9, 8, 1]);
Expand Down Expand Up @@ -587,6 +631,25 @@ mod tests {
.expect_err("should got error");
}

#[test]
fn test_try_binary_mut_null_buffer() {
    // First run: right-hand side built from `Option` values, all of them `Some`.
    let lhs = Int32Array::from(vec![Some(3), Some(4), Some(5), Some(6), None]);
    let rhs = Int32Array::from(vec![Some(10), Some(11), Some(12), Some(13), Some(14)]);
    let from_options = try_binary_mut(lhs, &rhs, |x, y| Ok(x + y)).unwrap();

    // Second run: same values, but the right-hand side carries an explicit
    // all-valid null buffer.
    let lhs = Int32Array::from(vec![Some(3), Some(4), Some(5), Some(6), None]);
    let rhs = Int32Array::new(
        vec![10, 11, 12, 13, 14].into(),
        Some(vec![true; 5].into()),
    );

    // unwrap here means that no copying occurred
    let from_explicit_nulls = try_binary_mut(lhs, &rhs, |x, y| Ok(x + y)).unwrap();

    assert_eq!(from_options.unwrap(), from_explicit_nulls.unwrap());
}

#[test]
fn test_unary_dict_mut() {
let values = Int32Array::from(vec![Some(10), Some(20), None]);
Expand Down
2 changes: 2 additions & 0 deletions arrow-arith/src/bitwise.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
// specific language governing permissions and limitations
// under the License.

//! Module contains bitwise operations on arrays

use crate::arity::{binary, unary};
use arrow_array::*;
use arrow_buffer::ArrowNativeType;
Expand Down
1 change: 1 addition & 0 deletions arrow-arith/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

//! Arrow arithmetic and aggregation kernels

#![warn(missing_docs)]
pub mod aggregate;
#[doc(hidden)] // Kernels to be removed in a future release
pub mod arithmetic;
Expand Down
8 changes: 6 additions & 2 deletions arrow-arith/src/temporal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,7 @@ impl<T: Datelike> ChronoDateExt for T {

/// Parse the given string into a string representing fixed-offset that is correct as of the given
/// UTC NaiveDateTime.
///
/// Note that the offset is a function of time and can vary depending on whether daylight saving
/// time is in effect, e.g. Australia/Sydney is +10:00 or +11:00 depending on DST.
#[deprecated(note = "Use arrow_array::timezone::Tz instead")]
Expand Down Expand Up @@ -811,6 +812,7 @@ where
}

/// Extracts the day of a given temporal array as an array of integers.
///
/// If the given array isn't temporal primitive or dictionary array,
/// an `Err` will be returned.
#[deprecated(since = "51.0.0", note = "Use `date_part` instead")]
Expand All @@ -828,7 +830,8 @@ where
date_part_primitive(array, DatePart::Day)
}

/// Extracts the day of year of a given temporal array as an array of integers
/// Extracts the day of year of a given temporal array as an array of integers.
///
/// The day of year that ranges from 1 to 366.
/// If the given array isn't temporal primitive or dictionary array,
/// an `Err` will be returned.
Expand All @@ -837,7 +840,8 @@ pub fn doy_dyn(array: &dyn Array) -> Result<ArrayRef, ArrowError> {
date_part(array, DatePart::DayOfYear)
}

/// Extracts the day of year of a given temporal primitive array as an array of integers
/// Extracts the day of year of a given temporal primitive array as an array of integers.
///
/// The day of year that ranges from 1 to 366
#[deprecated(since = "51.0.0", note = "Use `date_part` instead")]
pub fn doy<T>(array: &PrimitiveArray<T>) -> Result<Int32Array, ArrowError>
Expand Down
4 changes: 4 additions & 0 deletions arrow-array/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ harness = false
name = "fixed_size_list_array"
harness = false

[[bench]]
name = "decimal_overflow"
harness = false

[[bench]]
name = "union_array"
harness = false
53 changes: 53 additions & 0 deletions arrow-array/benches/decimal_overflow.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow_array::builder::{Decimal128Builder, Decimal256Builder};
use arrow_buffer::i256;
use criterion::*;

/// Benchmarks decimal precision validation and overflow-to-null conversion on
/// 8192-element `Decimal128` and `Decimal256` arrays.
fn criterion_benchmark(c: &mut Criterion) {
    let len = 8192;
    let mut decimal128 = Decimal128Builder::with_capacity(len);
    let mut decimal256 = Decimal256Builder::with_capacity(len);
    for idx in 0..len {
        // Every tenth slot holds `i128::MAX`; the remaining slots hold their own index.
        let value = if idx % 10 == 0 {
            i128::MAX
        } else {
            idx as i128
        };
        decimal128.append_value(value);
        decimal256.append_value(i256::from_i128(value));
    }
    let array_128 = decimal128.finish();
    let array_256 = decimal256.finish();

    // 128-bit decimals
    c.bench_function("validate_decimal_precision_128", |bencher| {
        bencher.iter(|| black_box(array_128.validate_decimal_precision(8)));
    });
    c.bench_function("null_if_overflow_precision_128", |bencher| {
        bencher.iter(|| black_box(array_128.null_if_overflow_precision(8)));
    });

    // 256-bit decimals
    c.bench_function("validate_decimal_precision_256", |bencher| {
        bencher.iter(|| black_box(array_256.validate_decimal_precision(8)));
    });
    c.bench_function("null_if_overflow_precision_256", |bencher| {
        bencher.iter(|| black_box(array_256.null_if_overflow_precision(8)));
    });
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
4 changes: 2 additions & 2 deletions arrow-array/src/array/binary_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ impl<OffsetSize: OffsetSizeTrait> GenericBinaryArray<OffsetSize> {
pub fn take_iter<'a>(
&'a self,
indexes: impl Iterator<Item = Option<usize>> + 'a,
) -> impl Iterator<Item = Option<&[u8]>> + 'a {
) -> impl Iterator<Item = Option<&'a [u8]>> {
indexes.map(|opt_index| opt_index.map(|index| self.value(index)))
}

Expand All @@ -95,7 +95,7 @@ impl<OffsetSize: OffsetSizeTrait> GenericBinaryArray<OffsetSize> {
pub unsafe fn take_iter_unchecked<'a>(
&'a self,
indexes: impl Iterator<Item = Option<usize>> + 'a,
) -> impl Iterator<Item = Option<&[u8]>> + 'a {
) -> impl Iterator<Item = Option<&'a [u8]>> {
indexes.map(|opt_index| opt_index.map(|index| self.value_unchecked(index)))
}
}
Expand Down
4 changes: 1 addition & 3 deletions arrow-array/src/array/primitive_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1570,9 +1570,7 @@ impl<T: DecimalType + ArrowPrimitiveType> PrimitiveArray<T> {
/// Validates the decimal array against the specified precision: any slot whose value
/// overflows that precision is replaced with null in the returned array
pub fn null_if_overflow_precision(&self, precision: u8) -> Self {
self.unary_opt::<_, T>(|v| {
(T::validate_decimal_precision(v, precision).is_ok()).then_some(v)
})
self.unary_opt::<_, T>(|v| T::is_valid_decimal_precision(v, precision).then_some(v))
}

/// Returns [`Self::value`] formatted as a string
Expand Down
4 changes: 2 additions & 2 deletions arrow-array/src/array/string_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ impl<OffsetSize: OffsetSizeTrait> GenericStringArray<OffsetSize> {
pub fn take_iter<'a>(
&'a self,
indexes: impl Iterator<Item = Option<usize>> + 'a,
) -> impl Iterator<Item = Option<&str>> + 'a {
) -> impl Iterator<Item = Option<&'a str>> {
indexes.map(|opt_index| opt_index.map(|index| self.value(index)))
}

Expand All @@ -53,7 +53,7 @@ impl<OffsetSize: OffsetSizeTrait> GenericStringArray<OffsetSize> {
pub unsafe fn take_iter_unchecked<'a>(
&'a self,
indexes: impl Iterator<Item = Option<usize>> + 'a,
) -> impl Iterator<Item = Option<&str>> + 'a {
) -> impl Iterator<Item = Option<&'a str>> {
indexes.map(|opt_index| opt_index.map(|index| self.value_unchecked(index)))
}

Expand Down
Loading

0 comments on commit 74a3b20

Please sign in to comment.