Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for DuckDB arrays when using Arrow's FixedSizeList #323

Merged
merged 13 commits into from
Jun 4, 2024
2 changes: 1 addition & 1 deletion crates/duckdb/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ mod test {
let db = Connection::open_in_memory()?;
assert_eq!(
300f32,
db.query_row::<f32, _, _>(r#"SELECT SUM(value) FROM read_parquet('https://github.com/wangfenjin/duckdb-rs/raw/main/examples/int32_decimal.parquet');"#, [], |r| r.get(0))?
db.query_row::<f32, _, _>(r#"SELECT SUM(value) FROM read_parquet('https://github.com/duckdb/duckdb-rs/raw/main/crates/duckdb/examples/int32_decimal.parquet');"#, [], |r| r.get(0))?
);
Ok(())
}
Expand Down
81 changes: 58 additions & 23 deletions crates/duckdb/src/vtab/arrow.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::{
vector::{FlatVector, ListVector, Vector},
vector::{ArrayVector, FlatVector, ListVector, Vector},
BindInfo, DataChunk, Free, FunctionInfo, InitInfo, LogicalType, LogicalTypeId, StructVector, VTab,
};
use std::ptr::null_mut;
Expand Down Expand Up @@ -196,8 +196,11 @@ pub fn to_duckdb_logical_type(data_type: &DataType) -> Result<LogicalType, Box<d
Ok(LogicalType::list(&to_duckdb_logical_type(child.data_type())?))
} else if let DataType::LargeList(child) = data_type {
Ok(LogicalType::list(&to_duckdb_logical_type(child.data_type())?))
} else if let DataType::FixedSizeList(child, _) = data_type {
Ok(LogicalType::list(&to_duckdb_logical_type(child.data_type())?))
} else if let DataType::FixedSizeList(child, array_size) = data_type {
Ok(LogicalType::array(
&to_duckdb_logical_type(child.data_type())?,
*array_size as u64,
))
} else {
Err(
format!("Unsupported data type: {data_type}, please file an issue https://github.com/wangfenjin/duckdb-rs")
Expand Down Expand Up @@ -234,7 +237,7 @@ pub fn record_batch_to_duckdb_data_chunk(
list_array_to_vector(as_large_list_array(col.as_ref()), &mut chunk.list_vector(i))?;
}
DataType::FixedSizeList(_, _) => {
fixed_size_list_array_to_vector(as_fixed_size_list_array(col.as_ref()), &mut chunk.list_vector(i))?;
fixed_size_list_array_to_vector(as_fixed_size_list_array(col.as_ref()), &mut chunk.array_vector(i))?;
}
DataType::Struct(_) => {
let struct_array = as_struct_array(col.as_ref());
Expand Down Expand Up @@ -455,33 +458,21 @@ fn list_array_to_vector<O: OffsetSizeTrait + AsPrimitive<usize>>(

fn fixed_size_list_array_to_vector(
array: &FixedSizeListArray,
out: &mut ListVector,
out: &mut ArrayVector,
) -> Result<(), Box<dyn std::error::Error>> {
let value_array = array.values();
let mut child = out.child(value_array.len());
match value_array.data_type() {
dt if dt.is_primitive() => {
primitive_array_to_vector(value_array.as_ref(), &mut child)?;
for i in 0..array.len() {
let offset = array.value_offset(i);
let length = array.value_length();
out.set_entry(i, offset as usize, length as usize);
}
out.set_len(value_array.len());
}
DataType::Utf8 => {
string_array_to_vector(as_string_array(value_array.as_ref()), &mut child);
}
_ => {
return Err("Nested list is not supported yet.".into());
return Err("Nested array is not supported yet.".into());
}
}
for i in 0..array.len() {
let offset = array.value_offset(i);
let length = array.value_length();
out.set_entry(i, offset as usize, length as usize);
}
out.set_len(value_array.len());

Ok(())
}
Expand Down Expand Up @@ -511,7 +502,7 @@ fn struct_array_to_vector(array: &StructArray, out: &mut StructVector) -> Result
DataType::FixedSizeList(_, _) => {
fixed_size_list_array_to_vector(
as_fixed_size_list_array(column.as_ref()),
&mut out.list_vector_child(i),
&mut out.array_vector_child(i),
)?;
}
DataType::Struct(_) => {
Expand Down Expand Up @@ -569,10 +560,10 @@ mod test {
use crate::{Connection, Result};
use arrow::{
array::{
Array, ArrayRef, AsArray, Date32Array, Date64Array, Decimal256Array, Float64Array, GenericListArray,
Int32Array, ListArray, OffsetSizeTrait, PrimitiveArray, StringArray, StructArray, Time32SecondArray,
Time64MicrosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
TimestampSecondArray,
Array, ArrayRef, AsArray, Date32Array, Date64Array, Decimal256Array, FixedSizeListArray, Float64Array,
GenericListArray, Int32Array, ListArray, OffsetSizeTrait, PrimitiveArray, StringArray, StructArray,
Time32SecondArray, Time64MicrosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray,
TimestampNanosecondArray, TimestampSecondArray,
},
buffer::{OffsetBuffer, ScalarBuffer},
datatypes::{i256, ArrowPrimitiveType, DataType, Field, Fields, Schema},
Expand Down Expand Up @@ -760,6 +751,50 @@ mod test {
Ok(())
}

//field: FieldRef, size: i32, values: ArrayRef, nulls: Option<NullBuffer>
#[test]
fn test_fixed_array_roundtrip() -> Result<(), Box<dyn Error>> {
let array = FixedSizeListArray::new(
Arc::new(Field::new("item", DataType::Int32, true)),
2,
Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3), Some(4), Some(5)])),
None,
);

let expected_output_array = array.clone();

let db = Connection::open_in_memory()?;
db.register_table_function::<ArrowVTab>("arrow")?;

// Roundtrip a record batch from Rust to DuckDB and back to Rust
let schema = Schema::new(vec![Field::new("a", array.data_type().clone(), false)]);

let rb = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array.clone())])?;
let param = arrow_recordbatch_to_query_params(rb);
let mut stmt = db.prepare("select a from arrow(?, ?)")?;
let rb = stmt.query_arrow(param)?.next().expect("no record batch");

let output_any_array = rb.column(0);
assert!(output_any_array
.data_type()
.equals_datatype(expected_output_array.data_type()));

match output_any_array.as_fixed_size_list_opt() {
Some(output_array) => {
assert_eq!(output_array.len(), expected_output_array.len());
for i in 0..output_array.len() {
assert_eq!(output_array.is_valid(i), expected_output_array.is_valid(i));
if output_array.is_valid(i) {
assert!(expected_output_array.value(i).eq(&output_array.value(i)));
}
}
}
None => panic!("Expected FixedSizeListArray"),
}

Ok(())
}

#[test]
fn test_primitive_roundtrip_contains_nulls() -> Result<(), Box<dyn Error>> {
let mut builder = arrow::array::PrimitiveBuilder::<arrow::datatypes::Int32Type>::new();
Expand Down
7 changes: 6 additions & 1 deletion crates/duckdb/src/vtab/data_chunk.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::{
logical_type::LogicalType,
vector::{FlatVector, ListVector, StructVector},
vector::{ArrayVector, FlatVector, ListVector, StructVector},
};
use crate::ffi::{
duckdb_create_data_chunk, duckdb_data_chunk, duckdb_data_chunk_get_column_count, duckdb_data_chunk_get_size,
Expand Down Expand Up @@ -35,6 +35,11 @@ impl DataChunk {
ListVector::from(unsafe { duckdb_data_chunk_get_vector(self.ptr, idx as u64) })
}

/// Get a array vector from the column index.
pub fn array_vector(&self, idx: usize) -> ArrayVector {
ArrayVector::from(unsafe { duckdb_data_chunk_get_vector(self.ptr, idx as u64) })
}

/// Get struct vector at the column index: `idx`.
pub fn struct_vector(&self, idx: usize) -> StructVector {
StructVector::from(unsafe { duckdb_data_chunk_get_vector(self.ptr, idx as u64) })
Expand Down
9 changes: 9 additions & 0 deletions crates/duckdb/src/vtab/logical_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,15 @@ impl LogicalType {
}
}

/// Creates an array type from its child type.
pub fn array(child_type: &LogicalType, array_size: u64) -> Self {
unsafe {
Self {
ptr: duckdb_create_array_type(child_type.ptr, array_size),
}
}
}

/// Creates a decimal type from its `width` and `scale`.
pub fn decimal(width: u8, scale: u8) -> Self {
unsafe {
Expand Down
43 changes: 43 additions & 0 deletions crates/duckdb/src/vtab/vector.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::{any::Any, ffi::CString, slice};

use libduckdb_sys::{duckdb_array_type_array_size, duckdb_array_vector_get_child};

use super::LogicalType;
use crate::ffi::{
duckdb_list_entry, duckdb_list_vector_get_child, duckdb_list_vector_get_size, duckdb_list_vector_reserve,
Expand Down Expand Up @@ -170,6 +172,42 @@ impl ListVector {
}
}

/// A array vector. (fixed-size list)
pub struct ArrayVector {
/// ArrayVector does not own the vector pointer.
ptr: duckdb_vector,
}

impl From<duckdb_vector> for ArrayVector {
fn from(ptr: duckdb_vector) -> Self {
Self { ptr }
}
}

impl ArrayVector {
/// Get the logical type of this ArrayVector.
pub fn logical_type(&self) -> LogicalType {
LogicalType::from(unsafe { duckdb_vector_get_column_type(self.ptr) })
}

pub fn get_array_size(&self) -> u64 {
let ty = self.logical_type();
unsafe { duckdb_array_type_array_size(ty.ptr) as u64 }
}

/// Returns the child vector.
/// capacity should be a multiple of the array size.
// TODO: not ideal interface. Where should we keep count.
pub fn child(&self, capacity: usize) -> FlatVector {
FlatVector::with_capacity(unsafe { duckdb_array_vector_get_child(self.ptr) }, capacity)
}

/// Set primitive data to the child node.
pub fn set_child<T: Copy>(&self, data: &[T]) {
self.child(data.len()).copy(data);
}
}

/// A struct vector.
pub struct StructVector {
/// ListVector does not own the vector pointer.
Expand Down Expand Up @@ -198,6 +236,11 @@ impl StructVector {
ListVector::from(unsafe { duckdb_struct_vector_get_child(self.ptr, idx as u64) })
}

/// Take the child as [ArrayVector].
pub fn array_vector_child(&self, idx: usize) -> ArrayVector {
ArrayVector::from(unsafe { duckdb_struct_vector_get_child(self.ptr, idx as u64) })
}

/// Get the logical type of this struct vector.
pub fn logical_type(&self) -> LogicalType {
LogicalType::from(unsafe { duckdb_vector_get_column_type(self.ptr) })
Expand Down
Loading