Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set nulls correctly for all types of arrays/vectors #344

Merged
merged 8 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/duckdb/src/core/data_chunk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ impl Drop for DataChunkHandle {
}

impl DataChunkHandle {
/// Wraps a raw `duckdb_data_chunk` pointer in a handle whose `owned` flag is `false`.
///
/// # Safety
///
/// No validation of `ptr` happens here. The caller is presumably responsible for
/// `ptr` referring to a live data chunk for as long as the handle is in use —
/// NOTE(review): confirm against the `Drop` impl, which is outside this view.
// The lint is silenced because this constructor may have no caller inside the crate.
#[allow(dead_code)]
pub(crate) unsafe fn new_unowned(ptr: duckdb_data_chunk) -> Self {
    let owned = false;
    Self { ptr, owned }
}
Expand Down
29 changes: 27 additions & 2 deletions crates/duckdb/src/core/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,15 @@ impl ListVector {
self.entries.as_mut_slice::<duckdb_list_entry>()[idx].length = length as u64;
}

/// Marks `row` as NULL in the validity mask of the list's entries vector.
pub fn set_null(&mut self, row: usize) {
    let entries = self.entries.ptr;
    // SAFETY (assumed, not checked here): `entries` is a live DuckDB vector and
    // `row` lies within its capacity — TODO confirm at call sites.
    unsafe {
        // Per the DuckDB C API docs, the validity mask has to be made writable
        // before it is fetched and modified.
        duckdb_vector_ensure_validity_writable(entries);
        let validity = duckdb_vector_get_validity(entries);
        duckdb_validity_set_row_invalid(validity, row as u64);
    }
}

/// Reserve the capacity for its child node.
fn reserve(&self, capacity: usize) {
unsafe {
Expand All @@ -190,7 +199,6 @@ impl ListVector {

/// An array vector (a fixed-size list).
pub struct ArrayVector {
/// Raw DuckDB vector handle. ArrayVector does not own the vector pointer.
ptr: duckdb_vector,
}

Expand Down Expand Up @@ -223,11 +231,19 @@ impl ArrayVector {
pub fn set_child<T: Copy>(&self, data: &[T]) {
// Fetch the child vector via `child`, passing the slice length (presumably the
// capacity it needs — confirm against `child`), then copy `data` into it.
self.child(data.len()).copy(data);
}

/// Marks `row` as NULL in this array vector's validity mask.
pub fn set_null(&mut self, row: usize) {
    let vector = self.ptr;
    // SAFETY (assumed, not checked here): `vector` is a live DuckDB vector and
    // `row` lies within its capacity — TODO confirm at call sites.
    unsafe {
        // Per the DuckDB C API docs, the validity mask has to be made writable
        // before it is fetched and modified.
        duckdb_vector_ensure_validity_writable(vector);
        let validity = duckdb_vector_get_validity(vector);
        duckdb_validity_set_row_invalid(validity, row as u64);
    }
}
}

/// A struct vector.
pub struct StructVector {
/// Raw DuckDB vector handle. StructVector does not own the vector pointer.
ptr: duckdb_vector,
}

Expand Down Expand Up @@ -277,4 +293,13 @@ impl StructVector {
let logical_type = self.logical_type();
unsafe { duckdb_struct_type_child_count(logical_type.ptr) as usize }
}

/// Marks `row` as NULL in this struct vector's validity mask.
pub fn set_null(&mut self, row: usize) {
    let vector = self.ptr;
    // SAFETY (assumed, not checked here): `vector` is a live DuckDB vector and
    // `row` lies within its capacity — TODO confirm at call sites.
    unsafe {
        // Per the DuckDB C API docs, the validity mask has to be made writable
        // before it is fetched and modified.
        duckdb_vector_ensure_validity_writable(vector);
        let validity = duckdb_vector_get_validity(vector);
        duckdb_validity_set_row_invalid(validity, row as u64);
    }
}
}
130 changes: 108 additions & 22 deletions crates/duckdb/src/vtab/arrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,13 +268,7 @@ pub fn record_batch_to_duckdb_data_chunk(
/// Copies the values of an Arrow primitive array into a DuckDB flat vector and
/// then marks every null row of `array` as NULL in `out_vector`.
///
/// * `array` - source Arrow array; its values buffer is copied as `T::Native`.
/// * `out_vector` - destination flat vector. No capacity check is performed here
///   (the assertion below is disabled), so the caller must size it appropriately.
fn primitive_array_to_flat_vector<T: ArrowPrimitiveType>(array: &PrimitiveArray<T>, out_vector: &mut FlatVector) {
    // assert!(array.len() <= out_vector.capacity());
    out_vector.copy::<T::Native>(array.values());
    // Null marking is delegated to the shared helper; the inline loop that used to
    // sit here did exactly the same work a second time.
    set_nulls_in_flat_vector(array, out_vector);
}

fn primitive_array_to_flat_vector_cast<T: ArrowPrimitiveType>(
Expand All @@ -285,13 +279,7 @@ fn primitive_array_to_flat_vector_cast<T: ArrowPrimitiveType>(
let array = arrow::compute::kernels::cast::cast(array, &data_type).unwrap();
let out_vector: &mut FlatVector = out_vector.as_mut_any().downcast_mut().unwrap();
out_vector.copy::<T::Native>(array.as_primitive::<T>().values());
if let Some(nulls) = array.nulls() {
for (i, null) in nulls.iter().enumerate() {
if !null {
out_vector.set_null(i);
}
}
}
set_nulls_in_flat_vector(&array, out_vector);
}

fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) -> Result<(), Box<dyn std::error::Error>> {
Expand Down Expand Up @@ -441,13 +429,7 @@ fn decimal_array_to_vector(array: &Decimal128Array, out: &mut FlatVector, width:
}

// Set nulls
if let Some(nulls) = array.nulls() {
for (i, null) in nulls.into_iter().enumerate() {
if !null {
out.set_null(i);
}
}
}
set_nulls_in_flat_vector(array, out);
}

/// Convert Arrow [BooleanArray] to a duckdb vector.
Expand All @@ -457,6 +439,7 @@ fn boolean_array_to_vector(array: &BooleanArray, out: &mut FlatVector) {
for i in 0..array.len() {
out.as_mut_slice()[i] = array.value(i);
}
set_nulls_in_flat_vector(array, out);
}

fn string_array_to_vector<O: OffsetSizeTrait>(array: &GenericStringArray<O>, out: &mut FlatVector) {
Expand All @@ -467,6 +450,7 @@ fn string_array_to_vector<O: OffsetSizeTrait>(array: &GenericStringArray<O>, out
let s = array.value(i);
out.insert(i, s);
}
set_nulls_in_flat_vector(array, out);
}

fn binary_array_to_vector(array: &BinaryArray, out: &mut FlatVector) {
Expand All @@ -476,6 +460,7 @@ fn binary_array_to_vector(array: &BinaryArray, out: &mut FlatVector) {
let s = array.value(i);
out.insert(i, s);
}
set_nulls_in_flat_vector(array, out);
}

fn list_array_to_vector<O: OffsetSizeTrait + AsPrimitive<usize>>(
Expand Down Expand Up @@ -504,6 +489,8 @@ fn list_array_to_vector<O: OffsetSizeTrait + AsPrimitive<usize>>(
let length = array.value_length(i);
out.set_entry(i, offset.as_(), length.as_());
}
set_nulls_in_list_vector(array, out);

Ok(())
}

Expand All @@ -528,6 +515,8 @@ fn fixed_size_list_array_to_vector(
}
}

set_nulls_in_array_vector(array, out);

Ok(())
}

Expand Down Expand Up @@ -575,6 +564,7 @@ fn struct_array_to_vector(array: &StructArray, out: &mut StructVector) -> Result
}
}
}
set_nulls_in_struct_vector(array, out);
Ok(())
}

