From d91b99c8b360b6ac91a29f358a6c684b424e1883 Mon Sep 17 00:00:00 2001 From: nameexhaustion Date: Wed, 16 Oct 2024 17:47:24 +1100 Subject: [PATCH] fix: Do not check dtypes of non-projected columns for parquet (#19254) --- crates/polars-io/src/parquet/read/reader.rs | 20 ++++++++-- .../nodes/parquet_source/metadata_fetch.rs | 2 +- py-polars/tests/unit/io/test_lazy_parquet.py | 40 +++++++++++++++++-- 3 files changed, 54 insertions(+), 8 deletions(-) diff --git a/crates/polars-io/src/parquet/read/reader.rs b/crates/polars-io/src/parquet/read/reader.rs index 2a70ef2c5046..25d1f51b098b 100644 --- a/crates/polars-io/src/parquet/read/reader.rs +++ b/crates/polars-io/src/parquet/read/reader.rs @@ -89,9 +89,15 @@ impl ParquetReader { projected_arrow_schema: Option<&ArrowSchema>, allow_missing_columns: bool, ) -> PolarsResult { + // `self.schema` gets overwritten if allow_missing_columns + let this_schema_width = self.schema()?.len(); + if allow_missing_columns { // Must check the dtypes - ensure_matching_dtypes_if_found(first_schema, self.schema()?.as_ref())?; + ensure_matching_dtypes_if_found( + projected_arrow_schema.unwrap_or(first_schema.as_ref()), + self.schema()?.as_ref(), + )?; self.schema.replace(first_schema.clone()); } @@ -104,7 +110,7 @@ impl ParquetReader { projected_arrow_schema, )?; } else { - if schema.len() > first_schema.len() { + if this_schema_width > first_schema.len() { polars_bail!( SchemaMismatch: "parquet file contained extra columns and no selection was given" @@ -328,9 +334,15 @@ impl ParquetAsyncReader { projected_arrow_schema: Option<&ArrowSchema>, allow_missing_columns: bool, ) -> PolarsResult { + // `self.schema` gets overwritten if allow_missing_columns + let this_schema_width = self.schema().await?.len(); + if allow_missing_columns { // Must check the dtypes - ensure_matching_dtypes_if_found(first_schema, self.schema().await?.as_ref())?; + ensure_matching_dtypes_if_found( + projected_arrow_schema.unwrap_or(first_schema.as_ref()), + self.schema().await?.as_ref(), + )?; self.schema.replace(first_schema.clone()); } @@ -343,7 +355,7 @@ impl ParquetAsyncReader { projected_arrow_schema, )?; } else { - if schema.len() > first_schema.len() { + if this_schema_width > first_schema.len() { polars_bail!( SchemaMismatch: "parquet file contained extra columns and no selection was given" diff --git a/crates/polars-stream/src/nodes/parquet_source/metadata_fetch.rs b/crates/polars-stream/src/nodes/parquet_source/metadata_fetch.rs index 746c517ce744..e3377036b908 100644 --- a/crates/polars-stream/src/nodes/parquet_source/metadata_fetch.rs +++ b/crates/polars-stream/src/nodes/parquet_source/metadata_fetch.rs @@ -141,7 +141,7 @@ impl ParquetSourceNode { } if allow_missing_columns { - ensure_matching_dtypes_if_found(&first_schema, &schema)?; + ensure_matching_dtypes_if_found(projected_arrow_schema.as_ref(), &schema)?; } else { ensure_schema_has_projected_fields( &schema, diff --git a/py-polars/tests/unit/io/test_lazy_parquet.py b/py-polars/tests/unit/io/test_lazy_parquet.py index 06efe16330d3..49a842386ea2 100644 --- a/py-polars/tests/unit/io/test_lazy_parquet.py +++ b/py-polars/tests/unit/io/test_lazy_parquet.py @@ -710,10 +710,18 @@ def test_parquet_schema_arg( schema: dict[str, type[pl.DataType]] = {"a": pl.Int64} # type: ignore[no-redef] - lf = pl.scan_parquet(paths, parallel=parallel, schema=schema) + for allow_missing_columns in [True, False]: + lf = pl.scan_parquet( + paths, + parallel=parallel, + schema=schema, + allow_missing_columns=allow_missing_columns, + ) - with pytest.raises(pl.exceptions.SchemaError, match="file contained extra columns"): - lf.collect(streaming=streaming) + with pytest.raises( + pl.exceptions.SchemaError, match="file contained extra columns" + ): + lf.collect(streaming=streaming) lf = pl.scan_parquet(paths, parallel=parallel, schema=schema).select("a") @@ -731,3 +739,29 @@ def test_parquet_schema_arg( match="data type mismatch for column b: expected: i8, found: i64", ): lf.collect(streaming=streaming) + + +@pytest.mark.parametrize("streaming", [True, False]) +@pytest.mark.parametrize("allow_missing_columns", [True, False]) +@pytest.mark.write_disk +def test_scan_parquet_ignores_dtype_mismatch_for_non_projected_columns_19249( + tmp_path: Path, + allow_missing_columns: bool, + streaming: bool, +) -> None: + tmp_path.mkdir(exist_ok=True) + paths = [tmp_path / "1", tmp_path / "2"] + + pl.DataFrame({"a": 1, "b": 1}, schema={"a": pl.Int32, "b": pl.UInt8}).write_parquet( + paths[0] + ) + pl.DataFrame( + {"a": 1, "b": 1}, schema={"a": pl.Int32, "b": pl.UInt64} + ).write_parquet(paths[1]) + + assert_frame_equal( + pl.scan_parquet(paths, allow_missing_columns=allow_missing_columns) + .select("a") + .collect(streaming=streaming), + pl.DataFrame({"a": [1, 1]}, schema={"a": pl.Int32}), + )