From 6afa233cf019b6468ed26cccf4b4582a45782df8 Mon Sep 17 00:00:00 2001
From: James Knighton
Date: Fri, 22 Sep 2023 15:09:05 -0700
Subject: [PATCH 01/11] Training on PQ: draft generate.py, index.py, leaving
 train.py

---
 scripts/parquet/generate.py | 141 ++++++++++++++++++++++++++++++++++++
 scripts/parquet/index.py    | 110 ++++++++++++++++++++++++++++
 2 files changed, 251 insertions(+)
 create mode 100644 scripts/parquet/generate.py
 create mode 100644 scripts/parquet/index.py

diff --git a/scripts/parquet/generate.py b/scripts/parquet/generate.py
new file mode 100644
index 000000000..de14e6a00
--- /dev/null
+++ b/scripts/parquet/generate.py
@@ -0,0 +1,141 @@
+# Copyright 2023 MosaicML Streaming authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""Generate a parquet dataset for testing."""
+
+import os
+from argparse import ArgumentParser, Namespace
+from typing import List, Tuple
+
+import numpy as np
+import pyarrow as pa
+from pyarrow import parquet as pq
+
+
+def parse_args() -> Namespace:
+    """Parse command-line arguments.
+
+    Returns:
+        Namespace: Command-line arguments.
+    """
+    args = ArgumentParser()
+    args.add_argument('--num_train', type=int, default=10_000_000)
+    args.add_argument('--num_val', type=int, default=1_000_000)
+    args.add_argument('--out', type=str, default='data/pq/')
+    args.add_argument('--samples_per_shard', type=int, default=10_000)
+    return args.parse_args()
+
+
+_ones = ('zero one two three four five six seven eight nine ten eleven twelve thirteen fourteen '
+         'fifteen sixteen seventeen eighteen nineteen').split()
+
+_tens = 'twenty thirty forty fifty sixty seventy eighty ninety'.split()
+
+
+def say(i: int) -> List[str]:
+    """Get the word form of a number.
+
+    Args:
+        i (int): The number.
+
+    Returns:
+        List[str]: The number in word form.
+    """
+    if i < 0:
+        return ['negative'] + say(-i)
+    elif i <= 19:
+        return [_ones[i]]
+    elif i < 100:
+        return [_tens[i // 10 - 2]] + ([_ones[i % 10]] if i % 10 else [])
+    elif i < 1_000:
+        return [_ones[i // 100], 'hundred'] + (say(i % 100) if i % 100 else [])
+    elif i < 1_000_000:
+        return say(i // 1_000) + ['thousand'] + (say(i % 1_000) if i % 1_000 else [])
+    elif i < 1_000_000_000:
+        return say(i // 1_000_000) + ['million'] + (say(i % 1_000_000) if i % 1_000_000 else [])
+    else:
+        raise ValueError(f'Integer must be less than a billion, but got: {i}')
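+
+
+# For intuition, a couple of hedged, doctest-style examples of `say` (values
+# chosen arbitrarily; shown for illustration, not executed by this script):
+#
+#   >>> say(117)
+#   ['one', 'hundred', 'seventeen']
+#   >>> say(-90)
+#   ['negative', 'ninety']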
+ """ + if not os.path.exists(dirname): + os.makedirs(dirname) + num_shards = (len(nums) + samples_per_shard - 1) // samples_per_shard + for shard_id in range(num_shards): + begin = shard_id * samples_per_shard + end = min(begin + samples_per_shard, len(nums)) + shard_nums = nums[begin:end] + shard_txts = txts[begin:end] + filename = os.path.join(dirname, f'{shard_id:05}.parquet') + obj = { + 'num': shard_nums, + 'txt': shard_txts, + } + table = pa.Table.from_pydict(obj) + pq.write_table(table, filename) + + +def main(args: Namespace) -> None: + """Generate a parquet dataset for testing. + + Args: + args (Namespace): Command-line arguments. + """ + train_nums, val_nums = generate_numbers(args.num_train, args.num_val) + + train_txts = [' '.join(say(num)) for num in train_nums] + val_txts = [' '.join(say(num)) for num in val_nums] + + dirname = os.path.join(args.out, 'train') + save_parquets(train_nums, train_txts, dirname, args.samples_per_shard) + + dirname = os.path.join(args.out, 'val') + save_parquets(val_nums, val_txts, dirname, args.samples_per_shard) + + +if __name__ == '__main__': + main(parse_args()) diff --git a/scripts/parquet/index.py b/scripts/parquet/index.py new file mode 100644 index 000000000..2168fe92e --- /dev/null +++ b/scripts/parquet/index.py @@ -0,0 +1,110 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Index a parquet dataset for use by Streaming.""" + +import json +import os +from argparse import ArgumentParser, Namespace +from typing import Any, Dict, Iterator, Tuple + +from pyarrow import parquet as pq + + +def parse_args() -> Namespace: + """Parse command-line arguments. + + Returns: + Namespace: Command-line arguments. + """ + args = ArgumentParser() + args.add_argument('--dataset', type=str, required=True) + args.add_argument('--shard_suffix', type=str, default='.parquet') + return args.parse_args() + + +def get_dataset_relative_path(dataset_root: str, path: str) -> str: + """Get the dataset-relative path of a shard file. + + Args: + dataset_root (str): Dataset root directory containing this shard. + path (str): Path to shard under dataset root dir. + + Returns: + Dataset-relative shard path. + """ + if not path.startswith(dataset_root): + raise ValueError('Path {path} was not found under {dataset_root}.') + rel_path = path[len(dataset_root):] + + while rel_path.startswith(os.path.sep): + rel_path = rel_path[1:] + + return rel_path + + +def each_shard_path(dataset_root: str, shard_suffix: str) -> Iterator[Tuple[str, str]]: + """Collect each Parquet shard, in order. + + Args: + dataset_root (str): Dataset root directory. + shard_suffix (str): Suffix of each Parquet shard file. + + Returns: + Iterator[Tuple[str, str]]: Iterator over absolute and dataset-relative paths. + """ + for root, _, files in os.walk(dataset_root): + files = filter(lambda file: file.endswith(shard_suffix), files) + files = (os.path.join(root, file) for file in files) + files = sorted(files) + for path in files: + dataset_rel_path = get_dataset_relative_path(dataset_root, path) + yield path, dataset_rel_path + + +def get_shard_info(path: str, dataset_rel_path: str) -> Dict[str, Any]: + """Get info the index needs about a Parquet shard. + + Args: + path (str): Absolute or relative-to-cwd file path. + dataset_rel_path (str): Relative-to-dataset file path. + + Returns: + Dict[str, Any]: Shard info. 
+ """ + num_bytes = os.stat(path).st_size + table = pq.read_table(path) + num_samples = len(table) + return { + 'format': 'parquet', + 'raw_parquet': { + 'basename': dataset_rel_path, + 'bytes': num_bytes, + }, + 'samples': num_samples, + } + + +def main(args: Namespace) -> None: + """Index a parquet dataset for use by Streaming. + + Args: + args (Namespace): Command-line arguments. + """ + infos = [] + for path, dataset_rel_path in each_shard_path(args.dataset, args.shard_suffix): + info = get_shard_info(path, dataset_rel_path) + infos.append(info) + obj = { + 'version': 2, + 'shards': infos, + } + filename = os.path.join(args.dataset, 'index.json') + if os.path.exists(filename): + raise ValueError(f'Index file {filename} already exists.') + with open(filename, 'w') as out: + json.dump(obj, out) + + +if __name__ == '__main__': + main(parse_args()) From 977980534cce943e8d2d2844d1e83e122b1aa757 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sat, 23 Sep 2023 12:16:53 -0700 Subject: [PATCH 02/11] Updates. --- scripts/parquet/generate.py | 6 +- scripts/parquet/index.py | 57 +++++++- scripts/parquet/iterate.py | 77 +++++++++++ streaming/base/dataset.py | 3 +- streaming/base/format/__init__.py | 4 +- streaming/base/format/base/reader.py | 24 +++- streaming/base/format/mds/reader.py | 5 +- streaming/base/format/pq/__init__.py | 8 ++ streaming/base/format/pq/reader.py | 189 +++++++++++++++++++++++++++ streaming/base/stream.py | 2 +- 10 files changed, 357 insertions(+), 18 deletions(-) create mode 100644 scripts/parquet/iterate.py create mode 100644 streaming/base/format/pq/__init__.py create mode 100644 streaming/base/format/pq/reader.py diff --git a/scripts/parquet/generate.py b/scripts/parquet/generate.py index de14e6a00..0d67846a1 100644 --- a/scripts/parquet/generate.py +++ b/scripts/parquet/generate.py @@ -21,7 +21,7 @@ def parse_args() -> Namespace: args = ArgumentParser() args.add_argument('--num_train', type=int, default=10_000_000) args.add_argument('--num_val', type=int, default=1_000_000) - args.add_argument('--out', type=str, default='data/pq/') + args.add_argument('--dataset', type=str, default='data/pq/') args.add_argument('--samples_per_shard', type=int, default=10_000) return args.parse_args() @@ -130,10 +130,10 @@ def main(args: Namespace) -> None: train_txts = [' '.join(say(num)) for num in train_nums] val_txts = [' '.join(say(num)) for num in val_nums] - dirname = os.path.join(args.out, 'train') + dirname = os.path.join(args.dataset, 'train') save_parquets(train_nums, train_txts, dirname, args.samples_per_shard) - dirname = os.path.join(args.out, 'val') + dirname = os.path.join(args.dataset, 'val') save_parquets(val_nums, val_txts, dirname, args.samples_per_shard) diff --git a/scripts/parquet/index.py b/scripts/parquet/index.py index 2168fe92e..ed0d73f96 100644 --- a/scripts/parquet/index.py +++ b/scripts/parquet/index.py @@ -6,10 +6,12 @@ import json import os from argparse import ArgumentParser, Namespace -from typing import Any, Dict, Iterator, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from pyarrow import parquet as pq +from streaming.base.format.mds.encodings import get_mds_encoded_size + def parse_args() -> Namespace: """Parse command-line arguments. @@ -62,6 +64,42 @@ def each_shard_path(dataset_root: str, shard_suffix: str) -> Iterator[Tuple[str, yield path, dataset_rel_path +def get_column(val: Any) -> str: + """Get the MDS column encoding of one field. + + Args: + val (Any): The field. + + Returns: + str: Its corresponding MDS encoding. 
+ """ + if isinstance(val, int): + return 'int' + elif isinstance(val, str): + return 'str' + else: + raise ValueError('Unsupported column type: {type(val)}.') + + +def get_columns(sample: Dict[str, Any]) -> Tuple[List[str], List[str], List[Optional[int]]]: + """Get column names, encodings, and sizes. + + Args: + sample (Dict[str, Any]): A sample to derive column info from. + + Returns: + Tuple[List[str], List[str], List[Optional[int]]]: Column names, encodings, and sizes. + """ + col_names = sorted(sample) + col_encs = [] + for name in col_names: + val = sample[name] + enc = get_column(val) + col_encs.append(enc) + col_sizes = list(map(get_mds_encoded_size, col_encs)) + return col_names, col_encs, col_sizes + + def get_shard_info(path: str, dataset_rel_path: str) -> Dict[str, Any]: """Get info the index needs about a Parquet shard. @@ -74,14 +112,23 @@ def get_shard_info(path: str, dataset_rel_path: str) -> Dict[str, Any]: """ num_bytes = os.stat(path).st_size table = pq.read_table(path) - num_samples = len(table) + samples = table.to_pylist() + num_samples = len(samples) + col_names, col_encs, col_sizes = get_columns(samples[0]) return { - 'format': 'parquet', + 'version': 2, + 'format': 'pq', + 'column_names': col_names, + 'column_encodings': col_encs, + 'column_sizes': col_sizes, 'raw_parquet': { 'basename': dataset_rel_path, - 'bytes': num_bytes, + 'bytes': num_bytes + }, + 'raw_data': { + 'basename': dataset_rel_path + '.mds' }, - 'samples': num_samples, + 'samples': num_samples } diff --git a/scripts/parquet/iterate.py b/scripts/parquet/iterate.py new file mode 100644 index 000000000..68d7e57dc --- /dev/null +++ b/scripts/parquet/iterate.py @@ -0,0 +1,77 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Randomly iterate over a Parquet dataset with Streaming.""" + +import os +from argparse import ArgumentParser, Namespace +from time import time + +import numpy as np +from matplotlib import pyplot as plt +from tqdm import tqdm, trange + +from streaming import StreamingDataset + + +def parse_args() -> Namespace: + """Parse command-line arguments. + + Returns: + Namespace: Command-line arguments. + """ + args = ArgumentParser() + args.add_argument('--dataset', type=str, required=True) + args.add_argument('--cache_mds', type=int, default=1) + args.add_argument('--plot', type=str, required=True) + return args.parse_args() + + +def clear(local: str) -> None: + """Clear the intermediate MDS shard files.""" + for root, _, files in os.walk(local): + for file in files: + file = os.path.join(root, file) + if file.endswith('.mds'): + os.remove(file) + + +def main(args: Namespace) -> None: + """Randomly iterate over a Parquet dataset with Streaming. + + Args: + args (Namespace): Command-line arguments. 
+ """ + dataset = StreamingDataset(local=args.dataset) + + if not args.cache_mds: + clear(args.dataset) + + seq_times = np.zeros(dataset.num_samples) + t0 = time() + for i in trange(dataset.num_samples): + dataset[i] + seq_times[i] = time() - t0 + + if not args.cache_mds: + clear(args.dataset) + + indices = np.random.permutation(dataset.num_samples) + rand_times = np.zeros(dataset.num_samples) + t0 = time() + for i, index in enumerate(tqdm(indices)): + dataset[index] + rand_times[i] = time() - t0 + + plt.title('Parquet sample access times') + plt.xlabel('Samples seen') + plt.ylabel('Time (seconds)') + samples = np.arange(dataset.num_samples) + plt.plot(samples, seq_times, c='blue', label='Sequential') + plt.plot(samples, rand_times, c='red', label='Random') + plt.legend() + plt.savefig(args.plot, dpi=500) + + +if __name__ == '__main__': + main(parse_args()) diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py index e4b7d1cfd..fe61390b2 100644 --- a/streaming/base/dataset.py +++ b/streaming/base/dataset.py @@ -1090,8 +1090,7 @@ def prepare_shard(self, shard_id: int, blocking: bool = True) -> None: shard = self.shards[shard_id] # We may need to decompress the shard (if local dir just contains zips). - raw_info, _ = shard.file_pairs[0] # Each file pair is present in the same way. - raw_filename = os.path.join(stream.local, stream.split, raw_info.basename) # Find raw. + raw_filename = os.path.join(stream.local, stream.split, shard.raw_data.basename) if not os.path.isfile(raw_filename): # Is raw missing? self._shard_states[shard_id] = _ShardState.PREPARING # Lock the shard. lock.release() # Unblock other workers. diff --git a/streaming/base/format/__init__.py b/streaming/base/format/__init__.py index 962828ae2..f8323e461 100644 --- a/streaming/base/format/__init__.py +++ b/streaming/base/format/__init__.py @@ -9,11 +9,12 @@ from streaming.base.format.index import get_index_basename from streaming.base.format.json import JSONReader, JSONWriter from streaming.base.format.mds import MDSReader, MDSWriter +from streaming.base.format.pq import PQReader from streaming.base.format.xsv import (CSVReader, CSVWriter, TSVReader, TSVWriter, XSVReader, XSVWriter) __all__ = [ - 'CSVWriter', 'FileInfo', 'get_index_basename', 'JSONWriter', 'MDSWriter', 'Reader', + 'CSVWriter', 'FileInfo', 'get_index_basename', 'JSONWriter', 'MDSWriter', 'PQReader', 'Reader', 'reader_from_json', 'TSVWriter', 'XSVWriter' ] @@ -21,6 +22,7 @@ 'csv': CSVReader, 'json': JSONReader, 'mds': MDSReader, + 'pq': PQReader, 'tsv': TSVReader, 'xsv': XSVReader } diff --git a/streaming/base/format/base/reader.py b/streaming/base/format/base/reader.py index 80ec45231..bffe2f1a6 100644 --- a/streaming/base/format/base/reader.py +++ b/streaming/base/format/base/reader.py @@ -23,9 +23,14 @@ class FileInfo(object): bytes (int): File size in bytes. hashes (Dict[str, str]): Mapping of hash algorithm to hash value. """ - basename: str - bytes: int - hashes: Dict[str, str] + + def __init__(self, + basename: str, + bytes: Optional[int] = None, + hashes: Optional[Dict[str, str]] = None): + self.basename = basename + self.bytes = bytes + self.hashes = hashes or {} class Reader(Array, ABC): @@ -133,7 +138,7 @@ def set_up_local(self, listing: Set[str], safe_keep_zip: bool) -> int: compression was used. Necessary when local is the remote or there is no remote. Returns: - bool: Whether the shard is present. + int: Shard cache usage. """ # For raw/zip to be considered present, each raw/zip file must be present. 
         raw_files_present = 0
@@ -318,6 +323,17 @@ def __iter__(self) -> Iterator[Dict[str, Any]]:
         for i in range(len(self)):
             yield self[i]
 
+    def prepare(self, safe_keep_zip: bool) -> int:
+        """Do any additional work to prepare a shard for use.
+
+        Args:
+            safe_keep_zip (bool): Whether to keep zip shard files, or drop post-conversion.
+
+        Returns:
+            int: Change in cache usage in bytes due to preparation.
+        """
+        return 0
+
 
 class JointReader(Reader):
     """Provides random access to the samples of a joint shard.
diff --git a/streaming/base/format/mds/reader.py b/streaming/base/format/mds/reader.py
index 275f01192..1f3f4a5ec 100644
--- a/streaming/base/format/mds/reader.py
+++ b/streaming/base/format/mds/reader.py
@@ -80,8 +80,9 @@ def from_json(cls, dirname: str, split: Optional[str], obj: Dict[str, Any]) -> S
         args['dirname'] = dirname
         args['split'] = split
         for key in ['raw_data', 'zip_data']:
-            arg = args[key]
-            args[key] = FileInfo(**arg) if arg else None
+            arg = args.get(key)
+            if arg:
+                args[key] = FileInfo(**arg)
         return cls(**args)
 
     def decode_sample(self, data: bytes) -> Dict[str, Any]:
diff --git a/streaming/base/format/pq/__init__.py b/streaming/base/format/pq/__init__.py
new file mode 100644
index 000000000..366fd3459
--- /dev/null
+++ b/streaming/base/format/pq/__init__.py
@@ -0,0 +1,8 @@
+# Copyright 2023 MosaicML Streaming authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""Module to read the dataset in PQ format."""
+
+from streaming.base.format.pq.reader import PQReader
+
+__all__ = ['PQReader']
diff --git a/streaming/base/format/pq/reader.py b/streaming/base/format/pq/reader.py
new file mode 100644
index 000000000..2bc0eb465
--- /dev/null
+++ b/streaming/base/format/pq/reader.py
@@ -0,0 +1,189 @@
+# Copyright 2023 MosaicML Streaming authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""PQReader reads Parquet shards for StreamingDataset (via MDS internally)."""
+
+import os
+from copy import deepcopy
+from tempfile import TemporaryDirectory
+from typing import Any, Dict, List, Optional, Set
+
+from pyarrow import parquet as pq
+from typing_extensions import Self
+
+from streaming.base.format.base.reader import FileInfo
+from streaming.base.format.mds.reader import MDSReader
+from streaming.base.format.mds.writer import MDSWriter
+
+
+class PQReader(MDSReader):
+    """Provides random access to the samples of a Parquet shard (via MDS internally).
+
+    Args:
+        dirname (str): Local dataset directory.
+        split (str, optional): Which dataset split to use, if any.
+        column_encodings (List[str]): Column encodings.
+        column_names (List[str]): Column names.
+        column_sizes (List[Optional[int]]): Column fixed sizes, if any.
+        raw_parquet (FileInfo): Non-compressed Parquet file info.
+        raw_data (FileInfo): Uncompressed data file info.
+        samples (int): Number of samples in this shard.
+    """
+
+    def __init__(
+        self,
+        dirname: str,
+        split: Optional[str],
+        column_encodings: List[str],
+        column_names: List[str],
+        column_sizes: List[Optional[int]],
+        raw_parquet: FileInfo,
+        raw_data: FileInfo,
+        samples: int,
+    ) -> None:
+        super().__init__(dirname=dirname,
+                         split=split,
+                         column_encodings=column_encodings,
+                         column_names=column_names,
+                         column_sizes=column_sizes,
+                         compression=None,
+                         hashes=[],
+                         raw_data=raw_data,
+                         samples=samples,
+                         size_limit=None,
+                         zip_data=None)
+        self.raw_parquet = raw_parquet
+        self.file_pairs.append((raw_parquet, None))
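+
+    # For reference, `from_json` below consumes index entries shaped like the
+    # ones scripts/parquet/index.py emits for `pq` shards (a hedged sketch;
+    # basenames, byte size, column sizes, and sample count here are
+    # illustrative values, not real output):
+    #
+    #   {
+    #       'version': 2,
+    #       'format': 'pq',
+    #       'column_names': ['num', 'txt'],
+    #       'column_encodings': ['int', 'str'],
+    #       'column_sizes': [8, None],
+    #       'raw_parquet': {'basename': '00000.parquet', 'bytes': 1024},
+    #       'raw_data': {'basename': '00000.parquet.mds'},
+    #       'samples': 10000,
+    #   }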
+
+    @classmethod
+    def from_json(cls, dirname: str, split: Optional[str], obj: Dict[str, Any]) -> Self:
+        """Initialize from JSON object.
+
+        Args:
+            dirname (str): Local directory containing shards.
+            split (str, optional): Which dataset split to use, if any.
+            obj (Dict[str, Any]): JSON object to load.
+
+        Returns:
+            Self: Loaded PQReader.
+        """
+        args = deepcopy(obj)
+
+        if args['version'] != 2:
+            raise ValueError(f'Unsupported streaming data version: {args["version"]}. ' +
+                             f'Expected version 2.')
+        del args['version']
+
+        if args['format'] != 'pq':
+            raise ValueError(f'Unsupported data format: {args["format"]}. ' +
+                             f'Expected to be `pq`.')
+        del args['format']
+
+        args['dirname'] = dirname
+        args['split'] = split
+        for key in ['raw_parquet', 'raw_data', 'zip_data']:
+            arg = args.get(key)
+            if arg:
+                args[key] = FileInfo(**arg)
+
+        return cls(**args)
+
+    def set_up_local(self, listing: Set[str], safe_keep_zip: bool) -> int:
+        """Bring what shard files are present to a consistent state, returning the cache usage.
+
+        Args:
+            listing (Set[str]): The listing of all files under dirname/[split/]. This is listed
+                once and then saved because there could potentially be very many shard files.
+            safe_keep_zip (bool): Whether to keep zip files when decompressing. Possible when
+                compression was used. Necessary when local is the remote or there is no remote.
+
+        Returns:
+            int: Shard cache usage.
+        """
+        pq_filename = os.path.join(self.dirname, self.split, self.raw_parquet.basename)
+        mds_filename = os.path.join(self.dirname, self.split, self.raw_data.basename)
+        if os.path.exists(mds_filename):
+            if os.path.exists(pq_filename):
+                if safe_keep_zip:
+                    # Present: keep both (because of safe_keep_zip).
+                    usage = os.stat(mds_filename).st_size + os.stat(pq_filename).st_size
+                else:
+                    # Present: keep MDS, drop PQ (because of no safe_keep_zip).
+                    os.remove(pq_filename)
+                    usage = os.stat(mds_filename).st_size
+            else:
+                if safe_keep_zip:
+                    # Normalize to missing, because safe_keep_zip requires that we keep the PQ.
+                    os.remove(mds_filename)
+                    usage = 0
+                else:
+                    # Present: have MDS, don't have or want PQ.
+                    usage = os.stat(mds_filename).st_size
+        else:
+            if os.path.exists(pq_filename):
+                # Present: PQ hasn't been converted to MDS yet and we don't have time to here.
+                usage = os.stat(pq_filename).st_size
+            else:
+                # Missing: both PQ and MDS are not there.
+                usage = 0
+        return usage
+
+    def get_column(self, val: Any) -> str:
+        """Get the MDS column encoding of one field.
+
+        Args:
+            val (Any): The field.
+
+        Returns:
+            str: Its corresponding MDS encoding.
+        """
+        if isinstance(val, int):
+            return 'int'
+        elif isinstance(val, str):
+            return 'str'
+        else:
+            raise ValueError(f'Unsupported column type: {type(val)}.')
+
+    def get_columns(self, sample: Dict[str, Any]) -> Dict[str, str]:
+        """Get the MDS columns given one sample.
+
+        Args:
+            sample (Dict[str, Any]): Mapping of column name to value.
+
+        Returns:
+            Dict[str, str]: Mapping of column name to MDS encoding.
+        """
+        col_names = sorted(sample)
+        col_encs = []
+        for name in col_names:
+            val = sample[name]
+            enc = self.get_column(val)
+            col_encs.append(enc)
+        return dict(zip(col_names, col_encs))
+
+    def prepare(self, safe_keep_zip: bool) -> int:
+        """Prepare a Parquet shard for fast random access by converting to MDS.
+
+        Args:
+            safe_keep_zip (bool): Whether to keep Parquet shards, or drop post-conversion.
+
+        Returns:
+            int: Change in cache usage in bytes due to PQ -> MDS conversion.
+ """ + pq_filename = os.path.join(self.dirname, self.split, self.raw_parquet.basename) + table = pq.read_table(pq_filename) + samples = table.to_pylist() + columns = self.get_columns(samples[0]) + with TemporaryDirectory() as temp_dir: + with MDSWriter(columns=columns, out=temp_dir, size_limit=None) as out: + for sample in samples: + out.write(sample) + temp_mds_filename = os.path.join(temp_dir, 'shard.00000.mds') + mds_filename = os.path.join(self.dirname, self.split, self.raw_data.basename) + os.rename(temp_mds_filename, mds_filename) + delta = os.stat(mds_filename).st_size + if not safe_keep_zip: + delta -= os.stat(pq_filename).st_size + print('REMOVING', pq_filename) + os.remove(pq_filename) + return delta diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 2770b08f3..ce1b5c041 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -419,7 +419,7 @@ def prepare_shard(self, shard: Reader) -> int: Returns: int: Change in cache usage. """ - delta = 0 + delta = shard.prepare(self.safe_keep_zip) for raw_info, zip_info in shard.file_pairs: delta += self._prepare_shard_part(raw_info, zip_info, shard.compression) return delta From bcd54356dd772fa1705d89f49d58eda999769dc7 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sat, 23 Sep 2023 12:18:31 -0700 Subject: [PATCH 03/11] Rm print. --- streaming/base/format/pq/reader.py | 1 - 1 file changed, 1 deletion(-) diff --git a/streaming/base/format/pq/reader.py b/streaming/base/format/pq/reader.py index 2bc0eb465..7bb622e36 100644 --- a/streaming/base/format/pq/reader.py +++ b/streaming/base/format/pq/reader.py @@ -184,6 +184,5 @@ def prepare(self, safe_keep_zip: bool) -> int: delta = os.stat(mds_filename).st_size if not safe_keep_zip: delta -= os.stat(pq_filename).st_size - print('REMOVING', pq_filename) os.remove(pq_filename) return delta From 2ac4f3cd86d221c0186482d673f2c9fc4f69ce6b Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sun, 24 Sep 2023 10:10:00 -0700 Subject: [PATCH 04/11] Updates. 
---
 scripts/parquet/generate.py |   2 +-
 scripts/parquet/iterate.py  | 217 +++++++++++++++++++++++++++++++++++++++++++++++++++++++-------
 2 files changed, 192 insertions(+), 27 deletions(-)

diff --git a/scripts/parquet/generate.py b/scripts/parquet/generate.py
index 0d67846a1..15bd4199b 100644
--- a/scripts/parquet/generate.py
+++ b/scripts/parquet/generate.py
@@ -22,7 +22,7 @@ def parse_args() -> Namespace:
     args.add_argument('--num_train', type=int, default=10_000_000)
     args.add_argument('--num_val', type=int, default=1_000_000)
     args.add_argument('--dataset', type=str, default='data/pq/')
-    args.add_argument('--samples_per_shard', type=int, default=10_000)
+    args.add_argument('--samples_per_shard', type=int, default=100_000)
     return args.parse_args()
diff --git a/scripts/parquet/iterate.py b/scripts/parquet/iterate.py
index 68d7e57dc..97594479a 100644
--- a/scripts/parquet/iterate.py
+++ b/scripts/parquet/iterate.py
@@ -6,9 +6,12 @@
 import os
 from argparse import ArgumentParser, Namespace
 from time import time
+from typing import Iterator
 
 import numpy as np
 from matplotlib import pyplot as plt
+from numpy.typing import NDArray
+from pyarrow import parquet as pq
 from tqdm import tqdm, trange
 
 from streaming import StreamingDataset
@@ -22,20 +25,174 @@ def parse_args() -> Namespace:
     """
     args = ArgumentParser()
     args.add_argument('--dataset', type=str, required=True)
-    args.add_argument('--cache_mds', type=int, default=1)
+    args.add_argument('--pq_suffix', type=str, default='.parquet')
+    args.add_argument('--tqdm', type=int, default=1)
+    args.add_argument('--time_limit', type=float, default=10)
     args.add_argument('--plot', type=str, required=True)
     return args.parse_args()
 
 
-def clear(local: str) -> None:
-    """Clear the intermediate MDS shard files."""
-    for root, _, files in os.walk(local):
+def each_pq(dataset_root: str, pq_suffix: str) -> Iterator[str]:
+    """Iterate over each Parquet shard file of the dataset in order.
+
+    Args:
+        dataset_root (str): Dataset root directory.
+        pq_suffix (str): Parquet shard file suffix.
+
+    Returns:
+        Iterator[str]: Each Parquet shard file.
+    """
+    for cwd, _, files in os.walk(dataset_root):
+        files = filter(lambda file: file.endswith(pq_suffix), files)
+        files = (os.path.join(cwd, file) for file in files)
+        yield from sorted(files)
+
+
+def bench_pq_seq(dataset: StreamingDataset, pq_suffix: str, use_tqdm: int) -> NDArray[np.float64]:
+    """Benchmark iterating a StreamingDataset in sequential order.
+
+    Args:
+        dataset (StreamingDataset): The streaming dataset to iterate.
+        pq_suffix (str): Parquet shard file suffix.
+        use_tqdm (int): Whether to use tqdm.
+
+    Returns:
+        NDArray[np.float64]: Time taken to process that many dataset samples.
+    """
+    times = np.zeros(dataset.num_samples, np.float64)
+    pbar = tqdm(total=dataset.num_samples) if use_tqdm else None
+    i = 0
+    dataset_root = dataset.streams[0].local
+    t0 = time()
+    for file in each_pq(dataset_root, pq_suffix):
+        table = pq.read_table(file)
+        for _ in table.to_pylist():
+            times[i] = time() - t0
+            i += 1
+            if use_tqdm:
+                pbar.update(1)
+    return times
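+
+
+# The random-access benchmarks below lean on `dataset.spanner`, which maps a
+# global sample index to (shard index, index within that shard). A hedged
+# sketch of the per-sample access pattern they time, with `shard_files` being
+# the sorted list produced by `each_pq`:
+#
+#   shard_id, index_in_shard = dataset.spanner[sample_id]
+#   table = pq.read_table(shard_files[shard_id])
+#   sample = table.to_pylist()[index_in_shard]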
+ """ + dataset_root = dataset.streams[0].local + shard_files = list(each_pq(dataset_root, pq_suffix)) + shard_sample_lists = [None] * len(shard_files) + indices = np.random.permutation(dataset.num_samples) + times = np.zeros(dataset.num_samples, np.float64) + pbar = tqdm(total=dataset.num_samples) if use_tqdm else None + t0 = time() + for i, sample_id in enumerate(indices): + shard_id, shard_sample_id = dataset.spanner[sample_id] + shard_samples = shard_sample_lists[shard_id] + if shard_samples is None: + shard_file = shard_files[shard_id] + table = pq.read_table(shard_file) + shard_sample_lists[shard_id] = shard_samples = table.to_pylist() + shard_samples[shard_sample_id] + times[i] = time() - t0 + if use_tqdm: + pbar.update(1) + return times + + +def bench_pq_rand_uncached(dataset: StreamingDataset, pq_suffix: str, use_tqdm: int, + time_limit: float) -> NDArray[np.float64]: + """Benchmark iterating a StreamingDataset in random order. + + Args: + dataset (StreamingDataset): The streaming dataset to iterate. + pq_suffix (str): Parquet shard file suffix. + use_tqdm (int): Whether to use tqdm. + + Returns: + NDArray[np.float64]: Time taken to process that many dataset samples. + """ + dataset_root = dataset.streams[0].local + shard_files = list(each_pq(dataset_root, pq_suffix)) + indices = np.random.permutation(dataset.num_samples) + times = np.zeros(dataset.num_samples, np.float64) + pbar = tqdm(total=dataset.num_samples) if use_tqdm else None + t0 = time() + for i, sample_id in enumerate(indices): + shard_id, shard_sample_id = dataset.spanner[sample_id] + shard_file = shard_files[shard_id] + table = pq.read_table(shard_file) + shard_samples = table.to_pylist() + shard_samples[shard_sample_id] + times[i] = t = time() - t0 + if use_tqdm: + pbar.update(1) + if time_limit <= t: + times = times[:i] + break + return times + + +def clear_mds(dataset_root: str) -> None: + """Clear the intermediate MDS shard files. + + Args: + dataset_root (str): Dataset root directoyr. + """ + for cwd, _, files in os.walk(dataset_root): for file in files: - file = os.path.join(root, file) if file.endswith('.mds'): + file = os.path.join(cwd, file) os.remove(file) +def bench_seq(dataset: StreamingDataset, use_tqdm: int) -> NDArray[np.float64]: + """Benchmark iterating a StreamingDataset in sequential order. + + Args: + dataset (StreamingDataset): The streaming dataset to iterate. + use_tqdm (int): Whether to use tqdm. + + Returns: + NDArray[np.float64]: Time taken to process that many dataset samples. + """ + times = np.zeros(dataset.num_samples, np.float64) + t0 = time() + xrange = trange if use_tqdm else range + for i in xrange(dataset.num_samples): + dataset[i] + times[i] = time() - t0 + return times + + +def bench_rand(dataset: StreamingDataset, use_tqdm: int) -> NDArray[np.float64]: + """Benchmark iterating a StreamingDataset in random order. + + Args: + dataset (StreamingDataset): The streaming dataset to iterate. + use_tqdm (int): Whether to use tqdm. + + Returns: + NDArray[np.float64]: Time taken to process that many dataset samples. + """ + indices = np.random.permutation(dataset.num_samples) + times = np.zeros(dataset.num_samples) + t0 = time() + if use_tqdm: + indices = tqdm(indices) + for i, sample_id in enumerate(indices): + dataset[sample_id] + times[i] = time() - t0 + return times + + def main(args: Namespace) -> None: """Randomly iterate over a Parquet dataset with Streaming. 
@@ -44,32 +201,40 @@ def main(args: Namespace) -> None:
     """
     dataset = StreamingDataset(local=args.dataset)
 
-    if not args.cache_mds:
-        clear(args.dataset)
+    plt.title('Time to iterate')
+    plt.xlabel('Seconds')
+    plt.ylabel('Samples')
+    samples = np.arange(dataset.num_samples)
 
-    seq_times = np.zeros(dataset.num_samples)
-    t0 = time()
-    for i in trange(dataset.num_samples):
-        dataset[i]
-        seq_times[i] = time() - t0
+    times = bench_pq_seq(dataset, args.pq_suffix, args.tqdm)
+    rate = int(len(times) / times[-1])
+    plt.plot(times, samples, c='green', ls='--', label=f'PQ seq (in mem): {rate:,}/s')
 
-    if not args.cache_mds:
-        clear(args.dataset)
+    times = bench_pq_rand_uncached(dataset, args.pq_suffix, args.tqdm, args.time_limit)
+    rate = int(len(times) / times[-1])
+    plt.plot(times, samples[:len(times)], c='green', ls=':',
+             label=f'PQ rand (in mem): {rate:,}/s')
 
-    indices = np.random.permutation(dataset.num_samples)
-    rand_times = np.zeros(dataset.num_samples)
-    t0 = time()
-    for i, index in enumerate(tqdm(indices)):
-        dataset[index]
-        rand_times[i] = time() - t0
+    clear_mds(args.dataset)
+    times = bench_seq(dataset, args.tqdm)
+    rate = int(len(times) / times[-1])
+    plt.plot(times, samples, c='blue', ls='--', label=f'Cold PQ>MDS seq: {rate:,}/s')
+
+    clear_mds(args.dataset)
+    times = bench_rand(dataset, args.tqdm)
+    rate = int(len(times) / times[-1])
+    plt.plot(times, samples, c='blue', ls=':', label=f'Cold PQ>MDS rand: {rate:,}/s')
+
+    times = bench_seq(dataset, args.tqdm)
+    rate = int(len(times) / times[-1])
+    plt.plot(times, samples, c='red', ls='--', label=f'Warm MDS seq: {rate:,}/s')
+
+    times = bench_rand(dataset, args.tqdm)
+    rate = int(len(times) / times[-1])
+    plt.plot(times, samples, c='red', ls=':', label=f'Warm MDS rand: {rate:,}/s')
 
-    plt.title('Parquet sample access times')
-    plt.xlabel('Samples seen')
-    plt.ylabel('Time (seconds)')
-    samples = np.arange(dataset.num_samples)
-    plt.plot(samples, seq_times, c='blue', label='Sequential')
-    plt.plot(samples, rand_times, c='red', label='Random')
     plt.legend()
+    plt.grid(which='major', ls='--', c='#ddd')
     plt.savefig(args.plot, dpi=500)

From 2c43df6ad4c03fa54fc4e2a1ac0ca075143fbd11 Mon Sep 17 00:00:00 2001
From: James Knighton
Date: Sun, 1 Oct 2023 21:23:55 -0700
Subject: [PATCH 05/11] Add tqdm to generate.py, etc.

---
 scripts/parquet/generate.py | 20 ++++++++++++++------
 1 file changed, 14 insertions(+), 6 deletions(-)

diff --git a/scripts/parquet/generate.py b/scripts/parquet/generate.py
index 15bd4199b..0da322245 100644
--- a/scripts/parquet/generate.py
+++ b/scripts/parquet/generate.py
@@ -5,6 +5,7 @@
 
 import os
 from argparse import ArgumentParser, Namespace
+from tqdm import tqdm
 from typing import List, Tuple
 
 import numpy as np
@@ -19,10 +20,11 @@ def parse_args() -> Namespace:
         Namespace: Command-line arguments.
""" args = ArgumentParser() - args.add_argument('--num_train', type=int, default=10_000_000) - args.add_argument('--num_val', type=int, default=1_000_000) - args.add_argument('--dataset', type=str, default='data/pq/') - args.add_argument('--samples_per_shard', type=int, default=100_000) + args.add_argument('--num_train', type=int, default=1 << 24) + args.add_argument('--num_val', type=int, default=1 << 20) + args.add_argument('--dataset', type=str, default='data/parquet/') + args.add_argument('--samples_per_shard', type=int, default=1 << 17) + args.add_argument('--tqdm', type=int, default=1) return args.parse_args() @@ -69,23 +71,29 @@ def generate_number() -> int: return sign * mag -def generate_numbers(num_train: int, num_val: int) -> Tuple[List[int], List[int]]: +def generate_numbers(num_train: int, num_val: int, use_tqdm: int) -> Tuple[List[int], List[int]]: """Get two non-overlapping splits of integers to say. Args: num_train (int): Number of training samples. num_val (int): Number of validation samples. + use_tqdm (int): Whether to display a progress bar. Returns: Tuple[List[int], List[int]]: The two generated splits. """ total = num_train + num_val nums = set() + pbar = tqdm(total=total) if use_tqdm else None while len(nums) < total: num = generate_number() if num in nums: continue nums.add(num) + if use_tqdm: + pbar.update(1) + if use_tqdm: + pbar.close() nums = sorted(nums) np.random.shuffle(nums) train_nums = nums[:num_train] @@ -125,7 +133,7 @@ def main(args: Namespace) -> None: Args: args (Namespace): Command-line arguments. """ - train_nums, val_nums = generate_numbers(args.num_train, args.num_val) + train_nums, val_nums = generate_numbers(args.num_train, args.num_val, args.tqdm) train_txts = [' '.join(say(num)) for num in train_nums] val_txts = [' '.join(say(num)) for num in val_nums] From 44a2dc287943e7d18ed445d856f9fd5f216a37f2 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sun, 1 Oct 2023 21:27:46 -0700 Subject: [PATCH 06/11] parquet_to_lance. --- scripts/parquet/parquet_to_lance.py | 39 +++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 scripts/parquet/parquet_to_lance.py diff --git a/scripts/parquet/parquet_to_lance.py b/scripts/parquet/parquet_to_lance.py new file mode 100644 index 000000000..fa3c63b58 --- /dev/null +++ b/scripts/parquet/parquet_to_lance.py @@ -0,0 +1,39 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Convert a Parquet dataset to Lance. + +Warning: apparently, Lance will crash with an unhelpful error message if there are any extraneous +files in the Parquet dataset. +""" + +from argparse import ArgumentParser, Namespace + +import lance +import pyarrow as pa + + +def parse_args() -> Namespace: + """Parse command-line arguments. + + Returns: + Namespace: Command-line arguments. + """ + args = ArgumentParser() + args.add_argument('--parquet', type=str, required=True) + args.add_argument('--lance', type=str, required=True) + return args.parse_args() + + +def main(args: Namespace) -> None: + """Convert a Parquet dataset to Lance. + + Args: + args (Namespace): Command-line arguments. + """ + dataset = pa.dataset.dataset(args.parquet, format='parquet') + lance.write_dataset(dataset, args.lance) + + +if __name__ == '__main__': + main(parse_args()) From b383b724229597a217081d9ac0f459d9ea9e19a2 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sun, 1 Oct 2023 21:32:48 -0700 Subject: [PATCH 07/11] Add Lance benchmarking. 
---
 scripts/parquet/generate.py |   6 +-
 scripts/parquet/iterate.py  | 194 +++++++++++++++++++++++++++++++++++-------
 2 files changed, 161 insertions(+), 39 deletions(-)

diff --git a/scripts/parquet/generate.py b/scripts/parquet/generate.py
index 0da322245..a6989db3f 100644
--- a/scripts/parquet/generate.py
+++ b/scripts/parquet/generate.py
@@ -5,12 +5,12 @@
 
 import os
 from argparse import ArgumentParser, Namespace
-from tqdm import tqdm
 from typing import List, Tuple
 
 import numpy as np
 import pyarrow as pa
 from pyarrow import parquet as pq
+from tqdm import tqdm
 
 
 def parse_args() -> Namespace:
@@ -90,9 +90,9 @@ def generate_numbers(num_train: int, num_val: int, use_tqdm: int) -> Tuple[List[
         if num in nums:
             continue
         nums.add(num)
-        if use_tqdm:
+        if pbar:
             pbar.update(1)
-    if use_tqdm:
+    if pbar:
         pbar.close()
     nums = sorted(nums)
     np.random.shuffle(nums)
diff --git a/scripts/parquet/iterate.py b/scripts/parquet/iterate.py
index 97594479a..2b8a1e80b 100644
--- a/scripts/parquet/iterate.py
+++ b/scripts/parquet/iterate.py
@@ -8,7 +8,9 @@
 from time import time
 from typing import Iterator
 
+import lance
 import numpy as np
+from lance import LanceDataset
 from matplotlib import pyplot as plt
 from numpy.typing import NDArray
 from pyarrow import parquet as pq
@@ -24,14 +26,82 @@ def parse_args() -> Namespace:
         Namespace: Command-line arguments.
     """
     args = ArgumentParser()
-    args.add_argument('--dataset', type=str, required=True)
+    args.add_argument('--streaming_dataset', type=str, required=True)
+    args.add_argument('--lance_dataset', type=str, required=True)
+    args.add_argument('--lance_pow', type=int, default=4)
     args.add_argument('--pq_suffix', type=str, default='.parquet')
     args.add_argument('--tqdm', type=int, default=1)
-    args.add_argument('--time_limit', type=float, default=10)
+    args.add_argument('--time_limit', type=float, default=20)
    args.add_argument('--plot', type=str, required=True)
     return args.parse_args()
 
 
+def bench_lance_seq(dataset: LanceDataset, take_count: int, use_tqdm: int,
+                    time_limit: float) -> NDArray[np.float64]:
+    """Benchmark iterating a Lance dataset in sequential order.
+
+    Args:
+        dataset (LanceDataset): The Lance dataset to iterate.
+        take_count (int): How many samples to take per sequential access.
+        use_tqdm (int): Whether to use tqdm.
+        time_limit (float): Benchmarking cutoff time.
+
+    Returns:
+        NDArray[np.float64]: Time taken to process that many dataset samples.
+    """
+    num_samples = dataset.count_rows()
+    if num_samples % take_count:
+        raise ValueError(f'`num_samples` ({num_samples}) must be divisible by `take_count` ' +
+                         f'({take_count}).')
+    shape = num_samples // take_count, take_count
+    times = np.zeros(shape, np.float64)
+    sample, = dataset.head(1).to_pylist()
+    columns = sorted(sample)
+    each_batch = enumerate(dataset.to_batches(columns=columns, batch_size=take_count))
+    if use_tqdm:
+        each_batch = tqdm(each_batch, total=num_samples // take_count, leave=False)
+    t0 = time()
+    for i, samples in each_batch:
+        samples.to_pylist()
+        times[i] = t = time() - t0
+        if time_limit <= t:
+            times = times[:i]
+            break
+    return times.flatten()
+
+
+def bench_lance_rand(dataset: LanceDataset, take_count: int, use_tqdm: int,
+                     time_limit: float) -> NDArray[np.float64]:
+    """Benchmark iterating a Lance dataset in random order.
+
+    Args:
+        dataset (LanceDataset): The Lance dataset to iterate.
+        take_count (int): How many samples to take per random access.
+        use_tqdm (int): Whether to use tqdm.
+        time_limit (float): Benchmarking cutoff time.
+
+    Returns:
+        NDArray[np.float64]: Time taken to process that many dataset samples.
+    """
+    num_samples = dataset.count_rows()
+    if num_samples % take_count:
+        raise ValueError(f'`num_samples` ({num_samples}) must be divisible by `take_count` ' +
+                         f'({take_count}).')
+    shape = num_samples // take_count, take_count
+    times = np.zeros(shape, np.float64)
+    batches = np.random.permutation(num_samples).reshape(shape)
+    if use_tqdm:
+        batches = tqdm(batches, leave=False)
+    t0 = time()
+    for i, sample_ids in enumerate(batches):
+        dataset.take(sample_ids).to_pylist()
+        times[i] = t = time() - t0
+        if time_limit <= t:
+            times = times[:i]
+            break
+    return times.flatten()
+
+
 def each_pq(dataset_root: str, pq_suffix: str) -> Iterator[str]:
     """Iterate over each Parquet shard file of the dataset in order.
 
     Args:
         dataset_root (str): Dataset root directory.
         pq_suffix (str): Parquet shard file suffix.
 
     Returns:
         Iterator[str]: Each Parquet shard file.
     """
     for cwd, _, files in os.walk(dataset_root):
         files = filter(lambda file: file.endswith(pq_suffix), files)
         files = (os.path.join(cwd, file) for file in files)
         yield from sorted(files)
 
 
-def bench_pq_seq(dataset: StreamingDataset, pq_suffix: str, use_tqdm: int) -> NDArray[np.float64]:
+def bench_pq_seq(dataset: StreamingDataset, pq_suffix: str, use_tqdm: int,
+                 time_limit: float) -> NDArray[np.float64]:
     """Benchmark iterating a StreamingDataset in sequential order.
 
     Args:
         dataset (StreamingDataset): The streaming dataset to iterate.
         pq_suffix (str): Parquet shard file suffix.
         use_tqdm (int): Whether to use tqdm.
+        time_limit (float): Benchmarking cutoff time.
 
     Returns:
         NDArray[np.float64]: Time taken to process that many dataset samples.
     """
     times = np.zeros(dataset.num_samples, np.float64)
-    pbar = tqdm(total=dataset.num_samples) if use_tqdm else None
+    pbar = tqdm(total=dataset.num_samples, leave=False) if use_tqdm else None
     i = 0
     dataset_root = dataset.streams[0].local
     t0 = time()
     for file in each_pq(dataset_root, pq_suffix):
         table = pq.read_table(file)
         for _ in table.to_pylist():
-            times[i] = time() - t0
+            times[i] = t = time() - t0
+            if time_limit <= t:
+                return times[:i]
             i += 1
-            if use_tqdm:
+            if pbar:
                 pbar.update(1)
     return times
@@ -91,7 +165,7 @@ def bench_pq_rand_cached(dataset: StreamingDataset, pq_suffix: str,
     shard_sample_lists = [None] * len(shard_files)
     indices = np.random.permutation(dataset.num_samples)
     times = np.zeros(dataset.num_samples, np.float64)
-    pbar = tqdm(total=dataset.num_samples) if use_tqdm else None
+    pbar = tqdm(total=dataset.num_samples, leave=False) if use_tqdm else None
     t0 = time()
     for i, sample_id in enumerate(indices):
         shard_id, shard_sample_id = dataset.spanner[sample_id]
@@ -102,7 +176,7 @@ def bench_pq_rand_cached(dataset: StreamingDataset, pq_suffix: str,
             shard_sample_lists[shard_id] = shard_samples = table.to_pylist()
         shard_samples[shard_sample_id]
         times[i] = time() - t0
-        if use_tqdm:
+        if pbar:
             pbar.update(1)
     return times
@@ -115,6 +189,7 @@ def bench_pq_rand_uncached(dataset: StreamingDataset, pq_suffix: str, use_tqdm:
         dataset (StreamingDataset): The streaming dataset to iterate.
         pq_suffix (str): Parquet shard file suffix.
         use_tqdm (int): Whether to use tqdm.
+        time_limit (float): Benchmarking cutoff time.
 
     Returns:
         NDArray[np.float64]: Time taken to process that many dataset samples.
@@ -123,7 +198,7 @@ def bench_pq_rand_uncached(dataset: StreamingDataset, pq_suffix: str, use_tqdm: shard_files = list(each_pq(dataset_root, pq_suffix)) indices = np.random.permutation(dataset.num_samples) times = np.zeros(dataset.num_samples, np.float64) - pbar = tqdm(total=dataset.num_samples) if use_tqdm else None + pbar = tqdm(total=dataset.num_samples, leave=False) if use_tqdm else None t0 = time() for i, sample_id in enumerate(indices): shard_id, shard_sample_id = dataset.spanner[sample_id] @@ -132,7 +207,7 @@ def bench_pq_rand_uncached(dataset: StreamingDataset, pq_suffix: str, use_tqdm: shard_samples = table.to_pylist() shard_samples[shard_sample_id] times[i] = t = time() - t0 - if use_tqdm: + if pbar: pbar.update(1) if time_limit <= t: times = times[:i] @@ -153,43 +228,51 @@ def clear_mds(dataset_root: str) -> None: os.remove(file) -def bench_seq(dataset: StreamingDataset, use_tqdm: int) -> NDArray[np.float64]: +def bench_seq(dataset: StreamingDataset, use_tqdm: int, time_limit: float) -> NDArray[np.float64]: """Benchmark iterating a StreamingDataset in sequential order. Args: dataset (StreamingDataset): The streaming dataset to iterate. use_tqdm (int): Whether to use tqdm. + time_limit (float): Benchmarking cutoff time. Returns: NDArray[np.float64]: Time taken to process that many dataset samples. """ times = np.zeros(dataset.num_samples, np.float64) + xrange = trange(dataset.num_samples, leave=False) if use_tqdm else range(dataset.num_samples) t0 = time() - xrange = trange if use_tqdm else range - for i in xrange(dataset.num_samples): + for i in xrange: dataset[i] - times[i] = time() - t0 + times[i] = t = time() - t0 + if time_limit <= t: + times = times[:i] + break return times -def bench_rand(dataset: StreamingDataset, use_tqdm: int) -> NDArray[np.float64]: +def bench_rand(dataset: StreamingDataset, use_tqdm: int, time_limit: float) -> NDArray[np.float64]: """Benchmark iterating a StreamingDataset in random order. Args: dataset (StreamingDataset): The streaming dataset to iterate. use_tqdm (int): Whether to use tqdm. + time_limit (float): Benchmarking cutoff time. Returns: NDArray[np.float64]: Time taken to process that many dataset samples. """ indices = np.random.permutation(dataset.num_samples) times = np.zeros(dataset.num_samples) - t0 = time() if use_tqdm: - indices = tqdm(indices) + indices = tqdm(indices, leave=False) + t0 = time() for i, sample_id in enumerate(indices): dataset[sample_id] - times[i] = time() - t0 + times[i] = t = time() - t0 + if time_limit <= t: + times = times[:i] + break return times @@ -199,39 +282,78 @@ def main(args: Namespace) -> None: Args: args (Namespace): Command-line arguments. 
""" - dataset = StreamingDataset(local=args.dataset) + streaming_dataset = StreamingDataset(local=args.streaming_dataset) + lance_dataset = lance.dataset(args.lance_dataset) + plt.rc('legend', fontsize=6) plt.title('Time to iterate') plt.xlabel('Seconds') plt.ylabel('Samples') - samples = np.arange(dataset.num_samples) - - times = bench_pq_seq(dataset, args.pq_suffix, args.tqdm) + line_width = 0.75 + + if args.lance_pow == 4: + lance_colors = '#a60', '#b70', '#c80', '#d90', '#ea0', '#fb1' + lance_take_counts = 1, 4, 16, 64, 256, 1024 + elif args.lance_pow == 2: + lance_colors = '#730', '#840', '#950', '#a60', '#b70', '#c80', '#d90', '#ea0', '#fb1', \ + '#fc4', '#fd7' + lance_take_counts = 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024 + else: + raise ValueError(f'Unsupported --lance_pow: {args.lance_pow}.') + + for color, take_count in reversed(list(zip(lance_colors, lance_take_counts))): + times = bench_lance_seq(lance_dataset, take_count, args.tqdm, args.time_limit) + rate = int(len(times) / times[-1]) + label = f'Lance seq n={take_count}: {rate:,}/s' + plt.plot(times, np.arange(len(times)), c=color, ls='-', lw=line_width, label=label) + print(label) + + for color, take_count in reversed(list(zip(lance_colors, lance_take_counts))): + times = bench_lance_rand(lance_dataset, take_count, args.tqdm, args.time_limit) + rate = int(len(times) / times[-1]) + label = f'Lance rand n={take_count}: {rate:,}/s' + plt.plot(times, np.arange(len(times)), c=color, ls=':', lw=line_width, label=label) + print(label) + + times = bench_pq_seq(streaming_dataset, args.pq_suffix, args.tqdm, args.time_limit) rate = int(len(times) / times[-1]) - plt.plot(times, samples, c='green', ls='--', label=f'PQ seq (in mem): {rate:,}/s') + label = f'PQ seq (in mem): {rate:,}/s' + plt.plot(times, np.arange(len(times)), c='green', ls='-', lw=line_width, label=label) + print(label) - times = bench_pq_rand_uncached(dataset, args.pq_suffix, args.tqdm, args.time_limit) + times = bench_pq_rand_uncached(streaming_dataset, args.pq_suffix, args.tqdm, args.time_limit) rate = int(len(times) / times[-1]) - plt.plot(times, samples[:len(times)], c='green', ls=':', - label=f'PQ rand (in mem): {rate:,}/s') + label = f'PQ rand (in mem): {rate:,}/s' + plt.plot(times, np.arange(len(times)), c='green', ls=':', lw=line_width, label=label) + print(label) - clear_mds(args.dataset) - times = bench_seq(dataset, args.tqdm) + clear_mds(args.streaming_dataset) + + times = bench_seq(streaming_dataset, args.tqdm, args.time_limit) rate = int(len(times) / times[-1]) - plt.plot(times, samples, c='blue', ls='--', label=f'Cold PQ>MDS seq: {rate:,}/s') + label = f'Cold PQ>MDS seq: {rate:,}/s' + plt.plot(times, np.arange(len(times)), c='blue', ls='-', lw=line_width, label=label) + print(label) + + clear_mds(args.streaming_dataset) - clear_mds(args.dataset) - times = bench_rand(dataset, args.tqdm) + times = bench_rand(streaming_dataset, args.tqdm, args.time_limit) rate = int(len(times) / times[-1]) - plt.plot(times, samples, c='blue', ls=':', label=f'Cold PQ>MDS rand: {rate:,}/s') + label = f'Cold PQ>MDS rand: {rate:,}/s' + plt.plot(times, np.arange(len(times)), c='blue', ls=':', lw=line_width, label=label) + print(label) - times = bench_seq(dataset, args.tqdm) + times = bench_seq(streaming_dataset, args.tqdm, args.time_limit) rate = int(len(times) / times[-1]) - plt.plot(times, samples, c='red', ls='--', label=f'Warm MDS seq: {rate:,}/s') + label = f'Warm MDS seq: {rate:,}/s' + plt.plot(times, np.arange(len(times)), c='red', ls='-', lw=line_width, label=label) + 
+    print(label)
 
-    times = bench_rand(dataset, args.tqdm)
+    times = bench_rand(streaming_dataset, args.tqdm, args.time_limit)
     rate = int(len(times) / times[-1])
-    plt.plot(times, samples, c='red', ls=':', label=f'Warm MDS rand: {rate:,}/s')
+    label = f'Warm MDS rand: {rate:,}/s'
+    plt.plot(times, np.arange(len(times)), c='red', ls=':', lw=line_width, label=label)
+    print(label)
 
     plt.legend()
     plt.grid(which='major', ls='--', c='#ddd')
     plt.savefig(args.plot, dpi=500)

From 4a5a171583f8e21c53084f55a648a3526a81a556 Mon Sep 17 00:00:00 2001
From: James Knighton
Date: Sun, 1 Oct 2023 21:37:43 -0700
Subject: [PATCH 08/11] Renames.

---
 scripts/{parquet/iterate.py => iteration/bench_and_plot.py}    | 0
 scripts/{parquet/generate.py => iteration/generate_parquet.py} | 0
 scripts/{parquet => iteration}/parquet_to_lance.py             | 0
 scripts/{parquet/index.py => iteration/streamify_parquet.py}   | 0
 4 files changed, 0 insertions(+), 0 deletions(-)
 rename scripts/{parquet/iterate.py => iteration/bench_and_plot.py} (100%)
 rename scripts/{parquet/generate.py => iteration/generate_parquet.py} (100%)
 rename scripts/{parquet => iteration}/parquet_to_lance.py (100%)
 rename scripts/{parquet/index.py => iteration/streamify_parquet.py} (100%)

diff --git a/scripts/parquet/iterate.py b/scripts/iteration/bench_and_plot.py
similarity index 100%
rename from scripts/parquet/iterate.py
rename to scripts/iteration/bench_and_plot.py
diff --git a/scripts/parquet/generate.py b/scripts/iteration/generate_parquet.py
similarity index 100%
rename from scripts/parquet/generate.py
rename to scripts/iteration/generate_parquet.py
diff --git a/scripts/parquet/parquet_to_lance.py b/scripts/iteration/parquet_to_lance.py
similarity index 100%
rename from scripts/parquet/parquet_to_lance.py
rename to scripts/iteration/parquet_to_lance.py
diff --git a/scripts/parquet/index.py b/scripts/iteration/streamify_parquet.py
similarity index 100%
rename from scripts/parquet/index.py
rename to scripts/iteration/streamify_parquet.py

From 68a56b476f4aafeee5d8d6202348c558c878e0d4 Mon Sep 17 00:00:00 2001
From: James Knighton
Date: Sun, 1 Oct 2023 22:44:00 -0700
Subject: [PATCH 09/11] Split bench and plot.
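
bench.py now just measures and writes one stats file, and plot.py renders it,
so plots can be restyled without re-running the benchmarks. The stats file is
a JSON object keyed by benchmark, where times are cumulative and stored as
integer nanoseconds. A hedged sketch of its shape (key and values here are
illustrative):

    {
        "mds_rand": {
            "label": "Warm MDS rand: 12,345/s",
            "rate": 12345,
            "times": [81000, 162000, 243000]
        }
    }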
---
 .../iteration/{bench_and_plot.py => bench.py} | 52 ++++++------
 scripts/iteration/plot.py                     | 85 +++++++++++++++
 2 files changed, 110 insertions(+), 27 deletions(-)
 rename scripts/iteration/{bench_and_plot.py => bench.py} (88%)
 create mode 100644 scripts/iteration/plot.py

diff --git a/scripts/iteration/bench_and_plot.py b/scripts/iteration/bench.py
similarity index 88%
rename from scripts/iteration/bench_and_plot.py
rename to scripts/iteration/bench.py
index 2b8a1e80b..4032d04fb 100644
--- a/scripts/iteration/bench_and_plot.py
+++ b/scripts/iteration/bench.py
@@ -1,8 +1,9 @@
 # Copyright 2023 MosaicML Streaming authors
 # SPDX-License-Identifier: Apache-2.0
 
-"""Randomly iterate over a Parquet dataset with Streaming."""
+"""Benchmark dataset iteration time."""
 
+import json
 import os
 from argparse import ArgumentParser, Namespace
 from time import time
@@ -11,7 +12,6 @@
 import lance
 import numpy as np
 from lance import LanceDataset
-from matplotlib import pyplot as plt
 from numpy.typing import NDArray
 from pyarrow import parquet as pq
 from tqdm import tqdm, trange
@@ -32,7 +32,7 @@ def parse_args() -> Namespace:
     args.add_argument('--pq_suffix', type=str, default='.parquet')
     args.add_argument('--tqdm', type=int, default=1)
     args.add_argument('--time_limit', type=float, default=20)
-    args.add_argument('--plot', type=str, required=True)
+    args.add_argument('--stats', type=str, required=True)
     return args.parse_args()
@@ -285,46 +285,45 @@ def main(args: Namespace) -> None:
     streaming_dataset = StreamingDataset(local=args.streaming_dataset)
     lance_dataset = lance.dataset(args.lance_dataset)
 
-    plt.rc('legend', fontsize=6)
-    plt.title('Time to iterate')
-    plt.xlabel('Seconds')
-    plt.ylabel('Samples')
-    line_width = 0.75
-
     if args.lance_pow == 4:
-        lance_colors = '#a60', '#b70', '#c80', '#d90', '#ea0', '#fb1'
         lance_take_counts = 1, 4, 16, 64, 256, 1024
     elif args.lance_pow == 2:
-        lance_colors = '#730', '#840', '#950', '#a60', '#b70', '#c80', '#d90', '#ea0', '#fb1', \
-            '#fc4', '#fd7'
         lance_take_counts = 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024
     else:
         raise ValueError(f'Unsupported --lance_pow: {args.lance_pow}.')
 
+    obj = {}
+
+    to_dict = lambda label, rate, times: ({
+        'label': label,
+        'rate': rate,
+        'times': (times * 1e9).astype(np.int64).tolist()
+    })
+
-    for color, take_count in reversed(list(zip(lance_colors, lance_take_counts))):
+    for take_count in lance_take_counts:
         times = bench_lance_seq(lance_dataset, take_count, args.tqdm, args.time_limit)
         rate = int(len(times) / times[-1])
-        label = f'Lance seq n={take_count}: {rate:,}/s'
-        plt.plot(times, np.arange(len(times)), c=color, ls='-', lw=line_width, label=label)
+        label = f'Lance seq n={take_count:04}: {rate:,}/s'
+        obj[f'lance_seq_{take_count:04}'] = to_dict(label, rate, times)
         print(label)
 
-    for color, take_count in reversed(list(zip(lance_colors, lance_take_counts))):
+    for take_count in lance_take_counts:
         times = bench_lance_rand(lance_dataset, take_count, args.tqdm, args.time_limit)
         rate = int(len(times) / times[-1])
-        label = f'Lance rand n={take_count}: {rate:,}/s'
-        plt.plot(times, np.arange(len(times)), c=color, ls=':', lw=line_width, label=label)
+        label = f'Lance rand n={take_count:04}: {rate:,}/s'
+        obj[f'lance_rand_{take_count:04}'] = to_dict(label, rate, times)
         print(label)
 
     times = bench_pq_seq(streaming_dataset, args.pq_suffix, args.tqdm, args.time_limit)
     rate = int(len(times) / times[-1])
     label = f'PQ seq (in mem): {rate:,}/s'
-    plt.plot(times, np.arange(len(times)), c='green', ls='-', lw=line_width, label=label)
+
obj['pq_seq'] = to_dict(label, rate, times)
     print(label)
 
     times = bench_pq_rand_uncached(streaming_dataset, args.pq_suffix, args.tqdm, args.time_limit)
     rate = int(len(times) / times[-1])
     label = f'PQ rand (in mem): {rate:,}/s'
-    plt.plot(times, np.arange(len(times)), c='green', ls=':', lw=line_width, label=label)
+    obj['pq_rand'] = to_dict(label, rate, times)
     print(label)
 
     clear_mds(args.streaming_dataset)
@@ -332,7 +331,7 @@ def main(args: Namespace) -> None:
     times = bench_seq(streaming_dataset, args.tqdm, args.time_limit)
     rate = int(len(times) / times[-1])
     label = f'Cold PQ>MDS seq: {rate:,}/s'
-    plt.plot(times, np.arange(len(times)), c='blue', ls='-', lw=line_width, label=label)
+    obj['pq_mds_seq'] = to_dict(label, rate, times)
     print(label)
 
     clear_mds(args.streaming_dataset)
@@ -340,24 +339,23 @@ def main(args: Namespace) -> None:
     times = bench_rand(streaming_dataset, args.tqdm, args.time_limit)
     rate = int(len(times) / times[-1])
     label = f'Cold PQ>MDS rand: {rate:,}/s'
-    plt.plot(times, np.arange(len(times)), c='blue', ls=':', lw=line_width, label=label)
+    obj['pq_mds_rand'] = to_dict(label, rate, times)
     print(label)
 
     times = bench_seq(streaming_dataset, args.tqdm, args.time_limit)
     rate = int(len(times) / times[-1])
     label = f'Warm MDS seq: {rate:,}/s'
-    plt.plot(times, np.arange(len(times)), c='red', ls='-', lw=line_width, label=label)
+    obj['mds_seq'] = to_dict(label, rate, times)
     print(label)
 
     times = bench_rand(streaming_dataset, args.tqdm, args.time_limit)
     rate = int(len(times) / times[-1])
     label = f'Warm MDS rand: {rate:,}/s'
-    plt.plot(times, np.arange(len(times)), c='red', ls=':', lw=line_width, label=label)
+    obj['mds_rand'] = to_dict(label, rate, times)
     print(label)
 
-    plt.legend()
-    plt.grid(which='major', ls='--', c='#ddd')
-    plt.savefig(args.plot, dpi=500)
+    with open(args.stats, 'w') as out:
+        json.dump(obj, out)
 
 
 if __name__ == '__main__':
diff --git a/scripts/iteration/plot.py b/scripts/iteration/plot.py
new file mode 100644
index 000000000..214f9149b
--- /dev/null
+++ b/scripts/iteration/plot.py
@@ -0,0 +1,85 @@
+# Copyright 2023 MosaicML Streaming authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""Plot dataset iteration time."""
+
+import json
+from argparse import ArgumentParser, Namespace
+from typing import Dict
+
+import numpy as np
+from matplotlib import pyplot as plt
+
+
+def parse_args() -> Namespace:
+    """Parse command-line arguments.
+
+    Returns:
+        Namespace: Command-line arguments.
+    """
+    args = ArgumentParser()
+    args.add_argument('--stats', type=str, required=True)
+    args.add_argument('--plot', type=str, required=True)
+    return args.parse_args()
+
+
+def get_color(key: str, pq_mds_colors: Dict[str, str], lance_colors: Dict[int, str]) -> str:
+    """Get a plot color for a given statistic key.
+
+    Args:
+        key (str): The statistic key.
+        pq_mds_colors (Dict[str, str]): Mapping of PQ/MDS type to color.
+        lance_colors (Dict[int, str]): Mapping of Lance take count to color.
+
+    Returns:
+        str: Color.
+    """
+    parts = key.split('_')
+    first = parts[0]
+    if first in {'pq', 'mds'}:
+        kind = '_'.join(parts[:-1])
+        color = pq_mds_colors[kind]
+    elif first == 'lance':
+        take_count = int(parts[-1])
+        color = lance_colors[take_count]
+    else:
+        raise ValueError(f'Unknown type of key: {key}.')
+    return color
+
+
+def main(args: Namespace) -> None:
+    """Plot dataset iteration time.
+
+    Args:
+        args (Namespace): Command-line arguments.
+    """
+    pq_mds_colors = {'pq': 'green', 'pq_mds': 'blue', 'mds': 'red'}
+
+    lance_take_counts = 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024
+    lance_colors = '#730', '#840', '#950', '#a60', '#b70', '#c80', '#d90', '#ea0', '#fb1', \
+        '#fc4', '#fd7'
+    lance_colors = dict(zip(lance_take_counts, lance_colors))
+
+    stats = json.load(open(args.stats))
+
+    plt.rc('legend', fontsize=6)
+    plt.title('Time to iterate')
+    plt.xlabel('Seconds')
+    plt.ylabel('Samples')
+    line_width = 0.75
+
+    for key in sorted(stats):
+        stat = stats[key]
+        times = np.array(stat['times']) / 1e9
+        color = get_color(key, pq_mds_colors, lance_colors)
+        line_style = '-' if 'seq' in key else ':'
+        label = stat['label']
+        plt.plot(times, np.arange(len(times)), c=color, ls=line_style, lw=line_width, label=label)
+
+    plt.legend()
+    plt.grid(which='major', ls='--', c='#ddd')
+    plt.savefig(args.plot, dpi=500)
+
+
+if __name__ == '__main__':
+    main(parse_args())

From c9bb6b9d9575728a4015559a6b1e85763c7ed6a2 Mon Sep 17 00:00:00 2001
From: James Knighton
Date: Tue, 17 Oct 2023 15:27:31 -0700
Subject: [PATCH 10/11] Misc.

---
 scripts/iteration/bench.py                    | 38 ++++++++++---------
 scripts/iteration/generate_parquet.py         | 16 ++++----
 scripts/iteration/parquet_to_lance.py         |  3 +-
 ...ify_parquet.py => parquet_to_streaming.py} |  7 +++-
 scripts/iteration/run.sh                      | 19 ++++++++++
 5 files changed, 56 insertions(+), 27 deletions(-)
 rename scripts/iteration/{streamify_parquet.py => parquet_to_streaming.py} (94%)
 create mode 100755 scripts/iteration/run.sh

diff --git a/scripts/iteration/bench.py b/scripts/iteration/bench.py
index 4032d04fb..b60581264 100644
--- a/scripts/iteration/bench.py
+++ b/scripts/iteration/bench.py
@@ -31,7 +31,7 @@ def parse_args() -> Namespace:
     args.add_argument('--lance_pow', type=int, default=4)
     args.add_argument('--pq_suffix', type=str, default='.parquet')
     args.add_argument('--tqdm', type=int, default=1)
-    args.add_argument('--time_limit', type=float, default=20)
+    args.add_argument('--time_limit', type=float, default=180)
    args.add_argument('--stats', type=str, required=True)
     return args.parse_args()
 
@@ -53,16 +53,20 @@ def bench_lance_seq(dataset: LanceDataset, take_count: int, use_tqdm: int,
     if num_samples % take_count:
         raise ValueError(f'`num_samples` ({num_samples}) must be divisible by `take_count` ' +
                          f'({take_count}).')
-    shape = num_samples // take_count, take_count
+    num_batches = num_samples // take_count
+    shape = num_batches, take_count
     times = np.zeros(shape, np.float64)
     sample, = dataset.head(1).to_pylist()
     columns = sorted(sample)
     each_batch = enumerate(dataset.to_batches(columns=columns, batch_size=take_count))
     if use_tqdm:
-        each_batch = tqdm(each_batch, total=num_samples // take_count, leave=False)
+        each_batch = tqdm(each_batch, total=num_batches, leave=False)
     t0 = time()
     for i, samples in each_batch:
         samples.to_pylist()
+        assert len(samples) == take_count
+        if num_batches <= i:  # Stop if the reader yields more batches than expected.
+ break times[i] = t = time() - t0 if time_limit <= t: times = times[:i] @@ -300,20 +304,6 @@ def main(args: Namespace) -> None: 'times': (times * 1e9).astype(np.int64).tolist() }) - for take_count in lance_take_counts: - times = bench_lance_seq(lance_dataset, take_count, args.tqdm, args.time_limit) - rate = int(len(times) / times[-1]) - label = f'Lance seq n={take_count:04}: {rate:,}/s' - obj[f'lance_seq_{take_count:04}'] = to_dict(label, rate, times) - print(label) - - for take_count in lance_take_counts: - times = bench_lance_rand(lance_dataset, take_count, args.tqdm, args.time_limit) - rate = int(len(times) / times[-1]) - label = f'Lance rand n={take_count:04}: {rate:,}/s' - obj[f'lance_rand_{take_count:04}'] = to_dict(label, rate, times) - print(label) - times = bench_pq_seq(streaming_dataset, args.pq_suffix, args.tqdm, args.time_limit) rate = int(len(times) / times[-1]) label = f'PQ seq (in mem): {rate:,}/s' @@ -354,6 +344,20 @@ def main(args: Namespace) -> None: obj['mds_rand'] = to_dict(label, rate, times) print(label) + for take_count in lance_take_counts: + times = bench_lance_seq(lance_dataset, take_count, args.tqdm, args.time_limit) + rate = int(len(times) / times[-1]) + label = f'Lance seq n={take_count:04}: {rate:,}/s' + obj[f'lance_seq_{take_count:04}'] = to_dict(label, rate, times) + print(label) + + for take_count in lance_take_counts: + times = bench_lance_rand(lance_dataset, take_count, args.tqdm, args.time_limit) + rate = int(len(times) / times[-1]) + label = f'Lance rand n={take_count:04}: {rate:,}/s' + obj[f'lance_rand_{take_count:04}'] = to_dict(label, rate, times) + print(label) + with open(args.stats, 'w') as out: json.dump(obj, out) diff --git a/scripts/iteration/generate_parquet.py b/scripts/iteration/generate_parquet.py index a6989db3f..9eab5db69 100644 --- a/scripts/iteration/generate_parquet.py +++ b/scripts/iteration/generate_parquet.py @@ -20,18 +20,18 @@ def parse_args() -> Namespace: Namespace: Command-line arguments. 
""" args = ArgumentParser() - args.add_argument('--num_train', type=int, default=1 << 24) - args.add_argument('--num_val', type=int, default=1 << 20) + args.add_argument('--num_train', type=int, default=1) + args.add_argument('--num_val', type=int, default=1 << 26) args.add_argument('--dataset', type=str, default='data/parquet/') - args.add_argument('--samples_per_shard', type=int, default=1 << 17) + args.add_argument('--samples_per_shard', type=int, default=1 << 20) args.add_argument('--tqdm', type=int, default=1) return args.parse_args() -_ones = ('zero one two three four five six seven eight nine ten eleven twelve thirteen fourteen ' +ones = ('zero one two three four five six seven eight nine ten eleven twelve thirteen fourteen ' 'fifteen sixteen seventeen eighteen nineteen').split() -_tens = 'twenty thirty forty fifty sixty seventy eighty ninety'.split() +tens = 'twenty thirty forty fifty sixty seventy eighty ninety'.split() def say(i: int) -> List[str]: @@ -46,11 +46,11 @@ def say(i: int) -> List[str]: if i < 0: return ['negative'] + say(-i) elif i <= 19: - return [_ones[i]] + return [ones[i]] elif i < 100: - return [_tens[i // 10 - 2]] + ([_ones[i % 10]] if i % 10 else []) + return [tens[i // 10 - 2]] + ([ones[i % 10]] if i % 10 else []) elif i < 1_000: - return [_ones[i // 100], 'hundred'] + (say(i % 100) if i % 100 else []) + return [ones[i // 100], 'hundred'] + (say(i % 100) if i % 100 else []) elif i < 1_000_000: return say(i // 1_000) + ['thousand'] + (say(i % 1_000) if i % 1_000 else []) elif i < 1_000_000_000: diff --git a/scripts/iteration/parquet_to_lance.py b/scripts/iteration/parquet_to_lance.py index fa3c63b58..6561d869b 100644 --- a/scripts/iteration/parquet_to_lance.py +++ b/scripts/iteration/parquet_to_lance.py @@ -22,6 +22,7 @@ def parse_args() -> Namespace: args = ArgumentParser() args.add_argument('--parquet', type=str, required=True) args.add_argument('--lance', type=str, required=True) + args.add_argument('--max_rows_per_group', type=int, default=1024) return args.parse_args() @@ -32,7 +33,7 @@ def main(args: Namespace) -> None: args (Namespace): Command-line arguments. """ dataset = pa.dataset.dataset(args.parquet, format='parquet') - lance.write_dataset(dataset, args.lance) + lance.write_dataset(dataset, args.lance, max_rows_per_group=args.max_rows_per_group) if __name__ == '__main__': diff --git a/scripts/iteration/streamify_parquet.py b/scripts/iteration/parquet_to_streaming.py similarity index 94% rename from scripts/iteration/streamify_parquet.py rename to scripts/iteration/parquet_to_streaming.py index ed0d73f96..e67b7561f 100644 --- a/scripts/iteration/streamify_parquet.py +++ b/scripts/iteration/parquet_to_streaming.py @@ -6,6 +6,7 @@ import json import os from argparse import ArgumentParser, Namespace +from tqdm import tqdm from typing import Any, Dict, Iterator, List, Optional, Tuple from pyarrow import parquet as pq @@ -22,6 +23,7 @@ def parse_args() -> Namespace: args = ArgumentParser() args.add_argument('--dataset', type=str, required=True) args.add_argument('--shard_suffix', type=str, default='.parquet') + args.add_argument('--tqdm', type=int, default=1) return args.parse_args() @@ -138,8 +140,11 @@ def main(args: Namespace) -> None: Args: args (Namespace): Command-line arguments. 
""" + each = each_shard_path(args.dataset, args.shard_suffix) + if args.tqdm: + each = tqdm(list(each), leave=False) infos = [] - for path, dataset_rel_path in each_shard_path(args.dataset, args.shard_suffix): + for path, dataset_rel_path in each: info = get_shard_info(path, dataset_rel_path) infos.append(info) obj = { diff --git a/scripts/iteration/run.sh b/scripts/iteration/run.sh new file mode 100755 index 000000000..6ca8e7605 --- /dev/null +++ b/scripts/iteration/run.sh @@ -0,0 +1,19 @@ +python3 scripts/iteration/generate_parquet.py \ + --dataset data/iteration/parquet/ + +python3 scripts/iteration/parquet_to_lance.py \ + --parquet data/iteration/parquet/val/ \ + --lance data/iteration/lance_1024/val/ \ + --max_rows_per_group 1024 + +python3 scripts/iteration/parquet_to_streaming.py \ + --dataset data/iteration/parquet/val/ + +python3 scripts/iteration/bench.py \ + --streaming_dataset data/iteration/parquet/val/ \ + --lance_dataset data/iteration/lance_1024/val/ \ + --stats data/iteration/stats_1024_val.json + +python3 scripts/iteration/plot.py \ + --stats data/iteration/stats_1024_val.json \ + --plot data/iteration/plot_1024_val.png From f90a2a0e6e46dea7866d72ae32fd31599bc2c694 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Tue, 17 Oct 2023 18:57:47 -0700 Subject: [PATCH 11/11] Update streaming/base/format/pq/reader.py Co-authored-by: Aaron Gokaslan --- streaming/base/format/pq/reader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/base/format/pq/reader.py b/streaming/base/format/pq/reader.py index 7bb622e36..d64c544ec 100644 --- a/streaming/base/format/pq/reader.py +++ b/streaming/base/format/pq/reader.py @@ -142,7 +142,7 @@ def get_column(self, val: Any) -> str: elif isinstance(val, str): return 'str' else: - raise ValueError('Unsupported column type: {type(val)}.') + raise TypeError('Unsupported column type: {type(val)}.') def get_columns(self, sample: Dict[str, Any]) -> Dict[str, str]: """Get the MDS columns given one sample.