Expand Down Expand Up @@ -611,6 +601,46 @@ pub fn arrow_ffi_to_query_params(array: FFI_ArrowArray, schema: FFI_ArrowSchema)
[arr as *mut _ as usize, sch as *mut _ as usize]
}

/// Calls `mark_null(row)` for every row of `array` whose validity bit is unset.
///
/// Does nothing when `array` has no null buffer. This is the single loop shared
/// by the per-vector-type wrappers below.
fn for_each_null_row(array: &dyn Array, mut mark_null: impl FnMut(usize)) {
    if let Some(nulls) = array.nulls() {
        for (row, is_valid) in nulls.iter().enumerate() {
            if !is_valid {
                mark_null(row);
            }
        }
    }
}

/// Marks every null row of `array` as NULL in the flat vector `out_vector`.
fn set_nulls_in_flat_vector(array: &dyn Array, out_vector: &mut FlatVector) {
    for_each_null_row(array, |row| {
        out_vector.set_null(row);
    });
}

/// Marks every null row of `array` as NULL in the struct vector `out_vector`.
fn set_nulls_in_struct_vector(array: &dyn Array, out_vector: &mut StructVector) {
    for_each_null_row(array, |row| {
        out_vector.set_null(row);
    });
}

/// Marks every null row of `array` as NULL in the array vector `out_vector`.
fn set_nulls_in_array_vector(array: &dyn Array, out_vector: &mut ArrayVector) {
    for_each_null_row(array, |row| {
        out_vector.set_null(row);
    });
}

