Skip to content

Commit

Permalink
fix: Correctly load Parquet statistics for f16
Browse files Browse the repository at this point in the history
  • Loading branch information
coastalwhite committed Oct 18, 2024
1 parent 01a4e06 commit daf3789
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 6 deletions.
42 changes: 39 additions & 3 deletions crates/polars-parquet/src/arrow/read/statistics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::collections::VecDeque;

use arrow::array::*;
use arrow::datatypes::{ArrowDataType, Field, IntervalUnit, PhysicalType};
use arrow::types::i256;
use arrow::types::{f16, i256, NativeType};
use arrow::with_match_primitive_type_full;
use ethnum::I256;
use polars_error::{polars_bail, PolarsResult};
Expand All @@ -28,6 +28,7 @@ mod struct_;
mod utf8;

use self::list::DynMutableListArray;
use super::PrimitiveLogicalType;

/// Arrow-deserialized parquet Statistics of a file
#[derive(Debug, PartialEq)]
Expand Down Expand Up @@ -549,6 +550,37 @@ fn push(
}
}

/// Cast deserialized Parquet `statistics` from their physical on-file
/// representation to the Arrow `output_type` the column is read as.
///
/// Parquet has no native half-float column type: `f16` values are stored as
/// little-endian 2-byte `FIXED_LEN_BYTE_ARRAY` values tagged with the
/// `Float16` logical type, while the reader materializes the column as
/// `Float32`. The raw fixed-len min/max statistics must therefore be decoded
/// and widened before they are comparable against `Float32` data.
///
/// Every other `(logical_type, output_type)` combination is returned
/// unchanged.
pub fn cast_statistics(
    statistics: ParquetStatistics,
    primitive_type: &ParquetPrimitiveType,
    output_type: &ArrowDataType,
) -> ParquetStatistics {
    use {ArrowDataType as DT, PrimitiveLogicalType as PT};

    match (primitive_type.logical_type, output_type) {
        (Some(PT::Float16), DT::Float32) => {
            let statistics = statistics.expect_fixedlen();

            let primitive_type = primitive_type.clone();

            // Decode one little-endian 2-byte buffer into an `f32`. The bytes
            // come straight from the (untrusted) file, so a buffer of the
            // wrong length is treated as "no statistic" rather than indexed
            // unchecked — `v[0]` / `v[1]` would panic on a malformed file.
            let decode_f16 = |v: &[u8]| match v {
                &[lo, hi] => Some(f16::from_le_bytes([lo, hi]).to_f32()),
                _ => None,
            };

            ParquetStatistics::Float(PrimitiveStatistics::<f32> {
                primitive_type,
                null_count: statistics.null_count,
                distinct_count: statistics.distinct_count,
                min_value: statistics.min_value.as_deref().and_then(decode_f16),
                max_value: statistics.max_value.as_deref().and_then(decode_f16),
            })
        },
        _ => statistics,
    }
}

/// Deserializes the statistics in the column chunks from a single `row_group`
/// into [`Statistics`] associated from `field`'s name.
///
Expand All @@ -562,9 +594,13 @@ pub fn deserialize<'a>(

let mut stats = field_md
.map(|column| {
let primitive_type = &column.descriptor().descriptor.primitive_type;
Ok((
column.statistics().transpose()?,
column.descriptor().descriptor.primitive_type.clone(),
column
.statistics()
.transpose()?
.map(|stats| cast_statistics(stats, primitive_type, &field.dtype)),
primitive_type.clone(),
))
})
.collect::<PolarsResult<VecDeque<(Option<ParquetStatistics>, ParquetPrimitiveType)>>>()?;
Expand Down
18 changes: 15 additions & 3 deletions py-polars/tests/unit/io/test_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2041,7 +2041,7 @@ def test_conserve_sortedness(
)


def test_decode_f16() -> None:
def test_f16() -> None:
values = [float("nan"), 0.0, 0.5, 1.0, 1.5]

table = pa.Table.from_pydict(
Expand All @@ -2050,10 +2050,22 @@ def test_decode_f16() -> None:
}
)

df = pl.Series("x", values, pl.Float32).to_frame()

f = io.BytesIO()
pq.write_table(table, f)

f.seek(0)
df = pl.read_parquet(f)
assert_frame_equal(pl.read_parquet(f), df)

f.seek(0)
assert_frame_equal(
pl.scan_parquet(f).filter(pl.col.x > 0.5).collect(),
df.filter(pl.col.x > 0.5),
)

assert_series_equal(df.get_column("x"), pl.Series("x", values, pl.Float32))
f.seek(0)
assert_frame_equal(
pl.scan_parquet(f).slice(1, 3).collect(),
df.slice(1, 3),
)

0 comments on commit daf3789

Please sign in to comment.