fix: Do not check dtypes of non-projected columns for parquet (#19254)
nameexhaustion authored Oct 16, 2024
1 parent 109c404 commit d91b99c
Showing 3 changed files with 54 additions and 8 deletions.
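For orientation, here is a minimal sketch of the user-visible behavior this commit fixes, distilled from the new test added below (the file paths are illustrative stand-ins for the tmp files used in the test): when multiple parquet files disagree on the dtype of a column that is not in the projection, the scan previously failed the allow_missing_columns dtype check and now succeeds.

import polars as pl

# Two files agree on "a" (Int32) but disagree on "b" (UInt8 vs. UInt64).
pl.DataFrame({"a": 1, "b": 1}, schema={"a": pl.Int32, "b": pl.UInt8}).write_parquet("1.parquet")
pl.DataFrame({"a": 1, "b": 1}, schema={"a": pl.Int32, "b": pl.UInt64}).write_parquet("2.parquet")

# Only "a" is projected, so the mismatch on "b" should be ignored.
# Before this fix, allow_missing_columns=True compared every column against
# the first file's schema and raised a SchemaError here.
df = (
    pl.scan_parquet(["1.parquet", "2.parquet"], allow_missing_columns=True)
    .select("a")
    .collect()
)
print(df)  # 2 rows, single column "a" with dtype Int32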
20 changes: 16 additions & 4 deletions crates/polars-io/src/parquet/read/reader.rs
@@ -89,9 +89,15 @@ impl<R: MmapBytesReader> ParquetReader<R> {
         projected_arrow_schema: Option<&ArrowSchema>,
         allow_missing_columns: bool,
     ) -> PolarsResult<Self> {
+        // `self.schema` gets overwritten if allow_missing_columns
+        let this_schema_width = self.schema()?.len();
+
         if allow_missing_columns {
             // Must check the dtypes
-            ensure_matching_dtypes_if_found(first_schema, self.schema()?.as_ref())?;
+            ensure_matching_dtypes_if_found(
+                projected_arrow_schema.unwrap_or(first_schema.as_ref()),
+                self.schema()?.as_ref(),
+            )?;
             self.schema.replace(first_schema.clone());
         }
 
@@ -104,7 +110,7 @@ impl<R: MmapBytesReader> ParquetReader<R> {
                 projected_arrow_schema,
             )?;
         } else {
-            if schema.len() > first_schema.len() {
+            if this_schema_width > first_schema.len() {
                 polars_bail!(
                     SchemaMismatch:
                     "parquet file contained extra columns and no selection was given"
@@ -328,9 +334,15 @@ impl ParquetAsyncReader {
         projected_arrow_schema: Option<&ArrowSchema>,
         allow_missing_columns: bool,
     ) -> PolarsResult<Self> {
+        // `self.schema` gets overwritten if allow_missing_columns
+        let this_schema_width = self.schema().await?.len();
+
         if allow_missing_columns {
             // Must check the dtypes
-            ensure_matching_dtypes_if_found(first_schema, self.schema().await?.as_ref())?;
+            ensure_matching_dtypes_if_found(
+                projected_arrow_schema.unwrap_or(first_schema.as_ref()),
+                self.schema().await?.as_ref(),
+            )?;
             self.schema.replace(first_schema.clone());
         }
 
@@ -343,7 +355,7 @@ impl ParquetAsyncReader {
                 projected_arrow_schema,
             )?;
         } else {
-            if schema.len() > first_schema.len() {
+            if this_schema_width > first_schema.len() {
                 polars_bail!(
                     SchemaMismatch:
                     "parquet file contained extra columns and no selection was given"
@@ -141,7 +141,7 @@ impl ParquetSourceNode {
             }
 
             if allow_missing_columns {
-                ensure_matching_dtypes_if_found(&first_schema, &schema)?;
+                ensure_matching_dtypes_if_found(projected_arrow_schema.as_ref(), &schema)?;
             } else {
                 ensure_schema_has_projected_fields(
                     &schema,
40 changes: 37 additions & 3 deletions py-polars/tests/unit/io/test_lazy_parquet.py
@@ -710,10 +710,18 @@ def test_parquet_schema_arg(
 
     schema: dict[str, type[pl.DataType]] = {"a": pl.Int64}  # type: ignore[no-redef]
 
-    lf = pl.scan_parquet(paths, parallel=parallel, schema=schema)
+    for allow_missing_columns in [True, False]:
+        lf = pl.scan_parquet(
+            paths,
+            parallel=parallel,
+            schema=schema,
+            allow_missing_columns=allow_missing_columns,
+        )
 
-    with pytest.raises(pl.exceptions.SchemaError, match="file contained extra columns"):
-        lf.collect(streaming=streaming)
+        with pytest.raises(
+            pl.exceptions.SchemaError, match="file contained extra columns"
+        ):
+            lf.collect(streaming=streaming)
 
     lf = pl.scan_parquet(paths, parallel=parallel, schema=schema).select("a")
 
@@ -731,3 +739,29 @@ def test_parquet_schema_arg(
         match="data type mismatch for column b: expected: i8, found: i64",
     ):
         lf.collect(streaming=streaming)
+
+
+@pytest.mark.parametrize("streaming", [True, False])
+@pytest.mark.parametrize("allow_missing_columns", [True, False])
+@pytest.mark.write_disk
+def test_scan_parquet_ignores_dtype_mismatch_for_non_projected_columns_19249(
+    tmp_path: Path,
+    allow_missing_columns: bool,
+    streaming: bool,
+) -> None:
+    tmp_path.mkdir(exist_ok=True)
+    paths = [tmp_path / "1", tmp_path / "2"]
+
+    pl.DataFrame({"a": 1, "b": 1}, schema={"a": pl.Int32, "b": pl.UInt8}).write_parquet(
+        paths[0]
+    )
+    pl.DataFrame(
+        {"a": 1, "b": 1}, schema={"a": pl.Int32, "b": pl.UInt64}
+    ).write_parquet(paths[1])
+
+    assert_frame_equal(
+        pl.scan_parquet(paths, allow_missing_columns=allow_missing_columns)
+        .select("a")
+        .collect(streaming=streaming),
+        pl.DataFrame({"a": [1, 1]}, schema={"a": pl.Int32}),
+    )
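Worth noting: columns that are part of the projection are still validated, as the existing dtype-mismatch test above shows. The sketch below is an assumption consistent with the commit title and that test, not an additional test from this commit; it continues the illustrative files from the first sketch.

import pytest
import polars as pl

# "b" is UInt8 in one file and UInt64 in the other, and it is projected
# here, so the dtype check still applies and the scan should still fail.
with pytest.raises(pl.exceptions.SchemaError, match="data type mismatch"):
    pl.scan_parquet(
        ["1.parquet", "2.parquet"], allow_missing_columns=True
    ).select("b").collect()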
