Skip to content

Commit

Permalink
Add IpcSchemaEncoder, deprecate ipc schema functions, Fix IPC not r…
Browse files Browse the repository at this point in the history
…especting not preserving dict ID (#6444)

* arrow-ipc: Add test for non preserving dict ID behavior with same ID

* arrow-ipc: Always set dict ID in IPC from dictionary tracker

This decouples dictionary IDs that end up in IPC from the schema further
because the dictionary tracker always first gathers the dict ID for each
field whether it is pre-defined and preserved or not.

Then when actually writing the IPC bytes the dictionary ID is always
taken from the dictionary tracker as opposed to falling back to the
`Field` of the `Schema`.

* arrow-ipc: Read dictionary IDs from dictionary tracker in correct order

When dictionary IDs are not preserved, then they are assigned depth
first, however, when reading them from the dictionary tracker to write
the IPC bytes, they were previously read from the dictionary tracker in
the order that the schema is traversed (first come first serve), which
caused an incorrect order of dictionaries serialized in IPC.

* Refine IpcSchemaEncoder API and docs

* reduce repeated code

* Fix lints

---------

Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
brancz and alamb authored Sep 25, 2024
1 parent 922a1ff commit 62825b2
Show file tree
Hide file tree
Showing 5 changed files with 275 additions and 67 deletions.
4 changes: 3 additions & 1 deletion arrow-flight/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,9 @@ pub struct IpcMessage(pub Bytes);

fn flight_schema_as_encoded_data(arrow_schema: &Schema, options: &IpcWriteOptions) -> EncodedData {
let data_gen = writer::IpcDataGenerator::default();
data_gen.schema_to_bytes(arrow_schema, options)
let mut dict_tracker =
writer::DictionaryTracker::new_with_preserve_dict_id(false, options.preserve_dict_id());
data_gen.schema_to_bytes_with_dictionary_tracker(arrow_schema, &mut dict_tracker, options)
}

fn flight_schema_as_flatbuffer(schema: &Schema, options: &IpcWriteOptions) -> IpcMessage {
Expand Down
188 changes: 143 additions & 45 deletions arrow-ipc/src/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,110 @@ use std::collections::HashMap;
use std::fmt::{Debug, Formatter};
use std::sync::Arc;

use crate::writer::DictionaryTracker;
use crate::{size_prefixed_root_as_message, KeyValue, Message, CONTINUATION_MARKER};
use DataType::*;

/// Serialize a schema in IPC format
pub fn schema_to_fb(schema: &Schema) -> FlatBufferBuilder {
let mut fbb = FlatBufferBuilder::new();
/// Low level Arrow [Schema] to IPC bytes converter
///
/// See also [`fb_to_schema`] for the reverse operation
///
/// # Example
/// ```
/// # use arrow_ipc::convert::{fb_to_schema, IpcSchemaEncoder};
/// # use arrow_ipc::root_as_schema;
/// # use arrow_ipc::writer::DictionaryTracker;
/// # use arrow_schema::{DataType, Field, Schema};
/// // given an arrow schema to serialize
/// let schema = Schema::new(vec![
/// Field::new("a", DataType::Int32, false),
/// ]);
///
/// // Use a dictionary tracker to track dictionary id if needed
/// let mut dictionary_tracker = DictionaryTracker::new(true);
/// // create a FlatBuffersBuilder that contains the encoded bytes
/// let fb = IpcSchemaEncoder::new()
/// .with_dictionary_tracker(&mut dictionary_tracker)
/// .schema_to_fb(&schema);
///
/// // the bytes are in `fb.finished_data()`
/// let ipc_bytes = fb.finished_data();
///
/// // convert the IPC bytes back to an Arrow schema
/// let ipc_schema = root_as_schema(ipc_bytes).unwrap();
/// let schema2 = fb_to_schema(ipc_schema);
/// assert_eq!(schema, schema2);
/// ```
#[derive(Debug)]
pub struct IpcSchemaEncoder<'a> {
dictionary_tracker: Option<&'a mut DictionaryTracker>,
}

impl<'a> Default for IpcSchemaEncoder<'a> {
fn default() -> Self {
Self::new()
}
}

let root = schema_to_fb_offset(&mut fbb, schema);
impl<'a> IpcSchemaEncoder<'a> {
/// Create a new schema encoder
pub fn new() -> IpcSchemaEncoder<'a> {
IpcSchemaEncoder {
dictionary_tracker: None,
}
}

/// Specify a dictionary tracker to use
pub fn with_dictionary_tracker(
mut self,
dictionary_tracker: &'a mut DictionaryTracker,
) -> Self {
self.dictionary_tracker = Some(dictionary_tracker);
self
}

fbb.finish(root, None);
/// Serialize a schema in IPC format, returning a completed [`FlatBufferBuilder`]
///
/// Note: Call [`FlatBufferBuilder::finished_data`] to get the serialized bytes
pub fn schema_to_fb<'b>(&mut self, schema: &Schema) -> FlatBufferBuilder<'b> {
let mut fbb = FlatBufferBuilder::new();

let root = self.schema_to_fb_offset(&mut fbb, schema);

fbb.finish(root, None);

fbb
}

fbb
/// Serialize a schema to an in progress [`FlatBufferBuilder`], returning the in progress offset.
pub fn schema_to_fb_offset<'b>(
&mut self,
fbb: &mut FlatBufferBuilder<'b>,
schema: &Schema,
) -> WIPOffset<crate::Schema<'b>> {
let fields = schema
.fields()
.iter()
.map(|field| build_field(fbb, &mut self.dictionary_tracker, field))
.collect::<Vec<_>>();
let fb_field_list = fbb.create_vector(&fields);

let fb_metadata_list =
(!schema.metadata().is_empty()).then(|| metadata_to_fb(fbb, schema.metadata()));

let mut builder = crate::SchemaBuilder::new(fbb);
builder.add_fields(fb_field_list);
if let Some(fb_metadata_list) = fb_metadata_list {
builder.add_custom_metadata(fb_metadata_list);
}
builder.finish()
}
}

