Skip to content

Commit

Permalink
Add ListBuilder::with_field to support non nullable list fields (#5330)…
Browse files Browse the repository at this point in the history
… (#5331)

* Add ListBuilder::with_field (#5330)

* Tweak docs

* Review feedback
  • Loading branch information
tustvold committed Jan 25, 2024
1 parent 5146419 commit 8fff5e4
Showing 1 changed file with 72 additions and 39 deletions.
111 changes: 72 additions & 39 deletions arrow-array/src/builder/generic_list_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,9 @@

use crate::builder::{ArrayBuilder, BufferBuilder};
use crate::{Array, ArrayRef, GenericListArray, OffsetSizeTrait};
use arrow_buffer::Buffer;
use arrow_buffer::NullBufferBuilder;
use arrow_data::ArrayData;
use arrow_schema::Field;
use arrow_buffer::{Buffer, OffsetBuffer};
use arrow_schema::{Field, FieldRef};
use std::any::Any;
use std::sync::Arc;

Expand Down Expand Up @@ -92,6 +91,7 @@ pub struct GenericListBuilder<OffsetSize: OffsetSizeTrait, T: ArrayBuilder> {
offsets_builder: BufferBuilder<OffsetSize>,
null_buffer_builder: NullBufferBuilder,
values_builder: T,
field: Option<FieldRef>,
}

impl<O: OffsetSizeTrait, T: ArrayBuilder + Default> Default for GenericListBuilder<O, T> {
Expand All @@ -116,6 +116,20 @@ impl<OffsetSize: OffsetSizeTrait, T: ArrayBuilder> GenericListBuilder<OffsetSize
offsets_builder,
null_buffer_builder: NullBufferBuilder::new(capacity),
values_builder,
field: None,
}
}

/// Override the field passed to [`GenericListArray::new`]
///
/// By default a nullable field is created with the name `item`
///
/// Note: [`Self::finish`] and [`Self::finish_cloned`] will panic if the
/// field's data type does not match that of `T`
pub fn with_field(self, field: impl Into<FieldRef>) -> Self {
Self {
field: Some(field.into()),
..self
}
}
}
Expand Down Expand Up @@ -275,53 +289,37 @@ where

/// Builds the [`GenericListArray`] and reset this builder.
pub fn finish(&mut self) -> GenericListArray<OffsetSize> {
let len = self.len();
let values_arr = self.values_builder.finish();
let values_data = values_arr.to_data();
let values = self.values_builder.finish();
let nulls = self.null_buffer_builder.finish();

let offset_buffer = self.offsets_builder.finish();
let null_bit_buffer = self.null_buffer_builder.finish();
let offsets = self.offsets_builder.finish();
// Safety: Safe by construction
let offsets = unsafe { OffsetBuffer::new_unchecked(offsets.into()) };
self.offsets_builder.append(OffsetSize::zero());
let field = Arc::new(Field::new(
"item",
values_data.data_type().clone(),
true, // TODO: find a consistent way of getting this
));
let data_type = GenericListArray::<OffsetSize>::DATA_TYPE_CONSTRUCTOR(field);
let array_data_builder = ArrayData::builder(data_type)
.len(len)
.add_buffer(offset_buffer)
.add_child_data(values_data)
.nulls(null_bit_buffer);

let array_data = unsafe { array_data_builder.build_unchecked() };
let field = match &self.field {
Some(f) => f.clone(),
None => Arc::new(Field::new("item", values.data_type().clone(), true)),
};

GenericListArray::<OffsetSize>::from(array_data)
GenericListArray::new(field, offsets, values, nulls)
}

/// Builds the [`GenericListArray`] without resetting the builder.
pub fn finish_cloned(&self) -> GenericListArray<OffsetSize> {
let len = self.len();
let values_arr = self.values_builder.finish_cloned();
let values_data = values_arr.to_data();

let offset_buffer = Buffer::from_slice_ref(self.offsets_builder.as_slice());
let values = self.values_builder.finish_cloned();
let nulls = self.null_buffer_builder.finish_cloned();
let field = Arc::new(Field::new(
"item",
values_data.data_type().clone(),
true, // TODO: find a consistent way of getting this
));
let data_type = GenericListArray::<OffsetSize>::DATA_TYPE_CONSTRUCTOR(field);
let array_data_builder = ArrayData::builder(data_type)
.len(len)
.add_buffer(offset_buffer)
.add_child_data(values_data)
.nulls(nulls);

let array_data = unsafe { array_data_builder.build_unchecked() };
let offsets = Buffer::from_slice_ref(self.offsets_builder.as_slice());
// Safety: safe by construction
let offsets = unsafe { OffsetBuffer::new_unchecked(offsets.into()) };

let field = match &self.field {
Some(f) => f.clone(),
None => Arc::new(Field::new("item", values.data_type().clone(), true)),
};

GenericListArray::<OffsetSize>::from(array_data)
GenericListArray::new(field, offsets, values, nulls)
}

/// Returns the current offsets buffer as a slice
Expand Down Expand Up @@ -765,4 +763,39 @@ mod tests {
assert_eq!(0, i1.null_count());
assert_eq!(i1.values(), &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
}

#[test]
fn test_with_field() {
let field = Arc::new(Field::new("bar", DataType::Int32, false));
let mut builder = ListBuilder::new(Int32Builder::new()).with_field(field.clone());
builder.append_value([Some(1), Some(2), Some(3)]);
builder.append_null(); // This is fine as nullability refers to nullability of values
builder.append_value([Some(4)]);
let array = builder.finish();
assert_eq!(array.len(), 3);
assert_eq!(array.data_type(), &DataType::List(field.clone()));

builder.append_value([Some(4), Some(5)]);
let array = builder.finish();
assert_eq!(array.data_type(), &DataType::List(field));
assert_eq!(array.len(), 1);
}

#[test]
#[should_panic(expected = "Non-nullable field of ListArray \\\"item\\\" cannot contain nulls")]
fn test_checks_nullability() {
let field = Arc::new(Field::new("item", DataType::Int32, false));
let mut builder = ListBuilder::new(Int32Builder::new()).with_field(field.clone());
builder.append_value([Some(1), None]);
builder.finish();
}

#[test]
#[should_panic(expected = "ListArray expected data type Int64 got Int32")]
fn test_checks_data_type() {
let field = Arc::new(Field::new("item", DataType::Int64, false));
let mut builder = ListBuilder::new(Int32Builder::new()).with_field(field.clone());
builder.append_value([Some(1)]);
builder.finish();
}
}

0 comments on commit 8fff5e4

Please sign in to comment.