From f4e65dfbd12e16d48ed0d1e83dce3aeef928904e Mon Sep 17 00:00:00 2001
From: Wainberg
Date: Fri, 19 Jan 2024 06:12:00 -0500
Subject: [PATCH] fix(rust): decompress the right number of rows when reading
 compressed CSVs (#13721)

Co-authored-by: Wainberg
---
 crates/polars-io/src/csv/read_impl/mod.rs | 11 ++++++++---
 py-polars/tests/unit/io/test_csv.py       | 18 ++++++++++++++++++
 2 files changed, 26 insertions(+), 3 deletions(-)

diff --git a/crates/polars-io/src/csv/read_impl/mod.rs b/crates/polars-io/src/csv/read_impl/mod.rs
index db268f92147c..88177211913d 100644
--- a/crates/polars-io/src/csv/read_impl/mod.rs
+++ b/crates/polars-io/src/csv/read_impl/mod.rs
@@ -183,10 +183,15 @@ impl<'a> CoreReader<'a> {
         // In case the file is compressed this schema inference is wrong and has to be done
         // again after decompression.
         #[cfg(any(feature = "decompress", feature = "decompress-fast"))]
-        if let Some(b) =
-            decompress(&reader_bytes, n_rows, separator, quote_char, eol_char)
         {
-            reader_bytes = ReaderBytes::Owned(b);
+            let total_n_rows = n_rows.map(|n| {
+                skip_rows + (has_header as usize) + skip_rows_after_header + n
+            });
+            if let Some(b) =
+                decompress(&reader_bytes, total_n_rows, separator, quote_char, eol_char)
+            {
+                reader_bytes = ReaderBytes::Owned(b);
+            }
         }
 
         let (inferred_schema, _, _) = infer_file_schema(
diff --git a/py-polars/tests/unit/io/test_csv.py b/py-polars/tests/unit/io/test_csv.py
index cfd9836f31f5..bdd0f564e376 100644
--- a/py-polars/tests/unit/io/test_csv.py
+++ b/py-polars/tests/unit/io/test_csv.py
@@ -1744,3 +1744,21 @@ def test_invalid_csv_raise() -> None:
         "SK0127960V000","SK BT 0018977","
 """.strip()
     )
+
+
+@pytest.mark.write_disk()
+def test_partial_read_compressed_file(tmp_path: Path) -> None:
+    df = pl.DataFrame(
+        {"idx": range(1_000), "dt": date(2025, 12, 31), "txt": "hello world"}
+    )
+    tmp_path.mkdir(exist_ok=True)
+    file_path = tmp_path / "large.csv.gz"
+    bytes_io = io.BytesIO()
+    df.write_csv(bytes_io)
+    bytes_io.seek(0)
+    with gzip.open(file_path, mode="wb") as f:
+        f.write(bytes_io.getvalue())
+    df = pl.read_csv(
+        file_path, skip_rows=40, has_header=False, skip_rows_after_header=20, n_rows=30
+    )
+    assert df.shape == (30, 3)