From c00c0932384a5bdb7e619da626a1426243715514 Mon Sep 17 00:00:00 2001 From: Saaketh Date: Fri, 20 Oct 2023 16:03:53 -0700 Subject: [PATCH 01/10] streaming dataset better defaults --- streaming/base/batching/per_stream.py | 1 + streaming/base/batching/random.py | 1 + streaming/base/batching/stratified.py | 1 + streaming/base/dataset.py | 52 +++++++++++++++------------ 4 files changed, 33 insertions(+), 22 deletions(-) diff --git a/streaming/base/batching/per_stream.py b/streaming/base/batching/per_stream.py index 8e686cc58..6fb65fe68 100644 --- a/streaming/base/batching/per_stream.py +++ b/streaming/base/batching/per_stream.py @@ -63,6 +63,7 @@ def generate_work_per_stream_batching(dataset: StreamingDataset, world: World, e # same as the ratio of the stream's samples to overall samples. # This ensures that the overall training shuffle block size is still approximately # equal to what is set by the user, and allows for reasoning about cache_limit as well. + assert isinstance(dataset.shuffle_block_size, int) shuffle_block_portion = int(dataset.shuffle_block_size * stream.proportion) stream_shuffle = get_shuffle(dataset.shuffle_algo, shuffle_units, dataset.num_canonical_nodes, dataset.shuffle_seed, epoch, diff --git a/streaming/base/batching/random.py b/streaming/base/batching/random.py index 7b8d3717c..193265e37 100644 --- a/streaming/base/batching/random.py +++ b/streaming/base/batching/random.py @@ -57,6 +57,7 @@ def generate_work_random_batching(dataset: StreamingDataset, world: World, epoch # If we need to shuffle, shuffle in a node-aware and *underlying* shard-aware way. if dataset.shuffle: + assert isinstance(dataset.shuffle_block_size, int) shuffle = get_shuffle(dataset.shuffle_algo, shuffle_units, dataset.num_canonical_nodes, dataset.shuffle_seed, epoch, dataset.shuffle_block_size) big_ids = np.where(big_ids != -1, shuffle[big_ids], -1) diff --git a/streaming/base/batching/stratified.py b/streaming/base/batching/stratified.py index 85e7008a6..791eb39dd 100644 --- a/streaming/base/batching/stratified.py +++ b/streaming/base/batching/stratified.py @@ -74,6 +74,7 @@ def generate_work_stratified_batching(dataset: StreamingDataset, world: World, e # same as the ratio of the stream's samples to overall samples. # This ensures that the overall training shuffle block size is still approximately # equal to what is set by the user, and allows for reasoning about cache_limit as well. + assert isinstance(dataset.shuffle_block_size, int) shuffle_block_portion = int(dataset.shuffle_block_size * stream.proportion) stream_shuffle = get_shuffle(dataset.shuffle_algo, shuffle_units, dataset.num_canonical_nodes, dataset.shuffle_seed, epoch, diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py index 4ab95d37f..d881382c0 100644 --- a/streaming/base/dataset.py +++ b/streaming/base/dataset.py @@ -265,9 +265,7 @@ class StreamingDataset(Array, IterableDataset): of current sample. Workers will attempt to download ahead by this many samples during, but not before, training. Recommendation is to provide a value greater than per device batch size to ensure at-least per device batch size number of samples cached locally. - If ``None``, its value gets derived using per device batch size and number of - canonical nodes ``max(batch_size, 256 * batch_size // num_canonical_nodes)``. - Defaults to ``None``. + If ``None``, its value is set to ``8 * batch_size``. Defaults to ``None``. cache_limit (Union[int, str], optional): Maximum size in bytes of this StreamingDataset's shard cache. 
Before downloading a shard, the least recently used resident shard(s) may be evicted (deleted from the local cache) in order to stay under the limit. @@ -285,8 +283,8 @@ class StreamingDataset(Array, IterableDataset): resumption. The sample space is divided evenly according to the number of canonical nodes. The higher the value, the more independent non-overlapping paths the StreamingDataset replicas take through the shards per model replica (increasing data - source diversity). Defaults to ``None``, which is interpreted as 64 times the number - of nodes of the initial run. + source diversity). Defaults to ``None``, which is interpreted as the number of physical + nodes of the initial run. .. note:: @@ -296,10 +294,12 @@ class StreamingDataset(Array, IterableDataset): partitioned over the workers. Defaults to ``None``. shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to ``False``. - shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1s``. + shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1e``. shuffle_seed (int): Seed for deterministic data shuffling. Defaults to ``9176``. - shuffle_block_size (int): Unit of shuffle. A canonical node's samples are split into blocks - of this size, and samples within each block are shuffled. Defaults to ``1 << 18``. + shuffle_block_size (int, optional): Unit of shuffle. A canonical node's samples are split + into blocks of this size, and samples within each block are shuffled. If ``None``, its + value is calculated as ``max(4_000_000 // num_canonical_nodes, 1 << 18)``. Defaults to + ``None``. batching_method (str): Which batching method to use, either ``random``, ``stratified``, or ``per_stream``. Defaults to ``random``. """ @@ -323,9 +323,9 @@ def __init__(self, num_canonical_nodes: Optional[int] = None, batch_size: Optional[int] = None, shuffle: bool = False, - shuffle_algo: str = 'py1s', + shuffle_algo: str = 'py1e', shuffle_seed: int = 9176, - shuffle_block_size: int = 1 << 18, + shuffle_block_size: Optional[int] = None, batching_method: str = 'random') -> None: # Global arguments (which do not live in Streams). self.predownload = predownload @@ -377,12 +377,15 @@ def __init__(self, raise ValueError(f'`shuffle_seed` must be a non-negative integer, but got: ' + f'{self.shuffle_seed}.') - # Check that predownload is at least per device batch size. + # Check that predownload is at least per device batch size, and set it if currently None. if self.predownload is not None and self.batch_size is not None and \ self.predownload < self.batch_size: warnings.warn(f'predownload < batch_size ({self.predownload} < {self.batch_size}).' + f'This may result in slower batch time. Recommendation is to set ' + f'predownload to at-least batch_size.') + elif self.predownload is None: + self.predownload = 8 * self.batch_size if self.batch_size is not None else 64 + # Convert epoch size from string to int, if needed. Cannot be negative. 
epoch_size_value = None if epoch_size: @@ -632,12 +635,11 @@ def __len__(self) -> int: """ return self.length - def _set_predownload(self) -> None: - """Set the predownload value which is per number of workers.""" - if self.predownload is None: - self.predownload = max( - self.batch_size, 256 * self.batch_size // self.num_canonical_nodes - ) if self.batch_size is not None and self.num_canonical_nodes is not None else 512 + def _set_shuffle_block_size(self): + """Set the shuffle block size value.""" + if self.shuffle_block_size is None: + self.shuffle_block_size = max(4_000_000 // self.num_canonical_nodes, 1 << 18) \ + if self.num_canonical_nodes is not None else 1 << 18 def _resume(self, world: World, epoch: int) -> Tuple[int, int]: """Either resume from checkpoint or start at the beginning. @@ -656,8 +658,11 @@ def _resume(self, world: World, epoch: int) -> Tuple[int, int]: except FileNotFoundError: # There is nothing to resume. if not self.num_canonical_nodes: - self.num_canonical_nodes = world.num_nodes * 64 - self._set_predownload() + if self.shuffle_algo in ['py1s', 'py2s']: + self.num_canonical_nodes = 64 * world.num_nodes + else: + self.num_canonical_nodes = world.num_nodes + self._set_shuffle_block_size() return epoch, 0 # SharedMemory buffers may contain additional null bytes at the end. @@ -669,8 +674,11 @@ def _resume(self, world: World, epoch: int) -> Tuple[int, int]: # Check if the resume state is stale. if obj['epoch'] < epoch: if not self.num_canonical_nodes: - self.num_canonical_nodes = world.num_nodes * 64 - self._set_predownload() + if self.shuffle_algo in ['py1s', 'py2s']: + self.num_canonical_nodes = 64 * world.num_nodes + else: + self.num_canonical_nodes = world.num_nodes + self._set_shuffle_block_size() return epoch, 0 # Load the correct resumption meta data. @@ -678,7 +686,7 @@ def _resume(self, world: World, epoch: int) -> Tuple[int, int]: sample_in_epoch = obj['sample_in_epoch'] self.num_canonical_nodes = obj['num_canonical_nodes'] self.shuffle_seed = obj['shuffle_seed'] - self._set_predownload() + self._set_shuffle_block_size() return epoch, sample_in_epoch From 5723286b3e082481af5424d4ff32a4a7928ad47f Mon Sep 17 00:00:00 2001 From: Saaketh Date: Fri, 20 Oct 2023 16:07:13 -0700 Subject: [PATCH 02/10] relaxed partition default --- streaming/base/dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py index d881382c0..7f166b2f6 100644 --- a/streaming/base/dataset.py +++ b/streaming/base/dataset.py @@ -278,7 +278,7 @@ class StreamingDataset(Array, IterableDataset): how many samples to pick from the same shard at a time (``1`` for evenly balanced across shards, ``1000`` to pick 1000 samples from the same shard at a time, etc). Defaults to ``1``. - partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``. + partition_algo (str): Which partitioning algorithm to use. Defaults to ``relaxed``. num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with resumption. The sample space is divided evenly according to the number of canonical nodes. 
The higher the value, the more independent non-overlapping paths the @@ -319,7 +319,7 @@ def __init__(self, cache_limit: Optional[Union[int, str]] = None, sampling_method: str = 'balanced', sampling_granularity: int = 1, - partition_algo: str = 'orig', + partition_algo: str = 'relaxed', num_canonical_nodes: Optional[int] = None, batch_size: Optional[int] = None, shuffle: bool = False, From eff06e24dc4640dd466fb42f17f4fc027f6970b7 Mon Sep 17 00:00:00 2001 From: Saaketh Date: Fri, 20 Oct 2023 16:21:36 -0700 Subject: [PATCH 03/10] modified docstring --- streaming/base/dataset.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py index 7f166b2f6..1df00cfbc 100644 --- a/streaming/base/dataset.py +++ b/streaming/base/dataset.py @@ -283,8 +283,9 @@ class StreamingDataset(Array, IterableDataset): resumption. The sample space is divided evenly according to the number of canonical nodes. The higher the value, the more independent non-overlapping paths the StreamingDataset replicas take through the shards per model replica (increasing data - source diversity). Defaults to ``None``, which is interpreted as the number of physical - nodes of the initial run. + source diversity). If ``None``, this is interpreted as 64 times the number of physical + nodes of the initial run if ``shuffle_algo`` is ``py1s`` or ``py2s``, and simply the + number of physical nodes of the initial run otherwise. Defaults to ``None``. .. note:: From 70caf40382026c7cf16a9ffec58b598a2686d81d Mon Sep 17 00:00:00 2001 From: Saaketh Date: Thu, 26 Oct 2023 08:47:59 -0700 Subject: [PATCH 04/10] removed assert statements --- streaming/base/batching/per_stream.py | 4 +++- streaming/base/batching/random.py | 4 +++- streaming/base/batching/stratified.py | 4 +++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/streaming/base/batching/per_stream.py b/streaming/base/batching/per_stream.py index 6fb65fe68..8ccaaf2a4 100644 --- a/streaming/base/batching/per_stream.py +++ b/streaming/base/batching/per_stream.py @@ -63,7 +63,9 @@ def generate_work_per_stream_batching(dataset: StreamingDataset, world: World, e # same as the ratio of the stream's samples to overall samples. # This ensures that the overall training shuffle block size is still approximately # equal to what is set by the user, and allows for reasoning about cache_limit as well. - assert isinstance(dataset.shuffle_block_size, int) + if not isinstance(dataset.shuffle_block_size, int): + raise TypeError(f'Dataset `shuffle_block_size` must be an integer. ' + + f'Got {dataset.shuffle_block_size} instead.') shuffle_block_portion = int(dataset.shuffle_block_size * stream.proportion) stream_shuffle = get_shuffle(dataset.shuffle_algo, shuffle_units, dataset.num_canonical_nodes, dataset.shuffle_seed, epoch, diff --git a/streaming/base/batching/random.py b/streaming/base/batching/random.py index 193265e37..e562655d6 100644 --- a/streaming/base/batching/random.py +++ b/streaming/base/batching/random.py @@ -57,7 +57,9 @@ def generate_work_random_batching(dataset: StreamingDataset, world: World, epoch # If we need to shuffle, shuffle in a node-aware and *underlying* shard-aware way. if dataset.shuffle: - assert isinstance(dataset.shuffle_block_size, int) + if not isinstance(dataset.shuffle_block_size, int): + raise TypeError(f'Dataset `shuffle_block_size` must be an integer. 
' + + f'Got {dataset.shuffle_block_size} instead.') shuffle = get_shuffle(dataset.shuffle_algo, shuffle_units, dataset.num_canonical_nodes, dataset.shuffle_seed, epoch, dataset.shuffle_block_size) big_ids = np.where(big_ids != -1, shuffle[big_ids], -1) diff --git a/streaming/base/batching/stratified.py b/streaming/base/batching/stratified.py index 791eb39dd..ee40ba0f2 100644 --- a/streaming/base/batching/stratified.py +++ b/streaming/base/batching/stratified.py @@ -74,7 +74,9 @@ def generate_work_stratified_batching(dataset: StreamingDataset, world: World, e # same as the ratio of the stream's samples to overall samples. # This ensures that the overall training shuffle block size is still approximately # equal to what is set by the user, and allows for reasoning about cache_limit as well. - assert isinstance(dataset.shuffle_block_size, int) + if not isinstance(dataset.shuffle_block_size, int): + raise TypeError(f'Dataset `shuffle_block_size` must be an integer. ' + + f'Got {dataset.shuffle_block_size} instead.') shuffle_block_portion = int(dataset.shuffle_block_size * stream.proportion) stream_shuffle = get_shuffle(dataset.shuffle_algo, shuffle_units, dataset.num_canonical_nodes, dataset.shuffle_seed, epoch, From 2ad8a89cff59e113fc95c6ad8475a7254414f7aa Mon Sep 17 00:00:00 2001 From: Saaketh Date: Thu, 26 Oct 2023 09:22:11 -0700 Subject: [PATCH 05/10] corrected type in TypeError --- streaming/base/batching/per_stream.py | 2 +- streaming/base/batching/random.py | 2 +- streaming/base/batching/stratified.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/streaming/base/batching/per_stream.py b/streaming/base/batching/per_stream.py index 8ccaaf2a4..d9eed4065 100644 --- a/streaming/base/batching/per_stream.py +++ b/streaming/base/batching/per_stream.py @@ -65,7 +65,7 @@ def generate_work_per_stream_batching(dataset: StreamingDataset, world: World, e # equal to what is set by the user, and allows for reasoning about cache_limit as well. if not isinstance(dataset.shuffle_block_size, int): raise TypeError(f'Dataset `shuffle_block_size` must be an integer. ' + - f'Got {dataset.shuffle_block_size} instead.') + f'Got {type(dataset.shuffle_block_size)} instead.') shuffle_block_portion = int(dataset.shuffle_block_size * stream.proportion) stream_shuffle = get_shuffle(dataset.shuffle_algo, shuffle_units, dataset.num_canonical_nodes, dataset.shuffle_seed, epoch, diff --git a/streaming/base/batching/random.py b/streaming/base/batching/random.py index e562655d6..82315a448 100644 --- a/streaming/base/batching/random.py +++ b/streaming/base/batching/random.py @@ -59,7 +59,7 @@ def generate_work_random_batching(dataset: StreamingDataset, world: World, epoch if dataset.shuffle: if not isinstance(dataset.shuffle_block_size, int): raise TypeError(f'Dataset `shuffle_block_size` must be an integer. ' + - f'Got {dataset.shuffle_block_size} instead.') + f'Got {type(dataset.shuffle_block_size)} instead.') shuffle = get_shuffle(dataset.shuffle_algo, shuffle_units, dataset.num_canonical_nodes, dataset.shuffle_seed, epoch, dataset.shuffle_block_size) big_ids = np.where(big_ids != -1, shuffle[big_ids], -1) diff --git a/streaming/base/batching/stratified.py b/streaming/base/batching/stratified.py index ee40ba0f2..aaa4ea752 100644 --- a/streaming/base/batching/stratified.py +++ b/streaming/base/batching/stratified.py @@ -76,7 +76,7 @@ def generate_work_stratified_batching(dataset: StreamingDataset, world: World, e # equal to what is set by the user, and allows for reasoning about cache_limit as well. 
if not isinstance(dataset.shuffle_block_size, int): raise TypeError(f'Dataset `shuffle_block_size` must be an integer. ' + - f'Got {dataset.shuffle_block_size} instead.') + f'Got {type(dataset.shuffle_block_size)} instead.') shuffle_block_portion = int(dataset.shuffle_block_size * stream.proportion) stream_shuffle = get_shuffle(dataset.shuffle_algo, shuffle_units, dataset.num_canonical_nodes, dataset.shuffle_seed, epoch, From e06ebc887dd05daeb9b588e88c8af842e48774ff Mon Sep 17 00:00:00 2001 From: Saaketh Date: Thu, 26 Oct 2023 13:13:55 -0700 Subject: [PATCH 06/10] linting fix --- streaming/base/dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py index a5e66be67..ed1a3fb93 100644 --- a/streaming/base/dataset.py +++ b/streaming/base/dataset.py @@ -749,12 +749,12 @@ def state_dict(self, num_samples: int, from_beginning: bool) -> Dict[str, Any]: else: sample_in_epoch = offset + num_samples - # If `self.initial_physical_nodes` is None, we are running for the first time, so we set + # If `self.initial_physical_nodes` is None, we are running for the first time, so we set # initial_physical_nodes to the current number of physical nodes. Otherwise, we persist # initial_physical_nodes as the value loaded and set from the resumption state. initial_physical_nodes = world.num_nodes if self.initial_physical_nodes is None \ else self.initial_physical_nodes - + return { 'epoch': epoch, 'sample_in_epoch': sample_in_epoch, From 5397b9cad5ffcf6e229566d3f396ed6647819f3d Mon Sep 17 00:00:00 2001 From: Saaketh Date: Thu, 26 Oct 2023 14:57:08 -0700 Subject: [PATCH 07/10] test cleanup --- streaming/base/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py index ed1a3fb93..40cc10b8e 100644 --- a/streaming/base/dataset.py +++ b/streaming/base/dataset.py @@ -382,7 +382,7 @@ def __init__(self, raise ValueError(f'`shuffle_seed` must be a non-negative integer, but got: ' + f'{self.shuffle_seed}.') - # Check that predownload is at least per device batch size, and set it if currently None. + # Check that predownload is at least per device batch size, and set it if currently `None`. if self.predownload is not None and self.batch_size is not None and \ self.predownload < self.batch_size: warnings.warn(f'predownload < batch_size ({self.predownload} < {self.batch_size}).' 
+ From 271a208bf62a3673a1aa07c496beeee2aa782bcf Mon Sep 17 00:00:00 2001 From: Saaketh Date: Thu, 26 Oct 2023 14:59:23 -0700 Subject: [PATCH 08/10] test cleanup --- tests/test_streaming.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 086f894a7..ad55b3659 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -11,6 +11,7 @@ from torch.utils.data import DataLoader from streaming.base import Stream, StreamingDataLoader, StreamingDataset +from streaming.base.util import clean_stale_shared_memory from tests.common.utils import convert_to_mds @@ -762,6 +763,8 @@ def test_streamingdataloader_mid_epoch_resumption(local_remote_dir: Any, batch_s del dataloader del dataset + clean_stale_shared_memory() + dataset = StreamingDataset(local=local_dir, remote=remote_dir, shuffle=shuffle, From 1307b97b0ac6c16820b84d4e9f3dcfd61452a158 Mon Sep 17 00:00:00 2001 From: Saaketh Date: Fri, 27 Oct 2023 06:21:56 -0700 Subject: [PATCH 09/10] test modification --- .github/workflows/pytest.yaml | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/.github/workflows/pytest.yaml b/.github/workflows/pytest.yaml index 4a2d49d9d..a1163f6b7 100644 --- a/.github/workflows/pytest.yaml +++ b/.github/workflows/pytest.yaml @@ -47,11 +47,12 @@ jobs: id: tests run: | set -ex - pytest --splits 8 --group 1 --cov-fail-under=10 - pytest --splits 8 --group 2 --cov-fail-under=10 - pytest --splits 8 --group 3 --cov-fail-under=10 - pytest --splits 8 --group 4 --cov-fail-under=10 - pytest --splits 8 --group 5 --cov-fail-under=10 - pytest --splits 8 --group 6 --cov-fail-under=10 - pytest --splits 8 --group 7 --cov-fail-under=10 - pytest --splits 8 --group 8 --cov-fail-under=10 + pytest --splits 9 --group 1 --cov-fail-under=10 + pytest --splits 9 --group 2 --cov-fail-under=10 + pytest --splits 9 --group 3 --cov-fail-under=10 + pytest --splits 9 --group 4 --cov-fail-under=10 + pytest --splits 9 --group 5 --cov-fail-under=10 + pytest --splits 9 --group 6 --cov-fail-under=10 + pytest --splits 9 --group 7 --cov-fail-under=10 + pytest --splits 9 --group 8 --cov-fail-under=10 + pytest --splits 9 --group 9 --cov-fail-under=10 From 47dd7c356b094865a416f8de5e0c665cbf50ca14 Mon Sep 17 00:00:00 2001 From: Saaketh Date: Fri, 27 Oct 2023 08:00:29 -0700 Subject: [PATCH 10/10] test modification --- .github/workflows/pytest.yaml | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/.github/workflows/pytest.yaml b/.github/workflows/pytest.yaml index a1163f6b7..8cb994a89 100644 --- a/.github/workflows/pytest.yaml +++ b/.github/workflows/pytest.yaml @@ -47,12 +47,13 @@ jobs: id: tests run: | set -ex - pytest --splits 9 --group 1 --cov-fail-under=10 - pytest --splits 9 --group 2 --cov-fail-under=10 - pytest --splits 9 --group 3 --cov-fail-under=10 - pytest --splits 9 --group 4 --cov-fail-under=10 - pytest --splits 9 --group 5 --cov-fail-under=10 - pytest --splits 9 --group 6 --cov-fail-under=10 - pytest --splits 9 --group 7 --cov-fail-under=10 - pytest --splits 9 --group 8 --cov-fail-under=10 - pytest --splits 9 --group 9 --cov-fail-under=10 + pytest --splits 10 --group 1 --cov-fail-under=10 + pytest --splits 10 --group 2 --cov-fail-under=10 + pytest --splits 10 --group 3 --cov-fail-under=10 + pytest --splits 10 --group 4 --cov-fail-under=10 + pytest --splits 10 --group 5 --cov-fail-under=10 + pytest --splits 10 --group 6 --cov-fail-under=10 + pytest --splits 10 --group 7 --cov-fail-under=10 + pytest --splits 10 --group 8 
--cov-fail-under=10 + pytest --splits 10 --group 9 --cov-fail-under=10 + pytest --splits 10 --group 10 --cov-fail-under=10
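Taken together, these patches change the defaults that StreamingDataset resolves when arguments are left unset. The sketch below is illustrative only and assumes the post-patch behavior shown in the diffs above; the local/remote paths and the batch size are placeholder values, not part of the change.

    from streaming.base import StreamingDataset

    # Placeholder paths and per-device batch size, for illustration only.
    dataset = StreamingDataset(
        local='/tmp/mds-cache',        # hypothetical local cache directory
        remote='s3://my-bucket/mds/',  # hypothetical remote shard location
        batch_size=32,                 # per-device batch size
        shuffle=True,
    )

    # With the new defaults, the unspecified arguments resolve as follows:
    #   predownload         -> 8 * batch_size (256 here); previously derived as
    #                          max(batch_size, 256 * batch_size // num_canonical_nodes)
    #   partition_algo      -> 'relaxed' (previously 'orig')
    #   shuffle_algo        -> 'py1e' (previously 'py1s')
    #   num_canonical_nodes -> the number of physical nodes of the initial run
    #                          (64 * num_nodes only for the 'py1s'/'py2s' algorithms)
    #   shuffle_block_size  -> max(4_000_000 // num_canonical_nodes, 1 << 18), set lazily
    #                          during _resume() via _set_shuffle_block_size()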