Skip to content

Commit

Permalink
Add support for DuckDB arrays when using Arrow's FixedSizeList (#323)
Browse files Browse the repository at this point in the history
* support UTF8[]

* add tests

* fix test

* format

* clippy

* bump cause github is broken

* add support for DuckDB arrays when using Arrow's FixedSizeList

* fmt

* add ArrayVector

* update path in remote test

---------

Co-authored-by: Max Gabrielsson <[email protected]>
  • Loading branch information
Jeadie and Maxxen authored Jun 4, 2024
1 parent 74fce0f commit f628e5a
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 25 deletions.
2 changes: 1 addition & 1 deletion crates/duckdb/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ mod test {
let db = Connection::open_in_memory()?;
assert_eq!(
300f32,
db.query_row::<f32, _, _>(r#"SELECT SUM(value) FROM read_parquet('https://github.com/wangfenjin/duckdb-rs/raw/main/examples/int32_decimal.parquet');"#, [], |r| r.get(0))?
db.query_row::<f32, _, _>(r#"SELECT SUM(value) FROM read_parquet('https://github.com/duckdb/duckdb-rs/raw/main/crates/duckdb/examples/int32_decimal.parquet');"#, [], |r| r.get(0))?
);
Ok(())
}
Expand Down
81 changes: 58 additions & 23 deletions crates/duckdb/src/vtab/arrow.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::{
vector::{FlatVector, ListVector, Vector},
vector::{ArrayVector, FlatVector, ListVector, Vector},
BindInfo, DataChunk, Free, FunctionInfo, InitInfo, LogicalType, LogicalTypeId, StructVector, VTab,
};
use std::ptr::null_mut;
Expand Down Expand Up @@ -196,8 +196,11 @@ pub fn to_duckdb_logical_type(data_type: &DataType) -> Result<LogicalType, Box<d
Ok(LogicalType::list(&to_duckdb_logical_type(child.data_type())?))
} else if let DataType::LargeList(child) = data_type {
Ok(LogicalType::list(&to_duckdb_logical_type(child.data_type())?))
} else if let DataType::FixedSizeList(child, _) = data_type {
Ok(LogicalType::list(&to_duckdb_logical_type(child.data_type())?))
} else if let DataType::FixedSizeList(child, array_size) = data_type {
Ok(LogicalType::array(
&to_duckdb_logical_type(child.data_type())?,
*array_size as u64,
))
} else {
Err(
format!("Unsupported data type: {data_type}, please file an issue https://github.com/wangfenjin/duckdb-rs")
Expand Down Expand Up @@ -234,7 +237,7 @@ pub fn record_batch_to_duckdb_data_chunk(
list_array_to_vector(as_large_list_array(col.as_ref()), &mut chunk.list_vector(i))?;
}
DataType::FixedSizeList(_, _) => {
fixed_size_list_array_to_vector(as_fixed_size_list_array(col.as_ref()), &mut chunk.list_vector(i))?;
fixed_size_list_array_to_vector(as_fixed_size_list_array(col.as_ref()), &mut chunk.array_vector(i))?;
}
DataType::Struct(_) => {
let struct_array = as_struct_array(col.as_ref());
Expand Down Expand Up @@ -455,33 +458,21 @@ fn list_array_to_vector<O: OffsetSizeTrait + AsPrimitive<usize>>(

fn fixed_size_list_array_to_vector(
array: &FixedSizeListArray,
out: &mut ListVector,
out: &mut ArrayVector,
) -> Result<(), Box<dyn std::error::Error>> {
let value_array = array.values();
let mut child = out.child(value_array.len());
match value_array.data_type() {
dt if dt.is_primitive() => {
primitive_array_to_vector(value_array.as_ref(), &mut child)?;
for i in 0..array.len() {
let offset = array.value_offset(i);
let length = array.value_length();
out.set_entry(i, offset as usize, length as usize);
}
out.set_len(value_array.len());
}
DataType::Utf8 => {
string_array_to_vector(as_string_array(value_array.as_ref()), &mut child);
}
_ => {
return Err("Nested list is not supported yet.".into());
return Err("Nested array is not supported yet.".into());
}
}
for i in 0..array.len() {
let offset = array.value_offset(i);
let length = array.value_length();
out.set_entry(i, offset as usize, length as usize);
}
out.set_len(value_array.len());

Ok(())
}
Expand Down Expand Up @@ -511,7 +502,7 @@ fn struct_array_to_vector(array: &StructArray, out: &mut StructVector) -> Result
DataType::FixedSizeList(_, _) => {
fixed_size_list_array_to_vector(
as_fixed_size_list_array(column.as_ref()),
&mut out.list_vector_child(i),
&mut out.array_vector_child(i),
)?;
}
DataType::Struct(_) => {
Expand Down Expand Up @@ -569,10 +560,10 @@ mod test {
use crate::{Connection, Result};
use arrow::{
array::{
Array, ArrayRef, AsArray, Date32Array, Date64Array, Decimal256Array, Float64Array, GenericListArray,
Int32Array, ListArray, OffsetSizeTrait, PrimitiveArray, StringArray, StructArray, Time32SecondArray,
Time64MicrosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
TimestampSecondArray,
Array, ArrayRef, AsArray, Date32Array, Date64Array, Decimal256Array, FixedSizeListArray, Float64Array,
GenericListArray, Int32Array, ListArray, OffsetSizeTrait, PrimitiveArray, StringArray, StructArray,
Time32SecondArray, Time64MicrosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray,
TimestampNanosecondArray, TimestampSecondArray,
},
buffer::{OffsetBuffer, ScalarBuffer},
datatypes::{i256, ArrowPrimitiveType, DataType, Field, Fields, Schema},
Expand Down Expand Up @@ -760,6 +751,50 @@ mod test {
Ok(())
}

/// Round-trips a fixed-size list column through the `arrow` table function and
/// checks that both the data type and every element come back unchanged.
#[test]
fn test_fixed_array_roundtrip() -> Result<(), Box<dyn Error>> {
    // Arguments are (field, size, values, nulls). The child value count is kept a whole
    // multiple of the list size (6 values / size 2 = 3 arrays) so that no trailing value
    // is left outside a complete array and silently excluded from the comparison.
    let array = FixedSizeListArray::new(
        Arc::new(Field::new("item", DataType::Int32, true)),
        2,
        Arc::new(Int32Array::from(vec![
            Some(1),
            Some(2),
            Some(3),
            Some(4),
            Some(5),
            Some(6),
        ])),
        None,
    );

    let expected_output_array = array.clone();

    let db = Connection::open_in_memory()?;
    db.register_table_function::<ArrowVTab>("arrow")?;

    // Roundtrip a record batch from Rust to DuckDB and back to Rust
    let schema = Schema::new(vec![Field::new("a", array.data_type().clone(), false)]);

    // This is the last use of `array`, so it is moved into the batch instead of cloned.
    let rb = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)])?;
    let param = arrow_recordbatch_to_query_params(rb);
    let mut stmt = db.prepare("select a from arrow(?, ?)")?;
    let rb = stmt.query_arrow(param)?.next().expect("no record batch");

    let output_any_array = rb.column(0);
    assert!(output_any_array
        .data_type()
        .equals_datatype(expected_output_array.data_type()));

    let output_array = output_any_array
        .as_fixed_size_list_opt()
        .expect("Expected FixedSizeListArray");

    assert_eq!(output_array.len(), expected_output_array.len());
    for i in 0..output_array.len() {
        assert_eq!(output_array.is_valid(i), expected_output_array.is_valid(i));
        if output_array.is_valid(i) {
            assert!(expected_output_array.value(i).eq(&output_array.value(i)));
        }
    }

    Ok(())
}

#[test]
fn test_primitive_roundtrip_contains_nulls() -> Result<(), Box<dyn Error>> {
let mut builder = arrow::array::PrimitiveBuilder::<arrow::datatypes::Int32Type>::new();
Expand Down
7 changes: 6 additions & 1 deletion crates/duckdb/src/vtab/data_chunk.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::{
logical_type::LogicalType,
vector::{FlatVector, ListVector, StructVector},
vector::{ArrayVector, FlatVector, ListVector, StructVector},
};
use crate::ffi::{
duckdb_create_data_chunk, duckdb_data_chunk, duckdb_data_chunk_get_column_count, duckdb_data_chunk_get_size,
Expand Down Expand Up @@ -35,6 +35,11 @@ impl DataChunk {
ListVector::from(unsafe { duckdb_data_chunk_get_vector(self.ptr, idx as u64) })
}

/// Get an array vector from the column index: `idx`.
pub fn array_vector(&self, idx: usize) -> ArrayVector {
    let column = unsafe { duckdb_data_chunk_get_vector(self.ptr, idx as u64) };
    ArrayVector::from(column)
}

/// Get struct vector at the column index: `idx`.
pub fn struct_vector(&self, idx: usize) -> StructVector {
StructVector::from(unsafe { duckdb_data_chunk_get_vector(self.ptr, idx as u64) })
Expand Down
9 changes: 9 additions & 0 deletions crates/duckdb/src/vtab/logical_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,15 @@ impl LogicalType {
}
}

/// Creates an array type with `array_size` elements of `child_type`.
pub fn array(child_type: &LogicalType, array_size: u64) -> Self {
    let ptr = unsafe { duckdb_create_array_type(child_type.ptr, array_size) };
    Self { ptr }
}

/// Creates a decimal type from its `width` and `scale`.
pub fn decimal(width: u8, scale: u8) -> Self {
unsafe {
Expand Down
43 changes: 43 additions & 0 deletions crates/duckdb/src/vtab/vector.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::{any::Any, ffi::CString, slice};

use libduckdb_sys::{duckdb_array_type_array_size, duckdb_array_vector_get_child};

use super::LogicalType;
use crate::ffi::{
duckdb_list_entry, duckdb_list_vector_get_child, duckdb_list_vector_get_size, duckdb_list_vector_reserve,
Expand Down Expand Up @@ -170,6 +172,42 @@ impl ListVector {
}
}

/// An array vector. (fixed-size list)
pub struct ArrayVector {
    /// ArrayVector does not own the vector pointer.
    ptr: duckdb_vector,
}

impl From<duckdb_vector> for ArrayVector {
fn from(ptr: duckdb_vector) -> Self {
Self { ptr }
}
}

impl ArrayVector {
    /// Returns the logical type of the column backing this vector.
    pub fn logical_type(&self) -> LogicalType {
        let column_type = unsafe { duckdb_vector_get_column_type(self.ptr) };
        LogicalType::from(column_type)
    }

    /// Returns the array size recorded in this vector's logical type.
    pub fn get_array_size(&self) -> u64 {
        let array_type = self.logical_type();
        let size = unsafe { duckdb_array_type_array_size(array_type.ptr) };
        size as u64
    }

    /// Returns the child vector as a [FlatVector] with the given `capacity`.
    /// `capacity` should be a multiple of the array size.
    // TODO: not ideal interface. Where should we keep count.
    pub fn child(&self, capacity: usize) -> FlatVector {
        let child_ptr = unsafe { duckdb_array_vector_get_child(self.ptr) };
        FlatVector::with_capacity(child_ptr, capacity)
    }

    /// Copies primitive `data` into the child vector, using `data.len()` as its capacity.
    pub fn set_child<T: Copy>(&self, data: &[T]) {
        let mut child = self.child(data.len());
        child.copy(data);
    }
}

/// A struct vector.
pub struct StructVector {
/// StructVector does not own the vector pointer.
Expand Down Expand Up @@ -198,6 +236,11 @@ impl StructVector {
ListVector::from(unsafe { duckdb_struct_vector_get_child(self.ptr, idx as u64) })
}

/// Take the struct child at index `idx` as an [ArrayVector].
pub fn array_vector_child(&self, idx: usize) -> ArrayVector {
    let child = unsafe { duckdb_struct_vector_get_child(self.ptr, idx as u64) };
    ArrayVector::from(child)
}

/// Get the logical type of this struct vector.
pub fn logical_type(&self) -> LogicalType {
LogicalType::from(unsafe { duckdb_vector_get_column_type(self.ptr) })
Expand Down

0 comments on commit f628e5a

Please sign in to comment.