/// Marks every null row of `array` as NULL in the list vector `out_vector`.
fn set_nulls_in_list_vector(array: &dyn Array, out_vector: &mut ListVector) {
    for_each_null_row(array, |row| {
        out_vector.set_null(row);
    });
}

#[cfg(test)]
mod test {
use super::{arrow_recordbatch_to_query_params, ArrowVTab};
Expand Down Expand Up @@ -705,6 +735,44 @@ mod test {
Ok(())
}

/// Appends a two-row struct column whose second row is NULL and checks that
/// only one row survives an `IS NOT NULL` filter.
#[test]
fn test_append_struct_contains_null() -> Result<(), Box<dyn Error>> {
    let db = Connection::open_in_memory()?;
    db.execute_batch("CREATE TABLE t1 (s STRUCT(v VARCHAR, i INTEGER))")?;

    // The same field list describes both the struct array and the schema column.
    let fields = Fields::from(vec![
        Field::new("v", DataType::Utf8, true),
        Field::new("i", DataType::Int32, true),
    ]);
    let columns = vec![
        Arc::new(StringArray::from(vec![Some("foo"), Some("bar")])) as ArrayRef,
        Arc::new(Int32Array::from(vec![Some(1), Some(2)])) as ArrayRef,
    ];
    // Validity mask: first struct is valid, second struct is null.
    let struct_array = StructArray::try_new(fields.clone(), columns, Some(vec![true, false].into()))?;
    let schema = Schema::new(vec![Field::new("s", DataType::Struct(fields), true)]);
    let record_batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(struct_array)])?;

    {
        // Keep the appender in its own scope so it is dropped before the query runs.
        let mut app = db.appender("t1")?;
        app.append_record_batch(record_batch)?;
    }

    let mut stmt = db.prepare("SELECT s FROM t1 where s IS NOT NULL")?;
    let rbs: Vec<RecordBatch> = stmt.query_arrow([])?.collect();
    let non_null_rows: usize = rbs.iter().map(|rb| rb.num_rows()).sum();
    assert_eq!(non_null_rows, 1);

    Ok(())
}

fn check_rust_primitive_array_roundtrip<T1, T2>(
input_array: PrimitiveArray<T1>,
expected_array: PrimitiveArray<T2>,
Expand Down Expand Up @@ -762,7 +830,7 @@ mod test {
db.register_table_function::<ArrowVTab>("arrow")?;

// Roundtrip a record batch from Rust to DuckDB and back to Rust
let schema = Schema::new(vec![Field::new("a", arry.data_type().clone(), false)]);
let schema = Schema::new(vec![Field::new("a", arry.data_type().clone(), true)]);

let rb = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arry.clone())])?;
let param = arrow_recordbatch_to_query_params(rb);
Expand Down Expand Up @@ -910,6 +978,24 @@ mod test {
Ok(())
}

/// Round-trips a three-entry list array of strings whose middle entry is null.
#[test]
fn test_check_generic_array_roundtrip_contains_null() -> Result<(), Box<dyn Error>> {
    let item_field = Arc::new(Field::new("item", DataType::Utf8, true));
    // Three list entries covering child ranges 0..2, 2..4 and 4..5.
    let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, 2, 4, 5]));
    let values = Arc::new(StringArray::from(vec![
        Some("foo"),
        Some("baz"),
        Some("bar"),
        Some("foo"),
        Some("baz"),
    ]));
    // The second list entry is null.
    let validity = Some(vec![true, false, true].into());

    let list_array = ListArray::new(item_field, offsets, values, validity);
    check_generic_array_roundtrip(list_array)?;

    Ok(())
}

#[test]
fn test_utf8_roundtrip() -> Result<(), Box<dyn Error>> {
check_generic_byte_roundtrip(
Expand Down
1 change: 1 addition & 0 deletions crates/libduckdb-sys/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ buildtime_bindgen = ["bindgen", "pkg-config", "vcpkg"]
json = ["bundled"]
parquet = ["bundled"]
extensions-full = ["json", "parquet"]
winduckdb = []

[dependencies]

Expand Down
Loading