Skip to content

Commit

Permalink
fix: Return an error on type mismatch rather than panic (#4995) (#5341)
Browse files Browse the repository at this point in the history
* fix: Return an error on type mismatch rather than panic (#4995)

* test: ArrowWriter and batch schema mismatch is an error

* docs: Clarify that ArrowWriter expects the batch's schema to match
  • Loading branch information
carols10cents committed Jan 30, 2024
1 parent 5117b38 commit 8e9d713
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 2 deletions.
26 changes: 25 additions & 1 deletion parquet/src/arrow/arrow_writer/levels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ struct LevelContext {
}

/// A helper to construct [`ArrayLevels`] from a potentially nested [`Field`]
#[derive(Debug)]
enum LevelInfoBuilder {
/// A primitive, leaf array
Primitive(ArrayLevels),
Expand Down Expand Up @@ -132,7 +133,15 @@ enum LevelInfoBuilder {
impl LevelInfoBuilder {
/// Create a new [`LevelInfoBuilder`] for the given [`Field`] and parent [`LevelContext`]
fn try_new(field: &Field, parent_ctx: LevelContext, array: &ArrayRef) -> Result<Self> {
assert_eq!(field.data_type(), array.data_type());
if field.data_type() != array.data_type() {
return Err(arrow_err!(format!(
"Incompatible type. Field '{}' has type {}, array has type {}",
field.name(),
field.data_type(),
array.data_type(),
)));
}

let is_nullable = field.is_nullable();

match array.data_type() {
Expand Down Expand Up @@ -1835,6 +1844,21 @@ mod tests {
assert_eq!(levels[0], expected_level);
}

#[test]
fn mismatched_types() {
    // Build a Float64 field paired with an Int32 array: `try_new` must
    // reject the mismatch with a descriptive error instead of panicking.
    let values = Arc::new(Int32Array::from_iter(0..10)) as ArrayRef;
    let float_field = Field::new("item", DataType::Float64, false);

    let result = LevelInfoBuilder::try_new(&float_field, Default::default(), &values);
    let message = result.unwrap_err().to_string();

    assert_eq!(
        message,
        "Arrow: Incompatible type. Field 'item' has type Float64, array has type Int32",
    );
}

fn levels<T: Array + 'static>(field: &Field, array: T) -> LevelInfoBuilder {
let v = Arc::new(array) as ArrayRef;
LevelInfoBuilder::try_new(field, Default::default(), &v).unwrap()
Expand Down
29 changes: 28 additions & 1 deletion parquet/src/arrow/arrow_writer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,9 @@ impl<W: Write + Send> ArrowWriter<W> {
///
/// If this would cause the current row group to exceed [`WriterProperties::max_row_group_size`]
/// rows, the contents of `batch` will be written to one or more row groups such that all but
/// the final row group in the file contain [`WriterProperties::max_row_group_size`] rows
/// the final row group in the file contain [`WriterProperties::max_row_group_size`] rows.
///
/// This will fail if the `batch`'s schema does not match the writer's schema.
pub fn write(&mut self, batch: &RecordBatch) -> Result<()> {
if batch.num_rows() == 0 {
return Ok(());
Expand Down Expand Up @@ -2963,4 +2965,29 @@ mod tests {
.any(|kv| kv.key.as_str() == ARROW_SCHEMA_META_KEY));
}
}

#[test]
fn mismatched_schemas() {
    // The writer is created with a Float64 "temperature" column, but the
    // batch carries an Int32 "count" column — writing it must surface an
    // error rather than panic.
    let file_schema = Arc::new(Schema::new(vec![Field::new(
        "temperature",
        DataType::Float64,
        false,
    )]));
    let batch_schema = Arc::new(Schema::new(vec![Field::new(
        "count",
        DataType::Int32,
        false,
    )]));

    let columns: Vec<ArrayRef> = vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4]))];
    let batch = RecordBatch::try_new(batch_schema, columns).unwrap();

    let mut buf = Vec::with_capacity(1024);
    let mut writer = ArrowWriter::try_new(&mut buf, file_schema.clone(), None).unwrap();

    let err = writer.write(&batch).unwrap_err().to_string();
    assert_eq!(
        err,
        "Arrow: Incompatible type. Field 'temperature' has type Float64, array has type Int32"
    );
}
}

0 comments on commit 8e9d713

Please sign in to comment.