/// Serialize a schema in IPC format
#[deprecated(since = "54.0.0", note = "Use `IpcSchemaConverter`.")]
pub fn schema_to_fb(schema: &Schema) -> FlatBufferBuilder<'_> {
IpcSchemaEncoder::new().schema_to_fb(schema)
}

pub fn metadata_to_fb<'a>(
Expand All @@ -60,26 +152,12 @@ pub fn metadata_to_fb<'a>(
fbb.create_vector(&custom_metadata)
}

#[deprecated(since = "54.0.0", note = "Use `IpcSchemaConverter`.")]
pub fn schema_to_fb_offset<'a>(
fbb: &mut FlatBufferBuilder<'a>,
schema: &Schema,
) -> WIPOffset<crate::Schema<'a>> {
let fields = schema
.fields()
.iter()
.map(|field| build_field(fbb, field))
.collect::<Vec<_>>();
let fb_field_list = fbb.create_vector(&fields);

let fb_metadata_list =
(!schema.metadata().is_empty()).then(|| metadata_to_fb(fbb, schema.metadata()));

let mut builder = crate::SchemaBuilder::new(fbb);
builder.add_fields(fb_field_list);
if let Some(fb_metadata_list) = fb_metadata_list {
builder.add_custom_metadata(fb_metadata_list);
}
builder.finish()
IpcSchemaEncoder::new().schema_to_fb_offset(fbb, schema)
}

/// Convert an IPC Field to Arrow Field
Expand Down Expand Up @@ -114,7 +192,7 @@ impl<'a> From<crate::Field<'a>> for Field {
}
}

/// Deserialize a Schema table from flat buffer format to Schema data type
/// Deserialize an ipc [crate::Schema`] from flat buffers to an arrow [Schema].
pub fn fb_to_schema(fb: crate::Schema) -> Schema {
let mut fields: Vec<Field> = vec![];
let c_fields = fb.fields().unwrap();
Expand Down Expand Up @@ -424,6 +502,7 @@ pub(crate) struct FBFieldType<'b> {
/// Create an IPC Field from an Arrow Field
pub(crate) fn build_field<'a>(
fbb: &mut FlatBufferBuilder<'a>,
dictionary_tracker: &mut Option<&mut DictionaryTracker>,
field: &Field,
) -> WIPOffset<crate::Field<'a>> {
// Optional custom metadata.
Expand All @@ -433,19 +512,29 @@ pub(crate) fn build_field<'a>(
};

