diff --git a/arrow-array/src/builder/generic_list_builder.rs b/arrow-array/src/builder/generic_list_builder.rs index 116e2553cfb7..b857224c5da6 100644 --- a/arrow-array/src/builder/generic_list_builder.rs +++ b/arrow-array/src/builder/generic_list_builder.rs @@ -17,10 +17,9 @@ use crate::builder::{ArrayBuilder, BufferBuilder}; use crate::{Array, ArrayRef, GenericListArray, OffsetSizeTrait}; -use arrow_buffer::Buffer; use arrow_buffer::NullBufferBuilder; -use arrow_data::ArrayData; -use arrow_schema::Field; +use arrow_buffer::{Buffer, OffsetBuffer}; +use arrow_schema::{Field, FieldRef}; use std::any::Any; use std::sync::Arc; @@ -92,6 +91,7 @@ pub struct GenericListBuilder { offsets_builder: BufferBuilder, null_buffer_builder: NullBufferBuilder, values_builder: T, + field: Option, } impl Default for GenericListBuilder { @@ -116,6 +116,20 @@ impl GenericListBuilder) -> Self { + Self { + field: Some(field.into()), + ..self } } } @@ -275,53 +289,37 @@ where /// Builds the [`GenericListArray`] and reset this builder. pub fn finish(&mut self) -> GenericListArray { - let len = self.len(); - let values_arr = self.values_builder.finish(); - let values_data = values_arr.to_data(); + let values = self.values_builder.finish(); + let nulls = self.null_buffer_builder.finish(); - let offset_buffer = self.offsets_builder.finish(); - let null_bit_buffer = self.null_buffer_builder.finish(); + let offsets = self.offsets_builder.finish(); + // Safety: Safe by construction + let offsets = unsafe { OffsetBuffer::new_unchecked(offsets.into()) }; self.offsets_builder.append(OffsetSize::zero()); - let field = Arc::new(Field::new( - "item", - values_data.data_type().clone(), - true, // TODO: find a consistent way of getting this - )); - let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(field); - let array_data_builder = ArrayData::builder(data_type) - .len(len) - .add_buffer(offset_buffer) - .add_child_data(values_data) - .nulls(null_bit_buffer); - let array_data = unsafe { array_data_builder.build_unchecked() }; + let field = match &self.field { + Some(f) => f.clone(), + None => Arc::new(Field::new("item", values.data_type().clone(), true)), + }; - GenericListArray::::from(array_data) + GenericListArray::new(field, offsets, values, nulls) } /// Builds the [`GenericListArray`] without resetting the builder. pub fn finish_cloned(&self) -> GenericListArray { - let len = self.len(); - let values_arr = self.values_builder.finish_cloned(); - let values_data = values_arr.to_data(); - - let offset_buffer = Buffer::from_slice_ref(self.offsets_builder.as_slice()); + let values = self.values_builder.finish_cloned(); let nulls = self.null_buffer_builder.finish_cloned(); - let field = Arc::new(Field::new( - "item", - values_data.data_type().clone(), - true, // TODO: find a consistent way of getting this - )); - let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(field); - let array_data_builder = ArrayData::builder(data_type) - .len(len) - .add_buffer(offset_buffer) - .add_child_data(values_data) - .nulls(nulls); - let array_data = unsafe { array_data_builder.build_unchecked() }; + let offsets = Buffer::from_slice_ref(self.offsets_builder.as_slice()); + // Safety: safe by construction + let offsets = unsafe { OffsetBuffer::new_unchecked(offsets.into()) }; + + let field = match &self.field { + Some(f) => f.clone(), + None => Arc::new(Field::new("item", values.data_type().clone(), true)), + }; - GenericListArray::::from(array_data) + GenericListArray::new(field, offsets, values, nulls) } /// Returns the current offsets buffer as a slice @@ -765,4 +763,39 @@ mod tests { assert_eq!(0, i1.null_count()); assert_eq!(i1.values(), &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); } + + #[test] + fn test_with_field() { + let field = Arc::new(Field::new("bar", DataType::Int32, false)); + let mut builder = ListBuilder::new(Int32Builder::new()).with_field(field.clone()); + builder.append_value([Some(1), Some(2), Some(3)]); + builder.append_null(); // This is fine as nullability refers to nullability of values + builder.append_value([Some(4)]); + let array = builder.finish(); + assert_eq!(array.len(), 3); + assert_eq!(array.data_type(), &DataType::List(field.clone())); + + builder.append_value([Some(4), Some(5)]); + let array = builder.finish(); + assert_eq!(array.data_type(), &DataType::List(field)); + assert_eq!(array.len(), 1); + } + + #[test] + #[should_panic(expected = "Non-nullable field of ListArray \\\"item\\\" cannot contain nulls")] + fn test_checks_nullability() { + let field = Arc::new(Field::new("item", DataType::Int32, false)); + let mut builder = ListBuilder::new(Int32Builder::new()).with_field(field.clone()); + builder.append_value([Some(1), None]); + builder.finish(); + } + + #[test] + #[should_panic(expected = "ListArray expected data type Int64 got Int32")] + fn test_checks_data_type() { + let field = Arc::new(Field::new("item", DataType::Int64, false)); + let mut builder = ListBuilder::new(Int32Builder::new()).with_field(field.clone()); + builder.append_value([Some(1)]); + builder.finish(); + } }