Training on PQ shards #443

Open · wants to merge 12 commits into base: main
141 changes: 141 additions & 0 deletions scripts/parquet/generate.py
@@ -0,0 +1,141 @@
# Copyright 2023 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0

"""Generate a parquet dataset for testing."""

import os
from argparse import ArgumentParser, Namespace
from typing import List, Tuple

import numpy as np
import pyarrow as pa
from pyarrow import parquet as pq


def parse_args() -> Namespace:
"""Parse command-line arguments.

Returns:
Namespace: Command-line arguments.
"""
args = ArgumentParser()
args.add_argument('--num_train', type=int, default=10_000_000)
args.add_argument('--num_val', type=int, default=1_000_000)
args.add_argument('--dataset', type=str, default='data/pq/')
args.add_argument('--samples_per_shard', type=int, default=100_000)
return args.parse_args()


_ones = ('zero one two three four five six seven eight nine ten eleven twelve thirteen fourteen '
'fifteen sixteen seventeen eighteen nineteen').split()

_tens = 'twenty thirty forty fifty sixty seventy eighty ninety'.split()


def say(i: int) -> List[str]:
"""Get the word form of a number.

Args:
i (int): The number.

Returns:
List[str]: The number in word form.
"""
if i < 0:
return ['negative'] + say(-i)
elif i <= 19:
return [_ones[i]]
elif i < 100:
return [_tens[i // 10 - 2]] + ([_ones[i % 10]] if i % 10 else [])
elif i < 1_000:
return [_ones[i // 100], 'hundred'] + (say(i % 100) if i % 100 else [])
elif i < 1_000_000:
return say(i // 1_000) + ['thousand'] + (say(i % 1_000) if i % 1_000 else [])
elif i < 1_000_000_000:
return say(i // 1_000_000) + ['million'] + (say(i % 1_000_000) if i % 1_000_000 else [])
else:
        raise ValueError(f'Integer must be less than a billion, but got: {i}')
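
# A few worked examples of the recursion above:
#   say(0)    -> ['zero']
#   say(-42)  -> ['negative', 'forty', 'two']
#   say(1234) -> ['one', 'thousand', 'two', 'hundred', 'thirty', 'four']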


def generate_number() -> int:
"""Generate a random integer to say.

Returns:
int: The integer.
"""
sign = (np.random.uniform() < 0.8) * 2 - 1
expt = np.random.uniform(0, 9)
mag = int(10**expt)
    return int(sign * mag)


def generate_numbers(num_train: int, num_val: int) -> Tuple[List[int], List[int]]:
"""Get two non-overlapping splits of integers to say.

Args:
num_train (int): Number of training samples.
num_val (int): Number of validation samples.

Returns:
Tuple[List[int], List[int]]: The two generated splits.
"""
total = num_train + num_val
nums = set()
while len(nums) < total:
num = generate_number()
if num in nums:
continue
nums.add(num)
nums = sorted(nums)
np.random.shuffle(nums)
train_nums = nums[:num_train]
val_nums = nums[num_train:]
return train_nums, val_nums


def save_parquets(nums: List[int], txts: List[str], dirname: str, samples_per_shard: int) -> None:
"""Save a parquet dataaset given the samples.

Args:
nums (List[int]): List of sample integers.
txts (List[str]): List of sample texts.
dirname (str): Output dirname.
samples_per_shard (int): Output shard size in samples.
"""
    os.makedirs(dirname, exist_ok=True)
num_shards = (len(nums) + samples_per_shard - 1) // samples_per_shard
for shard_id in range(num_shards):
begin = shard_id * samples_per_shard
end = min(begin + samples_per_shard, len(nums))
shard_nums = nums[begin:end]
shard_txts = txts[begin:end]
filename = os.path.join(dirname, f'{shard_id:05}.parquet')
obj = {
'num': shard_nums,
'txt': shard_txts,
}
table = pa.Table.from_pydict(obj)
pq.write_table(table, filename)


def main(args: Namespace) -> None:
"""Generate a parquet dataset for testing.

Args:
args (Namespace): Command-line arguments.
"""
train_nums, val_nums = generate_numbers(args.num_train, args.num_val)

train_txts = [' '.join(say(num)) for num in train_nums]
val_txts = [' '.join(say(num)) for num in val_nums]

dirname = os.path.join(args.dataset, 'train')
save_parquets(train_nums, train_txts, dirname, args.samples_per_shard)

dirname = os.path.join(args.dataset, 'val')
save_parquets(val_nums, val_txts, dirname, args.samples_per_shard)


if __name__ == '__main__':
main(parse_args())
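
For reference, a quick way to sanity-check the output (a minimal sketch, assuming the default --dataset path of data/pq/) is to read one shard back with pyarrow:

import pyarrow.parquet as pq

# Read the first training shard back; the path assumes the defaults above.
table = pq.read_table('data/pq/train/00000.parquet')
print(table.num_rows)        # samples in this shard (up to --samples_per_shard)
print(table.column_names)    # ['num', 'txt']
print(table.to_pylist()[0])  # e.g. {'num': 7, 'txt': 'seven'}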
157 changes: 157 additions & 0 deletions scripts/parquet/index.py
@@ -0,0 +1,157 @@
# Copyright 2023 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0

"""Index a parquet dataset for use by Streaming."""

import json
import os
from argparse import ArgumentParser, Namespace
from typing import Any, Dict, Iterator, List, Optional, Tuple

from pyarrow import parquet as pq

from streaming.base.format.mds.encodings import get_mds_encoded_size


def parse_args() -> Namespace:
"""Parse command-line arguments.

Returns:
Namespace: Command-line arguments.
"""
args = ArgumentParser()
args.add_argument('--dataset', type=str, required=True)
args.add_argument('--shard_suffix', type=str, default='.parquet')
return args.parse_args()


def get_dataset_relative_path(dataset_root: str, path: str) -> str:
"""Get the dataset-relative path of a shard file.

Args:
dataset_root (str): Dataset root directory containing this shard.
path (str): Path to shard under dataset root dir.

Returns:
        str: Dataset-relative shard path.
"""
if not path.startswith(dataset_root):
        raise ValueError(f'Path {path} was not found under {dataset_root}.')
rel_path = path[len(dataset_root):]

while rel_path.startswith(os.path.sep):
rel_path = rel_path[1:]

return rel_path
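
# For example:
#   get_dataset_relative_path('data/pq', 'data/pq/train/00000.parquet')
#   -> 'train/00000.parquet'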


def each_shard_path(dataset_root: str, shard_suffix: str) -> Iterator[Tuple[str, str]]:
"""Collect each Parquet shard, in order.

Args:
dataset_root (str): Dataset root directory.
shard_suffix (str): Suffix of each Parquet shard file.

Returns:
Iterator[Tuple[str, str]]: Iterator over absolute and dataset-relative paths.
"""
for root, _, files in os.walk(dataset_root):
files = filter(lambda file: file.endswith(shard_suffix), files)
files = (os.path.join(root, file) for file in files)
files = sorted(files)
for path in files:
dataset_rel_path = get_dataset_relative_path(dataset_root, path)
yield path, dataset_rel_path


def get_column(val: Any) -> str:
"""Get the MDS column encoding of one field.

Args:
val (Any): The field.

Returns:
str: Its corresponding MDS encoding.
"""
if isinstance(val, int):
return 'int'
elif isinstance(val, str):
return 'str'
else:
        raise ValueError(f'Unsupported column type: {type(val)}.')


def get_columns(sample: Dict[str, Any]) -> Tuple[List[str], List[str], List[Optional[int]]]:
"""Get column names, encodings, and sizes.

Args:
sample (Dict[str, Any]): A sample to derive column info from.

Returns:
Tuple[List[str], List[str], List[Optional[int]]]: Column names, encodings, and sizes.
"""
col_names = sorted(sample)
col_encs = []
for name in col_names:
val = sample[name]
enc = get_column(val)
col_encs.append(enc)
col_sizes = list(map(get_mds_encoded_size, col_encs))
return col_names, col_encs, col_sizes
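
# For example, given the sample {'num': 7, 'txt': 'seven'} this yields names
# ['num', 'txt'] and encodings ['int', 'str']; sizes are a byte count for
# fixed-width encodings and None for variable-length ones such as 'str'.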


def get_shard_info(path: str, dataset_rel_path: str) -> Dict[str, Any]:
"""Get info the index needs about a Parquet shard.

Args:
path (str): Absolute or relative-to-cwd file path.
dataset_rel_path (str): Relative-to-dataset file path.

Returns:
Dict[str, Any]: Shard info.
"""
num_bytes = os.stat(path).st_size
table = pq.read_table(path)
samples = table.to_pylist()
num_samples = len(samples)
col_names, col_encs, col_sizes = get_columns(samples[0])
return {
'version': 2,
'format': 'pq',
'column_names': col_names,
'column_encodings': col_encs,
'column_sizes': col_sizes,
'raw_parquet': {
'basename': dataset_rel_path,
'bytes': num_bytes
},
'raw_data': {
'basename': dataset_rel_path + '.mds'
},
'samples': num_samples
}


def main(args: Namespace) -> None:
"""Index a parquet dataset for use by Streaming.

Args:
args (Namespace): Command-line arguments.
"""
infos = []
for path, dataset_rel_path in each_shard_path(args.dataset, args.shard_suffix):
info = get_shard_info(path, dataset_rel_path)
infos.append(info)
obj = {
'version': 2,
'shards': infos,
}
filename = os.path.join(args.dataset, 'index.json')
if os.path.exists(filename):
raise ValueError(f'Index file {filename} already exists.')
with open(filename, 'w') as out:
json.dump(obj, out)


if __name__ == '__main__':
main(parse_args())
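
Once the index has been written, a quick check (a minimal sketch, assuming the script was run with --dataset data/pq/train on data generated with the defaults above) is to load it back and inspect a shard entry:

import json

# Inspect the index written by this script for the train split.
with open('data/pq/train/index.json') as f:
    index = json.load(f)

print(index['version'])              # 2
print(len(index['shards']))          # number of parquet shards found
print(index['shards'][0]['format'])  # 'pq'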