Skip to content

Commit

Permalink
Break Stream in two, rewrite core part of Stream, in prep for new shards (#547)
Browse files Browse the repository at this point in the history

* Stream -> StreamCore (Shard args) + Stream (all).

* Drop pointless underscore vars.

* Auto keyword.

* Fix handling of generating local when default split.

* Clean up.

* Improve apply_defaults().

* Plug it in, propagate rewrites outward.

* Adjust keep_old_phases vs keep_zip handling.

* Adjust hash args handling.

* Default apply_defaults() args to auto so you don't have to provide them all.

* Update usage in test cases.

* Fix edge case.

* Another tweak.
  • Loading branch information
knighton authored Dec 25, 2023
1 parent 3972c9d commit bdb5725
Show file tree
Hide file tree
Showing 9 changed files with 762 additions and 289 deletions.
116 changes: 60 additions & 56 deletions simulation/core/sim_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,41 +33,50 @@ class SimulationDataset(StreamingDataset):
nodes (int): Number of nodes.
devices (int): Number of devices.
workers (int): Number of workers.
streams (Optional[Sequence[Stream]]): One or more streams to stream/cache samples from,
streams (Sequence[Stream], optional): One or more streams to stream/cache samples from,
which may be upsampled or downsampled. StreamingDataset uses either ``streams`` or
``remote``/``local``. Defaults to ``None``.
remote (Optional[str]): Remote path or directory to download the dataset from. If ``None``,
remote (str, optional): Remote path or directory to download the dataset from. If ``None``,
its data must exist locally. StreamingDataset uses either ``streams`` or
``remote``/``local``. Defaults to ``None``.
local (Optional[str]): Local working directory to download shards to. This is where shards
local (str, optional): Local working directory to download shards to. This is where shards
are cached while they are being used. Uses a temp directory if not set.
StreamingDataset uses either ``streams`` or ``remote``/``local``. Defaults to ``None``.
split (Optional[str]): Which dataset split to use, if any. If provided, we stream from/to
split (str, optional): Which dataset split to use, if any. If provided, we stream from/to
the ``split`` subdirs of ``remote`` and ``local``. Defaults to ``None``.
download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``.
download_timeout (float): Number of seconds to wait for a shard to download before raising
an exception. Defaults to ``60``.
validate_hash (Optional[str]): Optional hash or checksum algorithm to use to validate
shards. Defaults to ``None``.
keep_zip (bool): Whether to keep or delete the compressed form when decompressing
downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to
``False``.
epoch_size (Union[int, str], optional): Number of samples to draw per epoch balanced across all
streams. If ``None``, takes its value from the total number of underlying samples.
Provide this field if you are weighting streams relatively to target a larger or
smaller epoch size. Defaults to ``None``. Can also take in human-readable number
abbreviations (e.g., ``"100k"``, ``"64M"``, ``"77b"``, and so on). Defaults to ``None``.
download_retry (int): Number of download re-attempts before raising an error. Defaults to
``2``.
download_timeout (str | float): Time in seconds to wait for a file download to complete
before raising an error. Streaming duration shorthand (e.g., ``1m23s``) is also
accepted. Defaults to ``1m``.
hash_algos (str | Sequence[str], optional): Ranked list of hashing algorithms to try.
Defaults to ``None``.
validate_hash (str, optional): Deprecated. See ``hash_algos``. Defaults to ``None``.
keep_old_phases (str): Which old phases of shard files to cache (until shard eviction).
Must be one of ``nil``, ``src``, or ``all``. Defaults to ``nil``.
keep_zip (bool, optional): Deprecated. See ``keep_old_phases``. Defaults to ``None``.
epoch_size (Union[str, int], optional): Number of samples to draw per epoch balanced
across all streams. If ``None``, takes its value from the total number of underlying
samples. Provide this field if you are weighting streams relatively to target a larger
or smaller epoch size. Can also take in human-readable number abbreviations (e.g.,
``"100k"``, ``"64M"``, ``"77b"``, etc). Defaults to ``None``.
predownload (int, optional): Target number of samples to download per worker in advance
of current sample. Workers will attempt to download ahead by this many samples during,
but not before, training. Recommendation is to provide a value greater than per device
batch size to ensure at-least per device batch size number of samples cached locally.
If ``None``, its value is set to ``8 * batch_size``. Defaults to ``None``.
cache_limit (Union[int, str], optional): Maximum size in bytes of this StreamingDataset's
cache_limit (Union[str, int], optional): Maximum size in bytes of this StreamingDataset's
shard cache. Before downloading a shard, the least recently used resident shard(s)
may be evicted (deleted from the local cache) in order to stay under the limit.
Set to ``None`` to disable shard eviction. Supports integer bytes as well as string
human-readable bytes (e.g., ``100b``, ``64kb``, ``77mb``, and so on). Defaults to
``None``.
sampling_method (str): Which sampling method to use, either ``balanced`` or ``fixed``.
Defaults to ``balanced``.
sampling_granularity (int): When picking samples for a stream's final partial repeat,
how many samples to pick from the same shard at a time (``1`` for evenly balanced
across shards, ``1000`` to pick 1000 samples from the same shard at a time, etc).
Defaults to ``1``.
partition_algo (str): Which partitioning algorithm to use. Defaults to ``relaxed``.
num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with
resumption. The sample space is divided evenly according to the number of canonical
Expand All @@ -86,17 +95,11 @@ class SimulationDataset(StreamingDataset):
shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to
``False``.
shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1e``.
shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``.
shuffle_seed (int): Seed for deterministic data shuffling. Defaults to ``9176``.
shuffle_block_size (int, optional): Unit of shuffle. A canonical node's samples are split
into blocks of this size, and samples within each block are shuffled. If ``None``, its
value is calculated as ``max(4_000_000 // num_canonical_nodes, 1 << 18)``. Defaults to
``None``.
sampling_method (str): Which sampling method to use, either ``balanced`` or ``fixed``.
Defaults to ``balanced``.
sampling_granularity (int): When picking samples for a stream's final partial repeat,
how many samples to pick from the same shard at a time (``1`` for evenly balanced
across shards, ``1000`` to pick 1000 samples from the same shard at a time, etc).
Defaults to ``1``.
batching_method (str): Which batching method to use, either ``random``, ``stratified``, or
``per_stream``. Defaults to ``random``.
allow_unsafe_types (bool): If a shard contains Pickle, which allows arbitrary code
Expand All @@ -105,6 +108,7 @@ class SimulationDataset(StreamingDataset):
"""

def __init__(self,
*,
nodes: int,
devices: int,
workers: int,
Expand All @@ -113,42 +117,44 @@ def __init__(self,
local: Optional[str] = None,
split: Optional[str] = None,
download_retry: int = 2,
download_timeout: float = 60,
download_timeout: Union[str, float] = '1m',
hash_algos: Optional[Union[str, Sequence[str]]] = None,
validate_hash: Optional[str] = None,
keep_zip: bool = False,
epoch_size: Optional[Union[int, str]] = None,
keep_old_phases: str = 'nil',
keep_zip: Optional[bool] = None,
epoch_size: Optional[Union[str, int]] = None,
predownload: Optional[int] = None,
cache_limit: Optional[Union[int, str]] = None,
cache_limit: Optional[Union[str, int]] = None,
sampling_method: str = 'balanced',
sampling_granularity: int = 1,
partition_algo: str = 'relaxed',
num_canonical_nodes: Optional[int] = None,
batch_size: Optional[int] = None,
shuffle: bool = False,
shuffle_algo: str = 'py1e',
shuffle_seed: int = 9176,
shuffle_block_size: Optional[int] = None,
sampling_method: str = 'balanced',
sampling_granularity: int = 1,
batching_method: str = 'random',
allow_unsafe_types: bool = False) -> None:

# Time how long it takes for StreamingDataset instantiation
t0 = time.time()

# Global arguments (which do not live in Streams).
self.nodes = nodes
self.devices = devices
self.workers = workers
self.partition_algo = partition_algo

# Global arguments (which do not live in Streams).
self.predownload = predownload
self.sampling_method = sampling_method
self.sampling_granularity = sampling_granularity
self.partition_algo = partition_algo
self.num_canonical_nodes = num_canonical_nodes
self.batch_size = batch_size
self.shuffle = shuffle
self.shuffle_algo = shuffle_algo
self.shuffle_seed = shuffle_seed
self.shuffle_block_size = shuffle_block_size
self.sampling_method = sampling_method
self.sampling_granularity = sampling_granularity
self.batching_method = batching_method
self.num_canonical_nodes = num_canonical_nodes
self.allow_unsafe_types = allow_unsafe_types

self.initial_physical_nodes = nodes
Expand Down Expand Up @@ -197,26 +203,24 @@ def __init__(self,

# Initialize the Stream defaults and normalize to a list of Streams.
if streams:
default = {
'remote': remote,
'local': local,
'split': split,
'download_retry': download_retry,
'download_timeout': download_timeout,
'validate_hash': validate_hash,
'keep_zip': keep_zip,
}
for stream in streams:
stream.apply_default(default)
stream.apply_defaults(split=split,
download_retry=download_retry,
download_timeout=download_timeout,
hash_algos=hash_algos,
validate_hash=validate_hash,
keep_old_phases=keep_old_phases,
keep_zip=keep_zip)
else:
default = Stream(remote=remote,
streams = Stream(remote=remote,
local=local,
split=split,
download_retry=download_retry,
download_timeout=download_timeout,
hash_algos=hash_algos,
validate_hash=validate_hash,
keep_zip=keep_zip)
streams = [default]
keep_old_phases=keep_old_phases,
keep_zip=keep_zip),

# Validate the stream weighting scheme (relative or absolute) to catch errors before we go
# to the trouble of loading them.
Expand All @@ -231,10 +235,10 @@ def __init__(self,
indices_created = []
for stream_idx, stream in enumerate(self.streams):
if stream.remote:
filepath = os.path.join(stream.remote, stream.split, get_index_basename())
filepath = os.path.join(stream.remote, stream.split or '', get_index_basename())
indices_created.append(0)
else:
filepath = os.path.join(stream.local, stream.split, get_index_basename())
filepath = os.path.join(stream.local, stream.split or '', get_index_basename())
# This suffix means a mock index file was created. Have to clean up later.
if stream.local.split('_')[-1] == 'indexcreated':
indices_created.append(2)
Expand All @@ -245,9 +249,9 @@ def __init__(self,
'path': filepath,
'local': stream.local,
'remote': stream.remote,
'proportion': stream._proportion,
'repeat': stream._repeat,
'choose': stream._choose
'proportion': getattr(stream, 'proportion', None),
'repeat': getattr(stream, 'repeat', None),
'choose': getattr(stream, 'choose', None),
}

# Initialize the SimulationWorld, which tells us about nodes/devices/workers
Expand All @@ -267,7 +271,7 @@ def __init__(self,
logger.info(f' Processing index file for stream {stream_id + 1}')
stream_shards = stream.get_shards(self.world, self.allow_unsafe_types)
num_stream_samples = sum(map(len, stream_shards))
index_filename = os.path.join(stream.local, stream.split, get_index_basename())
index_filename = os.path.join(stream.local, stream.split or '', get_index_basename())
index_filenames.append(index_filename)
local_foldernames.append(stream.local)
if not num_stream_samples:
Expand Down
30 changes: 24 additions & 6 deletions simulation/core/yaml_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,11 +187,29 @@ def create_simulation_dataset(nodes: int, devices: int, workers: int, global_bat
sampling_granularity = train_dataset.get('sampling_granularity', 1)
batching_method = train_dataset.get('batching_method', 'random')

dataset = SimulationDataset(nodes, devices, workers, streams, remote, local, split,
download_retry, download_timeout, validate_hash, keep_zip,
epoch_size, predownload, cache_limit, partition_algo,
num_canonical_nodes, batch_size, shuffle, shuffle_algo,
shuffle_seed, shuffle_block_size, sampling_method,
sampling_granularity, batching_method)
dataset = SimulationDataset(nodes=nodes,
devices=devices,
workers=workers,
streams=streams,
remote=remote,
local=local,
split=split,
download_retry=download_retry,
download_timeout=download_timeout,
validate_hash=validate_hash,
keep_zip=keep_zip,
epoch_size=epoch_size,
predownload=predownload,
cache_limit=cache_limit,
partition_algo=partition_algo,
num_canonical_nodes=num_canonical_nodes,
batch_size=batch_size,
shuffle=shuffle,
shuffle_algo=shuffle_algo,
shuffle_seed=shuffle_seed,
shuffle_block_size=shuffle_block_size,
sampling_method=sampling_method,
sampling_granularity=sampling_granularity,
batching_method=batching_method)

return dataset
Loading

0 comments on commit bdb5725

Please sign in to comment.