diff --git a/crates/polars-parquet/src/arrow/read/statistics/mod.rs b/crates/polars-parquet/src/arrow/read/statistics/mod.rs index d9210abcb5aa..22ba71caba17 100644 --- a/crates/polars-parquet/src/arrow/read/statistics/mod.rs +++ b/crates/polars-parquet/src/arrow/read/statistics/mod.rs @@ -3,7 +3,7 @@ use std::collections::VecDeque; use arrow::array::*; use arrow::datatypes::{ArrowDataType, Field, IntervalUnit, PhysicalType}; -use arrow::types::i256; +use arrow::types::{f16, i256, NativeType}; use arrow::with_match_primitive_type_full; use ethnum::I256; use polars_error::{polars_bail, PolarsResult}; @@ -28,6 +28,7 @@ mod struct_; mod utf8; use self::list::DynMutableListArray; +use super::PrimitiveLogicalType; /// Arrow-deserialized parquet Statistics of a file #[derive(Debug, PartialEq)] @@ -549,6 +550,37 @@ fn push( } } +pub fn cast_statistics( + statistics: ParquetStatistics, + primitive_type: &ParquetPrimitiveType, + output_type: &ArrowDataType, +) -> ParquetStatistics { + use {ArrowDataType as DT, PrimitiveLogicalType as PT}; + + match (primitive_type.logical_type, output_type) { + (Some(PT::Float16), DT::Float32) => { + let statistics = statistics.expect_fixedlen(); + + let primitive_type = primitive_type.clone(); + + ParquetStatistics::Float(PrimitiveStatistics:: { + primitive_type, + null_count: statistics.null_count, + distinct_count: statistics.distinct_count, + min_value: statistics + .min_value + .as_ref() + .map(|v| f16::from_le_bytes([v[0], v[1]]).to_f32()), + max_value: statistics + .max_value + .as_ref() + .map(|v| f16::from_le_bytes([v[0], v[1]]).to_f32()), + }) + }, + _ => statistics, + } +} + /// Deserializes the statistics in the column chunks from a single `row_group` /// into [`Statistics`] associated from `field`'s name. /// @@ -562,9 +594,13 @@ pub fn deserialize<'a>( let mut stats = field_md .map(|column| { + let primitive_type = &column.descriptor().descriptor.primitive_type; Ok(( - column.statistics().transpose()?, - column.descriptor().descriptor.primitive_type.clone(), + column + .statistics() + .transpose()? + .map(|stats| cast_statistics(stats, primitive_type, &field.dtype)), + primitive_type.clone(), )) }) .collect::, ParquetPrimitiveType)>>>()?; diff --git a/py-polars/tests/unit/io/test_parquet.py b/py-polars/tests/unit/io/test_parquet.py index 5f73284d9707..d44ed0412e53 100644 --- a/py-polars/tests/unit/io/test_parquet.py +++ b/py-polars/tests/unit/io/test_parquet.py @@ -2041,7 +2041,7 @@ def test_conserve_sortedness( ) -def test_decode_f16() -> None: +def test_f16() -> None: values = [float("nan"), 0.0, 0.5, 1.0, 1.5] table = pa.Table.from_pydict( @@ -2050,10 +2050,22 @@ def test_decode_f16() -> None: } ) + df = pl.Series("x", values, pl.Float32).to_frame() + f = io.BytesIO() pq.write_table(table, f) f.seek(0) - df = pl.read_parquet(f) + assert_frame_equal(pl.read_parquet(f), df) + + f.seek(0) + assert_frame_equal( + pl.scan_parquet(f).filter(pl.col.x > 0.5).collect(), + df.filter(pl.col.x > 0.5), + ) - assert_series_equal(df.get_column("x"), pl.Series("x", values, pl.Float32)) + f.seek(0) + assert_frame_equal( + pl.scan_parquet(f).slice(1, 3).collect(), + df.slice(1, 3), + )