Implement zstd Compression Support for JSONL and Parquet Files (#230)
* Add zstandard dependency for compression support

* feat: Add zstd compression support for jsonl reader

* feat: Add zstd compression support for ParquetWriter

* feat: Update DiskWriter to handle additional compression options for Parquet files

* Remove annotation

* feat: Update compression handling in DiskWriter and ParquetWriter

* Update src/datatrove/pipeline/writers/disk_base.py

Handle compression on ParquetWriter directly

Co-authored-by: Guilherme Penedo <[email protected]>

* Update src/datatrove/pipeline/writers/parquet.py

Move None out of the Literal list

Co-authored-by: Guilherme Penedo <[email protected]>

* Refactor constructor to explicitly set default compression to None

* Add validation for compression parameter in ParquetWriter

* Update src/datatrove/pipeline/writers/disk_base.py

The official extension for zstd is ".zst"

Co-authored-by: Guilherme Penedo <[email protected]>

---------

Co-authored-by: Guilherme Penedo <[email protected]>
justHungryMan and guipenedo authored Aug 28, 2024
1 parent 3b91550 commit d5d1924
Showing 5 changed files with 78 additions and 6 deletions.
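In short, both writers now accept compression="zstd". A minimal usage sketch of the new option, mirroring the tests below; the output folder path and document contents are illustrative:

# Minimal usage sketch; folder paths and document contents are made up.
from datatrove.data import Document
from datatrove.pipeline.writers.jsonl import JsonlWriter
from datatrove.pipeline.writers.parquet import ParquetWriter

doc = Document(text="hello", id="0")

# JSONL: the ".zst" extension is appended automatically (see the disk_base.py change below)
with JsonlWriter(output_folder="out/jsonl", compression="zstd") as writer:
    writer.write(doc)

# Parquet: compression happens inside the file via pyarrow, so no extra extension is added
with ParquetWriter(output_folder="out/parquet", compression="zstd") as writer:
    writer.write(doc)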
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -42,7 +42,8 @@ io = [
     "python-magic",
     "warcio",
     "datasets>=2.18.0",
-    "orjson"
+    "orjson",
+    "zstandard"
 ]
 s3 = [
     "s3fs>=2023.12.2",
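The new zstandard dependency provides the codec used for ".zst" streams. As an illustration only (datatrove's own io helpers are not part of this diff), writing and reading a zstd-compressed JSONL file with that package looks roughly like this; the file name is made up:

# Illustration of the zstandard package's streaming API; not code from this commit.
import zstandard

# writing
with open("docs.jsonl.zst", "wb") as fh:
    with zstandard.ZstdCompressor().stream_writer(fh) as zf:
        zf.write(b'{"text": "hello"}\n')

# reading
with open("docs.jsonl.zst", "rb") as fh:
    with zstandard.ZstdDecompressor().stream_reader(fh) as zf:
        print(zf.read())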
2 changes: 2 additions & 0 deletions src/datatrove/pipeline/writers/disk_base.py
@@ -42,6 +42,8 @@ def __init__(
         output_filename = output_filename or self.default_output_filename
         if self.compression == "gzip" and not output_filename.endswith(".gz"):
             output_filename += ".gz"
+        elif self.compression == "zstd" and not output_filename.endswith(".zst"):
+            output_filename += ".zst"
         self.max_file_size = max_file_size
         self.file_id_counter = Counter()
         if self.max_file_size > 0 and mode != "wb":
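The writer only appends a codec-specific extension when one is missing. A tiny equivalent sketch of that mapping, shown for illustration; the commit itself uses the explicit if/elif branches above:

# Equivalent extension lookup, illustration only.
COMPRESSION_EXTENSIONS = {"gzip": ".gz", "zstd": ".zst"}

def with_compression_extension(filename: str, compression: str | None) -> str:
    ext = COMPRESSION_EXTENSIONS.get(compression, "")
    return filename if not ext or filename.endswith(ext) else filename + ext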
19 changes: 14 additions & 5 deletions src/datatrove/pipeline/writers/parquet.py
@@ -1,5 +1,5 @@
 from collections import Counter, defaultdict
-from typing import IO, Callable
+from typing import IO, Callable, Literal
 
 from datatrove.io import DataFolderLike
 from datatrove.pipeline.writers.disk_base import DiskWriter
@@ -14,24 +14,31 @@ def __init__(
         self,
         output_folder: DataFolderLike,
         output_filename: str = None,
-        compression: str | None = None,
+        compression: Literal["snappy", "gzip", "brotli", "lz4", "zstd"] | None = None,
         adapter: Callable = None,
         batch_size: int = 1000,
         expand_metadata: bool = False,
         max_file_size: int = 5 * 2**30,  # 5GB
     ):
+        # Validate the compression setting
+        if compression not in {"snappy", "gzip", "brotli", "lz4", "zstd", None}:
+            raise ValueError(
+                "Invalid compression type. Allowed types are 'snappy', 'gzip', 'brotli', 'lz4', 'zstd', or None."
+            )
+
         super().__init__(
             output_folder,
             output_filename,
-            compression,
-            adapter,
+            compression=None,  # Ensure superclass initializes without compression
+            adapter=adapter,
             mode="wb",
             expand_metadata=expand_metadata,
             max_file_size=max_file_size,
         )
         self._writers = {}
         self._batches = defaultdict(list)
         self._file_counter = Counter()
+        self.compression = compression
         self.batch_size = batch_size
 
     def _on_file_switch(self, original_name, old_filename, new_filename):
@@ -62,7 +69,9 @@ def _write(self, document: dict, file_handler: IO, filename: str):
 
         if filename not in self._writers:
             self._writers[filename] = pq.ParquetWriter(
-                file_handler, schema=pa.RecordBatch.from_pylist([document]).schema
+                file_handler,
+                schema=pa.RecordBatch.from_pylist([document]).schema,
+                compression=self.compression,
             )
         self._batches[filename].append(document)
         if len(self._batches[filename]) == self.batch_size:
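Within ParquetWriter the compression string is simply forwarded to pyarrow, which performs the actual column-chunk encoding. A standalone sketch of that underlying call, with a made-up schema and file name:

# Standalone illustration of the underlying pyarrow call; schema and path are made up.
import pyarrow as pa
import pyarrow.parquet as pq

batch = pa.RecordBatch.from_pylist([{"text": "hello", "id": "0"}])
with pq.ParquetWriter("example.zstd.parquet", schema=batch.schema, compression="zstd") as writer:
    writer.write_batch(batch)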
30 changes: 30 additions & 0 deletions tests/pipeline/test_jsonl_zstd_compression.py
@@ -0,0 +1,30 @@
import shutil
import tempfile
import unittest

from datatrove.data import Document
from datatrove.pipeline.readers.jsonl import JsonlReader
from datatrove.pipeline.writers.jsonl import JsonlWriter


class TestZstdCompression(unittest.TestCase):
    def setUp(self):
        # Create a temporary directory
        self.tmp_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, self.tmp_dir)

    def test_jsonl_writer_reader(self):
        data = [
            Document(text=text, id=str(i), metadata={"somedata": 2 * i, "somefloat": i * 0.4, "somestring": "hello"})
            for i, text in enumerate(["hello", "text2", "more text"])
        ]
        with JsonlWriter(output_folder=self.tmp_dir, compression="zstd") as w:
            for doc in data:
                w.write(doc)
        reader = JsonlReader(self.tmp_dir, compression="zstd")
        c = 0
        for read_doc, original in zip(reader(), data):
            read_doc.metadata.pop("file_path", None)
            assert read_doc == original
            c += 1
        assert c == len(data)
30 changes: 30 additions & 0 deletions tests/pipeline/test_parquet_zstd_compression.py
@@ -0,0 +1,30 @@
import shutil
import tempfile
import unittest

from datatrove.data import Document
from datatrove.pipeline.readers.parquet import ParquetReader
from datatrove.pipeline.writers.parquet import ParquetWriter


class TestZstdCompression(unittest.TestCase):
    def setUp(self):
        # Create a temporary directory
        self.tmp_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, self.tmp_dir)

    def test_parquet_writer_reader(self):
        data = [
            Document(text=text, id=str(i), metadata={"somedata": 2 * i, "somefloat": i * 0.4, "somestring": "hello"})
            for i, text in enumerate(["hello", "text2", "more text"])
        ]
        with ParquetWriter(output_folder=self.tmp_dir, compression="zstd") as w:
            for doc in data:
                w.write(doc)
        reader = ParquetReader(self.tmp_dir)
        c = 0
        for read_doc, original in zip(reader(), data):
            read_doc.metadata.pop("file_path", None)
            assert read_doc == original
            c += 1
        assert c == len(data)
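Note that, unlike the JSONL round trip, the Parquet test reads the data back without passing a compression argument: the codec is recorded in the Parquet file's own metadata, so pyarrow decompresses it transparently. For illustration, reusing the made-up file from the earlier sketch:

# Illustration: pyarrow transparently decompresses zstd-encoded column chunks.
import pyarrow.parquet as pq

table = pq.read_table("example.zstd.parquet")  # file written in the sketch above
print(table.to_pylist())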
