From f39a0467f73d2169ec955a4b20ba20c8d3eb2a87 Mon Sep 17 00:00:00 2001
From: Artem Kozhevnikov
Date: Fri, 8 Dec 2023 03:10:19 +0100
Subject: [PATCH] Parquet dataloader with fairseq2 primitives (#162)

---
 .github/workflows/_build_wheel-linux.yaml  |   7 +-
 recipes/parquet/README.md                  |  50 +++
 recipes/parquet/parquet_dataloader.py      | 228 ++++++++++++
 setup.py                                   |   3 +
 src/fairseq2/data/parquet_tools.py         | 352 ++++++++++++++++++
 .../parquet/test_parquet_dataloader.py     | 289 ++++++++++++++
 6 files changed, 927 insertions(+), 2 deletions(-)
 create mode 100644 recipes/parquet/README.md
 create mode 100644 recipes/parquet/parquet_dataloader.py
 create mode 100644 src/fairseq2/data/parquet_tools.py
 create mode 100644 tests/integration/parquet/test_parquet_dataloader.py

diff --git a/.github/workflows/_build_wheel-linux.yaml b/.github/workflows/_build_wheel-linux.yaml
index b5aea3cc8..656914604 100644
--- a/.github/workflows/_build_wheel-linux.yaml
+++ b/.github/workflows/_build_wheel-linux.yaml
@@ -1,4 +1,4 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
+# Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
@@ -188,7 +188,10 @@ jobs:
           pip install --no-cache-dir ~/artifacts/fairseq2n/python/build/wheelhouse/*.whl
       - name: Install fairseq2
         run: |
-          pip install --no-cache-dir ~/artifacts/build/wheelhouse/*.whl
+          for whl in ~/artifacts/build/wheelhouse/*.whl; do
+            pip install --no-cache-dir $whl
+            pip install --no-cache-dir "fairseq2[arrow]@$whl"
+          done
       - name: Set the sanitizer variables
         if: inputs.sanitizers != 'nosan'
         env:
diff --git a/recipes/parquet/README.md b/recipes/parquet/README.md
new file mode 100644
index 000000000..d0ccf44e5
--- /dev/null
+++ b/recipes/parquet/README.md
@@ -0,0 +1,50 @@
+## Parquet Data Loading with fairseq2
+
+The recipe module [parquet_dataloader](./parquet_dataloader.py) shows one way to
+build an efficient dataloader over an Apache Parquet dataset (partitioned or not)
+using `fairseq2.data` primitives. It uses the [pyarrow.parquet](https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetDataset.html)
+API to interface with Parquet files, so it requires an extra package
+installation with `pip install fairseq2[arrow]`.
+
+The present dataloader is general purpose and can be combined with various
+downstream workflows. Some important technical notes to keep in mind:
+
+* The dataloader loads several Parquet dataset fragments simultaneously
+  (`nb_parallel_fragments`) and shuffles their elements together before
+  returning them.
+* Thus, increasing `nb_parallel_fragments` results in better randomization but
+  also a larger memory footprint.
+* For datasets with heavy rows, prefer saving the Parquet files with relatively
+  small `row_groups` to improve streaming regularity.
+* Since `fairseq2.data` is multithreaded, reading from S3 storage works best
+  with `pyarrow.fs.S3FileSystem`, which releases the GIL (see the sketch below
+  the configuration note).
+* Currently, only a subset of pyarrow dtypes is mapped to torch equivalents;
+  this support will improve in the future.
+
+Please refer to the `ParquetBasicDataloaderConfig` for more details about the
+existing configuration parameters.
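+
+As a sketch of the S3 note above, a `pyarrow.fs.S3FileSystem` can be passed
+through the `filesystem` field of the config; the bucket, prefix, and region
+below are placeholders:
+
+```python
+from pyarrow.fs import S3FileSystem
+
+from recipes.parquet.parquet_dataloader import (
+    ParquetBasicDataloaderConfig,
+    parquet_iterator,
+)
+
+# S3FileSystem releases the GIL, so it cooperates well with the multithreaded
+# fairseq2.data pipeline. Bucket and prefix here are hypothetical.
+s3 = S3FileSystem(region="us-west-2")
+
+config = ParquetBasicDataloaderConfig(
+    parquet_path="my-bucket/datasets/train",  # no "s3://" scheme when a filesystem is given
+    filesystem=s3,
+    batch_size=32,
+    nb_parallel_fragments=4,
+)
+
+for batch in parquet_iterator(config):
+    ...  # each batch is a pyarrow.Table by default
+```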
+
+Example of simple usage:
+
+```python
+import pyarrow.compute as pc
+
+from recipes.parquet.parquet_dataloader import (
+    ParquetBasicDataloaderConfig,
+    ParquetBatchFormat,
+    parquet_iterator,
+)
+
+config = ParquetBasicDataloaderConfig(
+    parquet_path="path/to/parquet/dataset",
+    filters=pc.greater(pc.utf8_length(pc.field("src_text")), 5),
+    columns=["src_text", "src_lang", "audio_wav"],
+    batch_size=20,
+    output_format=ParquetBatchFormat.torch,
+    world_size=1,
+    rank=0,
+    seed=123,
+)
+
+for batch in parquet_iterator(config):
+    pass
+```
diff --git a/recipes/parquet/parquet_dataloader.py b/recipes/parquet/parquet_dataloader.py
new file mode 100644
index 000000000..28f1a0cb8
--- /dev/null
+++ b/recipes/parquet/parquet_dataloader.py
@@ -0,0 +1,228 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+from enum import Enum
+from functools import partial
+from typing import Any, Generator, List, Optional, Union
+
+import pyarrow as pa
+import pyarrow.parquet as pq
+
+from fairseq2.data.data_pipeline import DataPipeline, DataPipelineBuilder
+from fairseq2.data.parquet_tools import (
+    BatchOutputType,
+    _TableWrapper,
+    _to_real_object,
+    apply_filter,
+    build_iterator_over_one_table,
+    concat_table,
+    list_parquet_fragments,
+    load_one_fragment,
+    pyarrow_cpu,
+    pyarrow_table_to_torch_dict,
+    table_func_wrap,
+)
+
+
+class ParquetBatchFormat(Enum):
+    pyarrow = 0
+    pandas = 1
+    torch = 2
+
+
+@dataclass  # TODO: (kw_only=True) with python3.10
+class ParquetBasicDataloaderConfig:
+    parquet_path: str
+    """The path to the parquet dataset file."""
+
+    batch_size: Optional[int] = None
+    """The output batch size."""
+
+    order_by_length: Optional[str] = None
+    """The column in the dataset whose length will be used for batch ordering.
+    This results in batches with relatively homogeneous values, typically to
+    support optimal padding."""
+
+    max_tokens: Optional[int] = None
+    """Used with the ``order_by_length`` option to control the total number of
+    padded tokens in each batch. Typically, this option is preferred over
+    ``batch_size`` to reduce the memory footprint.
+    """
+
+    columns: Optional[List[str]] = None
+    """The list of columns to load."""
+
+    filters: Optional[Union[List[Any], pa.dataset.Expression]] = None
+    """See https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Expression.html#pyarrow.dataset.Expression
+
+    Some examples:
+
+    >>> import pyarrow.compute as pc
+    >>> import pyarrow as pa
+
+    >>> filters = [("data_split", "=", "train"), ("lang1", "in", ["eng", "spa"]), ("lang2", "=", "eng")]
+    >>> filters = (pc.field("data_split") == pc.scalar("train")) & (pc.field("duration") > 7)
+    >>> filters = pa.compute.greater(pa.compute.utf8_length(pc.field("lang1_text")), 4)
+    >>> filters = pa.compute.less_equal(pa.compute.list_value_length(pa.dataset.field("audio_wav")), 16_000 * 30)
+
+    Note that all fields used here should be among the existing columns in the
+    dataset schema.
+    """
+
+    output_format: ParquetBatchFormat = ParquetBatchFormat.pyarrow
+    """The format to use for output batches."""
+
+    split_to_row_groups: bool = True
+    """If ``True``, splits each Parquet fragment into its row groups, which are
+    generally smaller than whole partitions. Highly recommended for
+    non-partitioned parquet files."""
+
+    shuffle: bool = True
+    """If ``True``, shuffles the dataset samples during the iteration.
+    If ``False`` and ``order_by_length`` is ``None``, the batch samples will be
+    produced in natural Parquet dataset reading order."""
+
+    drop_null: bool = True
+    """If ``True``, drops rows containing any null value."""
+
+    seed: Optional[int] = None
+    """The RNG seed value for deterministic behavior."""
+
+    min_batch_size: int = 1
+    """Drops batches whose length is less than ``min_batch_size``."""
+
+    nb_parallel_fragments: int = 5
+    """The number of Parquet fragments allowed to be read in parallel. Higher
+    values result in higher throughput and better randomization, but also a
+    larger memory footprint. If the partition size is rather small compared to
+    the batch size, we recommend increasing ``nb_parallel_fragments``."""
+
+    nb_prefetch: int = 2
+    """The number of producer groups (of size ``nb_parallel_fragments``) to
+    prefetch."""
+
+    world_size: int = 1
+    """The world size of the process group."""
+
+    rank: int = 0
+    """The rank of this worker in the process group."""
+
+    num_parallel_calls: int = 8
+    """The number of parallel calls in map operations."""
+
+    use_threads: bool = False
+    """Whether pyarrow should use its internal threads to read the Parquet file.
+    Since we rely on external parallelism, this parameter is turned off by
+    default."""
+
+    filesystem: Optional[pa.fs.FileSystem] = None
+    """The filesystem to read the Parquet files from. S3 example:
+    >>> import s3fs
+    >>> filesystem = s3fs.core.S3FileSystem(...)
+    """
+
+    def __post_init__(self) -> None:
+        if not self.parquet_path:
+            raise ValueError(f"requires a non-empty path, got {self.parquet_path}")
+
+        if not ((self.batch_size is None) ^ (self.max_tokens is None)):
+            raise ValueError("need to provide either `batch_size` or `max_tokens`")
+        if self.max_tokens is not None and self.order_by_length is None:
+            raise ValueError(
+                "`order_by_length` must be given when using `max_tokens`"
+            )
+
+        if self.filters is not None and not isinstance(
+            self.filters, pa.dataset.Expression
+        ):
+            self.filters = pq.filters_to_expression(self.filters)
+
+
+def build_parquet_iterator_pipeline(
+    config: ParquetBasicDataloaderConfig,
+) -> DataPipelineBuilder:
+    def inner_iterator(wrap_table: _TableWrapper) -> DataPipeline:
+        return build_iterator_over_one_table(
+            table=wrap_table.table,
+            order_by_length=config.order_by_length,
+            batch_size=config.batch_size,
+            max_tokens=config.max_tokens,
+            shuffle=config.shuffle,
+            seed=config.seed,
+            num_parallel_calls=max(config.num_parallel_calls // 2, 1),
+        )
+
+    pipeline_builder = (
+        list_parquet_fragments(
+            parquet_path=config.parquet_path,
+            filters=config.filters,
+            columns=config.columns,
+            split_to_row_groups=config.split_to_row_groups,
+            filesystem=config.filesystem,
+            shuffle_window=2 * config.nb_prefetch * config.nb_parallel_fragments
+            if config.shuffle
+            else None,
+            seed=config.seed,
+        )
+        .shard(shard_idx=config.rank, num_shards=config.world_size)
+        .map(
+            table_func_wrap(partial(load_one_fragment, columns=config.columns)),
+            num_parallel_calls=config.num_parallel_calls,
+        )
+        .map(
+            table_func_wrap(
+                partial(
+                    apply_filter, filters=config.filters, drop_null=config.drop_null
+                )
+            )
+        )
+        .bucket(config.nb_parallel_fragments)
+        .prefetch(config.nb_prefetch)
+        .map(
+            table_func_wrap(concat_table),
+            num_parallel_calls=config.nb_prefetch,
+        )
+        .yield_from(inner_iterator)
+        .filter(
+            table_func_wrap(lambda table: bool(len(table) >= config.min_batch_size))
+        )
+    )
+
+    if config.output_format == ParquetBatchFormat.pandas:
+        pipeline_builder = pipeline_builder.map(
table_func_wrap(lambda table: table.to_pandas()) + ) + elif config.output_format == ParquetBatchFormat.torch: + pipeline_builder = pipeline_builder.map( + table_func_wrap(pyarrow_table_to_torch_dict) + ) + return pipeline_builder + + +def parquet_iterator( + config: ParquetBasicDataloaderConfig, +) -> Generator[BatchOutputType, None, None]: + """ + Example of usage: + + >>> from recipes.parquet.parquet_dataloader import ( + ... ParquetBasicDataloaderConfig, ParquetBatchFormat, build_parquet_iterator_pipeline) + >>> from tqdm.auto import tqdm + >>> bpd_config = ParquetBasicDataloaderConfig(parquet_path="...", batch_size=20, + ... columns=["src_text", "src_lang", "audio_wav"], + ... output_format=ParquetBatchFormat.torch) + >>> ei_batch = parquet_iterator(bpd_config) + >>> res = [] + >>> for i, batch in tqdm(enumerate(ei_batch)): res.append(len(batch)) + """ + with pyarrow_cpu(config.num_parallel_calls): + yield from map( + _to_real_object, + iter( + build_parquet_iterator_pipeline(config) + .prefetch(config.num_parallel_calls) + .and_return(max_num_warnings=4) + ), + ) diff --git a/setup.py b/setup.py index 65ae68e52..926fcc340 100644 --- a/setup.py +++ b/setup.py @@ -55,4 +55,7 @@ "tqdm~=4.62", "typing_extensions~=4.3;python_version<'3.10'", ], + extras_require={ + "arrow": ["pyarrow>=13.0.0", "pandas~=2.0.0"], + }, ) diff --git a/src/fairseq2/data/parquet_tools.py b/src/fairseq2/data/parquet_tools.py new file mode 100644 index 000000000..60f521359 --- /dev/null +++ b/src/fairseq2/data/parquet_tools.py @@ -0,0 +1,352 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from contextlib import contextmanager +from typing import Dict, Generator, List, Optional, Union + +import numpy as np +import pandas as pd +import pyarrow as pa +import pyarrow.parquet as pq +import torch +from numpy.typing import NDArray +from pyarrow.dataset import get_partition_keys # requires pyarrow >= 13 + +from fairseq2.data import CString +from fairseq2.data.data_pipeline import DataPipeline, DataPipelineBuilder, read_sequence + + +@contextmanager +def pyarrow_cpu(nb_cpu: int) -> Generator[None, None, None]: + nb_cpu_old = pa.cpu_count() + nb_io_cpu_old = pa.io_thread_count() + pa.set_cpu_count(nb_cpu) + pa.set_io_thread_count(nb_cpu) + try: + yield + finally: + pa.set_cpu_count(nb_cpu_old) + pa.set_io_thread_count(nb_io_cpu_old) + + +@contextmanager +def torch_random_seed(seed: Optional[int] = None) -> Generator[None, None, None]: + if seed is not None: + torch.manual_seed(seed) + yield + + +NestedDict = Dict[str, "NestedDictValue"] +NestedDictValue = Union[torch.Tensor, List[CString], pd.Series, NestedDict] +BatchOutputType = Union[pa.Table, pd.DataFrame, NestedDict] + + +def from_pyarrow_to_torch_tensor( + arr: Union[pa.Array, pa.ChunkedArray], strict: bool = True +) -> NestedDictValue: + """ + struct_array = pa.Array.from_pandas([{"x": 4, "y": "RR"}] * 10) + nest_array = pa.Array.from_pandas([[{'a': 1}, {'a': 2}]]) + """ + # for future ideas https://arrow.apache.org/docs/python/generated/pyarrow.Tensor.html + # for sparse matrix support https://github.com/apache/arrow/blob/main/python/pyarrow/tests/test_sparse_tensor.py + + if arr.null_count != 0: + raise ValueError("to torch conversion does not support null values") + + if isinstance(arr, pa.ChunkedArray): + arr = arr.chunks[0] if arr.num_chunks == 1 else arr.combine_chunks() + + arr_type = 
arr.type + if pa.types.is_primitive(arr_type): + return torch.from_numpy(arr.to_numpy(zero_copy_only=True)) + + try: + return torch.from_numpy(arr.to_numpy(zero_copy_only=True)) + except pa.ArrowInvalid: + pass + + if pa.types.is_dictionary(arr_type): + return from_pyarrow_to_torch_tensor(arr.dictionary_decode()) + + if pa.types.is_string(arr_type): + return list(map(CString, arr.to_pandas())) + + if ( + pa.types.is_list(arr_type) or pa.types.is_large_list(arr_type) + ) and pa.types.is_primitive(arr_type.value_type): + return torch.nested.as_nested_tensor( + list(map(torch.from_numpy, arr.to_pandas())) + ) + + if pa.types.is_fixed_size_list(arr_type) and pa.types.is_primitive( + arr_type.value_type + ): + return torch.from_numpy(np.reshape(arr.values, (-1, arr_type.list_size))) + + if pa.types.is_struct(arr_type): + return { + arr_type.field(i).name: from_pyarrow_to_torch_tensor(arr.field(i)) + for i in range(arr_type.num_fields) + } + + if pa.types.is_nested(arr_type): + # TODO: deal with arr = [[{'a': 1}, {'a': 2}]] + pass + + if strict: + raise NotImplementedError(f"{arr_type} cannot be converted to torch.Tensor") + else: + return arr + + +def pyarrow_table_to_torch_dict(tt: pa.Table, strict: bool = True) -> NestedDict: + return { + col: from_pyarrow_to_torch_tensor(tt[col], strict) for col in tt.column_names + } + + +def init_parquet_dataset( + parquet_path: str, + filters: Optional[pa.dataset.Expression] = None, + filesystem: Optional[pa.fs.FileSystem] = None, +) -> pq.ParquetDataset: + source_ds = pq.ParquetDataset( + parquet_path, + validate_schema=True, + filters=filters, + filesystem=filesystem, + ) + return source_ds + + +def get_dataset_fragments( + dataset: pq.ParquetDataset, filters: pa.dataset.Expression +) -> List[pa.dataset.Fragment]: + """ + This could be simplified once `split_row_groups=True` is implemented at `pq.ParquetDataset`. + We could also return a generator instead of list (when getting full infos from S3 may be slow) + """ + return list(dataset._dataset.get_fragments(filters)) + + +def split_fragment_in_row_groups( + fragment: pa.dataset.Fragment, +) -> List[pa.dataset.Fragment]: + return list(fragment.split_by_row_group()) + + +def add_partitioning_values( + table: pa.Table, fragment: pa.dataset.Fragment, columns: Optional[List[str]] +) -> pa.Table: + """ + When loading a single fragment, pyarrow does not add the partitioning columns, + so we need to do it manually. 
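+
+    For example (hypothetical hive layout), a fragment stored under
+    ``part_key=3/data.parquet`` is read back without a ``part_key`` column;
+    this helper appends it as a dictionary-encoded column filled with ``3``.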
+    """
+    for key, val in get_partition_keys(fragment.partition_expression).items():
+        if columns is None or key in columns:
+            values = pa.DictionaryArray.from_arrays(
+                np.zeros(len(table), dtype=np.int32), [val]
+            )
+            table = table.append_column(key, values)
+    return table
+
+
+def load_one_fragment(
+    fragment: pa.dataset.Fragment, columns: Optional[List[str]] = None
+) -> pa.Table:
+    fragment_columns = columns
+    if fragment_columns is not None:
+        fragment_columns = [
+            col for col in fragment_columns if col in fragment.physical_schema.names
+        ]
+    fragment_table = fragment.to_table(columns=fragment_columns, use_threads=False)
+    fragment_table = add_partitioning_values(fragment_table, fragment, columns)
+    return fragment_table
+
+
+def apply_filter(
+    table: pa.Table,
+    filters: Optional[pa.dataset.Expression] = None,
+    drop_null: bool = True,
+) -> pa.Table:
+    if drop_null:
+        table = table.drop_null()
+    if filters is not None:
+        table = table.filter(filters)
+    return table
+
+
+def concat_table(tables: List[pa.Table], combine: bool = True) -> pa.Table:
+    result = pa.concat_tables(
+        tables,
+        promote_options="permissive",  # needed to deal with empty segments
+    )
+    if combine:
+        result = result.combine_chunks()
+    return result
+
+
+def compute_length_splits(
+    length_col: NDArray[np.int32], max_tokens: int
+) -> List[NDArray[np.int32]]:
+    """Split the indices of ``length_col`` into chunks so that the total length
+    of each chunk is ~ ``max_tokens``, counting the padding to the max length
+    of the elements in the chunk.
+
+    Args:
+        length_col (np.ndarray): the lengths of the individual rows
+        max_tokens (int): the maximum number of padded tokens per chunk
+
+    Returns:
+        List[np.ndarray]: splits that contain indices over the original length_col
+    """
+    argsort_ind = np.argsort(length_col)
+    # TODO: remove 0 lengths
+    sorted_length_col = length_col[argsort_ind]
+
+    splits = []
+    ptr = 0
+    for i, length in enumerate(sorted_length_col):
+        # lengths are sorted, so `length` is the current chunk's max; close the
+        # chunk before its padded size would exceed `max_tokens`
+        if length * (i + 1 - ptr) > max_tokens:
+            splits.append(argsort_ind[ptr:i])
+            ptr = i
+    if (
+        length <= max_tokens
+    ):  # we drop the last chunk if it results in a batch greater than max_tokens
+        splits.append(argsort_ind[ptr:])
+    return splits
+
+
+def compute_rows_length(pa_array: pa.Array) -> NDArray[np.int32]:
+    type_ = pa_array.type
+    if pa.types.is_list(type_) or pa.types.is_large_list(type_):
+        length_col = pa.compute.list_value_length(pa_array).to_numpy()
+    elif pa.types.is_string(type_):
+        length_col = pa.compute.utf8_length(pa_array).to_numpy()
+    else:
+        length_col = np.asarray(pa_array.to_pandas().apply(len))
+
+    length_col = length_col.copy()
+    length_col[np.isnan(length_col)] = 0
+    return np.asarray(length_col, dtype=np.int32)
+
+
+class _TableWrapper:
+    """
+    Wrapper to prevent fairseq2 from casting a ``pa.Table`` to an iterable
+    object, which currently fails.
+    """
+
+    def __init__(self, table: pa.Table) -> None:
+        self.table: pa.Table = table
+
+
+def _to_real_object(x: Union[_TableWrapper, NestedDict]) -> BatchOutputType:
+    if isinstance(x, _TableWrapper):
+        return x.table
+    elif isinstance(x, list):
+        return [_to_real_object(e) for e in x]
+    elif isinstance(x, tuple):
+        return tuple(_to_real_object(e) for e in x)
+    else:
+        return x
+
+
+def table_func_wrap(func):  # type: ignore
+    def inner(*args):  # type: ignore
+        fixed_args = [_to_real_object(x) for x in args]
+        result = func(*fixed_args)
+        if isinstance(result, (pa.Table, pd.DataFrame)):
+            result = _TableWrapper(result)
+        return result
+
+    return inner
+
+
+def list_parquet_fragments(
+    parquet_path: str,
+    filters: Optional[pa.dataset.Expression] = None,
+    columns: Optional[List[str]] = None,
split_to_row_groups: bool = True, + filesystem: Optional[pa.fs.FileSystem] = None, + shuffle_window: Optional[int] = None, + seed: Optional[int] = None, +) -> DataPipelineBuilder: + dataset = init_parquet_dataset(parquet_path, filters=filters, filesystem=filesystem) + columns = columns or dataset.schema.names + if not set(columns).issubset(set(dataset.schema.names)): + raise ValueError( + f"columns {sorted(set(columns) - set(dataset.schema.names))} are not found in the dataset schema" + ) + + pipeline_builder = read_sequence(get_dataset_fragments(dataset, filters)) + + with torch_random_seed(seed): + if shuffle_window is not None: + # shuffle them in full memory since fragments are already known + pipeline_builder = pipeline_builder.shuffle(shuffle_window=0) + + if split_to_row_groups: + pipeline_builder = pipeline_builder.yield_from( + lambda fragment: read_sequence( + split_fragment_in_row_groups(fragment) + ).and_return() + ) + if shuffle_window is not None: + pipeline_builder = pipeline_builder.shuffle( + shuffle_window=shuffle_window + ) + + return pipeline_builder + + +def build_iterator_over_one_table( + table: pa.Table, + order_by_length: Optional[str] = None, + batch_size: Optional[int] = None, + max_tokens: Optional[int] = None, + shuffle: bool = True, + seed: Optional[int] = None, + num_parallel_calls: int = 8, +) -> DataPipeline: + random_state = np.random.RandomState(seed) + if order_by_length is not None: + length_col = compute_rows_length(table[order_by_length]) + # add small perturbation to avoid same sample appear together during different epochs + if shuffle: + perturbation = random_state.randint( + 0, + np.quantile(length_col, 0.001).astype(np.int32) + 2, + len(length_col), + ) + length_col += np.asarray(perturbation, dtype=np.int32) + else: + if shuffle: + length_col = random_state.randint(0, 2**23, len(table)) + else: + length_col = np.zeros(len(table), dtype=np.int32) + + if batch_size is not None: + order_tt = pa.Table.from_arrays( + [pa.array(np.argsort(length_col, kind="stable"))], ["order"] + ) + batches = [ind["order"] for ind in order_tt.to_batches(batch_size)] + elif max_tokens is not None: + batches = compute_length_splits(length_col, max_tokens) + else: + raise ValueError("unknown batching method") + + if shuffle: + batches = [batches[i] for i in random_state.permutation(len(batches))] + + return ( + read_sequence(batches) + .map( + table_func_wrap(lambda ind: table.take(ind).combine_chunks()), + num_parallel_calls=num_parallel_calls, + ) + .and_return(max_num_warnings=4) + ) diff --git a/tests/integration/parquet/test_parquet_dataloader.py b/tests/integration/parquet/test_parquet_dataloader.py new file mode 100644 index 000000000..fe00187c3 --- /dev/null +++ b/tests/integration/parquet/test_parquet_dataloader.py @@ -0,0 +1,289 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +import os +import random +import shutil +import string +import tempfile +from collections import Counter +from typing import Any, Dict, Generator, List, Union + +import pytest + +try: + import numpy as np + import pandas as pd + import pyarrow as pa + import pyarrow.parquet as pq + + arrow_found = True + + from numpy.typing import NDArray + + from recipes.parquet.parquet_dataloader import ( + ParquetBasicDataloaderConfig, + ParquetBatchFormat, + parquet_iterator, + ) +except ImportError: + arrow_found = False + + +def gen_random_string(length: int) -> str: + return "".join( + random.choice(string.ascii_letters + string.digits) for n in range(length) + ) + + +def generate_random_pandas_df(size: int, seed: int = 123) -> pd.DataFrame: + np_rs = np.random.RandomState(seed) + df: Dict[str, Union[NDArray[Any], List[Any]]] = {} + df["int_col"] = np_rs.randint(0, 200, size) + df["float_col"] = np_rs.randn(size) + + df["string_col1"] = [gen_random_string(10) for _ in range(size)] + df["string_col2"] = [gen_random_string(2) for _ in range(size)] + + df["list_int_col"] = [ + np_rs.randint(-10, 10, np_rs.randint(0, 100)) for _ in range(size) + ] + df["list_float_col"] = [ + np_rs.rand(np_rs.randint(0, 10)).astype(np.float32) for _ in range(size) + ] + df["list_float_fixed_size_col"] = [ + np_rs.rand(7).astype(np.float32) for _ in range(size) + ] + return pd.DataFrame(df) + + +def generated_partitioned_parquet_file( + path: str, size: int, n_partitions: int = 20, seed: int = 123 +) -> None: + df = generate_random_pandas_df(size, seed) + + if n_partitions > 0: + df["part_key"] = np.arange(size) % n_partitions + + table = pa.Table.from_pandas(df) + + pq.write_to_dataset( + table, + path, + partition_cols=["part_key"] if n_partitions > 0 else None, + existing_data_behavior="delete_matching", + **{"row_group_size": 110}, + ) + + +@pytest.fixture() +def single_file() -> Generator[str, None, None]: + tmpdir = tempfile.mkdtemp() + tmp_parquet_ds_path = os.path.join(tmpdir, "test") + generated_partitioned_parquet_file( + tmp_parquet_ds_path, size=10**3, n_partitions=0 + ) + yield tmp_parquet_ds_path + shutil.rmtree(tmpdir) + + +@pytest.fixture() +def multi_partition_file() -> Generator[str, None, None]: + tmpdir = tempfile.mkdtemp() + tmp_parquet_ds_path = os.path.join(tmpdir, "test") + generated_partitioned_parquet_file(tmp_parquet_ds_path, size=2 * 10**3) + yield tmp_parquet_ds_path + shutil.rmtree(tmpdir) + + +@pytest.mark.skipif(not arrow_found, reason="arrow not found") +class TestParquetDataloader: + def test_simple_dataload(self, multi_partition_file: str) -> None: + config = ParquetBasicDataloaderConfig( + parquet_path=multi_partition_file, + batch_size=11, + nb_parallel_fragments=2, + seed=333, + ) + res: List[pd.DataFrame] = list(parquet_iterator(config)) + + assert all(isinstance(x, pa.Table) for x in res) + + assert list(res[0].to_pandas().columns) == [ + "int_col", + "float_col", + "string_col1", + "string_col2", + "list_int_col", + "list_float_col", + "list_float_fixed_size_col", + "part_key", + ] + + assert Counter(map(len, res)) == Counter({11: 180, 2: 10}) # 180 * 11 + assert sum(map(len, res)) == 2000 + + # determinism check + config_new = ParquetBasicDataloaderConfig( + parquet_path=multi_partition_file, + batch_size=11, + nb_parallel_fragments=2, + seed=333, + output_format=ParquetBatchFormat.pandas, + ) + res_bis = list(parquet_iterator(config_new)) + + assert all(isinstance(x, pd.DataFrame) for x in res_bis) + + assert all( + (x["float_col"].to_pandas() == y["float_col"]).all() + for 
x, y in zip(res, res_bis) + ) + + config_another_seed = ParquetBasicDataloaderConfig( + parquet_path=multi_partition_file, + batch_size=11, + nb_parallel_fragments=2, + seed=111, + output_format=ParquetBatchFormat.pandas, + ) + res_ter = list(parquet_iterator(config_another_seed)) + assert any( + (x["float_col"] != y["float_col"]).any() for x, y in zip(res, res_ter) + ) + + def test_filtered_with_columns_dataload(self, multi_partition_file: str) -> None: + config = ParquetBasicDataloaderConfig( + parquet_path=multi_partition_file, + batch_size=3, + nb_parallel_fragments=5, + seed=111, + columns=["string_col2", "list_int_col", "float_col"], + filters=[("float_col", ">", 0)], + output_format=ParquetBatchFormat.pandas, + ) + + res: List[pd.DataFrame] = list(parquet_iterator(config)) + + assert list(res[0].columns) == ["string_col2", "list_int_col", "float_col"] + + assert Counter(map(len, res)) == Counter({3: 339, 1: 3, 2: 1}) + + def test_filtered_with_columns_dataload_min_batch_size( + self, multi_partition_file: str + ) -> None: + config = ParquetBasicDataloaderConfig( + parquet_path=multi_partition_file, + batch_size=3, + nb_parallel_fragments=5, + seed=111, + min_batch_size=3, + columns=["string_col2", "list_int_col", "float_col"], + filters=[("float_col", ">", 0)], + output_format=ParquetBatchFormat.pandas, + ) + res = list(parquet_iterator(config)) + assert Counter(map(len, res)) == Counter({3: 339}) + + def test_ordered_dataload(self, multi_partition_file: str) -> None: + config = ParquetBasicDataloaderConfig( + parquet_path=multi_partition_file, + batch_size=20, + nb_parallel_fragments=20, + order_by_length="list_int_col", + seed=123, + output_format=ParquetBatchFormat.pandas, + ) + res: List[pd.DataFrame] = list(parquet_iterator(config)) + length_by_batches = [tt["list_int_col"].apply(len) for tt in res] + length_by_batches_diff = max(tt.max() - tt.min() for tt in length_by_batches) + total_length = sum(map(len, length_by_batches)) + + assert length_by_batches_diff < 4 + assert total_length == 2000 + assert all(len(tt) == 20 for tt in length_by_batches) + + def test_ordered_max_token_dataload(self, multi_partition_file: str) -> None: + config = ParquetBasicDataloaderConfig( + parquet_path=multi_partition_file, + nb_parallel_fragments=20, + order_by_length="list_int_col", + max_tokens=3000, + seed=123, + output_format=ParquetBatchFormat.pandas, + ) + res: List[pd.DataFrame] = list(parquet_iterator(config)) + length_by_batches = [tt["list_int_col"].apply(len) for tt in res] + length_by_batches_diff = max(tt.max() - tt.min() for tt in length_by_batches) + max_padded_total_length = max(tt.max() * len(tt) for tt in length_by_batches) + mean_padded_total_length = np.mean( + [tt.max() * len(tt) for tt in length_by_batches] + ) + total_length = sum(map(len, length_by_batches)) + + assert length_by_batches_diff <= 12 + assert total_length == 2000 + assert max_padded_total_length <= 3000 + assert mean_padded_total_length >= 2900 + + def test_ordered_max_token_single_file_dataload(self, single_file: str) -> None: + config = ParquetBasicDataloaderConfig( + parquet_path=single_file, + nb_parallel_fragments=2, + batch_size=10, + seed=333, + ) + res: List[pa.Table] = list(parquet_iterator(config)) + + assert Counter(map(len, res)) == Counter({10: 100}) + + assert res[0].column_names == [ + "int_col", + "float_col", + "string_col1", + "string_col2", + "list_int_col", + "list_float_col", + "list_float_fixed_size_col", + ] + + def test_dataload_without_shuffle(self, multi_partition_file: str) -> 
None:
+        config = ParquetBasicDataloaderConfig(
+            parquet_path=multi_partition_file,
+            nb_parallel_fragments=4,
+            nb_prefetch=2,
+            num_parallel_calls=3,
+            shuffle=False,
+            batch_size=17,
+            columns=["float_col"],
+        )
+        res = pa.concat_tables(list(parquet_iterator(config)))
+        res_reload = pq.read_table(multi_partition_file, columns=["float_col"])
+
+        assert res.equals(res_reload)
+
+    def test_dataload_max_row_groups(self, single_file: str) -> None:
+        config = ParquetBasicDataloaderConfig(
+            parquet_path=single_file,
+            nb_parallel_fragments=1,
+            nb_prefetch=2,
+            num_parallel_calls=3,
+            batch_size=250,
+        )
+        res = list(parquet_iterator(config))
+
+        assert Counter(map(len, res)) == Counter({110: 9, 10: 1})
+
+        config = ParquetBasicDataloaderConfig(
+            parquet_path=single_file,
+            nb_parallel_fragments=2,  # increasing this
+            nb_prefetch=2,
+            num_parallel_calls=3,
+            batch_size=250,
+        )
+        res = list(parquet_iterator(config))
+
+        assert Counter(map(len, res)) == Counter({220: 4, 120: 1})