From d5d1924e91b378f3084a7b23c26d240c5f627702 Mon Sep 17 00:00:00 2001
From: sungjun lee
Date: Wed, 28 Aug 2024 18:51:36 +0900
Subject: [PATCH] Implement zstd Compression Support for JSONL and Parquet Files (#230)

* Add zstandard dependency for compression support

* feat: Add zstd compression support for jsonl reader

* feat: Add zstd compression support for ParquetWriter

* feat: Update DiskWriter to handle other compression formats for Parquet files

* Remove annotation

* feat: Update compression handling in DiskWriter and ParquetWriter

* Update src/datatrove/pipeline/writers/disk_base.py

Handle compression on ParquetWriter directly

Co-authored-by: Guilherme Penedo

* Update src/datatrove/pipeline/writers/parquet.py

Move None out of the Literal list

Co-authored-by: Guilherme Penedo

* Refactor constructor to explicitly set default compression to None

* Add validation for compression parameter in ParquetWriter

* Update src/datatrove/pipeline/writers/disk_base.py

The official extension for zstd is ".zst"

Co-authored-by: Guilherme Penedo

---------

Co-authored-by: Guilherme Penedo
---
 pyproject.toml                                |  3 +-
 src/datatrove/pipeline/writers/disk_base.py   |  2 ++
 src/datatrove/pipeline/writers/parquet.py     | 19 ++++++++----
 tests/pipeline/test_jsonl_zstd_compression.py | 30 +++++++++++++++++++
 .../pipeline/test_parquet_zstd_compression.py | 30 +++++++++++++++++++
 5 files changed, 78 insertions(+), 6 deletions(-)
 create mode 100644 tests/pipeline/test_jsonl_zstd_compression.py
 create mode 100644 tests/pipeline/test_parquet_zstd_compression.py

diff --git a/pyproject.toml b/pyproject.toml
index a290adf6..d5ace92e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -42,7 +42,8 @@ io = [
     "python-magic",
     "warcio",
     "datasets>=2.18.0",
-    "orjson"
+    "orjson",
+    "zstandard"
 ]
 s3 = [
     "s3fs>=2023.12.2",
diff --git a/src/datatrove/pipeline/writers/disk_base.py b/src/datatrove/pipeline/writers/disk_base.py
index 168af37c..8106bafa 100644
--- a/src/datatrove/pipeline/writers/disk_base.py
+++ b/src/datatrove/pipeline/writers/disk_base.py
@@ -42,6 +42,8 @@ def __init__(
         output_filename = output_filename or self.default_output_filename
         if self.compression == "gzip" and not output_filename.endswith(".gz"):
             output_filename += ".gz"
+        elif self.compression == "zstd" and not output_filename.endswith(".zst"):
+            output_filename += ".zst"
         self.max_file_size = max_file_size
         self.file_id_counter = Counter()
         if self.max_file_size > 0 and mode != "wb":
diff --git a/src/datatrove/pipeline/writers/parquet.py b/src/datatrove/pipeline/writers/parquet.py
index d9e7d093..01b5387c 100644
--- a/src/datatrove/pipeline/writers/parquet.py
+++ b/src/datatrove/pipeline/writers/parquet.py
@@ -1,5 +1,5 @@
 from collections import Counter, defaultdict
-from typing import IO, Callable
+from typing import IO, Callable, Literal
 
 from datatrove.io import DataFolderLike
 from datatrove.pipeline.writers.disk_base import DiskWriter
@@ -14,17 +14,23 @@ def __init__(
         self,
         output_folder: DataFolderLike,
         output_filename: str = None,
-        compression: str | None = None,
+        compression: Literal["snappy", "gzip", "brotli", "lz4", "zstd"] | None = None,
         adapter: Callable = None,
         batch_size: int = 1000,
         expand_metadata: bool = False,
         max_file_size: int = 5 * 2**30,  # 5GB
     ):
+        # Validate the compression setting
+        if compression not in {"snappy", "gzip", "brotli", "lz4", "zstd", None}:
+            raise ValueError(
+                "Invalid compression type. Allowed types are 'snappy', 'gzip', 'brotli', 'lz4', 'zstd', or None."
+            )
+
         super().__init__(
             output_folder,
             output_filename,
-            compression,
-            adapter,
+            compression=None,  # Ensure superclass initializes without compression
+            adapter=adapter,
             mode="wb",
             expand_metadata=expand_metadata,
             max_file_size=max_file_size,
@@ -32,6 +38,7 @@ def __init__(
         self._writers = {}
         self._batches = defaultdict(list)
         self._file_counter = Counter()
+        self.compression = compression
         self.batch_size = batch_size
 
     def _on_file_switch(self, original_name, old_filename, new_filename):
@@ -62,7 +69,9 @@ def _write(self, document: dict, file_handler: IO, filename: str):
 
         if filename not in self._writers:
             self._writers[filename] = pq.ParquetWriter(
-                file_handler, schema=pa.RecordBatch.from_pylist([document]).schema
+                file_handler,
+                schema=pa.RecordBatch.from_pylist([document]).schema,
+                compression=self.compression,
             )
         self._batches[filename].append(document)
         if len(self._batches[filename]) == self.batch_size:
diff --git a/tests/pipeline/test_jsonl_zstd_compression.py b/tests/pipeline/test_jsonl_zstd_compression.py
new file mode 100644
index 00000000..4c0d4911
--- /dev/null
+++ b/tests/pipeline/test_jsonl_zstd_compression.py
@@ -0,0 +1,30 @@
+import shutil
+import tempfile
+import unittest
+
+from datatrove.data import Document
+from datatrove.pipeline.readers.jsonl import JsonlReader
+from datatrove.pipeline.writers.jsonl import JsonlWriter
+
+
+class TestZstdCompression(unittest.TestCase):
+    def setUp(self):
+        # Create a temporary directory
+        self.tmp_dir = tempfile.mkdtemp()
+        self.addCleanup(shutil.rmtree, self.tmp_dir)
+
+    def test_jsonl_writer_reader(self):
+        data = [
+            Document(text=text, id=str(i), metadata={"somedata": 2 * i, "somefloat": i * 0.4, "somestring": "hello"})
+            for i, text in enumerate(["hello", "text2", "more text"])
+        ]
+        with JsonlWriter(output_folder=self.tmp_dir, compression="zstd") as w:
+            for doc in data:
+                w.write(doc)
+        reader = JsonlReader(self.tmp_dir, compression="zstd")
+        c = 0
+        for read_doc, original in zip(reader(), data):
+            read_doc.metadata.pop("file_path", None)
+            assert read_doc == original
+            c += 1
+        assert c == len(data)
diff --git a/tests/pipeline/test_parquet_zstd_compression.py b/tests/pipeline/test_parquet_zstd_compression.py
new file mode 100644
index 00000000..898c7399
--- /dev/null
+++ b/tests/pipeline/test_parquet_zstd_compression.py
@@ -0,0 +1,30 @@
+import shutil
+import tempfile
+import unittest
+
+from datatrove.data import Document
+from datatrove.pipeline.readers.parquet import ParquetReader
+from datatrove.pipeline.writers.parquet import ParquetWriter
+
+
+class TestZstdCompression(unittest.TestCase):
+    def setUp(self):
+        # Create a temporary directory
+        self.tmp_dir = tempfile.mkdtemp()
+        self.addCleanup(shutil.rmtree, self.tmp_dir)
+
+    def test_parquet_writer_reader(self):
+        data = [
+            Document(text=text, id=str(i), metadata={"somedata": 2 * i, "somefloat": i * 0.4, "somestring": "hello"})
+            for i, text in enumerate(["hello", "text2", "more text"])
+        ]
+        with ParquetWriter(output_folder=self.tmp_dir, compression="zstd") as w:
+            for doc in data:
+                w.write(doc)
+        reader = ParquetReader(self.tmp_dir)
+        c = 0
+        for read_doc, original in zip(reader(), data):
+            read_doc.metadata.pop("file_path", None)
+            assert read_doc == original
+            c += 1
+        assert c == len(data)
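
For reference, a minimal usage sketch of the zstd option added by this patch, mirroring the tests above; the output paths are placeholders:

    from datatrove.data import Document
    from datatrove.pipeline.writers.jsonl import JsonlWriter
    from datatrove.pipeline.writers.parquet import ParquetWriter

    doc = Document(text="hello", id="0", metadata={"somedata": 1})

    # JsonlWriter: zstd is handled by DiskWriter, which appends the ".zst" extension
    with JsonlWriter(output_folder="out/jsonl", compression="zstd") as writer:  # "out/jsonl" is a placeholder path
        writer.write(doc)

    # ParquetWriter: the compression setting is passed through to pyarrow's ParquetWriter instead of DiskWriter
    with ParquetWriter(output_folder="out/parquet", compression="zstd") as writer:  # "out/parquet" is a placeholder path
        writer.write(doc)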