Commit

fix: when taking struct fields they should be merged into the output in the correct order (#3277)

In various situations we need to fetch some fields from a struct and
then later add more fields to the struct. For example, suppose we have a
`struct<big_string: string, filter: i32>`. We might query with a filter
on `filter` and then use late materialization to take `big_string`. When
we did this we were previously creating `struct<filter: i32, big_string:
string>`, which would cause issues since that isn't the correct data
type.
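
As an illustration, here is a minimal sketch using Arrow schema types (the field types are taken from the example above; this is not Lance code, just a demonstration that the two field orders are distinct Arrow data types):

use arrow_schema::{DataType, Field, Fields};

fn main() {
    // The original type: struct<big_string: string, filter: i32>
    let expected = DataType::Struct(Fields::from(vec![
        Field::new("big_string", DataType::Utf8, true),
        Field::new("filter", DataType::Int32, true),
    ]));
    // What the naive merge produced: `filter` was read first (for the filter
    // step) and `big_string` was taken later, so the fields came back in
    // fetch order rather than schema order.
    let naive = DataType::Struct(Fields::from(vec![
        Field::new("filter", DataType::Int32, true),
        Field::new("big_string", DataType::Utf8, true),
    ]));
    // Field order is part of the type, so these are different data types.
    assert_ne!(expected, naive);
}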

In creating this fix I added a new `Projection` concept, and I would like
to gradually replace many of the places where we use schemas as
projections with `Projection` instead. This is not necessarily for
performance but more for convenience.
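
For intuition, a deliberately hypothetical sketch of a projection-as-a-type (this is not the actual `Projection` added in lance-core, just an illustration of why a dedicated type is more convenient than a pruned schema):

use std::collections::BTreeSet;

// Hypothetical: a projection as an explicit set of field ids against a
// reference schema, rather than a pruned copy of the schema itself.
struct Projection {
    field_ids: BTreeSet<i32>,
}

impl Projection {
    // Combining two projections is a plain set union; with schemas-as-
    // projections the same operation is an order-sensitive schema merge,
    // which is exactly where this bug crept in.
    fn union(&self, other: &Self) -> Self {
        Self {
            field_ids: self.field_ids.union(&other.field_ids).cloned().collect(),
        }
    }
}

fn main() {
    let filter_step = Projection { field_ids: BTreeSet::from([1]) };
    let late_take = Projection { field_ids: BTreeSet::from([0]) };
    // The union is order-independent, unlike a left-to-right schema merge.
    assert_eq!(
        filter_step.union(&late_take).field_ids,
        late_take.union(&filter_step).field_ids
    );
}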
westonpace authored Dec 20, 2024
1 parent 022135b commit 805438f
Showing 14 changed files with 1,045 additions and 219 deletions.
16 changes: 16 additions & 0 deletions python/python/tests/test_filter.py
@@ -257,3 +257,19 @@ def test_duckdb(tmp_path):
expected = duckdb.query("SELECT id, meta, price FROM ds").to_df()
expected = expected[expected.meta == "aa"].reset_index(drop=True)
tm.assert_frame_equal(actual, expected)


def test_struct_field_order(tmp_path):
"""
This test regresses some old behavior where the order of struct fields would get
messed up due to late materialization and we would get {y,x} instead of {x,y}
"""
data = pa.table({"struct": [{"x": i, "y": i} for i in range(10)]})
dataset = lance.write_dataset(data, tmp_path)

for late_materialization in [True, False]:
result = dataset.to_table(
filter="struct.y > 5", late_materialization=late_materialization
)
expected = pa.table({"struct": [{"x": i, "y": i} for i in range(6, 10)]})
assert result == expected
159 changes: 157 additions & 2 deletions rust/lance-arrow/src/lib.rs
@@ -349,6 +349,17 @@ pub trait RecordBatchExt {

/// Merge with another [`RecordBatch`] and return a new one.
///
/// Fields are merged based on name. First we iterate the left columns. If a matching
/// name is found in the right batch then we merge the two columns. If there is no match
/// then we add the left column to the output.
///
/// To merge two columns we consider the type. If both arrays are struct arrays we recurse.
/// Otherwise we use the left array.
///
/// Afterwards we add all non-matching right columns to the output.
///
/// Note: this method likely does not handle nested fields correctly, and you may want to
/// consider using [`merge_with_schema`] instead.
/// ```
/// use std::sync::Arc;
/// use arrow_array::*;
@@ -382,6 +393,17 @@ pub trait RecordBatchExt {
/// TODO: add support for merging nested fields.
fn merge(&self, other: &RecordBatch) -> Result<RecordBatch>;

/// Create a batch by merging columns between two batches with a given schema.
///
/// A reference schema is used to determine the proper ordering of nested fields.
///
/// For each field in the reference schema, we look for corresponding fields in
/// the left and right batches. If a field is found in both batches, we recursively
/// merge it.
///
/// If a field is only in the left or right batch, we take it as is.
fn merge_with_schema(&self, other: &RecordBatch, schema: &Schema) -> Result<RecordBatch>;

/// Drop one column specified with the name and return the new [`RecordBatch`].
///
/// If the named column does not exist, it returns a copy of this [`RecordBatch`].
@@ -450,6 +472,23 @@ impl RecordBatchExt for RecordBatch {
self.try_new_from_struct_array(merge(&left_struct_array, &right_struct_array))
}

fn merge_with_schema(&self, other: &RecordBatch, schema: &Schema) -> Result<RecordBatch> {
if self.num_rows() != other.num_rows() {
return Err(ArrowError::InvalidArgumentError(format!(
"Attempt to merge two RecordBatch with different sizes: {} != {}",
self.num_rows(),
other.num_rows()
)));
}
let left_struct_array: StructArray = self.clone().into();
let right_struct_array: StructArray = other.clone().into();
self.try_new_from_struct_array(merge_with_schema(
&left_struct_array,
&right_struct_array,
schema.fields(),
))
}

fn drop_column(&self, name: &str) -> Result<Self> {
let mut fields = vec![];
let mut columns = vec![];
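
A usage sketch of the new method (assuming the `RecordBatchExt` trait is imported from the lance-arrow crate; the helper function is illustrative only). Merging two flat batches against a reference schema yields the reference's column order regardless of fetch order:

use std::sync::Arc;
use arrow_array::{ArrayRef, Int32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};
use lance_arrow::RecordBatchExt; // assumed import path for the trait

// Illustrative helper: a one-column batch of Int32 values.
fn single_column(name: &str, values: Vec<i32>) -> RecordBatch {
    let schema = Arc::new(Schema::new(vec![Field::new(name, DataType::Int32, false)]));
    RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(values)) as ArrayRef]).unwrap()
}

fn main() {
    // The left batch was fetched first (e.g. a filter column), the right later.
    let left = single_column("filter", vec![1, 2]);
    let right = single_column("other", vec![3, 4]);
    // The reference schema, not the fetch order, dictates the output order.
    let reference = Schema::new(vec![
        Field::new("other", DataType::Int32, false),
        Field::new("filter", DataType::Int32, false),
    ]);
    let merged = left.merge_with_schema(&right, &reference).unwrap();
    assert_eq!(merged.schema().as_ref(), &reference);
}
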
@@ -542,7 +581,6 @@ fn project(struct_array: &StructArray, fields: &Fields) -> Result<StructArray> {
StructArray::try_new(fields.clone(), columns, None)
}

/// Merge the fields and columns of two RecordBatch's recursively
fn merge(left_struct_array: &StructArray, right_struct_array: &StructArray) -> StructArray {
let mut fields: Vec<Field> = vec![];
let mut columns: Vec<ArrayRef> = vec![];
@@ -616,6 +654,77 @@ fn merge(left_struct_array: &StructArray, right_struct_array: &StructArray) -> StructArray {
StructArray::from(zipped)
}

fn merge_with_schema(
left_struct_array: &StructArray,
right_struct_array: &StructArray,
fields: &Fields,
) -> StructArray {
// Helper function that returns true if both types are struct or both are non-struct
fn same_type_kind(left: &DataType, right: &DataType) -> bool {
match (left, right) {
(DataType::Struct(_), DataType::Struct(_)) => true,
(DataType::Struct(_), _) => false,
(_, DataType::Struct(_)) => false,
_ => true,
}
}

let mut output_fields: Vec<Field> = Vec::with_capacity(fields.len());
let mut columns: Vec<ArrayRef> = Vec::with_capacity(fields.len());

let left_fields = left_struct_array.fields();
let left_columns = left_struct_array.columns();
let right_fields = right_struct_array.fields();
let right_columns = right_struct_array.columns();

for field in fields {
let left_match_idx = left_fields.iter().position(|f| {
f.name() == field.name() && same_type_kind(f.data_type(), field.data_type())
});
let right_match_idx = right_fields.iter().position(|f| {
f.name() == field.name() && same_type_kind(f.data_type(), field.data_type())
});

match (left_match_idx, right_match_idx) {
(None, Some(right_idx)) => {
output_fields.push(right_fields[right_idx].as_ref().clone());
columns.push(right_columns[right_idx].clone());
}
(Some(left_idx), None) => {
output_fields.push(left_fields[left_idx].as_ref().clone());
columns.push(left_columns[left_idx].clone());
}
(Some(left_idx), Some(right_idx)) => {
if let DataType::Struct(child_fields) = field.data_type() {
let left_sub_array = left_columns[left_idx].as_struct();
let right_sub_array = right_columns[right_idx].as_struct();
let merged_sub_array =
merge_with_schema(left_sub_array, right_sub_array, child_fields);
output_fields.push(Field::new(
field.name(),
merged_sub_array.data_type().clone(),
field.is_nullable(),
));
columns.push(Arc::new(merged_sub_array) as ArrayRef);
} else {
output_fields.push(left_fields[left_idx].as_ref().clone());
columns.push(left_columns[left_idx].clone());
}
}
(None, None) => {
// The field will not be included in the output
}
}
}

let zipped: Vec<(FieldRef, ArrayRef)> = output_fields
.into_iter()
.map(Arc::new)
.zip(columns)
.collect::<Vec<_>>();
StructArray::from(zipped)
}

fn get_sub_array<'a>(array: &'a ArrayRef, components: &[&str]) -> Option<&'a ArrayRef> {
if components.is_empty() {
return Some(array);
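
Because merge_with_schema recurses on struct fields, it also repairs the nested case from the commit message. A sketch mirroring the Python regression test above (the helper is illustrative; the trait import path is assumed):

use std::sync::Arc;
use arrow_array::{ArrayRef, Int32Array, RecordBatch, StructArray};
use arrow_schema::{DataType, Field, Fields, Schema};
use lance_arrow::RecordBatchExt; // assumed import path for the trait

// Illustrative helper: a batch with one struct column holding a single child.
fn struct_batch(child: &str, values: Vec<i32>) -> RecordBatch {
    let fields = Fields::from(vec![Field::new(child, DataType::Int32, false)]);
    let array = StructArray::new(
        fields.clone(),
        vec![Arc::new(Int32Array::from(values)) as ArrayRef],
        None,
    );
    let schema = Schema::new(vec![Field::new("struct", DataType::Struct(fields), false)]);
    RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array) as ArrayRef]).unwrap()
}

fn main() {
    // Left carries struct.y (the filter column); right carries struct.x,
    // taken later. The reference schema says x comes first.
    let left = struct_batch("y", vec![6, 7]);
    let right = struct_batch("x", vec![6, 7]);
    let reference = Schema::new(vec![Field::new(
        "struct",
        DataType::Struct(Fields::from(vec![
            Field::new("x", DataType::Int32, false),
            Field::new("y", DataType::Int32, false),
        ])),
        false,
    )]);
    let merged = left.merge_with_schema(&right, &reference).unwrap();
    assert_eq!(merged.schema().as_ref(), &reference);
}
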
@@ -721,7 +830,7 @@ impl BufferExt for arrow_buffer::Buffer {
#[cfg(test)]
mod tests {
use super::*;
use arrow_array::{Int32Array, StringArray};
use arrow_array::{new_empty_array, Int32Array, StringArray};

#[test]
fn test_merge_recursive() {
@@ -808,6 +917,52 @@ mod tests {
assert_eq!(result, merged_batch);
}

#[test]
fn test_merge_with_schema() {
fn test_batch(names: &[&str], types: &[DataType]) -> (Schema, RecordBatch) {
let fields: Fields = names
.iter()
.zip(types)
.map(|(name, ty)| Field::new(name.to_string(), ty.clone(), false))
.collect();
let schema = Schema::new(vec![Field::new(
"struct",
DataType::Struct(fields.clone()),
false,
)]);
let children = types
.iter()
.map(|ty| new_empty_array(ty))
.collect::<Vec<_>>();
let batch = RecordBatch::try_new(
Arc::new(schema.clone()),
vec![Arc::new(StructArray::new(fields, children, None)) as ArrayRef],
);
(schema, batch.unwrap())
}

let (_, left_batch) = test_batch(&["a", "b"], &[DataType::Int32, DataType::Int64]);
let (_, right_batch) = test_batch(&["c", "b"], &[DataType::Int32, DataType::Int64]);
let (output_schema, _) = test_batch(
&["b", "a", "c"],
&[DataType::Int64, DataType::Int32, DataType::Int32],
);

// If we use merge_with_schema the schema is respected
let merged = left_batch
.merge_with_schema(&right_batch, &output_schema)
.unwrap();
assert_eq!(merged.schema().as_ref(), &output_schema);

// If we use merge, ordering is first-come, first-served based on the left batch
let (naive_schema, _) = test_batch(
&["a", "b", "c"],
&[DataType::Int32, DataType::Int64, DataType::Int32],
);
let merged = left_batch.merge(&right_batch).unwrap();
assert_eq!(merged.schema().as_ref(), &naive_schema);
}

#[test]
fn test_take_record_batch() {
let schema = Arc::new(Schema::new(vec![
4 changes: 2 additions & 2 deletions rust/lance-core/src/datatypes.rs
@@ -19,10 +19,10 @@ mod schema;

use crate::{Error, Result};
pub use field::{
Encoding, Field, NullabilityComparison, SchemaCompareOptions, StorageClass,
Encoding, Field, NullabilityComparison, OnTypeMismatch, SchemaCompareOptions, StorageClass,
LANCE_STORAGE_CLASS_SCHEMA_META_KEY,
};
pub use schema::Schema;
pub use schema::{OnMissing, Projectable, Projection, Schema};

pub const COMPRESSION_META_KEY: &str = "lance-encoding:compression";
pub const COMPRESSION_LEVEL_META_KEY: &str = "lance-encoding:compression-level";