diff --git a/crates/duckdb/src/core/data_chunk.rs b/crates/duckdb/src/core/data_chunk.rs index 7b6d2e2c..3ef35992 100644 --- a/crates/duckdb/src/core/data_chunk.rs +++ b/crates/duckdb/src/core/data_chunk.rs @@ -26,6 +26,7 @@ impl Drop for DataChunkHandle { } impl DataChunkHandle { + #[allow(dead_code)] pub(crate) unsafe fn new_unowned(ptr: duckdb_data_chunk) -> Self { Self { ptr, owned: false } } diff --git a/crates/duckdb/src/core/vector.rs b/crates/duckdb/src/core/vector.rs index befda697..92e5622a 100644 --- a/crates/duckdb/src/core/vector.rs +++ b/crates/duckdb/src/core/vector.rs @@ -173,6 +173,15 @@ impl ListVector { self.entries.as_mut_slice::()[idx].length = length as u64; } + /// Set row as null + pub fn set_null(&mut self, row: usize) { + unsafe { + duckdb_vector_ensure_validity_writable(self.entries.ptr); + let idx = duckdb_vector_get_validity(self.entries.ptr); + duckdb_validity_set_row_invalid(idx, row as u64); + } + } + /// Reserve the capacity for its child node. fn reserve(&self, capacity: usize) { unsafe { @@ -190,7 +199,6 @@ impl ListVector { /// A array vector. (fixed-size list) pub struct ArrayVector { - /// ArrayVector does not own the vector pointer. ptr: duckdb_vector, } @@ -223,11 +231,19 @@ impl ArrayVector { pub fn set_child(&self, data: &[T]) { self.child(data.len()).copy(data); } + + /// Set row as null + pub fn set_null(&mut self, row: usize) { + unsafe { + duckdb_vector_ensure_validity_writable(self.ptr); + let idx = duckdb_vector_get_validity(self.ptr); + duckdb_validity_set_row_invalid(idx, row as u64); + } + } } /// A struct vector. pub struct StructVector { - /// ListVector does not own the vector pointer. ptr: duckdb_vector, } @@ -277,4 +293,13 @@ impl StructVector { let logical_type = self.logical_type(); unsafe { duckdb_struct_type_child_count(logical_type.ptr) as usize } } + + /// Set row as null + pub fn set_null(&mut self, row: usize) { + unsafe { + duckdb_vector_ensure_validity_writable(self.ptr); + let idx = duckdb_vector_get_validity(self.ptr); + duckdb_validity_set_row_invalid(idx, row as u64); + } + } } diff --git a/crates/duckdb/src/vtab/arrow.rs b/crates/duckdb/src/vtab/arrow.rs index 0dbbd7f5..1e985146 100644 --- a/crates/duckdb/src/vtab/arrow.rs +++ b/crates/duckdb/src/vtab/arrow.rs @@ -268,13 +268,7 @@ pub fn record_batch_to_duckdb_data_chunk( fn primitive_array_to_flat_vector(array: &PrimitiveArray, out_vector: &mut FlatVector) { // assert!(array.len() <= out_vector.capacity()); out_vector.copy::(array.values()); - if let Some(nulls) = array.nulls() { - for (i, null) in nulls.into_iter().enumerate() { - if !null { - out_vector.set_null(i); - } - } - } + set_nulls_in_flat_vector(array, out_vector); } fn primitive_array_to_flat_vector_cast( @@ -285,13 +279,7 @@ fn primitive_array_to_flat_vector_cast( let array = arrow::compute::kernels::cast::cast(array, &data_type).unwrap(); let out_vector: &mut FlatVector = out_vector.as_mut_any().downcast_mut().unwrap(); out_vector.copy::(array.as_primitive::().values()); - if let Some(nulls) = array.nulls() { - for (i, null) in nulls.iter().enumerate() { - if !null { - out_vector.set_null(i); - } - } - } + set_nulls_in_flat_vector(&array, out_vector); } fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) -> Result<(), Box> { @@ -441,13 +429,7 @@ fn decimal_array_to_vector(array: &Decimal128Array, out: &mut FlatVector, width: } // Set nulls - if let Some(nulls) = array.nulls() { - for (i, null) in nulls.into_iter().enumerate() { - if !null { - out.set_null(i); - } - } - } + set_nulls_in_flat_vector(array, out); } /// Convert Arrow [BooleanArray] to a duckdb vector. @@ -457,6 +439,7 @@ fn boolean_array_to_vector(array: &BooleanArray, out: &mut FlatVector) { for i in 0..array.len() { out.as_mut_slice()[i] = array.value(i); } + set_nulls_in_flat_vector(array, out); } fn string_array_to_vector(array: &GenericStringArray, out: &mut FlatVector) { @@ -467,6 +450,7 @@ fn string_array_to_vector(array: &GenericStringArray, out let s = array.value(i); out.insert(i, s); } + set_nulls_in_flat_vector(array, out); } fn binary_array_to_vector(array: &BinaryArray, out: &mut FlatVector) { @@ -476,6 +460,7 @@ fn binary_array_to_vector(array: &BinaryArray, out: &mut FlatVector) { let s = array.value(i); out.insert(i, s); } + set_nulls_in_flat_vector(array, out); } fn list_array_to_vector>( @@ -504,6 +489,8 @@ fn list_array_to_vector>( let length = array.value_length(i); out.set_entry(i, offset.as_(), length.as_()); } + set_nulls_in_list_vector(array, out); + Ok(()) } @@ -528,6 +515,8 @@ fn fixed_size_list_array_to_vector( } } + set_nulls_in_array_vector(array, out); + Ok(()) } @@ -575,6 +564,7 @@ fn struct_array_to_vector(array: &StructArray, out: &mut StructVector) -> Result } } } + set_nulls_in_struct_vector(array, out); Ok(()) } @@ -611,6 +601,46 @@ pub fn arrow_ffi_to_query_params(array: FFI_ArrowArray, schema: FFI_ArrowSchema) [arr as *mut _ as usize, sch as *mut _ as usize] } +fn set_nulls_in_flat_vector(array: &dyn Array, out_vector: &mut FlatVector) { + if let Some(nulls) = array.nulls() { + for (i, null) in nulls.into_iter().enumerate() { + if !null { + out_vector.set_null(i); + } + } + } +} + +fn set_nulls_in_struct_vector(array: &dyn Array, out_vector: &mut StructVector) { + if let Some(nulls) = array.nulls() { + for (i, null) in nulls.into_iter().enumerate() { + if !null { + out_vector.set_null(i); + } + } + } +} + +fn set_nulls_in_array_vector(array: &dyn Array, out_vector: &mut ArrayVector) { + if let Some(nulls) = array.nulls() { + for (i, null) in nulls.into_iter().enumerate() { + if !null { + out_vector.set_null(i); + } + } + } +} + +fn set_nulls_in_list_vector(array: &dyn Array, out_vector: &mut ListVector) { + if let Some(nulls) = array.nulls() { + for (i, null) in nulls.into_iter().enumerate() { + if !null { + out_vector.set_null(i); + } + } + } +} + #[cfg(test)] mod test { use super::{arrow_recordbatch_to_query_params, ArrowVTab}; @@ -705,6 +735,44 @@ mod test { Ok(()) } + #[test] + fn test_append_struct_contains_null() -> Result<(), Box> { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE t1 (s STRUCT(v VARCHAR, i INTEGER))")?; + { + let struct_array = StructArray::try_new( + vec![ + Arc::new(Field::new("v", DataType::Utf8, true)), + Arc::new(Field::new("i", DataType::Int32, true)), + ] + .into(), + vec![ + Arc::new(StringArray::from(vec![Some("foo"), Some("bar")])) as ArrayRef, + Arc::new(Int32Array::from(vec![Some(1), Some(2)])) as ArrayRef, + ], + Some(vec![true, false].into()), + )?; + + let schema = Schema::new(vec![Field::new( + "s", + DataType::Struct(Fields::from(vec![ + Field::new("v", DataType::Utf8, true), + Field::new("i", DataType::Int32, true), + ])), + true, + )]); + + let record_batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(struct_array)])?; + let mut app = db.appender("t1")?; + app.append_record_batch(record_batch)?; + } + let mut stmt = db.prepare("SELECT s FROM t1 where s IS NOT NULL")?; + let rbs: Vec = stmt.query_arrow([])?.collect(); + assert_eq!(rbs.iter().map(|op| op.num_rows()).sum::(), 1); + + Ok(()) + } + fn check_rust_primitive_array_roundtrip( input_array: PrimitiveArray, expected_array: PrimitiveArray, @@ -762,7 +830,7 @@ mod test { db.register_table_function::("arrow")?; // Roundtrip a record batch from Rust to DuckDB and back to Rust - let schema = Schema::new(vec![Field::new("a", arry.data_type().clone(), false)]); + let schema = Schema::new(vec![Field::new("a", arry.data_type().clone(), true)]); let rb = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arry.clone())])?; let param = arrow_recordbatch_to_query_params(rb); @@ -910,6 +978,24 @@ mod test { Ok(()) } + #[test] + fn test_check_generic_array_roundtrip_contains_null() -> Result<(), Box> { + check_generic_array_roundtrip(ListArray::new( + Arc::new(Field::new("item", DataType::Utf8, true)), + OffsetBuffer::new(ScalarBuffer::from(vec![0, 2, 4, 5])), + Arc::new(StringArray::from(vec![ + Some("foo"), + Some("baz"), + Some("bar"), + Some("foo"), + Some("baz"), + ])), + Some(vec![true, false, true].into()), + ))?; + + Ok(()) + } + #[test] fn test_utf8_roundtrip() -> Result<(), Box> { check_generic_byte_roundtrip( diff --git a/crates/libduckdb-sys/Cargo.toml b/crates/libduckdb-sys/Cargo.toml index d02c0ea0..8e6804bd 100644 --- a/crates/libduckdb-sys/Cargo.toml +++ b/crates/libduckdb-sys/Cargo.toml @@ -22,6 +22,7 @@ buildtime_bindgen = ["bindgen", "pkg-config", "vcpkg"] json = ["bundled"] parquet = ["bundled"] extensions-full = ["json", "parquet"] +winduckdb = [] [dependencies]