let fb_field_name = fbb.create_string(field.name().as_str());
let field_type = get_fb_field_type(field.data_type(), fbb);
let field_type = get_fb_field_type(field.data_type(), dictionary_tracker, fbb);

let fb_dictionary = if let Dictionary(index_type, _) = field.data_type() {
Some(get_fb_dictionary(
index_type,
field
.dict_id()
.expect("All Dictionary types have `dict_id`"),
field
.dict_is_ordered()
.expect("All Dictionary types have `dict_is_ordered`"),
fbb,
))
match dictionary_tracker {
Some(tracker) => Some(get_fb_dictionary(
index_type,
tracker.set_dict_id(field),
field
.dict_is_ordered()
.expect("All Dictionary types have `dict_is_ordered`"),
fbb,
)),
None => Some(get_fb_dictionary(
index_type,
field
.dict_id()
.expect("Dictionary type must have a dictionary id"),
field
.dict_is_ordered()
.expect("All Dictionary types have `dict_is_ordered`"),
fbb,
)),
}
} else {
None
};
Expand Down Expand Up @@ -473,6 +562,7 @@ pub(crate) fn build_field<'a>(
/// Get the IPC type of a data type
pub(crate) fn get_fb_field_type<'a>(
data_type: &DataType,
dictionary_tracker: &mut Option<&mut DictionaryTracker>,
fbb: &mut FlatBufferBuilder<'a>,
) -> FBFieldType<'a> {
// some IPC implementations expect an empty list for child data, instead of a null value.
Expand Down Expand Up @@ -673,7 +763,7 @@ pub(crate) fn get_fb_field_type<'a>(
}
}
List(ref list_type) => {
let child = build_field(fbb, list_type);
let child = build_field(fbb, dictionary_tracker, list_type);
FBFieldType {
type_type: crate::Type::List,
type_: crate::ListBuilder::new(fbb).finish().as_union_value(),
Expand All @@ -682,15 +772,15 @@ pub(crate) fn get_fb_field_type<'a>(
}
ListView(_) | LargeListView(_) => unimplemented!("ListView/LargeListView not implemented"),
LargeList(ref list_type) => {
let child = build_field(fbb, list_type);
let child = build_field(fbb, dictionary_tracker, list_type);
FBFieldType {
type_type: crate::Type::LargeList,
type_: crate::LargeListBuilder::new(fbb).finish().as_union_value(),
children: Some(fbb.create_vector(&[child])),
}
}
FixedSizeList(ref list_type, len) => {
let child = build_field(fbb, list_type);
let child = build_field(fbb, dictionary_tracker, list_type);
let mut builder = crate::FixedSizeListBuilder::new(fbb);
builder.add_listSize(*len);
FBFieldType {
Expand All @@ -703,7 +793,7 @@ pub(crate) fn get_fb_field_type<'a>(
// struct's fields are children
let mut children = vec![];
for field in fields {
children.push(build_field(fbb, field));
children.push(build_field(fbb, dictionary_tracker, field));
}
FBFieldType {
type_type: crate::Type::Struct_,
Expand All @@ -712,8 +802,8 @@ pub(crate) fn get_fb_field_type<'a>(
}
}
RunEndEncoded(run_ends, values) => {
let run_ends_field = build_field(fbb, run_ends);
let values_field = build_field(fbb, values);
let run_ends_field = build_field(fbb, dictionary_tracker, run_ends);
let values_field = build_field(fbb, dictionary_tracker, values);
let children = [run_ends_field, values_field];
FBFieldType {
type_type: crate::Type::RunEndEncoded,
Expand All @@ -724,7 +814,7 @@ pub(crate) fn get_fb_field_type<'a>(
}
}
Map(map_field, keys_sorted) => {
let child = build_field(fbb, map_field);
let child = build_field(fbb, dictionary_tracker, map_field);
let mut field_type = crate::MapBuilder::new(fbb);
field_type.add_keysSorted(*keys_sorted);
FBFieldType {
Expand All @@ -737,7 +827,7 @@ pub(crate) fn get_fb_field_type<'a>(
// In this library, the dictionary "type" is a logical construct. Here we
// pass through to the value type, as we've already captured the index
// type in the DictionaryEncoding metadata in the parent field
get_fb_field_type(value_type, fbb)
get_fb_field_type(value_type, dictionary_tracker, fbb)
}
Decimal128(precision, scale) => {
let mut builder = crate::DecimalBuilder::new(fbb);
Expand All @@ -764,7 +854,7 @@ pub(crate) fn get_fb_field_type<'a>(
Union(fields, mode) => {
let mut children = vec![];
for (_, field) in fields.iter() {
children.push(build_field(fbb, field));
children.push(build_field(fbb, dictionary_tracker, field));
}

let union_mode = match mode {
Expand Down Expand Up @@ -1067,7 +1157,10 @@ mod tests {
md,
);

let fb = schema_to_fb(&schema);
let mut dictionary_tracker = DictionaryTracker::new(true);
let fb = IpcSchemaEncoder::new()
.with_dictionary_tracker(&mut dictionary_tracker)
.schema_to_fb(&schema);

// read back fields
let ipc = crate::root_as_schema(fb.finished_data()).unwrap();
Expand Down Expand Up @@ -1098,9 +1191,14 @@ mod tests {

// generate same message with Rust
let data_gen = crate::writer::IpcDataGenerator::default();
let mut dictionary_tracker = DictionaryTracker::new(true);
let arrow_schema = Schema::new(vec![Field::new("field1", DataType::UInt32, false)]);
let bytes = data_gen
.schema_to_bytes(&arrow_schema, &crate::writer::IpcWriteOptions::default())
.schema_to_bytes_with_dictionary_tracker(
&arrow_schema,
&mut dictionary_tracker,
&crate::writer::IpcWriteOptions::default(),
)
.ipc_message;

let ipc2 = crate::root_as_message(&bytes).unwrap();
Expand Down
55 changes: 55 additions & 0 deletions arrow-ipc/src/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2283,4 +2283,59 @@ mod tests {
let err = reader.next().unwrap().unwrap_err();
assert!(matches!(err, ArrowError::InvalidArgumentError(_)));
}

#[test]
fn test_same_dict_id_without_preserve() {
let batch = RecordBatch::try_new(
Arc::new(Schema::new(
["a", "b"]
.iter()
.map(|name| {
Field::new_dict(
name.to_string(),
DataType::Dictionary(
Box::new(DataType::Int32),
Box::new(DataType::Utf8),
),
true,
0,
false,
)
})
.collect::<Vec<Field>>(),
)),
vec![
Arc::new(
vec![Some("c"), Some("d")]
.into_iter()
.collect::<DictionaryArray<Int32Type>>(),
) as ArrayRef,
Arc::new(
vec![Some("e"), Some("f")]
.into_iter()
.collect::<DictionaryArray<Int32Type>>(),
) as ArrayRef,
],
)
.expect("Failed to create RecordBatch");

// serialize the record batch as an IPC stream
let mut buf = vec![];
{
let mut writer = crate::writer::StreamWriter::try_new_with_options(
&mut buf,
batch.schema().as_ref(),
crate::writer::IpcWriteOptions::default().with_preserve_dict_id(false),
)
.expect("Failed to create StreamWriter");
writer.write(&batch).expect("Failed to write RecordBatch");
writer.finish().expect("Failed to finish StreamWriter");
}

StreamReader::try_new(std::io::Cursor::new(buf), None)
.expect("Failed to create StreamReader")
.for_each(|decoded_batch| {
assert_eq!(decoded_batch.expect("Failed to read RecordBatch"), batch);
});
}
}
Loading

0 comments on commit 62825b2

Please sign in to comment.