From 79a8fa2dfd3667640a544ce6304f0b0f93de8027 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Fri, 21 Jun 2024 20:21:34 -0400 Subject: [PATCH 01/61] temporary commit --- src/accelerate/accelerator.py | 8 ++++---- src/accelerate/data_loader.py | 28 ++++++++++++++++++++++++---- src/accelerate/utils/dataclasses.py | 8 ++++++++ src/accelerate/utils/imports.py | 10 ++++++++++ tests/test_accelerator.py | 12 ++++++------ 5 files changed, 52 insertions(+), 14 deletions(-) diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 9005e1fb563..1cbf93bc0d1 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -34,7 +34,7 @@ from huggingface_hub import split_torch_state_dict_into_shards from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state -from .data_loader import DataLoaderDispatcher, prepare_data_loader, skip_first_batches +from .data_loader import DataLoaderDispatcher, DataLoaderWrapper, prepare_data_loader, skip_first_batches from .hooks import AlignDevicesHook from .logging import get_logger from .optimizer import AcceleratedOptimizer @@ -1174,7 +1174,7 @@ def print(self, *args, **kwargs): def _prepare_one(self, obj, first_pass=False, device_placement=None): # First pass of preparation: DataLoader, model, optimizer if first_pass: - if isinstance(obj, torch.utils.data.DataLoader): + if isinstance(obj, torch.utils.data.DataLoader) or isinstance(obj, DataLoaderWrapper): return self.prepare_data_loader(obj, device_placement=device_placement) elif isinstance(obj, torch.nn.Module): return self.prepare_model(obj, device_placement=device_placement) @@ -1582,7 +1582,7 @@ def _prepare_deepspeed(self, *args): is_dataloader_present = any(isinstance(obj, torch.utils.data.DataLoader) for obj in args) result = [ - self._prepare_one(obj, first_pass=True) if isinstance(obj, torch.utils.data.DataLoader) else obj + self._prepare_one(obj, first_pass=True) if isinstance(obj, torch.utils.data.DataLoader) or isinstance(obj, DataLoaderWrapper) else obj for obj in args ] @@ -1833,7 +1833,7 @@ def _prepare_megatron_lm(self, *args): scheduler = None batch_data = None for obj in args: - if isinstance(obj, torch.utils.data.DataLoader) and batch_data is None: + if isinstance(obj, torch.utils.data.DataLoader) or isinstance(obj, DataLoaderWrapper) and batch_data is None: batch_data = next(iter(obj)) elif isinstance(obj, torch.nn.Module): model = obj diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index fcf6631f162..8a294151ad0 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -19,6 +19,8 @@ import torch from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler +from accelerate.utils.imports import is_torchdata_stateful_dataloader_available + from .logging import get_logger from .state import AcceleratorState, DistributedType, GradientState, is_torch_xla_available from .utils import ( @@ -35,6 +37,8 @@ synchronize_rng_states, ) +if is_torchdata_stateful_dataloader_available(): + from torchdata.stateful_dataloader import StatefulDataLoader logger = get_logger(__name__) @@ -387,10 +391,26 @@ def end(self): "Cleans up the gradient state after exiting the dataloader" self.gradient_state._remove_dataloader(self) +class DataLoaderWrapper: + """ + Class that wraps around a PyTorch `DataLoader` (or subclasses, such as torchdata's `StatefulDataLoader`). 
+ + """ + def __init__(self, dataset, **kwargs): + if False and is_torchdata_stateful_dataloader_available(): + self.dataloader = StatefulDataLoader(dataset, **kwargs) + else: + self.dataloader = DataLoader(dataset, **kwargs) + + for attr in self.dataloader.__dict__.keys(): + setattr(self, attr, getattr(self.dataloader, attr)) + + def __iter__(self): + return self.dataloader.__iter__() -class DataLoaderShard(DataLoader, DataLoaderStateMixin): +class DataLoaderShard(DataLoaderWrapper, DataLoaderStateMixin): """ - Subclass of a PyTorch `DataLoader` that will deal with device placement and current distributed setup. + Subclass of `DataLoaderWrapper` that will deal with device placement and current distributed setup. Args: dataset (`torch.utils.data.dataset.Dataset`): @@ -559,9 +579,9 @@ def batch_sampler(self): return self._loader.batch_sampler -class DataLoaderDispatcher(DataLoader, DataLoaderStateMixin): +class DataLoaderDispatcher(DataLoaderWrapper, DataLoaderStateMixin): """ - Subclass of a PyTorch `DataLoader` that will iterate and preprocess on process 0 only, then dispatch on each + Subclass of `DataLoaderWrapper` that will iterate and preprocess on process 0 only, then dispatch on each process their part of the batch. Args: diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py index 11d3bc31aae..2b7e930107c 100644 --- a/src/accelerate/utils/dataclasses.py +++ b/src/accelerate/utils/dataclasses.py @@ -600,6 +600,14 @@ class DataLoaderConfiguration: " prepared dataloader has `pin_memory` set to `True` to work properly." }, ) + use_stateful_dataloader: bool = field( + default=False, + metadata={ + "help": "If set to true, the dataloader prepared by the Accelerator will be backed by " + "[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). This requires a version" + " of `torchdata` with StatefulDataLoader to be installed." + }, + ) @dataclass diff --git a/src/accelerate/utils/imports.py b/src/accelerate/utils/imports.py index 2beb22795cb..9c278c9a3c0 100644 --- a/src/accelerate/utils/imports.py +++ b/src/accelerate/utils/imports.py @@ -405,3 +405,13 @@ def is_xpu_available(check_device=False): def is_dvclive_available(): return _is_package_available("dvclive") + +def is_torchdata_available(): + return _is_package_available("torchdata") + +# TODO: Remove this function once stateful_dataloader is a stable feature in torchdata. 
+def is_torchdata_stateful_dataloader_available(): + if not _is_package_available("torchdata"): + return False + import torchdata + return hasattr(torchdata, "stateful_dataloader") and hasattr(torchdata.stateful_dataloader, "StatefulDataLoader") \ No newline at end of file diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py index 6a868f8d356..148cc5f705d 100644 --- a/tests/test_accelerator.py +++ b/tests/test_accelerator.py @@ -230,7 +230,7 @@ def noop(*args, **kwargs): accelerator = Accelerator() assert str(accelerator.state.device) == "cuda:64" - @parameterized.expand((True, False), name_func=parameterized_custom_name_func) + @parameterized.expand([(True,), (False,)], name_func=parameterized_custom_name_func) def test_save_load_model(self, use_safetensors): accelerator = Accelerator() model, optimizer, scheduler, train_dl, valid_dl = create_components() @@ -249,7 +249,7 @@ def test_save_load_model(self, use_safetensors): accelerator.load_state(tmpdirname) assert abs(model_signature - get_signature(model)) < 1e-3 - @parameterized.expand([True, False], name_func=parameterized_custom_name_func) + @parameterized.expand([(True,), (False,)], name_func=parameterized_custom_name_func) def test_save_model(self, use_safetensors): accelerator = Accelerator() model = torch.nn.Linear(10, 10) @@ -261,7 +261,7 @@ def test_save_model(self, use_safetensors): load_checkpoint_in_model(model, tmpdirname) assert abs(model_signature - get_signature(model)) < 1e-3 - @parameterized.expand([True, False], name_func=parameterized_custom_name_func) + @parameterized.expand([(True,), (False,)], name_func=parameterized_custom_name_func) def test_save_sharded_model(self, use_safetensors): accelerator = Accelerator() inputs = torch.randn(3, 3) @@ -277,7 +277,7 @@ def test_save_sharded_model(self, use_safetensors): assert torch.allclose(expected, output, atol=1e-5) - @parameterized.expand([True, False], name_func=parameterized_custom_name_func) + @parameterized.expand([(True,), (False,)], name_func=parameterized_custom_name_func) def test_save_model_offload(self, use_safetensors): accelerator = Accelerator() @@ -297,7 +297,7 @@ def test_save_model_offload(self, use_safetensors): output = model(inputs) assert torch.allclose(expected, output, atol=1e-5) - @parameterized.expand([True, False], name_func=parameterized_custom_name_func) + @parameterized.expand([(True,), (False,)], name_func=parameterized_custom_name_func) @require_cuda def test_get_state_dict_from_offload(self, use_safetensors): accelerator = Accelerator() @@ -325,7 +325,7 @@ def test_get_state_dict_from_offload(self, use_safetensors): assert cpu_onloaded_layer_weight.device.type == "cpu" assert cuda_onloaded_layer_weight.device.type == "cuda" - @parameterized.expand([True, False], name_func=parameterized_custom_name_func) + @parameterized.expand([(True,), (False,)], name_func=parameterized_custom_name_func) def test_save_load_model_with_hooks(self, use_safetensors): accelerator = Accelerator() model, optimizer, scheduler, train_dl, valid_dl = create_components() From efa1e7dd17bc7cf57f7625d4493eaf4611057b87 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Fri, 21 Jun 2024 20:53:32 -0400 Subject: [PATCH 02/61] checkout? 
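For context: PATCH 01 above adds a `use_stateful_dataloader` field to `DataLoaderConfiguration` together with an `is_torchdata_stateful_dataloader_available()` guard, and the later patches in this series plumb that flag through `Accelerator.prepare`. A minimal usage sketch, assuming a `torchdata` build that ships `StatefulDataLoader` is installed (illustrative only, not part of any commit in the series):

    from torch.utils.data import DataLoader

    from accelerate import Accelerator
    from accelerate.utils.dataclasses import DataLoaderConfiguration

    # Ask the Accelerator to back every prepared dataloader with torchdata's StatefulDataLoader.
    dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
    accelerator = Accelerator(dataloader_config=dataloader_config)

    train_dl = accelerator.prepare(DataLoader(list(range(16)), batch_size=4))

    # Once prepared, the loader exposes state_dict()/load_state_dict() for mid-epoch checkpointing.
    state = train_dl.state_dict()
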
--- tests/test_accelerator.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py index 148cc5f705d..6a868f8d356 100644 --- a/tests/test_accelerator.py +++ b/tests/test_accelerator.py @@ -230,7 +230,7 @@ def noop(*args, **kwargs): accelerator = Accelerator() assert str(accelerator.state.device) == "cuda:64" - @parameterized.expand([(True,), (False,)], name_func=parameterized_custom_name_func) + @parameterized.expand((True, False), name_func=parameterized_custom_name_func) def test_save_load_model(self, use_safetensors): accelerator = Accelerator() model, optimizer, scheduler, train_dl, valid_dl = create_components() @@ -249,7 +249,7 @@ def test_save_load_model(self, use_safetensors): accelerator.load_state(tmpdirname) assert abs(model_signature - get_signature(model)) < 1e-3 - @parameterized.expand([(True,), (False,)], name_func=parameterized_custom_name_func) + @parameterized.expand([True, False], name_func=parameterized_custom_name_func) def test_save_model(self, use_safetensors): accelerator = Accelerator() model = torch.nn.Linear(10, 10) @@ -261,7 +261,7 @@ def test_save_model(self, use_safetensors): load_checkpoint_in_model(model, tmpdirname) assert abs(model_signature - get_signature(model)) < 1e-3 - @parameterized.expand([(True,), (False,)], name_func=parameterized_custom_name_func) + @parameterized.expand([True, False], name_func=parameterized_custom_name_func) def test_save_sharded_model(self, use_safetensors): accelerator = Accelerator() inputs = torch.randn(3, 3) @@ -277,7 +277,7 @@ def test_save_sharded_model(self, use_safetensors): assert torch.allclose(expected, output, atol=1e-5) - @parameterized.expand([(True,), (False,)], name_func=parameterized_custom_name_func) + @parameterized.expand([True, False], name_func=parameterized_custom_name_func) def test_save_model_offload(self, use_safetensors): accelerator = Accelerator() @@ -297,7 +297,7 @@ def test_save_model_offload(self, use_safetensors): output = model(inputs) assert torch.allclose(expected, output, atol=1e-5) - @parameterized.expand([(True,), (False,)], name_func=parameterized_custom_name_func) + @parameterized.expand([True, False], name_func=parameterized_custom_name_func) @require_cuda def test_get_state_dict_from_offload(self, use_safetensors): accelerator = Accelerator() @@ -325,7 +325,7 @@ def test_get_state_dict_from_offload(self, use_safetensors): assert cpu_onloaded_layer_weight.device.type == "cpu" assert cuda_onloaded_layer_weight.device.type == "cuda" - @parameterized.expand([(True,), (False,)], name_func=parameterized_custom_name_func) + @parameterized.expand([True, False], name_func=parameterized_custom_name_func) def test_save_load_model_with_hooks(self, use_safetensors): accelerator = Accelerator() model, optimizer, scheduler, train_dl, valid_dl = create_components() From 8dc107dc87195b068e7ef6736b7affe030cb3dca Mon Sep 17 00:00:00 2001 From: byi8220 Date: Fri, 21 Jun 2024 21:37:53 -0400 Subject: [PATCH 03/61] dataloader wrapper --- src/accelerate/data_loader.py | 7 +++++-- src/accelerate/test_utils/scripts/test_script.py | 2 ++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index 8a294151ad0..fba7fa5385e 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -393,7 +393,8 @@ def end(self): class DataLoaderWrapper: """ - Class that wraps around a PyTorch `DataLoader` (or subclasses, such as torchdata's 
`StatefulDataLoader`). + Class that wraps around a PyTorch `DataLoader` (or subclasses of `DataLoader`, such as torchdata's `StatefulDataLoader`). + """ def __init__(self, dataset, **kwargs): @@ -401,12 +402,14 @@ def __init__(self, dataset, **kwargs): self.dataloader = StatefulDataLoader(dataset, **kwargs) else: self.dataloader = DataLoader(dataset, **kwargs) - for attr in self.dataloader.__dict__.keys(): setattr(self, attr, getattr(self.dataloader, attr)) def __iter__(self): return self.dataloader.__iter__() + + def __len__(self): + return self.dataloader.__len__() class DataLoaderShard(DataLoaderWrapper, DataLoaderStateMixin): """ diff --git a/src/accelerate/test_utils/scripts/test_script.py b/src/accelerate/test_utils/scripts/test_script.py index ff09d9daaad..45292b54a0e 100644 --- a/src/accelerate/test_utils/scripts/test_script.py +++ b/src/accelerate/test_utils/scripts/test_script.py @@ -377,6 +377,8 @@ def check_seedable_sampler(): for batch in train_dl: new_items.append(batch["x"]) new_items = torch.cat(new_items) + print(original_items) + print(new_items) assert torch.allclose(original_items, new_items), "Did not obtain the same items with the same seed and epoch." From f342f4c5d9ef8ca3c722af2c27d086e71864114b Mon Sep 17 00:00:00 2001 From: byi8220 Date: Fri, 21 Jun 2024 22:08:42 -0400 Subject: [PATCH 04/61] tmp --- src/accelerate/accelerator.py | 8 ++++---- src/accelerate/data_loader.py | 2 -- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 1cbf93bc0d1..9005e1fb563 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -34,7 +34,7 @@ from huggingface_hub import split_torch_state_dict_into_shards from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state -from .data_loader import DataLoaderDispatcher, DataLoaderWrapper, prepare_data_loader, skip_first_batches +from .data_loader import DataLoaderDispatcher, prepare_data_loader, skip_first_batches from .hooks import AlignDevicesHook from .logging import get_logger from .optimizer import AcceleratedOptimizer @@ -1174,7 +1174,7 @@ def print(self, *args, **kwargs): def _prepare_one(self, obj, first_pass=False, device_placement=None): # First pass of preparation: DataLoader, model, optimizer if first_pass: - if isinstance(obj, torch.utils.data.DataLoader) or isinstance(obj, DataLoaderWrapper): + if isinstance(obj, torch.utils.data.DataLoader): return self.prepare_data_loader(obj, device_placement=device_placement) elif isinstance(obj, torch.nn.Module): return self.prepare_model(obj, device_placement=device_placement) @@ -1582,7 +1582,7 @@ def _prepare_deepspeed(self, *args): is_dataloader_present = any(isinstance(obj, torch.utils.data.DataLoader) for obj in args) result = [ - self._prepare_one(obj, first_pass=True) if isinstance(obj, torch.utils.data.DataLoader) or isinstance(obj, DataLoaderWrapper) else obj + self._prepare_one(obj, first_pass=True) if isinstance(obj, torch.utils.data.DataLoader) else obj for obj in args ] @@ -1833,7 +1833,7 @@ def _prepare_megatron_lm(self, *args): scheduler = None batch_data = None for obj in args: - if isinstance(obj, torch.utils.data.DataLoader) or isinstance(obj, DataLoaderWrapper) and batch_data is None: + if isinstance(obj, torch.utils.data.DataLoader) and batch_data is None: batch_data = next(iter(obj)) elif isinstance(obj, torch.nn.Module): model = obj diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index 
fba7fa5385e..2a93fd96776 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -394,8 +394,6 @@ def end(self): class DataLoaderWrapper: """ Class that wraps around a PyTorch `DataLoader` (or subclasses of `DataLoader`, such as torchdata's `StatefulDataLoader`). - - """ def __init__(self, dataset, **kwargs): if False and is_torchdata_stateful_dataloader_available(): From 065849a8c596060066f4f7f86f2f11452a4d4ddb Mon Sep 17 00:00:00 2001 From: byi8220 Date: Fri, 21 Jun 2024 22:31:07 -0400 Subject: [PATCH 05/61] weird failing test --- src/accelerate/data_loader.py | 34 ++++++------------- .../test_utils/scripts/test_script.py | 2 -- 2 files changed, 11 insertions(+), 25 deletions(-) diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index 2a93fd96776..df36ac8fbae 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -37,8 +37,6 @@ synchronize_rng_states, ) -if is_torchdata_stateful_dataloader_available(): - from torchdata.stateful_dataloader import StatefulDataLoader logger = get_logger(__name__) @@ -391,27 +389,10 @@ def end(self): "Cleans up the gradient state after exiting the dataloader" self.gradient_state._remove_dataloader(self) -class DataLoaderWrapper: - """ - Class that wraps around a PyTorch `DataLoader` (or subclasses of `DataLoader`, such as torchdata's `StatefulDataLoader`). - """ - def __init__(self, dataset, **kwargs): - if False and is_torchdata_stateful_dataloader_available(): - self.dataloader = StatefulDataLoader(dataset, **kwargs) - else: - self.dataloader = DataLoader(dataset, **kwargs) - for attr in self.dataloader.__dict__.keys(): - setattr(self, attr, getattr(self.dataloader, attr)) - def __iter__(self): - return self.dataloader.__iter__() - - def __len__(self): - return self.dataloader.__len__() - -class DataLoaderShard(DataLoaderWrapper, DataLoaderStateMixin): +class DataLoaderShard(DataLoader, DataLoaderStateMixin): """ - Subclass of `DataLoaderWrapper` that will deal with device placement and current distributed setup. + Subclass of a PyTorch `DataLoader` that will deal with device placement and current distributed setup. Args: dataset (`torch.utils.data.dataset.Dataset`): @@ -580,9 +561,9 @@ def batch_sampler(self): return self._loader.batch_sampler -class DataLoaderDispatcher(DataLoaderWrapper, DataLoaderStateMixin): +class DataLoaderDispatcher(DataLoader, DataLoaderStateMixin): """ - Subclass of `DataLoaderWrapper` that will iterate and preprocess on process 0 only, then dispatch on each + Subclass of a PyTorch `DataLoader` that will iterate and preprocess on process 0 only, then dispatch on each process their part of the batch. 
Args: @@ -796,6 +777,13 @@ def set_sampler(self, sampler): if hasattr(self.batch_sampler, "batch_sampler"): self.batch_sampler.batch_sampler.sampler = sampler +if is_torchdata_stateful_dataloader_available(): + from torchdata.stateful_dataloader import StatefulDataLoader + + class StatefulDataLoaderShard(DataLoaderShard, StatefulDataLoader, DataLoaderStateMixin): + """ + Subclass of DataLoaderShard which inherits from torchdata's `StatefulDataLoader` + """ def get_sampler(dataloader): """ diff --git a/src/accelerate/test_utils/scripts/test_script.py b/src/accelerate/test_utils/scripts/test_script.py index 45292b54a0e..ff09d9daaad 100644 --- a/src/accelerate/test_utils/scripts/test_script.py +++ b/src/accelerate/test_utils/scripts/test_script.py @@ -377,8 +377,6 @@ def check_seedable_sampler(): for batch in train_dl: new_items.append(batch["x"]) new_items = torch.cat(new_items) - print(original_items) - print(new_items) assert torch.allclose(original_items, new_items), "Did not obtain the same items with the same seed and epoch." From 1e3fad194500a87a5a50af627d91b394d5767d51 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Fri, 21 Jun 2024 22:54:14 -0400 Subject: [PATCH 06/61] trying multiple inheritance --- src/accelerate/data_loader.py | 3 ++- src/accelerate/utils/imports.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index df36ac8fbae..e1a403d28e5 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -777,13 +777,14 @@ def set_sampler(self, sampler): if hasattr(self.batch_sampler, "batch_sampler"): self.batch_sampler.batch_sampler.sampler = sampler + if is_torchdata_stateful_dataloader_available(): from torchdata.stateful_dataloader import StatefulDataLoader - class StatefulDataLoaderShard(DataLoaderShard, StatefulDataLoader, DataLoaderStateMixin): """ Subclass of DataLoaderShard which inherits from torchdata's `StatefulDataLoader` """ + pass def get_sampler(dataloader): """ diff --git a/src/accelerate/utils/imports.py b/src/accelerate/utils/imports.py index 9c278c9a3c0..c7a486b337b 100644 --- a/src/accelerate/utils/imports.py +++ b/src/accelerate/utils/imports.py @@ -411,7 +411,7 @@ def is_torchdata_available(): # TODO: Remove this function once stateful_dataloader is a stable feature in torchdata. 
def is_torchdata_stateful_dataloader_available(): - if not _is_package_available("torchdata"): + if not is_torchdata_available(): return False import torchdata return hasattr(torchdata, "stateful_dataloader") and hasattr(torchdata.stateful_dataloader, "StatefulDataLoader") \ No newline at end of file From a41cf38443f7e20cbd8a5d08e34055763e8bab14 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Sat, 22 Jun 2024 20:32:26 -0400 Subject: [PATCH 07/61] DataLoaderAdapter --- src/accelerate/accelerator.py | 17 ++++--- src/accelerate/data_loader.py | 76 +++++++++++++++++++++------- src/accelerate/test_utils/testing.py | 14 +++++ tests/test_data_loader.py | 49 ++++++++++++++++++ 4 files changed, 133 insertions(+), 23 deletions(-) diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 9005e1fb563..842c3d61ebf 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -34,7 +34,7 @@ from huggingface_hub import split_torch_state_dict_into_shards from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state -from .data_loader import DataLoaderDispatcher, prepare_data_loader, skip_first_batches +from .data_loader import DataLoaderAdapter, DataLoaderDispatcher, prepare_data_loader, skip_first_batches from .hooks import AlignDevicesHook from .logging import get_logger from .optimizer import AcceleratedOptimizer @@ -569,6 +569,10 @@ def use_seedable_sampler(self): @property def non_blocking(self): return self.dataloader_config.non_blocking + + @property + def use_stateful_dataloader(self): + return self.dataloader_config.use_stateful_dataloader @property def project_dir(self): @@ -1174,7 +1178,7 @@ def print(self, *args, **kwargs): def _prepare_one(self, obj, first_pass=False, device_placement=None): # First pass of preparation: DataLoader, model, optimizer if first_pass: - if isinstance(obj, torch.utils.data.DataLoader): + if isinstance(obj, torch.utils.data.DataLoader) or isinstance(obj, DataLoaderAdapter): return self.prepare_data_loader(obj, device_placement=device_placement) elif isinstance(obj, torch.nn.Module): return self.prepare_model(obj, device_placement=device_placement) @@ -1580,9 +1584,9 @@ def _prepare_deepspeed(self, *args): deepspeed_plugin = self.state.deepspeed_plugin - is_dataloader_present = any(isinstance(obj, torch.utils.data.DataLoader) for obj in args) + is_dataloader_present = any((isinstance(obj, torch.utils.data.DataLoader) or isinstance(obj, DataLoaderAdapter)) for obj in args) result = [ - self._prepare_one(obj, first_pass=True) if isinstance(obj, torch.utils.data.DataLoader) else obj + self._prepare_one(obj, first_pass=True) if (isinstance(obj, torch.utils.data.DataLoader or isinstance(obj, DataLoaderAdapter))) else obj for obj in args ] @@ -1833,7 +1837,7 @@ def _prepare_megatron_lm(self, *args): scheduler = None batch_data = None for obj in args: - if isinstance(obj, torch.utils.data.DataLoader) and batch_data is None: + if (isinstance(obj, torch.utils.data.DataLoader) or isinstance(obj, DataLoaderAdapter)) and batch_data is None: batch_data = next(iter(obj)) elif isinstance(obj, torch.nn.Module): model = obj @@ -1862,7 +1866,7 @@ def _prepare_megatron_lm(self, *args): counter = 0 result = [] for obj in args: - if isinstance(obj, torch.utils.data.DataLoader): + if (isinstance(obj, torch.utils.data.DataLoader) or isinstance(obj, DataLoaderAdapter)): result.append(megatron_lm_prepare_data_loader(self, obj)) counter += 1 elif isinstance(obj, MegatronLMDummyDataLoader): @@ -2025,6 
+2029,7 @@ def prepare_data_loader( slice_fn_for_dispatch=slice_fn_for_dispatch, use_seedable_sampler=self.use_seedable_sampler, non_blocking=self.non_blocking, + use_stateful_dataloader=self.use_stateful_dataloader, ) self._dataloaders.append(prepared_data_loader) return prepared_data_loader diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index e1a403d28e5..0366b76aad7 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -20,6 +20,8 @@ from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler from accelerate.utils.imports import is_torchdata_stateful_dataloader_available +if is_torchdata_stateful_dataloader_available(): + from torchdata.stateful_dataloader import StatefulDataLoader from .logging import get_logger from .state import AcceleratorState, DistributedType, GradientState, is_torch_xla_available @@ -38,6 +40,7 @@ ) + logger = get_logger(__name__) # kwargs of the DataLoader in min version 1.4.0. @@ -389,10 +392,44 @@ def end(self): "Cleans up the gradient state after exiting the dataloader" self.gradient_state._remove_dataloader(self) +# TODO: Maybe generalize this class? +class DataLoaderAdapter: + """ + A class which wraps around a PyTorch `DataLoader` (or variants of it) to be used with the `Accelerator`. + """ + def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, **kwargs): + self.use_stateful_dataloader = use_stateful_dataloader + if use_stateful_dataloader and not is_torchdata_stateful_dataloader_available(): + raise ValueError("StatefulDataLoader is not available. Please install torchdata to use it.") + if use_stateful_dataloader: + self.base_dataloader = StatefulDataLoader(dataset, **kwargs) + else: + self.base_dataloader = DataLoader(dataset, **kwargs) + + for attr in self.base_dataloader.__dict__.keys(): + setattr(self, attr, getattr(self.base_dataloader, attr)) -class DataLoaderShard(DataLoader, DataLoaderStateMixin): + def __iter__(self): + return iter(self.base_dataloader) + + def __len__(self): + return len(self.base_dataloader) + + def load_state_dict(self): + """ + Only supported for `StatefulDataLoader`. + """ + return self.base_dataloader.load_state_dict() + + def state_dict(self): + """ + Only supported for `StatefulDataLoader`. + """ + return self.base_dataloader.state_dict() + +class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin): """ - Subclass of a PyTorch `DataLoader` that will deal with device placement and current distributed setup. + Subclass of `DataLoaderAdapter` that will deal with device placement and current distributed setup. Args: dataset (`torch.utils.data.dataset.Dataset`): @@ -411,6 +448,8 @@ class DataLoaderShard(DataLoader, DataLoaderStateMixin): A random number generator to keep synchronized across processes. skip_batches (`int`, *optional*, defaults to 0): The number of batches to skip at the beginning. + use_stateful_dataloader (`bool`, *optional*, defaults to `False`): + Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`. **kwargs (additional keyword arguments, *optional*): All other keyword arguments to pass to the regular `DataLoader` initialization. 
@@ -430,11 +469,12 @@ def __init__( rng_types=None, synchronized_generator=None, skip_batches=0, + use_stateful_dataloader=False, _drop_last: bool = False, _non_blocking: bool = False, **kwargs, ): - super().__init__(dataset, **kwargs) + super().__init__(dataset, use_stateful_dataloader, **kwargs) self.device = device self.rng_types = rng_types self.synchronized_generator = synchronized_generator @@ -561,9 +601,9 @@ def batch_sampler(self): return self._loader.batch_sampler -class DataLoaderDispatcher(DataLoader, DataLoaderStateMixin): +class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin): """ - Subclass of a PyTorch `DataLoader` that will iterate and preprocess on process 0 only, then dispatch on each + Subclass of `DataLoaderAdapter` that will iterate and preprocess on process 0 only, then dispatch on each process their part of the batch. Args: @@ -576,6 +616,8 @@ class DataLoaderDispatcher(DataLoader, DataLoaderStateMixin): size of the `dataloader` is a round multiple of `batch_size`. skip_batches (`int`, *optional*, defaults to 0): The number of batches to skip at the beginning of an iteration. + use_stateful_dataloader (`bool`, *optional*, defaults to `False`): + Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`. **Available attributes:** @@ -591,6 +633,7 @@ def __init__( dataset, split_batches: bool = False, skip_batches=0, + use_stateful_dataloader=False, _drop_last: bool = False, _non_blocking: bool = False, slice_fn=None, @@ -603,7 +646,7 @@ def __init__( # We need to save the shuffling state of the DataPipe if isinstance(dataset, ShufflerIterDataPipe): shuffle = dataset._shuffle_enabled - super().__init__(dataset, **kwargs) + super().__init__(dataset, use_stateful_dataloader, **kwargs) self.split_batches = split_batches if shuffle: torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle) @@ -778,14 +821,6 @@ def set_sampler(self, sampler): self.batch_sampler.batch_sampler.sampler = sampler -if is_torchdata_stateful_dataloader_available(): - from torchdata.stateful_dataloader import StatefulDataLoader - class StatefulDataLoaderShard(DataLoaderShard, StatefulDataLoader, DataLoaderStateMixin): - """ - Subclass of DataLoaderShard which inherits from torchdata's `StatefulDataLoader` - """ - pass - def get_sampler(dataloader): """ Get the sampler associated to the dataloader @@ -817,6 +852,7 @@ def prepare_data_loader( slice_fn_for_dispatch: Optional[Callable] = None, use_seedable_sampler: bool = False, non_blocking: bool = False, + use_stateful_dataloader: bool = False, ) -> DataLoader: """ Wraps a PyTorch `DataLoader` to generate batches for one of the processes only. @@ -878,6 +914,10 @@ def prepare_data_loader( non_blocking (`bool`, *optional*, defaults to `False`): If set to `True`, dataloader will utilize non-blocking host-to-device transfers. If the dataloader has `pin_memory` set to `True`, this will help to increase overlap between data transfer and computations. + use_stateful_dataloader (`bool`, *optional*, defaults to `False`): + "If set to true, the dataloader prepared by the Accelerator will be backed by " + "[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). This requires a version" + " of `torchdata` with StatefulDataLoader to be installed." 
Returns: @@ -1066,7 +1106,7 @@ def __len__(self): return len(self.batch_sampler) - self.skip_batches -class SkipDataLoader(DataLoader): +class SkipDataLoader(DataLoaderAdapter): """ Subclass of a PyTorch `DataLoader` that will skip the first batches. @@ -1075,12 +1115,14 @@ class SkipDataLoader(DataLoader): The dataset to use to build this datalaoder. skip_batches (`int`, *optional*, defaults to 0): The number of batches to skip at the beginning. + use_stateful_dataloader (`bool`, *optional*, defaults to `False`): + Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`. kwargs: All other keyword arguments to pass to the regular `DataLoader` initialization. """ - def __init__(self, dataset, skip_batches=0, **kwargs): - super().__init__(dataset, **kwargs) + def __init__(self, dataset, skip_batches=0, use_stateful_dataloader=False, **kwargs): + super().__init__(dataset, use_stateful_dataloader, **kwargs) self.skip_batches = skip_batches def __iter__(self): diff --git a/src/accelerate/test_utils/testing.py b/src/accelerate/test_utils/testing.py index 41e3ac35115..c99029dc20d 100644 --- a/src/accelerate/test_utils/testing.py +++ b/src/accelerate/test_utils/testing.py @@ -393,6 +393,20 @@ def require_trackers(test_case): )(test_case) +def require_torchdata_stateful_dataloader(test_case): + """ + Decorator marking a test that requires torchdata.stateful_dataloader. + + These tests are skipped when torchdata with stateful_dataloader module isn't installed. + + """ + try: + import torchdata.stateful_dataloader # noqa F401 + except (ImportError, AssertionError): + return unittest.skip("test requires torchdata.stateful_dataloader")(test_case) + else: + return test_case + class TempDirTestCase(unittest.TestCase): """ A TestCase class that keeps a single `tempfile.TemporaryDirectory` open for the duration of the class, wipes its diff --git a/tests/test_data_loader.py b/tests/test_data_loader.py index 2f360d71bcb..86532d38d17 100644 --- a/tests/test_data_loader.py +++ b/tests/test_data_loader.py @@ -28,6 +28,16 @@ skip_first_batches, ) +from accelerate.utils.imports import is_torchdata_stateful_dataloader_available +from accelerate.test_utils.testing import require_torchdata_stateful_dataloader + +if is_torchdata_stateful_dataloader_available(): + from torchdata.stateful_dataloader import ( + StatefulDataLoader, + ) + from accelerate.data_loader import ( + DataLoaderAdapter, + ) class RandomIterableDataset(IterableDataset): # For testing, an iterable dataset of random length @@ -396,3 +406,42 @@ def test_end_of_dataloader_dispatcher(self): # Test it also works on the second iteration for idx, _ in enumerate(dataloader): assert dataloader.end_of_dataloader == (idx == 3) + + +class StatefulDataLoaderTester(unittest.TestCase): + + @require_torchdata_stateful_dataloader + def test_skip_data_loader(self): + dataloader = SkipDataLoader(list(range(16)), batch_size=4, skip_batches=2, use_stateful_dataloader=True) + + assert [t.tolist() for t in dataloader] == [[8, 9, 10, 11], [12, 13, 14, 15]] + + @require_torchdata_stateful_dataloader + def test_skip_first_batches(self): + dataloader = StatefulDataLoader(list(range(16)), batch_size=4) + new_dataloader = skip_first_batches(dataloader, num_batches=2) + + assert [t.tolist() for t in new_dataloader] == [[8, 9, 10, 11], [12, 13, 14, 15]] + + @require_torchdata_stateful_dataloader + def test_end_of_dataloader(self): + dataloader = DataLoaderShard(list(range(16)), batch_size=4, use_stateful_dataloader=True) + + for 
idx, _ in enumerate(dataloader): + assert dataloader.end_of_dataloader == (idx == 3) + + # Test it also works on the second iteration + for idx, _ in enumerate(dataloader): + assert dataloader.end_of_dataloader == (idx == 3) + + @require_torchdata_stateful_dataloader + def test_end_of_dataloader_dispatcher(self): + Accelerator() + dataloader = DataLoaderDispatcher(range(16), batch_size=4, use_stateful_dataloader=True) + + for idx, _ in enumerate(dataloader): + assert dataloader.end_of_dataloader == (idx == 3) + + # Test it also works on the second iteration + for idx, _ in enumerate(dataloader): + assert dataloader.end_of_dataloader == (idx == 3) From 88314882b90bb73e93d569e23f4675a1d47a759b Mon Sep 17 00:00:00 2001 From: byi8220 Date: Sat, 22 Jun 2024 20:34:41 -0400 Subject: [PATCH 08/61] make style --- src/accelerate/accelerator.py | 16 +++++++++++----- src/accelerate/data_loader.py | 20 ++++++++++++-------- src/accelerate/test_utils/testing.py | 1 + src/accelerate/utils/dataclasses.py | 2 +- src/accelerate/utils/imports.py | 5 ++++- tests/test_data_loader.py | 9 +++------ 6 files changed, 32 insertions(+), 21 deletions(-) diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 842c3d61ebf..59415617bfd 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -569,7 +569,7 @@ def use_seedable_sampler(self): @property def non_blocking(self): return self.dataloader_config.non_blocking - + @property def use_stateful_dataloader(self): return self.dataloader_config.use_stateful_dataloader @@ -1584,9 +1584,13 @@ def _prepare_deepspeed(self, *args): deepspeed_plugin = self.state.deepspeed_plugin - is_dataloader_present = any((isinstance(obj, torch.utils.data.DataLoader) or isinstance(obj, DataLoaderAdapter)) for obj in args) + is_dataloader_present = any( + (isinstance(obj, torch.utils.data.DataLoader) or isinstance(obj, DataLoaderAdapter)) for obj in args + ) result = [ - self._prepare_one(obj, first_pass=True) if (isinstance(obj, torch.utils.data.DataLoader or isinstance(obj, DataLoaderAdapter))) else obj + self._prepare_one(obj, first_pass=True) + if (isinstance(obj, torch.utils.data.DataLoader or isinstance(obj, DataLoaderAdapter))) + else obj for obj in args ] @@ -1837,7 +1841,9 @@ def _prepare_megatron_lm(self, *args): scheduler = None batch_data = None for obj in args: - if (isinstance(obj, torch.utils.data.DataLoader) or isinstance(obj, DataLoaderAdapter)) and batch_data is None: + if ( + isinstance(obj, torch.utils.data.DataLoader) or isinstance(obj, DataLoaderAdapter) + ) and batch_data is None: batch_data = next(iter(obj)) elif isinstance(obj, torch.nn.Module): model = obj @@ -1866,7 +1872,7 @@ def _prepare_megatron_lm(self, *args): counter = 0 result = [] for obj in args: - if (isinstance(obj, torch.utils.data.DataLoader) or isinstance(obj, DataLoaderAdapter)): + if isinstance(obj, torch.utils.data.DataLoader) or isinstance(obj, DataLoaderAdapter): result.append(megatron_lm_prepare_data_loader(self, obj)) counter += 1 elif isinstance(obj, MegatronLMDummyDataLoader): diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index 0366b76aad7..ce7593effa3 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -20,6 +20,8 @@ from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler from accelerate.utils.imports import is_torchdata_stateful_dataloader_available + + if is_torchdata_stateful_dataloader_available(): from torchdata.stateful_dataloader import 
StatefulDataLoader @@ -40,7 +42,6 @@ ) - logger = get_logger(__name__) # kwargs of the DataLoader in min version 1.4.0. @@ -392,11 +393,13 @@ def end(self): "Cleans up the gradient state after exiting the dataloader" self.gradient_state._remove_dataloader(self) + # TODO: Maybe generalize this class? class DataLoaderAdapter: """ A class which wraps around a PyTorch `DataLoader` (or variants of it) to be used with the `Accelerator`. """ + def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, **kwargs): self.use_stateful_dataloader = use_stateful_dataloader if use_stateful_dataloader and not is_torchdata_stateful_dataloader_available(): @@ -405,16 +408,16 @@ def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, * self.base_dataloader = StatefulDataLoader(dataset, **kwargs) else: self.base_dataloader = DataLoader(dataset, **kwargs) - + for attr in self.base_dataloader.__dict__.keys(): setattr(self, attr, getattr(self.base_dataloader, attr)) def __iter__(self): return iter(self.base_dataloader) - + def __len__(self): return len(self.base_dataloader) - + def load_state_dict(self): """ Only supported for `StatefulDataLoader`. @@ -427,6 +430,7 @@ def state_dict(self): """ return self.base_dataloader.state_dict() + class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin): """ Subclass of `DataLoaderAdapter` that will deal with device placement and current distributed setup. @@ -603,8 +607,8 @@ def batch_sampler(self): class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin): """ - Subclass of `DataLoaderAdapter` that will iterate and preprocess on process 0 only, then dispatch on each - process their part of the batch. + Subclass of `DataLoaderAdapter` that will iterate and preprocess on process 0 only, then dispatch on each process + their part of the batch. Args: split_batches (`bool`, *optional*, defaults to `False`): @@ -916,8 +920,8 @@ def prepare_data_loader( `pin_memory` set to `True`, this will help to increase overlap between data transfer and computations. use_stateful_dataloader (`bool`, *optional*, defaults to `False`): "If set to true, the dataloader prepared by the Accelerator will be backed by " - "[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). This requires a version" - " of `torchdata` with StatefulDataLoader to be installed." + "[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). + This requires a version" " of `torchdata` with StatefulDataLoader to be installed." 
Returns: diff --git a/src/accelerate/test_utils/testing.py b/src/accelerate/test_utils/testing.py index c99029dc20d..288d3bf64e2 100644 --- a/src/accelerate/test_utils/testing.py +++ b/src/accelerate/test_utils/testing.py @@ -407,6 +407,7 @@ def require_torchdata_stateful_dataloader(test_case): else: return test_case + class TempDirTestCase(unittest.TestCase): """ A TestCase class that keeps a single `tempfile.TemporaryDirectory` open for the duration of the class, wipes its diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py index 2b7e930107c..811ea7ae4d8 100644 --- a/src/accelerate/utils/dataclasses.py +++ b/src/accelerate/utils/dataclasses.py @@ -604,7 +604,7 @@ class DataLoaderConfiguration: default=False, metadata={ "help": "If set to true, the dataloader prepared by the Accelerator will be backed by " - "[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). This requires a version" + "[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). This requires a version" " of `torchdata` with StatefulDataLoader to be installed." }, ) diff --git a/src/accelerate/utils/imports.py b/src/accelerate/utils/imports.py index c7a486b337b..e1e207f4e6a 100644 --- a/src/accelerate/utils/imports.py +++ b/src/accelerate/utils/imports.py @@ -406,12 +406,15 @@ def is_xpu_available(check_device=False): def is_dvclive_available(): return _is_package_available("dvclive") + def is_torchdata_available(): return _is_package_available("torchdata") + # TODO: Remove this function once stateful_dataloader is a stable feature in torchdata. def is_torchdata_stateful_dataloader_available(): if not is_torchdata_available(): return False import torchdata - return hasattr(torchdata, "stateful_dataloader") and hasattr(torchdata.stateful_dataloader, "StatefulDataLoader") \ No newline at end of file + + return hasattr(torchdata, "stateful_dataloader") and hasattr(torchdata.stateful_dataloader, "StatefulDataLoader") diff --git a/tests/test_data_loader.py b/tests/test_data_loader.py index 86532d38d17..8e96e4bec85 100644 --- a/tests/test_data_loader.py +++ b/tests/test_data_loader.py @@ -27,17 +27,15 @@ SkipDataLoader, skip_first_batches, ) - -from accelerate.utils.imports import is_torchdata_stateful_dataloader_available from accelerate.test_utils.testing import require_torchdata_stateful_dataloader +from accelerate.utils.imports import is_torchdata_stateful_dataloader_available + if is_torchdata_stateful_dataloader_available(): from torchdata.stateful_dataloader import ( StatefulDataLoader, ) - from accelerate.data_loader import ( - DataLoaderAdapter, - ) + class RandomIterableDataset(IterableDataset): # For testing, an iterable dataset of random length @@ -409,7 +407,6 @@ def test_end_of_dataloader_dispatcher(self): class StatefulDataLoaderTester(unittest.TestCase): - @require_torchdata_stateful_dataloader def test_skip_data_loader(self): dataloader = SkipDataLoader(list(range(16)), batch_size=4, skip_batches=2, use_stateful_dataloader=True) From 140f1e65f274ecfde6c70d9642ff203f14767a10 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Sat, 22 Jun 2024 21:48:04 -0400 Subject: [PATCH 09/61] Some dark magic dynamic reflection (for backwards compat) --- src/accelerate/data_loader.py | 43 +++++++++---------- .../scripts/external_deps/test_metrics.py | 1 - tests/test_data_loader.py | 9 ++-- 3 files changed, 26 insertions(+), 27 deletions(-) diff --git a/src/accelerate/data_loader.py 
b/src/accelerate/data_loader.py index ce7593effa3..1ab061a920f 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -394,10 +394,10 @@ def end(self): self.gradient_state._remove_dataloader(self) -# TODO: Maybe generalize this class? class DataLoaderAdapter: """ - A class which wraps around a PyTorch `DataLoader` (or variants of it) to be used with the `Accelerator`. + A class which wraps around a PyTorch `DataLoader` (or variants of it) to be used with the `Accelerator`. For + compatability reasons, this class inherits from the class it wraps around, so it can be used as a drop-in. """ def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, **kwargs): @@ -409,27 +409,18 @@ def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, * else: self.base_dataloader = DataLoader(dataset, **kwargs) + # Dynamically mixin the parent class. See https://stackoverflow.com/a/31075641 + # This is pretty awkward, but it's the only way to make `isinstance(obj, StatefulDataLoader)` work transparently. + # It would be better if DataLoaderAdapter does not inherit from DataLoader, but that would not be backwards compatible. + base_cls = self.__class__ + base_cls_name = self.__class__.__name__ + parent_cls_name = self.base_dataloader.__class__ + self.__class__ = type(base_cls_name, (base_cls, parent_cls_name), {}) + + # Allow this class to transparently pass through attributes from the underlying class for attr in self.base_dataloader.__dict__.keys(): setattr(self, attr, getattr(self.base_dataloader, attr)) - def __iter__(self): - return iter(self.base_dataloader) - - def __len__(self): - return len(self.base_dataloader) - - def load_state_dict(self): - """ - Only supported for `StatefulDataLoader`. - """ - return self.base_dataloader.load_state_dict() - - def state_dict(self): - """ - Only supported for `StatefulDataLoader`. 
- """ - return self.base_dataloader.state_dict() - class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin): """ @@ -1055,6 +1046,7 @@ def prepare_data_loader( _drop_last=dataloader.drop_last, _non_blocking=non_blocking, slice_fn=slice_fn_for_dispatch, + use_stateful_dataloader=use_stateful_dataloader, **kwargs, ) elif sampler_is_batch_sampler: @@ -1067,7 +1059,7 @@ def prepare_data_loader( _drop_last=dataloader.drop_last, _non_blocking=non_blocking, synchronized_generator=synchronized_generator, - **kwargs, + use_stateful_dataloader=use_stateful_dataloader**kwargs, ) else: dataloader = DataLoaderShard( @@ -1078,6 +1070,7 @@ def prepare_data_loader( synchronized_generator=synchronized_generator, _drop_last=dataloader.drop_last, _non_blocking=non_blocking, + use_stateful_dataloader=use_stateful_dataloader, **kwargs, ) @@ -1177,6 +1170,7 @@ def skip_first_batches(dataloader, num_batches=0): split_batches=dataloader.split_batches, batch_sampler=new_batch_sampler, _drop_last=dataloader._drop_last, + use_stateful_dataloader=dataloader.use_stateful_dataloader, **kwargs, ) elif isinstance(dataloader, DataLoaderShard): @@ -1193,12 +1187,17 @@ def skip_first_batches(dataloader, num_batches=0): device=dataloader.device, rng_types=dataloader.rng_types, synchronized_generator=dataloader.synchronized_generator, + use_stateful_dataloader=dataloader.use_stateful_dataloader, **kwargs, ) else: if new_batch_sampler is None: # Need to manually skip batches in the dataloader - dataloader = SkipDataLoader(dataset, skip_batches=num_batches, **kwargs) + dataloader = SkipDataLoader( + dataset, skip_batches=num_batches, use_stateful_dataloader=dataloader.use_stateful_dataloader, **kwargs + ) + elif is_torchdata_stateful_dataloader_available() and isinstance(dataloader, StatefulDataLoader): + dataloader = StatefulDataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs) else: dataloader = DataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs) diff --git a/src/accelerate/test_utils/scripts/external_deps/test_metrics.py b/src/accelerate/test_utils/scripts/external_deps/test_metrics.py index 9ac13aba626..9925e60a647 100755 --- a/src/accelerate/test_utils/scripts/external_deps/test_metrics.py +++ b/src/accelerate/test_utils/scripts/external_deps/test_metrics.py @@ -227,7 +227,6 @@ def test_gather_for_metrics_drop_last(): num_items = (10 * accelerator.num_processes) + 1 dataloader = DataLoader(range(num_items), batch_size=per_device_batch_size, drop_last=True) dataloader = accelerator.prepare(dataloader) - iterator = iter(dataloader) next(iterator) # Skip first batch tensor([0, 1, 2, 3, 4], device='cuda:0') batch = next(iterator) diff --git a/tests/test_data_loader.py b/tests/test_data_loader.py index 8e96e4bec85..8053562c4fc 100644 --- a/tests/test_data_loader.py +++ b/tests/test_data_loader.py @@ -410,20 +410,21 @@ class StatefulDataLoaderTester(unittest.TestCase): @require_torchdata_stateful_dataloader def test_skip_data_loader(self): dataloader = SkipDataLoader(list(range(16)), batch_size=4, skip_batches=2, use_stateful_dataloader=True) - + assert isinstance(dataloader, StatefulDataLoader) assert [t.tolist() for t in dataloader] == [[8, 9, 10, 11], [12, 13, 14, 15]] @require_torchdata_stateful_dataloader def test_skip_first_batches(self): dataloader = StatefulDataLoader(list(range(16)), batch_size=4) new_dataloader = skip_first_batches(dataloader, num_batches=2) - + assert isinstance(new_dataloader, StatefulDataLoader) assert [t.tolist() for t in new_dataloader] == [[8, 9, 10, 11], [12, 13, 14, 
15]] @require_torchdata_stateful_dataloader def test_end_of_dataloader(self): dataloader = DataLoaderShard(list(range(16)), batch_size=4, use_stateful_dataloader=True) - + assert dataloader.use_stateful_dataloader + assert isinstance(dataloader, StatefulDataLoader) for idx, _ in enumerate(dataloader): assert dataloader.end_of_dataloader == (idx == 3) @@ -435,7 +436,7 @@ def test_end_of_dataloader(self): def test_end_of_dataloader_dispatcher(self): Accelerator() dataloader = DataLoaderDispatcher(range(16), batch_size=4, use_stateful_dataloader=True) - + assert isinstance(dataloader, StatefulDataLoader) for idx, _ in enumerate(dataloader): assert dataloader.end_of_dataloader == (idx == 3) From 727afeb9ef2b316418b3ecb23940e53b1d4740e7 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Sat, 22 Jun 2024 22:32:26 -0400 Subject: [PATCH 10/61] typo --- src/accelerate/data_loader.py | 4 ++-- .../test_utils/scripts/external_deps/test_metrics.py | 6 ++++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index 1ab061a920f..442b7bc22f3 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -405,9 +405,9 @@ def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, * if use_stateful_dataloader and not is_torchdata_stateful_dataloader_available(): raise ValueError("StatefulDataLoader is not available. Please install torchdata to use it.") if use_stateful_dataloader: - self.base_dataloader = StatefulDataLoader(dataset, **kwargs) + self.base_dataloader = StatefulDataLoader(dataset, batch_sampler=batch_sampler **kwargs) else: - self.base_dataloader = DataLoader(dataset, **kwargs) + self.base_dataloader = DataLoader(dataset, batch_sampler=batch_sampler, **kwargs) # Dynamically mixin the parent class. See https://stackoverflow.com/a/31075641 # This is pretty awkward, but it's the only way to make `isinstance(obj, StatefulDataLoader)` work transparently. 
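The dynamic mix-in mentioned in the context lines above (introduced in PATCH 09) is the core trick of `DataLoaderAdapter`: the instance's class is rebuilt at runtime so `isinstance` checks against the wrapped loader's class keep passing. A stripped-down sketch of the same pattern, using a hypothetical `Adapter` class rather than the real `DataLoaderAdapter`:

    from torch.utils.data import DataLoader

    class Adapter:
        def __init__(self, dataset, **kwargs):
            self.base_dataloader = DataLoader(dataset, **kwargs)
            base_cls = self.__class__
            # Rebuild this instance's class so it also derives from the wrapped loader's class.
            self.__class__ = type(base_cls.__name__, (base_cls, self.base_dataloader.__class__), {})
            # Mirror the wrapped loader's attributes so attribute reads pass straight through.
            for attr in self.base_dataloader.__dict__:
                setattr(self, attr, getattr(self.base_dataloader, attr))

    adapter = Adapter(list(range(8)), batch_size=2)
    assert isinstance(adapter, Adapter)
    assert isinstance(adapter, DataLoader)  # Holds because of the rebuilt class, not static inheritance.

In the real adapter the same re-typing happens against `StatefulDataLoader` when `use_stateful_dataloader=True`, which is why the new tests can assert `isinstance(dataloader, StatefulDataLoader)` on the prepared objects.
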
diff --git a/src/accelerate/test_utils/scripts/external_deps/test_metrics.py b/src/accelerate/test_utils/scripts/external_deps/test_metrics.py index 9925e60a647..aca0f5ad07c 100755 --- a/src/accelerate/test_utils/scripts/external_deps/test_metrics.py +++ b/src/accelerate/test_utils/scripts/external_deps/test_metrics.py @@ -227,6 +227,7 @@ def test_gather_for_metrics_drop_last(): num_items = (10 * accelerator.num_processes) + 1 dataloader = DataLoader(range(num_items), batch_size=per_device_batch_size, drop_last=True) dataloader = accelerator.prepare(dataloader) + iterator = iter(dataloader) next(iterator) # Skip first batch tensor([0, 1, 2, 3, 4], device='cuda:0') batch = next(iterator) @@ -234,6 +235,11 @@ def test_gather_for_metrics_drop_last(): # Should return a full set of complete batches from each GPU num_expected_items = per_device_batch_size * accelerator.num_processes + print("dataloader.batch_size:", dataloader.batch_size) + print("accelerator.num_processes:", accelerator.num_processes) + print("gathered_items:", gathered_items) + print("batch:", batch) + print("len(dataloader):", len(dataloader)) assert gathered_items.size(0) == ( num_expected_items ), f"Expected number of items: {num_expected_items}, Actual: {gathered_items.size(0)}" From 73683b42a3f31fc08c3fdd987318102b369df4b8 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Tue, 25 Jun 2024 14:12:44 -0400 Subject: [PATCH 11/61] some tests --- src/accelerate/data_loader.py | 17 +++++- tests/test_accelerator.py | 98 +++++++++++++++++++++++++++++++++-- tests/test_data_loader.py | 52 +++++++++++++++++++ 3 files changed, 161 insertions(+), 6 deletions(-) diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index 442b7bc22f3..575d58e6ea5 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -387,10 +387,12 @@ def begin(self): if not self._drop_last: length = getattr(self.dataset, "total_dataset_length", len(self.dataset)) self.remainder = length % self.total_batch_size + print("adding dataloader", self) self.gradient_state._add_dataloader(self) def end(self): "Cleans up the gradient state after exiting the dataloader" + print("removing dataloader", self) self.gradient_state._remove_dataloader(self) @@ -405,7 +407,7 @@ def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, * if use_stateful_dataloader and not is_torchdata_stateful_dataloader_available(): raise ValueError("StatefulDataLoader is not available. 
Please install torchdata to use it.") if use_stateful_dataloader: - self.base_dataloader = StatefulDataLoader(dataset, batch_sampler=batch_sampler **kwargs) + self.base_dataloader = StatefulDataLoader(dataset, batch_sampler=batch_sampler, **kwargs) else: self.base_dataloader = DataLoader(dataset, batch_sampler=batch_sampler, **kwargs) @@ -421,6 +423,16 @@ def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, * for attr in self.base_dataloader.__dict__.keys(): setattr(self, attr, getattr(self.base_dataloader, attr)) + if hasattr(self.base_dataloader, "state_dict"): + self.dl_state_dict = self.base_dataloader.state_dict() + + def state_dict(self): + return self.dl_state_dict + + def _save_state_dict(self): + if hasattr(self.base_dataloader, "state_dict"): + self.dl_state_dict = self.base_dataloader.state_dict() + class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin): """ @@ -498,6 +510,7 @@ def __iter__(self): # But we still move it to the device so it is done before `StopIteration` is reached if self.device is not None: current_batch = send_to_device(current_batch, self.device, non_blocking=self._non_blocking) + self._save_state_dict() next_batch = next(dataloader_iter) if batch_index >= self.skip_batches: yield current_batch @@ -662,12 +675,14 @@ def _fetch_batches(self, iterator): try: if self.split_batches: # One batch of the main iterator is dispatched and split. + self._save_state_dict() batch = next(iterator) else: # num_processes batches of the main iterator are concatenated then dispatched and split. # We add the batches one by one so we have the remainder available when drop_last=False. batches = [] for _ in range(self.state.num_processes): + self._save_state_dict() batches.append(next(iterator)) try: batch = concatenate(batches, dim=0) diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py index 6a868f8d356..ae3a7e4a21d 100644 --- a/tests/test_accelerator.py +++ b/tests/test_accelerator.py @@ -27,18 +27,30 @@ from accelerate.accelerator import Accelerator from accelerate.state import GradientState, PartialState from accelerate.test_utils import require_bnb, require_multi_gpu, require_non_cpu, slow, torch_device -from accelerate.test_utils.testing import AccelerateTestCase, require_cuda, require_non_torch_xla +from accelerate.test_utils.testing import ( + AccelerateTestCase, + require_cuda, + require_non_torch_xla, + require_torchdata_stateful_dataloader, +) from accelerate.utils import patch_environment +from accelerate.utils.dataclasses import DataLoaderConfiguration +from accelerate.utils.imports import is_torchdata_stateful_dataloader_available from accelerate.utils.modeling import get_state_dict_from_offload, load_checkpoint_in_model -def create_components(): +if is_torchdata_stateful_dataloader_available(): + from torchdata.stateful_dataloader import ( + StatefulDataLoader, + ) + + +def create_components(dataset_size=3): model = torch.nn.Linear(2, 4) optimizer = torch.optim.AdamW(model.parameters(), lr=1.0) scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=2, epochs=1) - train_dl = DataLoader(TensorDataset(torch.tensor([1, 2, 3]))) - valid_dl = DataLoader(TensorDataset(torch.tensor([4, 5, 6]))) - + train_dl = DataLoader(TensorDataset(torch.tensor([i for i in range(1, dataset_size+1)]))) + valid_dl = DataLoader(TensorDataset(torch.tensor([i+dataset_size for i in range(1, dataset_size+1)]))) return model, optimizer, scheduler, train_dl, valid_dl @@ -571,3 +583,79 @@ def 
test_can_unwrap_model(self): # check that pickle roundtrip works model_loaded = pickle.loads(pickle.dumps(model)) model_loaded(inputs) + + # Ideally would be a parameterized test which works with either stateful or non-stateful dataloaders, but dependencies are a bit awkward. + @require_torchdata_stateful_dataloader + def test_prepared_objects_are_referenced_with_stateful_dataloader(self): + """Test that setting `use_stateful_dataloader=True` in `DataLoaderConfiguration` prepares a `StatefulDataLoader` object instead of a `DataLoader` object.""" + dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True) + accelerator = Accelerator(dataloader_config=dataloader_config) + model, optimizer, scheduler, train_dl, valid_dl = create_components() + + ( + prepared_model, + prepared_optimizer, + prepared_scheduler, + prepared_train_dl, + prepared_valid_dl, + ) = accelerator.prepare(model, optimizer, scheduler, train_dl, valid_dl) + + assert prepared_model in accelerator._models + assert prepared_optimizer in accelerator._optimizers + assert prepared_scheduler in accelerator._schedulers + assert prepared_train_dl in accelerator._dataloaders + assert prepared_valid_dl in accelerator._dataloaders + assert isinstance(prepared_train_dl, StatefulDataLoader) + assert isinstance(prepared_valid_dl, StatefulDataLoader) + + @parameterized.expand([True, False], name_func=parameterized_custom_name_func) + @require_torchdata_stateful_dataloader + def test_save_model_with_stateful_dataloader(self, use_safetensors): + """Test that saving and loading a model with a stateful dataloader returns the same model, and that the dataloader's iterator is restored properly.""" + dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True) + accelerator = Accelerator(dataloader_config=dataloader_config) + + model, optimizer, scheduler, train_dl, valid_dl = create_components(dataset_size=6) + ( + prepared_model, + prepared_optimizer, + prepared_scheduler, + prepared_train_dl, + prepared_valid_dl, + ) = accelerator.prepare(model, optimizer, scheduler, train_dl, valid_dl) + + assert isinstance(prepared_train_dl, StatefulDataLoader) + assert isinstance(prepared_valid_dl, StatefulDataLoader) + + print("len before iterating", len(prepared_train_dl)) + # Perform 3 training iterations to ensure the dataloader's iterator is advanced + for i, input in enumerate(prepared_train_dl): + print(i, input) + if i == 2: + state_dict = prepared_train_dl.state_dict() + break + + for i, input in enumerate(prepared_train_dl): + print("Pass of initial dict yielding input {}".format(input), prepared_train_dl.state_dict()) + + print("State dict to be loaded", state_dict) + prepared_train_dl.load_state_dict(state_dict) + print("State dict immediately after loading", prepared_train_dl.state_dict()) + + for i, input in enumerate(prepared_train_dl): + print("Pass of loaded dict yielding input {}".format(input), prepared_train_dl.state_dict()) + + + model_signature = get_signature(prepared_model) + with tempfile.TemporaryDirectory() as tmpdirname: + + # Save the model's state. 
+ accelerator.save_model(prepared_model, tmpdirname, safe_serialization=use_safetensors) + + # Load the saved model + loaded_model = prepared_model + load_checkpoint_in_model(loaded_model, tmpdirname) + # make sure loaded weights match + assert abs(model_signature - get_signature(prepared_model)) < 1e-3 + + # iterate through both dataloaders and assert their behaviors are identical \ No newline at end of file diff --git a/tests/test_data_loader.py b/tests/test_data_loader.py index 8053562c4fc..81cae2b4d21 100644 --- a/tests/test_data_loader.py +++ b/tests/test_data_loader.py @@ -15,6 +15,7 @@ import random import unittest +import torch from torch.utils.data import BatchSampler, DataLoader, IterableDataset from accelerate import Accelerator @@ -438,8 +439,59 @@ def test_end_of_dataloader_dispatcher(self): dataloader = DataLoaderDispatcher(range(16), batch_size=4, use_stateful_dataloader=True) assert isinstance(dataloader, StatefulDataLoader) for idx, _ in enumerate(dataloader): + print(idx) assert dataloader.end_of_dataloader == (idx == 3) # Test it also works on the second iteration for idx, _ in enumerate(dataloader): assert dataloader.end_of_dataloader == (idx == 3) + + @require_torchdata_stateful_dataloader + def test_dataloader_state_dict(self): + """ + Test that saving a stateful dataloader's state, then loading it back, gives the same results. + """ + dataset = list(range(16)) + dataloader = DataLoaderShard(dataset, batch_size=4, use_stateful_dataloader=True) + + assert dataloader.use_stateful_dataloader + assert isinstance(dataloader, StatefulDataLoader) + vals = [] + for idx, val in enumerate(dataloader): + vals.append(val) + if idx == 1: + sd = dataloader.state_dict() + assert len(vals) == 4 + + dataloader2 = DataLoaderShard(dataset, batch_size=4, use_stateful_dataloader=True) + dataloader2.load_state_dict(sd) + + data1 = vals[2:] + data2 = list(dataloader2) + for d1, d2 in zip(data1, data2): + assert torch.allclose(d1, d2) + + @require_torchdata_stateful_dataloader + def test_dataloader_dispatcher_state_dict(self): + """ + Test that saving a stateful dataloader's state, then loading it back, gives the same results. 
+ """ + dataset = list(range(16)) + dataloader = DataLoaderDispatcher(dataset, batch_size=4, use_stateful_dataloader=True) + + assert dataloader.use_stateful_dataloader + assert isinstance(dataloader, StatefulDataLoader) + vals = [] + for idx, val in enumerate(dataloader): + vals.append(val) + if idx == 1: + sd = dataloader.state_dict() + assert len(vals) == 4 + + dataloader2 = DataLoaderDispatcher(dataset, batch_size=4, use_stateful_dataloader=True) + dataloader2.load_state_dict(sd) + + data1 = vals[2:] + data2 = list(dataloader2) + for d1, d2 in zip(data1, data2): + assert torch.allclose(d1, d2) \ No newline at end of file From 32c318eef347431a2e9aec0d11db0cf912159fed Mon Sep 17 00:00:00 2001 From: byi8220 Date: Tue, 25 Jun 2024 14:58:49 -0400 Subject: [PATCH 12/61] more mixin stuff --- src/accelerate/accelerator.py | 14 ++++---- src/accelerate/data_loader.py | 6 ++-- .../test_utils/scripts/test_sync.py | 2 ++ tests/test_data_loader.py | 32 ++++++++++++++++++- 4 files changed, 42 insertions(+), 12 deletions(-) diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 59415617bfd..70335f2bcc9 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -34,7 +34,7 @@ from huggingface_hub import split_torch_state_dict_into_shards from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state -from .data_loader import DataLoaderAdapter, DataLoaderDispatcher, prepare_data_loader, skip_first_batches +from .data_loader import DataLoaderDispatcher, prepare_data_loader, skip_first_batches from .hooks import AlignDevicesHook from .logging import get_logger from .optimizer import AcceleratedOptimizer @@ -1178,7 +1178,7 @@ def print(self, *args, **kwargs): def _prepare_one(self, obj, first_pass=False, device_placement=None): # First pass of preparation: DataLoader, model, optimizer if first_pass: - if isinstance(obj, torch.utils.data.DataLoader) or isinstance(obj, DataLoaderAdapter): + if isinstance(obj, torch.utils.data.DataLoader): return self.prepare_data_loader(obj, device_placement=device_placement) elif isinstance(obj, torch.nn.Module): return self.prepare_model(obj, device_placement=device_placement) @@ -1585,11 +1585,11 @@ def _prepare_deepspeed(self, *args): deepspeed_plugin = self.state.deepspeed_plugin is_dataloader_present = any( - (isinstance(obj, torch.utils.data.DataLoader) or isinstance(obj, DataLoaderAdapter)) for obj in args + (isinstance(obj, torch.utils.data.DataLoader)) for obj in args ) result = [ self._prepare_one(obj, first_pass=True) - if (isinstance(obj, torch.utils.data.DataLoader or isinstance(obj, DataLoaderAdapter))) + if (isinstance(obj, torch.utils.data.DataLoader)) else obj for obj in args ] @@ -1841,9 +1841,7 @@ def _prepare_megatron_lm(self, *args): scheduler = None batch_data = None for obj in args: - if ( - isinstance(obj, torch.utils.data.DataLoader) or isinstance(obj, DataLoaderAdapter) - ) and batch_data is None: + if isinstance(obj, torch.utils.data.DataLoader) and batch_data is None: batch_data = next(iter(obj)) elif isinstance(obj, torch.nn.Module): model = obj @@ -1872,7 +1870,7 @@ def _prepare_megatron_lm(self, *args): counter = 0 result = [] for obj in args: - if isinstance(obj, torch.utils.data.DataLoader) or isinstance(obj, DataLoaderAdapter): + if isinstance(obj, torch.utils.data.DataLoader): result.append(megatron_lm_prepare_data_loader(self, obj)) counter += 1 elif isinstance(obj, MegatronLMDummyDataLoader): diff --git a/src/accelerate/data_loader.py 
b/src/accelerate/data_loader.py index 575d58e6ea5..c7c188a200a 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -396,7 +396,7 @@ def end(self): self.gradient_state._remove_dataloader(self) -class DataLoaderAdapter: +class DataLoaderAdapter(DataLoaderStateMixin): """ A class which wraps around a PyTorch `DataLoader` (or variants of it) to be used with the `Accelerator`. For compatability reasons, this class inherits from the class it wraps around, so it can be used as a drop-in. @@ -434,7 +434,7 @@ def _save_state_dict(self): self.dl_state_dict = self.base_dataloader.state_dict() -class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin): +class DataLoaderShard(DataLoaderAdapter): """ Subclass of `DataLoaderAdapter` that will deal with device placement and current distributed setup. @@ -609,7 +609,7 @@ def batch_sampler(self): return self._loader.batch_sampler -class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin): +class DataLoaderDispatcher(DataLoaderAdapter): """ Subclass of `DataLoaderAdapter` that will iterate and preprocess on process 0 only, then dispatch on each process their part of the batch. diff --git a/src/accelerate/test_utils/scripts/test_sync.py b/src/accelerate/test_utils/scripts/test_sync.py index fd829231770..1abfe02da57 100644 --- a/src/accelerate/test_utils/scripts/test_sync.py +++ b/src/accelerate/test_utils/scripts/test_sync.py @@ -310,7 +310,9 @@ def test_dataloader_break(): first_dataloader = DataLoader(first_dset, batch_size=16) second_dset = RegressionDataset(length=96) second_dataloader = DataLoader(second_dset, batch_size=16) + print("Dataloaders to be prepared", first_dataloader, second_dataloader) first_dataloader, second_dataloader = accelerator.prepare(first_dataloader, second_dataloader) + print("Dataloaders prepared", first_dataloader, second_dataloader) assert accelerator.gradient_state.active_dataloader is None for iteration, _ in enumerate(first_dataloader): assert id(accelerator.gradient_state.active_dataloader) == id(first_dataloader) diff --git a/tests/test_data_loader.py b/tests/test_data_loader.py index 81cae2b4d21..5e6f028176a 100644 --- a/tests/test_data_loader.py +++ b/tests/test_data_loader.py @@ -23,6 +23,7 @@ BatchSamplerShard, DataLoaderDispatcher, DataLoaderShard, + DataLoaderStateMixin, IterableDatasetShard, SkipBatchSampler, SkipDataLoader, @@ -378,6 +379,20 @@ def test_skip_batch_sampler(self): new_batch_sampler = SkipBatchSampler(batch_sampler, 2) assert list(new_batch_sampler) == [[8, 9, 10, 11], [12, 13, 14, 15]] + def test_dataloader_inheritance(self): + """`DataLoaderAdapter`'s parent classes are dynamically constructed, assert that subclasses of DataLoaderAdapter are instances of DataLoader and DataLoaderStateMixin.""" + Accelerator() + skip_dl = SkipDataLoader(range(16), batch_size=4, skip_batches=2) + dl_shard = DataLoaderShard(range(16), batch_size=4) + dl_dispatcher = DataLoaderDispatcher(range(16), batch_size=4) + assert isinstance(skip_dl, DataLoader) + assert isinstance(dl_shard, DataLoader) + assert isinstance(dl_dispatcher, DataLoader) + + assert isinstance(skip_dl, DataLoaderStateMixin) + assert isinstance(dl_shard, DataLoaderStateMixin) + assert isinstance(dl_dispatcher, DataLoaderStateMixin) + def test_skip_data_loader(self): dataloader = SkipDataLoader(list(range(16)), batch_size=4, skip_batches=2) assert [t.tolist() for t in dataloader] == [[8, 9, 10, 11], [12, 13, 14, 15]] @@ -494,4 +509,19 @@ def test_dataloader_dispatcher_state_dict(self): data1 = vals[2:] data2 = 
list(dataloader2) for d1, d2 in zip(data1, data2): - assert torch.allclose(d1, d2) \ No newline at end of file + assert torch.allclose(d1, d2) + + @require_torchdata_stateful_dataloader + def test_dataloader_inheritance(self): + """`DataLoaderAdapter`'s parent classes are dynamically constructed, assert that when use_stateful_dataloader=True, subclasses of DataLoaderAdapter are instances of StatefulDataLoader and DataLoaderStateMixin.""" + Accelerator() + skip_dl = SkipDataLoader(range(16), batch_size=4, skip_batches=2, use_stateful_dataloader=True) + dl_shard = DataLoaderShard(range(16), batch_size=4, use_stateful_dataloader=True) + dl_dispatcher = DataLoaderDispatcher(range(16), batch_size=4, use_stateful_dataloader=True) + assert isinstance(skip_dl, StatefulDataLoader) + assert isinstance(dl_shard, StatefulDataLoader) + assert isinstance(dl_dispatcher, StatefulDataLoader) + + assert isinstance(skip_dl, DataLoaderStateMixin) + assert isinstance(dl_shard, DataLoaderStateMixin) + assert isinstance(dl_dispatcher, DataLoaderStateMixin) From 57c6f575414e4cf3b149f95cdbca07d6b0e1726c Mon Sep 17 00:00:00 2001 From: byi8220 Date: Tue, 25 Jun 2024 16:33:54 -0400 Subject: [PATCH 13/61] maybe found broken test? --- src/accelerate/test_utils/scripts/test_sync.py | 4 +--- tests/test_accelerator.py | 1 + 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/accelerate/test_utils/scripts/test_sync.py b/src/accelerate/test_utils/scripts/test_sync.py index 1abfe02da57..672c4852046 100644 --- a/src/accelerate/test_utils/scripts/test_sync.py +++ b/src/accelerate/test_utils/scripts/test_sync.py @@ -305,14 +305,12 @@ def test_gradient_accumulation_with_opt_and_scheduler( def test_dataloader_break(): accelerator = Accelerator() - first_dset = RegressionDataset(length=80) first_dataloader = DataLoader(first_dset, batch_size=16) second_dset = RegressionDataset(length=96) second_dataloader = DataLoader(second_dset, batch_size=16) - print("Dataloaders to be prepared", first_dataloader, second_dataloader) first_dataloader, second_dataloader = accelerator.prepare(first_dataloader, second_dataloader) - print("Dataloaders prepared", first_dataloader, second_dataloader) + assert accelerator.gradient_state.active_dataloader is None for iteration, _ in enumerate(first_dataloader): assert id(accelerator.gradient_state.active_dataloader) == id(first_dataloader) diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py index ae3a7e4a21d..63a8639e2d9 100644 --- a/tests/test_accelerator.py +++ b/tests/test_accelerator.py @@ -626,6 +626,7 @@ def test_save_model_with_stateful_dataloader(self, use_safetensors): assert isinstance(prepared_train_dl, StatefulDataLoader) assert isinstance(prepared_valid_dl, StatefulDataLoader) + assert accelerator.gradient_state.active_dataloader is None print("len before iterating", len(prepared_train_dl)) # Perform 3 training iterations to ensure the dataloader's iterator is advanced From ed612d142aba5bc7f5c58ea1ab860902da4532e6 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Tue, 25 Jun 2024 17:25:05 -0400 Subject: [PATCH 14/61] this is a very invasive feature --- tests/test_accelerator.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py index 63a8639e2d9..f1ce224e597 100644 --- a/tests/test_accelerator.py +++ b/tests/test_accelerator.py @@ -634,6 +634,8 @@ def test_save_model_with_stateful_dataloader(self, use_safetensors): print(i, input) if i == 2: state_dict = prepared_train_dl.state_dict() + # When 
breaking out without fully going through the iterator, must call end() to unregister this iterator from gradient state. + prepared_train_dl.end() break for i, input in enumerate(prepared_train_dl): @@ -659,4 +661,4 @@ def test_save_model_with_stateful_dataloader(self, use_safetensors): # make sure loaded weights match assert abs(model_signature - get_signature(prepared_model)) < 1e-3 - # iterate through both dataloaders and assert their behaviors are identical \ No newline at end of file + # iterate through both dataloaders and assert their behaviors are identical From 511050e3a3f05d5c66fdab7d1acbf19e79034eca Mon Sep 17 00:00:00 2001 From: byi8220 Date: Wed, 26 Jun 2024 15:19:45 -0400 Subject: [PATCH 15/61] i think the feature is done? --- src/accelerate/accelerator.py | 8 +- src/accelerate/data_loader.py | 37 +++++++--- tests/test_accelerator.py | 133 +++++++++++++++++++++++++--------- tests/test_data_loader.py | 126 ++++++++++++++++++++++++++++---- 4 files changed, 241 insertions(+), 63 deletions(-) diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 70335f2bcc9..ea0e3f64822 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -1584,13 +1584,9 @@ def _prepare_deepspeed(self, *args): deepspeed_plugin = self.state.deepspeed_plugin - is_dataloader_present = any( - (isinstance(obj, torch.utils.data.DataLoader)) for obj in args - ) + is_dataloader_present = any((isinstance(obj, torch.utils.data.DataLoader)) for obj in args) result = [ - self._prepare_one(obj, first_pass=True) - if (isinstance(obj, torch.utils.data.DataLoader)) - else obj + self._prepare_one(obj, first_pass=True) if (isinstance(obj, torch.utils.data.DataLoader)) else obj for obj in args ] diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index c7c188a200a..b628c2f559d 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -387,12 +387,10 @@ def begin(self): if not self._drop_last: length = getattr(self.dataset, "total_dataset_length", len(self.dataset)) self.remainder = length % self.total_batch_size - print("adding dataloader", self) self.gradient_state._add_dataloader(self) def end(self): "Cleans up the gradient state after exiting the dataloader" - print("removing dataloader", self) self.gradient_state._remove_dataloader(self) @@ -404,6 +402,7 @@ class DataLoaderAdapter(DataLoaderStateMixin): def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, **kwargs): self.use_stateful_dataloader = use_stateful_dataloader + if use_stateful_dataloader and not is_torchdata_stateful_dataloader_available(): raise ValueError("StatefulDataLoader is not available. Please install torchdata to use it.") if use_stateful_dataloader: @@ -412,26 +411,41 @@ def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, * self.base_dataloader = DataLoader(dataset, batch_sampler=batch_sampler, **kwargs) # Dynamically mixin the parent class. See https://stackoverflow.com/a/31075641 - # This is pretty awkward, but it's the only way to make `isinstance(obj, StatefulDataLoader)` work transparently. - # It would be better if DataLoaderAdapter does not inherit from DataLoader, but that would not be backwards compatible. 
+ # In C++ terms, this is analogous to creating `DataLoaderAdapter : T`, where T is a DataLoader or + # StatefulDataLoader + # + # The same functionality could be achieved by directly creating the required subclasses for both {DataLoader, + # StatefulDataLoader}, however that could lead to much messier code, with duplicated classes and conditional + # dispatching scattered throughout various functions and files. + # + # This code is incredibly awkward but it's the only way to make `isinstance(obj, StatefulDataLoader)` work + # transparently. + # + # A more robust solution is for DataLoaderAdapter to not inherit from DataLoader (compose rather than inherit), + # but this would not be backwards compatible with existing code which assumes + # DataLoaderShard/DataLoaderDispatcher are DataLoaders. base_cls = self.__class__ base_cls_name = self.__class__.__name__ parent_cls_name = self.base_dataloader.__class__ self.__class__ = type(base_cls_name, (base_cls, parent_cls_name), {}) # Allow this class to transparently pass through attributes from the underlying class + if hasattr(self.base_dataloader, "state_dict"): + self.dl_state_dict = self.base_dataloader.state_dict() + for attr in self.base_dataloader.__dict__.keys(): setattr(self, attr, getattr(self.base_dataloader, attr)) - if hasattr(self.base_dataloader, "state_dict"): - self.dl_state_dict = self.base_dataloader.state_dict() - def state_dict(self): return self.dl_state_dict - + + def load_state_dict(self, state_dict): + super().load_state_dict(state_dict) + self.dl_state_dict = self.state_dict + def _save_state_dict(self): if hasattr(self.base_dataloader, "state_dict"): - self.dl_state_dict = self.base_dataloader.state_dict() + self.dl_state_dict = super().state_dict() class DataLoaderShard(DataLoaderAdapter): @@ -1074,7 +1088,8 @@ def prepare_data_loader( _drop_last=dataloader.drop_last, _non_blocking=non_blocking, synchronized_generator=synchronized_generator, - use_stateful_dataloader=use_stateful_dataloader**kwargs, + use_stateful_dataloader=use_stateful_dataloader, + **kwargs, ) else: dataloader = DataLoaderShard( @@ -1140,6 +1155,7 @@ def __init__(self, dataset, skip_batches=0, use_stateful_dataloader=False, **kwa def __iter__(self): for index, batch in enumerate(super().__iter__()): if index >= self.skip_batches: + self._save_state_dict() yield batch @@ -1215,5 +1231,4 @@ def skip_first_batches(dataloader, num_batches=0): dataloader = StatefulDataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs) else: dataloader = DataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs) - return dataloader diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py index f1ce224e597..b73fe8642c5 100644 --- a/tests/test_accelerator.py +++ b/tests/test_accelerator.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
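# (Aside on the dynamic-mixin comment added to DataLoaderAdapter above.) A minimal,
# self-contained sketch of that pattern, using illustrative names rather than the real
# accelerate classes: the adapter picks its backing loader class at construction time
# and then splices that class into its own MRO so isinstance() checks pass transparently.
from torch.utils.data import DataLoader


class LoaderAdapter:
    """Illustrative stand-in for DataLoaderAdapter, not the actual implementation."""

    def __init__(self, dataset, **kwargs):
        # The backing class could instead be torchdata's StatefulDataLoader.
        self.base_dataloader = DataLoader(dataset, **kwargs)
        # Build type("LoaderAdapter", (LoaderAdapter, DataLoader), {}) and adopt it,
        # which is what makes isinstance(adapter, DataLoader) hold.
        self.__class__ = type(
            self.__class__.__name__, (self.__class__, self.base_dataloader.__class__), {}
        )

    def __iter__(self):
        return iter(self.base_dataloader)


adapter = LoaderAdapter(list(range(8)), batch_size=2)
assert isinstance(adapter, DataLoader)     # True because of the dynamically built subclass
assert isinstance(adapter, LoaderAdapter)  # and it is still an adapter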
+import itertools import json import os import pickle @@ -25,6 +26,7 @@ from accelerate import DistributedType, infer_auto_device_map, init_empty_weights, load_checkpoint_and_dispatch from accelerate.accelerator import Accelerator +from accelerate.data_loader import skip_first_batches from accelerate.state import GradientState, PartialState from accelerate.test_utils import require_bnb, require_multi_gpu, require_non_cpu, slow, torch_device from accelerate.test_utils.testing import ( @@ -37,6 +39,7 @@ from accelerate.utils.dataclasses import DataLoaderConfiguration from accelerate.utils.imports import is_torchdata_stateful_dataloader_available from accelerate.utils.modeling import get_state_dict_from_offload, load_checkpoint_in_model +from accelerate.utils.random import set_seed if is_torchdata_stateful_dataloader_available(): @@ -45,12 +48,12 @@ ) -def create_components(dataset_size=3): +def create_components(): model = torch.nn.Linear(2, 4) optimizer = torch.optim.AdamW(model.parameters(), lr=1.0) scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=2, epochs=1) - train_dl = DataLoader(TensorDataset(torch.tensor([i for i in range(1, dataset_size+1)]))) - valid_dl = DataLoader(TensorDataset(torch.tensor([i+dataset_size for i in range(1, dataset_size+1)]))) + train_dl = DataLoader(TensorDataset(torch.tensor([1, 2, 3]))) + valid_dl = DataLoader(TensorDataset(torch.tensor([4, 5, 6]))) return model, optimizer, scheduler, train_dl, valid_dl @@ -65,6 +68,23 @@ def forward(self, x): return self.linear2(self.batchnorm(self.linear1(x))) +def create_dataloaders_for_test( + a=2, b=3, batch_size=3, n_train_batches: int = 12, n_valid_batches: int = 2, num_workers=0 +): + "Generates a tuple of dummy DataLoaders to test with" + + def get_dataset(n_batches): + x = torch.randn(batch_size * n_batches, 3) + y = torch.randn(batch_size * n_batches, 5) + return TensorDataset(x, y) + + train_dataset = get_dataset(n_train_batches) + valid_dataset = get_dataset(n_valid_batches) + train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers) + valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, num_workers=num_workers) + return (train_dataloader, valid_dataloader) + + def get_signature(model): return (model.weight.abs().sum() + model.bias.abs().sum()).item() @@ -78,6 +98,8 @@ def parameterized_custom_name_func(func, param_num, param): # customize the test name generator function as we want both params to appear in the sub-test # name, as by default it shows only the first param param_based_name = "use_safetensors" if param.args[0] is True else "use_pytorch" + if len(param.args) > 1: + param_based_name += f"_num_workers_{param.args[1]}" return f"{func.__name__}_{param_based_name}" @@ -608,14 +630,18 @@ def test_prepared_objects_are_referenced_with_stateful_dataloader(self): assert isinstance(prepared_train_dl, StatefulDataLoader) assert isinstance(prepared_valid_dl, StatefulDataLoader) - @parameterized.expand([True, False], name_func=parameterized_custom_name_func) + @parameterized.expand(itertools.product([True, False], [0, 2]), name_func=parameterized_custom_name_func) @require_torchdata_stateful_dataloader - def test_save_model_with_stateful_dataloader(self, use_safetensors): + def test_save_model_with_stateful_dataloader(self, use_safetensors, num_workers): """Test that saving and loading a model with a stateful dataloader returns the same model, and that the dataloader's iterator is restored properly.""" + set_seed(42) 
dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True) accelerator = Accelerator(dataloader_config=dataloader_config) - model, optimizer, scheduler, train_dl, valid_dl = create_components(dataset_size=6) + model, optimizer, scheduler, train_dl, valid_dl = create_components() + train_dl, valid_dl = create_dataloaders_for_test(num_workers=num_workers) + model = ModelForTest() + ( prepared_model, prepared_optimizer, @@ -626,39 +652,80 @@ def test_save_model_with_stateful_dataloader(self, use_safetensors): assert isinstance(prepared_train_dl, StatefulDataLoader) assert isinstance(prepared_valid_dl, StatefulDataLoader) - assert accelerator.gradient_state.active_dataloader is None - print("len before iterating", len(prepared_train_dl)) # Perform 3 training iterations to ensure the dataloader's iterator is advanced - for i, input in enumerate(prepared_train_dl): - print(i, input) - if i == 2: + num_batches_to_skip = 3 + model.train() + for step, batch in enumerate(prepared_train_dl): + x, y = batch + x.to(accelerator.device) + y.to(accelerator.device) + with accelerator.accumulate(prepared_model): + outputs = prepared_model(x) + loss = torch.nn.functional.mse_loss(outputs, y) + accelerator.backward(loss) + prepared_optimizer.step() + prepared_scheduler.step() + prepared_optimizer.zero_grad() + if step == num_batches_to_skip - 1: state_dict = prepared_train_dl.state_dict() # When breaking out without fully going through the iterator, must call end() to unregister this iterator from gradient state. + # TODO: Maybe this could be done automatically? prepared_train_dl.end() break + assert accelerator.gradient_state.active_dataloader is None - for i, input in enumerate(prepared_train_dl): - print("Pass of initial dict yielding input {}".format(input), prepared_train_dl.state_dict()) - - print("State dict to be loaded", state_dict) - prepared_train_dl.load_state_dict(state_dict) - print("State dict immediately after loading", prepared_train_dl.state_dict()) - - for i, input in enumerate(prepared_train_dl): - print("Pass of loaded dict yielding input {}".format(input), prepared_train_dl.state_dict()) - - - model_signature = get_signature(prepared_model) with tempfile.TemporaryDirectory() as tmpdirname: + # Save model for later use + accelerator.save_model(model, tmpdirname, safe_serialization=use_safetensors) - # Save the model's state. 
- accelerator.save_model(prepared_model, tmpdirname, safe_serialization=use_safetensors) - - # Load the saved model - loaded_model = prepared_model - load_checkpoint_in_model(loaded_model, tmpdirname) - # make sure loaded weights match - assert abs(model_signature - get_signature(prepared_model)) < 1e-3 - - # iterate through both dataloaders and assert their behaviors are identical + # Starting from where we left off, train this model to the end of the DataLoader + prepared_train_dl = skip_first_batches(prepared_train_dl, num_batches_to_skip) + batches_seen_with_original_dl = 0 + for batch in prepared_train_dl: + x, y = batch + x.to(accelerator.device) + y.to(accelerator.device) + with accelerator.accumulate(prepared_model): + outputs = prepared_model(x) + loss = torch.nn.functional.mse_loss(outputs, y) + accelerator.backward(loss) + prepared_optimizer.step() + prepared_scheduler.step() + prepared_optimizer.zero_grad() + batches_seen_with_original_dl += 1 + + original_linear1 = prepared_model.linear1.weight.clone() + original_batchnorm = prepared_model.batchnorm.weight.clone() + original_linear2 = prepared_model.linear2.weight.clone() + + # Load the model and state dict + load_checkpoint_in_model(model, tmpdirname) + stateful_train_dl, _ = create_dataloaders_for_test(num_workers=num_workers) + prepared_stateful_train_dl = accelerator.prepare_data_loader(stateful_train_dl) + prepared_stateful_train_dl.load_state_dict(state_dict) + + # Train this to the end of the DataLoader + batches_seen_with_loaded_dl = 0 + for batch in prepared_stateful_train_dl: + x, y = batch + x.to(accelerator.device) + y.to(accelerator.device) + with accelerator.accumulate(prepared_model): + outputs = prepared_model(x) + loss = torch.nn.functional.mse_loss(outputs, y) + accelerator.backward(loss) + prepared_optimizer.step() + prepared_scheduler.step() + prepared_optimizer.zero_grad() + batches_seen_with_loaded_dl += 1 + + new_linear1 = prepared_model.linear1.weight + new_batchnorm = prepared_model.batchnorm.weight + new_linear2 = prepared_model.linear2.weight + + # Assert equalities + assert batches_seen_with_original_dl == batches_seen_with_loaded_dl + assert torch.allclose(original_linear1, new_linear1) + assert torch.allclose(original_batchnorm, new_batchnorm) + assert torch.allclose(original_linear2, new_linear2) diff --git a/tests/test_data_loader.py b/tests/test_data_loader.py index 5e6f028176a..e48c20246bf 100644 --- a/tests/test_data_loader.py +++ b/tests/test_data_loader.py @@ -16,6 +16,7 @@ import unittest import torch +from parameterized import parameterized from torch.utils.data import BatchSampler, DataLoader, IterableDataset from accelerate import Accelerator @@ -30,6 +31,7 @@ skip_first_batches, ) from accelerate.test_utils.testing import require_torchdata_stateful_dataloader +from accelerate.utils.dataclasses import DataLoaderConfiguration from accelerate.utils.imports import is_torchdata_stateful_dataloader_available @@ -39,6 +41,13 @@ ) +def parameterized_custom_name_func(func, param_num, param): + # customize the test name generator function as we want both params to appear in the sub-test + # name, as by default it shows only the first param + param_based_name = f"num_workers_{param.args[0]}" + return f"{func.__name__}_{param_based_name}" + + class RandomIterableDataset(IterableDataset): # For testing, an iterable dataset of random length def __init__(self, p_stop=0.01, max_length=1000): @@ -380,7 +389,10 @@ def test_skip_batch_sampler(self): assert list(new_batch_sampler) == [[8, 9, 10, 11], [12, 
13, 14, 15]] def test_dataloader_inheritance(self): - """`DataLoaderAdapter`'s parent classes are dynamically constructed, assert that subclasses of DataLoaderAdapter are instances of DataLoader and DataLoaderStateMixin.""" + """ + `DataLoaderAdapter`'s parent classes are dynamically constructed, assert that subclasses of DataLoaderAdapter + are instances of DataLoader and DataLoaderStateMixin. + """ Accelerator() skip_dl = SkipDataLoader(range(16), batch_size=4, skip_batches=2) dl_shard = DataLoaderShard(range(16), batch_size=4) @@ -454,7 +466,6 @@ def test_end_of_dataloader_dispatcher(self): dataloader = DataLoaderDispatcher(range(16), batch_size=4, use_stateful_dataloader=True) assert isinstance(dataloader, StatefulDataLoader) for idx, _ in enumerate(dataloader): - print(idx) assert dataloader.end_of_dataloader == (idx == 3) # Test it also works on the second iteration @@ -462,48 +473,53 @@ def test_end_of_dataloader_dispatcher(self): assert dataloader.end_of_dataloader == (idx == 3) @require_torchdata_stateful_dataloader - def test_dataloader_state_dict(self): + @parameterized.expand([0, 2], name_func=parameterized_custom_name_func) + def test_dataloader_state_dict(self, num_workers): """ Test that saving a stateful dataloader's state, then loading it back, gives the same results. """ dataset = list(range(16)) - dataloader = DataLoaderShard(dataset, batch_size=4, use_stateful_dataloader=True) + dataloader = DataLoaderShard(dataset, batch_size=4, use_stateful_dataloader=True, num_workers=num_workers) assert dataloader.use_stateful_dataloader assert isinstance(dataloader, StatefulDataLoader) - vals = [] + vals = [] for idx, val in enumerate(dataloader): vals.append(val) if idx == 1: sd = dataloader.state_dict() assert len(vals) == 4 - dataloader2 = DataLoaderShard(dataset, batch_size=4, use_stateful_dataloader=True) + dataloader2 = DataLoaderShard(dataset, batch_size=4, use_stateful_dataloader=True, num_workers=num_workers) dataloader2.load_state_dict(sd) data1 = vals[2:] data2 = list(dataloader2) for d1, d2 in zip(data1, data2): assert torch.allclose(d1, d2) - + @require_torchdata_stateful_dataloader - def test_dataloader_dispatcher_state_dict(self): + @parameterized.expand([0, 2], name_func=parameterized_custom_name_func) + def test_dataloader_dispatcher_state_dict(self, num_workers): """ Test that saving a stateful dataloader's state, then loading it back, gives the same results. 
""" + dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True) + Accelerator(dataloader_config=dataloader_config) dataset = list(range(16)) - dataloader = DataLoaderDispatcher(dataset, batch_size=4, use_stateful_dataloader=True) + dataloader = DataLoaderDispatcher(dataset, batch_size=4, use_stateful_dataloader=True, num_workers=num_workers) assert dataloader.use_stateful_dataloader assert isinstance(dataloader, StatefulDataLoader) - vals = [] + vals = [] for idx, val in enumerate(dataloader): vals.append(val) if idx == 1: sd = dataloader.state_dict() assert len(vals) == 4 - - dataloader2 = DataLoaderDispatcher(dataset, batch_size=4, use_stateful_dataloader=True) + dataloader2 = DataLoaderDispatcher( + dataset, batch_size=4, use_stateful_dataloader=True, num_workers=num_workers + ) dataloader2.load_state_dict(sd) data1 = vals[2:] @@ -513,7 +529,10 @@ def test_dataloader_dispatcher_state_dict(self): @require_torchdata_stateful_dataloader def test_dataloader_inheritance(self): - """`DataLoaderAdapter`'s parent classes are dynamically constructed, assert that when use_stateful_dataloader=True, subclasses of DataLoaderAdapter are instances of StatefulDataLoader and DataLoaderStateMixin.""" + """ + `DataLoaderAdapter`'s parent classes are dynamically constructed, assert that if use_stateful_dataloader=True, + subclasses of DataLoaderAdapter are instances of StatefulDataLoader and DataLoaderStateMixin. + """ Accelerator() skip_dl = SkipDataLoader(range(16), batch_size=4, skip_batches=2, use_stateful_dataloader=True) dl_shard = DataLoaderShard(range(16), batch_size=4, use_stateful_dataloader=True) @@ -525,3 +544,84 @@ def test_dataloader_inheritance(self): assert isinstance(skip_dl, DataLoaderStateMixin) assert isinstance(dl_shard, DataLoaderStateMixin) assert isinstance(dl_dispatcher, DataLoaderStateMixin) + + @require_torchdata_stateful_dataloader + @parameterized.expand([0, 2], name_func=parameterized_custom_name_func) + def test_stateful_dataloader_adapter_equivalent_to_torchdata_stateful_dataloader(self, num_workers): + """ + Assert that `state_dict()` and `load_state_dict()` for derived subclasses of `DataLoaderAdapter` produce + the same behavior as `state_dict()` and `load_state_dict()` for `StatefulDataLoader`. 
+ """ + dataset = list(range(64)) + + # Set the seed for reproducibility + def g(): + return torch.Generator().manual_seed(42) + + accelerator = Accelerator() + stateful_dl = StatefulDataLoader(dataset, batch_size=4, num_workers=num_workers, generator=g()) + skip_dl = SkipDataLoader( + dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True + ) + dl_shard = DataLoaderShard( + dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True + ) + dl_dispatcher = DataLoaderDispatcher( + dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True + ) + + stateful_dl_iter = iter(stateful_dl) + iterators_under_test = [iter(skip_dl), iter(dl_shard), iter(dl_dispatcher)] + + num_batches = 8 + # Iterate over all of the dataloaders identically, expect the same values + for _ in range(num_batches): + expected_val = next(stateful_dl_iter).to(accelerator.device) + for dl_iter in iterators_under_test: + val = next(dl_iter).to(accelerator.device) + assert torch.allclose(val, expected_val) + + # The adapters should all produce the same state_dict as the reference stateful dataloader + expected_state_dict = stateful_dl.state_dict() + skip_dl_state_dict = skip_dl.state_dict() + dl_shard_state_dict = dl_shard.state_dict() + dl_dispatcher_state_dict = dl_dispatcher.state_dict() + + assert expected_state_dict == skip_dl_state_dict + assert expected_state_dict == dl_shard_state_dict + assert expected_state_dict == dl_dispatcher_state_dict + + # Load the state dict into new dataloaders + manual_skip_dl = SkipDataLoader( + dataset, batch_size=4, num_workers=num_workers, generator=g(), skip_batches=8, use_stateful_dataloader=True + ) + loaded_stateful_dl = StatefulDataLoader(dataset, batch_size=4, num_workers=num_workers, generator=g()) + loaded_stateful_dl.load_state_dict(expected_state_dict) + loaded_skip_dl = SkipDataLoader( + dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True + ) + loaded_skip_dl.load_state_dict(expected_state_dict) + loaded_dl_shard = DataLoaderShard( + dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True + ) + loaded_dl_shard.load_state_dict(expected_state_dict) + loaded_dl_dispatcher = DataLoaderDispatcher( + dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True + ) + loaded_dl_dispatcher.load_state_dict(expected_state_dict) + + # Continue the iteration, expecting identical behavior across the board + iterators_under_test.extend( + [ + iter(manual_skip_dl), + iter(loaded_stateful_dl), + iter(loaded_skip_dl), + iter(loaded_dl_shard), + iter(loaded_dl_dispatcher), + ] + ) + for _ in range(num_batches): + expected_val = next(stateful_dl_iter).to(accelerator.device) + for dl_iter in iterators_under_test: + val = next(dl_iter).to(accelerator.device) + assert torch.allclose(val, expected_val) From df439606c682e4c4ad78f63f50f820e35cea3bb8 Mon Sep 17 00:00:00 2001 From: Fanli Lin Date: Wed, 26 Jun 2024 20:56:13 +0800 Subject: [PATCH 16/61] add xpu support (#2864) --- .../scripts/external_deps/test_zero3_integration.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/accelerate/test_utils/scripts/external_deps/test_zero3_integration.py b/src/accelerate/test_utils/scripts/external_deps/test_zero3_integration.py index 67e78a7d37c..2bbb324c8cd 100644 --- a/src/accelerate/test_utils/scripts/external_deps/test_zero3_integration.py +++ 
b/src/accelerate/test_utils/scripts/external_deps/test_zero3_integration.py @@ -14,7 +14,7 @@ import torch.distributed -from accelerate.test_utils import require_huggingface_suite +from accelerate.test_utils import require_huggingface_suite, torch_device from accelerate.utils import is_transformers_available @@ -27,7 +27,8 @@ @require_huggingface_suite def init_torch_dist_then_launch_deepspeed(): - torch.distributed.init_process_group(backend="nccl") + backend = "ccl" if torch_device == "xpu" else "nccl" + torch.distributed.init_process_group(backend=backend) deepspeed_config = { "zero_optimization": { "stage": 3, From 4e0005594b22f78ba99ecb559c8bd5e6e41c5036 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Wed, 26 Jun 2024 16:20:09 -0400 Subject: [PATCH 17/61] better tests --- src/accelerate/utils/dataclasses.py | 2 +- tests/test_accelerator.py | 41 ++++++++++++++++++++++++++--- 2 files changed, 38 insertions(+), 5 deletions(-) diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py index 811ea7ae4d8..7afff388d24 100644 --- a/src/accelerate/utils/dataclasses.py +++ b/src/accelerate/utils/dataclasses.py @@ -572,7 +572,7 @@ class DataLoaderConfiguration: metadata={ "help": "If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process" " and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose" - " underlying dataset is an `IterableDataslet`, `False` otherwise." + " underlying dataset is an `IterableDataset`, `False` otherwise." }, ) even_batches: bool = field( diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py index b73fe8642c5..6c0c3e56b8c 100644 --- a/tests/test_accelerator.py +++ b/tests/test_accelerator.py @@ -100,6 +100,8 @@ def parameterized_custom_name_func(func, param_num, param): param_based_name = "use_safetensors" if param.args[0] is True else "use_pytorch" if len(param.args) > 1: param_based_name += f"_num_workers_{param.args[1]}" + if len(param.args) > 2: + param_based_name += "_dispatch_batches" if param.args[2] is True else "_no_dispatch_batches" return f"{func.__name__}_{param_based_name}" @@ -630,12 +632,15 @@ def test_prepared_objects_are_referenced_with_stateful_dataloader(self): assert isinstance(prepared_train_dl, StatefulDataLoader) assert isinstance(prepared_valid_dl, StatefulDataLoader) - @parameterized.expand(itertools.product([True, False], [0, 2]), name_func=parameterized_custom_name_func) + @parameterized.expand(itertools.product([True, False], [0, 2], [True, False]), name_func=parameterized_custom_name_func) @require_torchdata_stateful_dataloader - def test_save_model_with_stateful_dataloader(self, use_safetensors, num_workers): - """Test that saving and loading a model with a stateful dataloader returns the same model, and that the dataloader's iterator is restored properly.""" + def test_save_model_with_stateful_dataloader(self, use_safetensors, num_workers, dispatch_batches): + """ + Test that saving and loading a model with a stateful dataloader returns the same model, + and that the dataloader's iterator is restored properly.""" + print() set_seed(42) - dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True) + dataloader_config = DataLoaderConfiguration(dispatch_batches=dispatch_batches, use_stateful_dataloader=True) accelerator = Accelerator(dataloader_config=dataloader_config) model, optimizer, scheduler, train_dl, valid_dl = create_components() @@ -673,6 +678,7 @@ def 
test_save_model_with_stateful_dataloader(self, use_safetensors, num_workers) # TODO: Maybe this could be done automatically? prepared_train_dl.end() break + assert accelerator.gradient_state.active_dataloader is None with tempfile.TemporaryDirectory() as tmpdirname: @@ -729,3 +735,30 @@ def test_save_model_with_stateful_dataloader(self, use_safetensors, num_workers) assert torch.allclose(original_linear1, new_linear1) assert torch.allclose(original_batchnorm, new_batchnorm) assert torch.allclose(original_linear2, new_linear2) + + @require_torchdata_stateful_dataloader + def test_stateful_dataloader_dispatcher_deactivate_dataloaders(self): + """ + Test that we can break iteration in a DataLoaderDispatcher backed by a StatefulDataLoader partway through + in a way that removes it from the gradient state active dataloader list. + """ + print() + set_seed(42) + dataloader_config = DataLoaderConfiguration(dispatch_batches=True, use_stateful_dataloader=True) + accelerator = Accelerator(dataloader_config=dataloader_config) + model, optimizer, scheduler, train_dl, valid_dl = create_components() + ( + prepared_model, + prepared_optimizer, + prepared_scheduler, + prepared_train_dl, + prepared_valid_dl, + ) = accelerator.prepare(model, optimizer, scheduler, train_dl, valid_dl) + + # Perform 3 training iterations to ensure the dataloader's iterator is advanced + num_batches_to_skip = 3 + for step, _ in enumerate(prepared_train_dl): + if step == num_batches_to_skip - 1: + prepared_train_dl.end() + break + assert accelerator.gradient_state.active_dataloader is None \ No newline at end of file From 0471fe317868af56ac457661028a43934389a5e9 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Wed, 26 Jun 2024 16:45:22 -0400 Subject: [PATCH 18/61] discovered a bug --- tests/test_accelerator.py | 27 --------------------------- 1 file changed, 27 deletions(-) diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py index 6c0c3e56b8c..031255274f2 100644 --- a/tests/test_accelerator.py +++ b/tests/test_accelerator.py @@ -735,30 +735,3 @@ def test_save_model_with_stateful_dataloader(self, use_safetensors, num_workers, assert torch.allclose(original_linear1, new_linear1) assert torch.allclose(original_batchnorm, new_batchnorm) assert torch.allclose(original_linear2, new_linear2) - - @require_torchdata_stateful_dataloader - def test_stateful_dataloader_dispatcher_deactivate_dataloaders(self): - """ - Test that we can break iteration in a DataLoaderDispatcher backed by a StatefulDataLoader partway through - in a way that removes it from the gradient state active dataloader list. - """ - print() - set_seed(42) - dataloader_config = DataLoaderConfiguration(dispatch_batches=True, use_stateful_dataloader=True) - accelerator = Accelerator(dataloader_config=dataloader_config) - model, optimizer, scheduler, train_dl, valid_dl = create_components() - ( - prepared_model, - prepared_optimizer, - prepared_scheduler, - prepared_train_dl, - prepared_valid_dl, - ) = accelerator.prepare(model, optimizer, scheduler, train_dl, valid_dl) - - # Perform 3 training iterations to ensure the dataloader's iterator is advanced - num_batches_to_skip = 3 - for step, _ in enumerate(prepared_train_dl): - if step == num_batches_to_skip - 1: - prepared_train_dl.end() - break - assert accelerator.gradient_state.active_dataloader is None \ No newline at end of file From 3036b7fedb3c7f587d44afd6bcec9e30e5820ab0 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Wed, 26 Jun 2024 16:49:24 -0400 Subject: [PATCH 19/61] maybe fixed bug? 
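A hedged reading of the one-line fix below (an assumption on my part; the commit message does not explain it): wrapping a `BatchSampler` in `SkipBatchSampler` hid the underlying `.sampler` attribute that other code paths in `data_loader.py` reach for when re-preparing a dataloader, so the wrapper now forwards it. A sketch of the patched class, where only the forwarding line is new and the skipping loop is included just to make the sketch self-contained:

from torch.utils.data import BatchSampler

class SkipBatchSampler(BatchSampler):
    """Sketch mirroring the diff below, not a verbatim copy of the library class."""

    def __init__(self, batch_sampler, skip_batches=0):
        self.batch_sampler = batch_sampler
        self.sampler = batch_sampler.sampler  # forward the inner sampler to callers
        self.skip_batches = skip_batches

    def __iter__(self):
        for index, samples in enumerate(self.batch_sampler):
            if index >= self.skip_batches:
                yield samples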
--- src/accelerate/data_loader.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index b628c2f559d..72a65c88b88 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -1118,6 +1118,7 @@ class SkipBatchSampler(BatchSampler): def __init__(self, batch_sampler, skip_batches=0): self.batch_sampler = batch_sampler + self.sampler = batch_sampler.sampler self.skip_batches = skip_batches def __iter__(self): From 9ade2e9d4de51e7532dd81e609f5d7a795e51457 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Wed, 26 Jun 2024 16:49:35 -0400 Subject: [PATCH 20/61] make style --- tests/test_accelerator.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py index 031255274f2..9b960bbb4f3 100644 --- a/tests/test_accelerator.py +++ b/tests/test_accelerator.py @@ -632,7 +632,9 @@ def test_prepared_objects_are_referenced_with_stateful_dataloader(self): assert isinstance(prepared_train_dl, StatefulDataLoader) assert isinstance(prepared_valid_dl, StatefulDataLoader) - @parameterized.expand(itertools.product([True, False], [0, 2], [True, False]), name_func=parameterized_custom_name_func) + @parameterized.expand( + itertools.product([True, False], [0, 2], [True, False]), name_func=parameterized_custom_name_func + ) @require_torchdata_stateful_dataloader def test_save_model_with_stateful_dataloader(self, use_safetensors, num_workers, dispatch_batches): """ From ba0f5c60cbf05afb274531dc7cd6cc320d666f1f Mon Sep 17 00:00:00 2001 From: byi8220 Date: Wed, 26 Jun 2024 17:43:10 -0400 Subject: [PATCH 21/61] hopefully this is PR ready --- src/accelerate/data_loader.py | 6 ++-- tests/test_data_loader.py | 67 ++++++++++++++++++++++------------- 2 files changed, 45 insertions(+), 28 deletions(-) diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index 72a65c88b88..a67d69b7e46 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -394,7 +394,7 @@ def end(self): self.gradient_state._remove_dataloader(self) -class DataLoaderAdapter(DataLoaderStateMixin): +class DataLoaderAdapter: """ A class which wraps around a PyTorch `DataLoader` (or variants of it) to be used with the `Accelerator`. For compatability reasons, this class inherits from the class it wraps around, so it can be used as a drop-in. @@ -448,7 +448,7 @@ def _save_state_dict(self): self.dl_state_dict = super().state_dict() -class DataLoaderShard(DataLoaderAdapter): +class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin): """ Subclass of `DataLoaderAdapter` that will deal with device placement and current distributed setup. @@ -623,7 +623,7 @@ def batch_sampler(self): return self._loader.batch_sampler -class DataLoaderDispatcher(DataLoaderAdapter): +class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin): """ Subclass of `DataLoaderAdapter` that will iterate and preprocess on process 0 only, then dispatch on each process their part of the batch. 
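Taken together, the pieces of this PR support a mid-epoch checkpoint/resume flow along these lines (a minimal sketch assuming torchdata's StatefulDataLoader is installed; the toy model, dataset and variable names are illustrative, not part of the patch):

import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator
from accelerate.utils.dataclasses import DataLoaderConfiguration

accelerator = Accelerator(dataloader_config=DataLoaderConfiguration(use_stateful_dataloader=True))
model = torch.nn.Linear(3, 5)
optimizer = torch.optim.AdamW(model.parameters())
dataloader = DataLoader(TensorDataset(torch.randn(64, 3), torch.randn(64, 5)), batch_size=4)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for step, (x, y) in enumerate(dataloader):
    loss = torch.nn.functional.mse_loss(model(x), y)
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
    if step == 2:
        dl_state = dataloader.state_dict()  # position of the underlying StatefulDataLoader
        dataloader.end()  # unregister from gradient state when breaking out early
        break

# ... later, after re-creating and re-preparing the same objects, resume where we left off:
dataloader.load_state_dict(dl_state)
for x, y in dataloader:  # yields only the batches not seen before the checkpoint
    ...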
diff --git a/tests/test_data_loader.py b/tests/test_data_loader.py index e48c20246bf..ac5e2d40686 100644 --- a/tests/test_data_loader.py +++ b/tests/test_data_loader.py @@ -401,7 +401,6 @@ def test_dataloader_inheritance(self): assert isinstance(dl_shard, DataLoader) assert isinstance(dl_dispatcher, DataLoader) - assert isinstance(skip_dl, DataLoaderStateMixin) assert isinstance(dl_shard, DataLoaderStateMixin) assert isinstance(dl_dispatcher, DataLoaderStateMixin) @@ -541,7 +540,6 @@ def test_dataloader_inheritance(self): assert isinstance(dl_shard, StatefulDataLoader) assert isinstance(dl_dispatcher, StatefulDataLoader) - assert isinstance(skip_dl, DataLoaderStateMixin) assert isinstance(dl_shard, DataLoaderStateMixin) assert isinstance(dl_dispatcher, DataLoaderStateMixin) @@ -570,16 +568,30 @@ def g(): dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True ) - stateful_dl_iter = iter(stateful_dl) - iterators_under_test = [iter(skip_dl), iter(dl_shard), iter(dl_dispatcher)] + dataloaders_under_test = [skip_dl, dl_shard, dl_dispatcher] + + num_batches_to_skip = 8 + + def get_first_n_batches(dl, n, device): + """ + Iterate over the first `n` batches of a dataloader then break, returning the batches in a list. + """ + batches = [] + for idx, batch in enumerate(dl): + if idx == n-1: + if hasattr(dl, "end"): + dl.end() + break + batches.append(batch.to(device)) + return batches - num_batches = 8 # Iterate over all of the dataloaders identically, expect the same values - for _ in range(num_batches): - expected_val = next(stateful_dl_iter).to(accelerator.device) - for dl_iter in iterators_under_test: - val = next(dl_iter).to(accelerator.device) - assert torch.allclose(val, expected_val) + expected_batches = get_first_n_batches(stateful_dl, num_batches_to_skip, accelerator.device) + batches_from_dataloaders = [get_first_n_batches(dl, num_batches_to_skip, accelerator.device) for dl in dataloaders_under_test] + + for dl_batches in batches_from_dataloaders: + for expected, actual in zip(expected_batches, dl_batches): + assert torch.allclose(expected, actual) # The adapters should all produce the same state_dict as the reference stateful dataloader expected_state_dict = stateful_dl.state_dict() @@ -593,7 +605,7 @@ def g(): # Load the state dict into new dataloaders manual_skip_dl = SkipDataLoader( - dataset, batch_size=4, num_workers=num_workers, generator=g(), skip_batches=8, use_stateful_dataloader=True + dataset, batch_size=4, num_workers=num_workers, generator=g(), skip_batches=num_batches_to_skip, use_stateful_dataloader=True ) loaded_stateful_dl = StatefulDataLoader(dataset, batch_size=4, num_workers=num_workers, generator=g()) loaded_stateful_dl.load_state_dict(expected_state_dict) @@ -611,17 +623,22 @@ def g(): loaded_dl_dispatcher.load_state_dict(expected_state_dict) # Continue the iteration, expecting identical behavior across the board - iterators_under_test.extend( - [ - iter(manual_skip_dl), - iter(loaded_stateful_dl), - iter(loaded_skip_dl), - iter(loaded_dl_shard), - iter(loaded_dl_dispatcher), - ] - ) - for _ in range(num_batches): - expected_val = next(stateful_dl_iter).to(accelerator.device) - for dl_iter in iterators_under_test: - val = next(dl_iter).to(accelerator.device) - assert torch.allclose(val, expected_val) + def get_all_batches(dl, device): + """ + Iterate over all batches of a dataloader, returning (batches, num_batches_yielded) + """ + batches = [] + num_batches_yielded = 0 + for batch in dl: + batches.append(batch.to(device)) + 
num_batches_yielded += 1 + return (batches, num_batches_yielded) + + expected_batch_results = get_all_batches(loaded_stateful_dl, accelerator.device) + dataloader_batch_results = [get_all_batches(dl, accelerator.device) for dl in [manual_skip_dl, loaded_skip_dl, loaded_dl_shard, loaded_dl_dispatcher]] + for dl_results in dataloader_batch_results: + for expected, actual in zip(expected_batches, dl_batches): + assert torch.allclose(expected[0], actual[0]) + assert expected_batch_results[1] == dl_results[1] + + assert accelerator.gradient_state.active_dataloader is None \ No newline at end of file From b774291d817d9a46f8c28030a5d81084b299cede Mon Sep 17 00:00:00 2001 From: byi8220 Date: Wed, 26 Jun 2024 18:00:17 -0400 Subject: [PATCH 22/61] properly skip tests --- src/accelerate/test_utils/testing.py | 10 ++++------ tests/test_data_loader.py | 22 ++++++++++++++++------ 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/src/accelerate/test_utils/testing.py b/src/accelerate/test_utils/testing.py index 288d3bf64e2..7229a027cfd 100644 --- a/src/accelerate/test_utils/testing.py +++ b/src/accelerate/test_utils/testing.py @@ -29,6 +29,7 @@ import torch import accelerate +from accelerate.utils.imports import is_torchdata_stateful_dataloader_available from ..state import AcceleratorState, PartialState from ..utils import ( @@ -400,12 +401,9 @@ def require_torchdata_stateful_dataloader(test_case): These tests are skipped when torchdata with stateful_dataloader module isn't installed. """ - try: - import torchdata.stateful_dataloader # noqa F401 - except (ImportError, AssertionError): - return unittest.skip("test requires torchdata.stateful_dataloader")(test_case) - else: - return test_case + return unittest.skipUnless( + is_torchdata_stateful_dataloader_available(), "test requires torchdata.stateful_dataloader" + )(test_case) class TempDirTestCase(unittest.TestCase): diff --git a/tests/test_data_loader.py b/tests/test_data_loader.py index ac5e2d40686..27b0fb437b7 100644 --- a/tests/test_data_loader.py +++ b/tests/test_data_loader.py @@ -578,7 +578,7 @@ def get_first_n_batches(dl, n, device): """ batches = [] for idx, batch in enumerate(dl): - if idx == n-1: + if idx == n - 1: if hasattr(dl, "end"): dl.end() break @@ -587,8 +587,10 @@ def get_first_n_batches(dl, n, device): # Iterate over all of the dataloaders identically, expect the same values expected_batches = get_first_n_batches(stateful_dl, num_batches_to_skip, accelerator.device) - batches_from_dataloaders = [get_first_n_batches(dl, num_batches_to_skip, accelerator.device) for dl in dataloaders_under_test] - + batches_from_dataloaders = [ + get_first_n_batches(dl, num_batches_to_skip, accelerator.device) for dl in dataloaders_under_test + ] + for dl_batches in batches_from_dataloaders: for expected, actual in zip(expected_batches, dl_batches): assert torch.allclose(expected, actual) @@ -605,7 +607,12 @@ def get_first_n_batches(dl, n, device): # Load the state dict into new dataloaders manual_skip_dl = SkipDataLoader( - dataset, batch_size=4, num_workers=num_workers, generator=g(), skip_batches=num_batches_to_skip, use_stateful_dataloader=True + dataset, + batch_size=4, + num_workers=num_workers, + generator=g(), + skip_batches=num_batches_to_skip, + use_stateful_dataloader=True, ) loaded_stateful_dl = StatefulDataLoader(dataset, batch_size=4, num_workers=num_workers, generator=g()) loaded_stateful_dl.load_state_dict(expected_state_dict) @@ -635,10 +642,13 @@ def get_all_batches(dl, device): return (batches, num_batches_yielded) 
expected_batch_results = get_all_batches(loaded_stateful_dl, accelerator.device) - dataloader_batch_results = [get_all_batches(dl, accelerator.device) for dl in [manual_skip_dl, loaded_skip_dl, loaded_dl_shard, loaded_dl_dispatcher]] + dataloader_batch_results = [ + get_all_batches(dl, accelerator.device) + for dl in [manual_skip_dl, loaded_skip_dl, loaded_dl_shard, loaded_dl_dispatcher] + ] for dl_results in dataloader_batch_results: for expected, actual in zip(expected_batches, dl_batches): assert torch.allclose(expected[0], actual[0]) assert expected_batch_results[1] == dl_results[1] - assert accelerator.gradient_state.active_dataloader is None \ No newline at end of file + assert accelerator.gradient_state.active_dataloader is None From fde597d29b81dc1c4bf464a54c1e2aa48f492d02 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Wed, 26 Jun 2024 18:08:26 -0400 Subject: [PATCH 23/61] parameterize --- tests/test_data_loader.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_data_loader.py b/tests/test_data_loader.py index 27b0fb437b7..b202caed788 100644 --- a/tests/test_data_loader.py +++ b/tests/test_data_loader.py @@ -471,8 +471,8 @@ def test_end_of_dataloader_dispatcher(self): for idx, _ in enumerate(dataloader): assert dataloader.end_of_dataloader == (idx == 3) - @require_torchdata_stateful_dataloader @parameterized.expand([0, 2], name_func=parameterized_custom_name_func) + @require_torchdata_stateful_dataloader def test_dataloader_state_dict(self, num_workers): """ Test that saving a stateful dataloader's state, then loading it back, gives the same results. @@ -497,8 +497,8 @@ def test_dataloader_state_dict(self, num_workers): for d1, d2 in zip(data1, data2): assert torch.allclose(d1, d2) - @require_torchdata_stateful_dataloader @parameterized.expand([0, 2], name_func=parameterized_custom_name_func) + @require_torchdata_stateful_dataloader def test_dataloader_dispatcher_state_dict(self, num_workers): """ Test that saving a stateful dataloader's state, then loading it back, gives the same results. 
@@ -543,8 +543,8 @@ def test_dataloader_inheritance(self): assert isinstance(dl_shard, DataLoaderStateMixin) assert isinstance(dl_dispatcher, DataLoaderStateMixin) - @require_torchdata_stateful_dataloader @parameterized.expand([0, 2], name_func=parameterized_custom_name_func) + @require_torchdata_stateful_dataloader def test_stateful_dataloader_adapter_equivalent_to_torchdata_stateful_dataloader(self, num_workers): """ Assert that `state_dict()` and `load_state_dict()` for derived subclasses of `DataLoaderAdapter` produce From 8bf2fe24fa7da4624a7c89b91cc231de80cb75a9 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Fri, 21 Jun 2024 20:21:34 -0400 Subject: [PATCH 24/61] temporary commit --- src/accelerate/accelerator.py | 8 ++++---- src/accelerate/data_loader.py | 28 ++++++++++++++++++++++++---- src/accelerate/utils/dataclasses.py | 8 ++++++++ src/accelerate/utils/imports.py | 10 ++++++++++ tests/test_accelerator.py | 10 +++++----- 5 files changed, 51 insertions(+), 13 deletions(-) diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 46b859699be..c5dac726cf8 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -34,7 +34,7 @@ from huggingface_hub import split_torch_state_dict_into_shards from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state -from .data_loader import DataLoaderDispatcher, prepare_data_loader, skip_first_batches +from .data_loader import DataLoaderDispatcher, DataLoaderWrapper, prepare_data_loader, skip_first_batches from .hooks import AlignDevicesHook from .logging import get_logger from .optimizer import AcceleratedOptimizer @@ -1175,7 +1175,7 @@ def print(self, *args, **kwargs): def _prepare_one(self, obj, first_pass=False, device_placement=None): # First pass of preparation: DataLoader, model, optimizer if first_pass: - if isinstance(obj, torch.utils.data.DataLoader): + if isinstance(obj, torch.utils.data.DataLoader) or isinstance(obj, DataLoaderWrapper): return self.prepare_data_loader(obj, device_placement=device_placement) elif isinstance(obj, torch.nn.Module): return self.prepare_model(obj, device_placement=device_placement) @@ -1587,7 +1587,7 @@ def _prepare_deepspeed(self, *args): is_dataloader_present = any(isinstance(obj, torch.utils.data.DataLoader) for obj in args) result = [ - self._prepare_one(obj, first_pass=True) if isinstance(obj, torch.utils.data.DataLoader) else obj + self._prepare_one(obj, first_pass=True) if isinstance(obj, torch.utils.data.DataLoader) or isinstance(obj, DataLoaderWrapper) else obj for obj in args ] @@ -1838,7 +1838,7 @@ def _prepare_megatron_lm(self, *args): scheduler = None batch_data = None for obj in args: - if isinstance(obj, torch.utils.data.DataLoader) and batch_data is None: + if isinstance(obj, torch.utils.data.DataLoader) or isinstance(obj, DataLoaderWrapper) and batch_data is None: batch_data = next(iter(obj)) elif isinstance(obj, torch.nn.Module): model = obj diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index fcf6631f162..8a294151ad0 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -19,6 +19,8 @@ import torch from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler +from accelerate.utils.imports import is_torchdata_stateful_dataloader_available + from .logging import get_logger from .state import AcceleratorState, DistributedType, GradientState, is_torch_xla_available from .utils import ( @@ -35,6 +37,8 @@ 
synchronize_rng_states, ) +if is_torchdata_stateful_dataloader_available(): + from torchdata.stateful_dataloader import StatefulDataLoader logger = get_logger(__name__) @@ -387,10 +391,26 @@ def end(self): "Cleans up the gradient state after exiting the dataloader" self.gradient_state._remove_dataloader(self) +class DataLoaderWrapper: + """ + Class that wraps around a PyTorch `DataLoader` (or subclasses, such as torchdata's `StatefulDataLoader`). + + """ + def __init__(self, dataset, **kwargs): + if False and is_torchdata_stateful_dataloader_available(): + self.dataloader = StatefulDataLoader(dataset, **kwargs) + else: + self.dataloader = DataLoader(dataset, **kwargs) + + for attr in self.dataloader.__dict__.keys(): + setattr(self, attr, getattr(self.dataloader, attr)) + + def __iter__(self): + return self.dataloader.__iter__() -class DataLoaderShard(DataLoader, DataLoaderStateMixin): +class DataLoaderShard(DataLoaderWrapper, DataLoaderStateMixin): """ - Subclass of a PyTorch `DataLoader` that will deal with device placement and current distributed setup. + Subclass of `DataLoaderWrapper` that will deal with device placement and current distributed setup. Args: dataset (`torch.utils.data.dataset.Dataset`): @@ -559,9 +579,9 @@ def batch_sampler(self): return self._loader.batch_sampler -class DataLoaderDispatcher(DataLoader, DataLoaderStateMixin): +class DataLoaderDispatcher(DataLoaderWrapper, DataLoaderStateMixin): """ - Subclass of a PyTorch `DataLoader` that will iterate and preprocess on process 0 only, then dispatch on each + Subclass of `DataLoaderWrapper` that will iterate and preprocess on process 0 only, then dispatch on each process their part of the batch. Args: diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py index f8da14e2104..7f4a7e7fc06 100644 --- a/src/accelerate/utils/dataclasses.py +++ b/src/accelerate/utils/dataclasses.py @@ -717,6 +717,14 @@ class DataLoaderConfiguration: " prepared dataloader has `pin_memory` set to `True` to work properly." }, ) + use_stateful_dataloader: bool = field( + default=False, + metadata={ + "help": "If set to true, the dataloader prepared by the Accelerator will be backed by " + "[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). This requires a version" + " of `torchdata` with StatefulDataLoader to be installed." + }, + ) @dataclass diff --git a/src/accelerate/utils/imports.py b/src/accelerate/utils/imports.py index 669044c3910..8d97dc44767 100644 --- a/src/accelerate/utils/imports.py +++ b/src/accelerate/utils/imports.py @@ -413,3 +413,13 @@ def is_xpu_available(check_device=False): def is_dvclive_available(): return _is_package_available("dvclive") + +def is_torchdata_available(): + return _is_package_available("torchdata") + +# TODO: Remove this function once stateful_dataloader is a stable feature in torchdata. 
+def is_torchdata_stateful_dataloader_available(): + if not _is_package_available("torchdata"): + return False + import torchdata + return hasattr(torchdata, "stateful_dataloader") and hasattr(torchdata.stateful_dataloader, "StatefulDataLoader") \ No newline at end of file diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py index 56febe0938e..b50837816de 100644 --- a/tests/test_accelerator.py +++ b/tests/test_accelerator.py @@ -230,7 +230,7 @@ def noop(*args, **kwargs): accelerator = Accelerator() assert str(accelerator.state.device) == "cuda:64" - @parameterized.expand((True, False), name_func=parameterized_custom_name_func) + @parameterized.expand([(True,), (False,)], name_func=parameterized_custom_name_func) def test_save_load_model(self, use_safetensors): accelerator = Accelerator() model, optimizer, scheduler, train_dl, valid_dl = create_components() @@ -249,7 +249,7 @@ def test_save_load_model(self, use_safetensors): accelerator.load_state(tmpdirname) assert abs(model_signature - get_signature(model)) < 1e-3 - @parameterized.expand([True, False], name_func=parameterized_custom_name_func) + @parameterized.expand([(True,), (False,)], name_func=parameterized_custom_name_func) def test_save_model(self, use_safetensors): accelerator = Accelerator() model = torch.nn.Linear(10, 10) @@ -261,7 +261,7 @@ def test_save_model(self, use_safetensors): load_checkpoint_in_model(model, tmpdirname) assert abs(model_signature - get_signature(model)) < 1e-3 - @parameterized.expand([True, False], name_func=parameterized_custom_name_func) + @parameterized.expand([(True,), (False,)], name_func=parameterized_custom_name_func) def test_save_sharded_model(self, use_safetensors): accelerator = Accelerator() inputs = torch.randn(3, 3) @@ -277,7 +277,7 @@ def test_save_sharded_model(self, use_safetensors): assert torch.allclose(expected, output, atol=1e-5) - @parameterized.expand([True, False], name_func=parameterized_custom_name_func) + @parameterized.expand([(True,), (False,)], name_func=parameterized_custom_name_func) def test_save_model_offload(self, use_safetensors): accelerator = Accelerator() @@ -325,7 +325,7 @@ def test_get_state_dict_from_offload(self, use_safetensors): assert cpu_onloaded_layer_weight.device.type == "cpu" assert device_onloaded_layer_weight.device.type == torch_device - @parameterized.expand([True, False], name_func=parameterized_custom_name_func) + @parameterized.expand([(True,), (False,)], name_func=parameterized_custom_name_func) def test_save_load_model_with_hooks(self, use_safetensors): accelerator = Accelerator() model, optimizer, scheduler, train_dl, valid_dl = create_components() From ca4338d1cd648897bfd02d89335998425eacdff9 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Fri, 21 Jun 2024 20:53:32 -0400 Subject: [PATCH 25/61] checkout? 
--- tests/test_accelerator.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py index b50837816de..063b40b9aa5 100644 --- a/tests/test_accelerator.py +++ b/tests/test_accelerator.py @@ -230,7 +230,7 @@ def noop(*args, **kwargs): accelerator = Accelerator() assert str(accelerator.state.device) == "cuda:64" - @parameterized.expand([(True,), (False,)], name_func=parameterized_custom_name_func) + @parameterized.expand((True, False), name_func=parameterized_custom_name_func) def test_save_load_model(self, use_safetensors): accelerator = Accelerator() model, optimizer, scheduler, train_dl, valid_dl = create_components() @@ -249,7 +249,7 @@ def test_save_load_model(self, use_safetensors): accelerator.load_state(tmpdirname) assert abs(model_signature - get_signature(model)) < 1e-3 - @parameterized.expand([(True,), (False,)], name_func=parameterized_custom_name_func) + @parameterized.expand([True, False], name_func=parameterized_custom_name_func) def test_save_model(self, use_safetensors): accelerator = Accelerator() model = torch.nn.Linear(10, 10) @@ -261,7 +261,7 @@ def test_save_model(self, use_safetensors): load_checkpoint_in_model(model, tmpdirname) assert abs(model_signature - get_signature(model)) < 1e-3 - @parameterized.expand([(True,), (False,)], name_func=parameterized_custom_name_func) + @parameterized.expand([True, False], name_func=parameterized_custom_name_func) def test_save_sharded_model(self, use_safetensors): accelerator = Accelerator() inputs = torch.randn(3, 3) @@ -277,7 +277,7 @@ def test_save_sharded_model(self, use_safetensors): assert torch.allclose(expected, output, atol=1e-5) - @parameterized.expand([(True,), (False,)], name_func=parameterized_custom_name_func) + @parameterized.expand([True, False], name_func=parameterized_custom_name_func) def test_save_model_offload(self, use_safetensors): accelerator = Accelerator() @@ -298,7 +298,11 @@ def test_save_model_offload(self, use_safetensors): assert torch.allclose(expected, output, atol=1e-5) @parameterized.expand([True, False], name_func=parameterized_custom_name_func) +<<<<<<< HEAD @require_non_cpu +======= + @require_cuda +>>>>>>> efa1e7d (checkout?) def test_get_state_dict_from_offload(self, use_safetensors): accelerator = Accelerator() @@ -325,7 +329,7 @@ def test_get_state_dict_from_offload(self, use_safetensors): assert cpu_onloaded_layer_weight.device.type == "cpu" assert device_onloaded_layer_weight.device.type == torch_device - @parameterized.expand([(True,), (False,)], name_func=parameterized_custom_name_func) + @parameterized.expand([True, False], name_func=parameterized_custom_name_func) def test_save_load_model_with_hooks(self, use_safetensors): accelerator = Accelerator() model, optimizer, scheduler, train_dl, valid_dl = create_components() From c38f317fde9f89165c317d5171584dd656a787ac Mon Sep 17 00:00:00 2001 From: byi8220 Date: Fri, 21 Jun 2024 21:37:53 -0400 Subject: [PATCH 26/61] dataloader wrapper --- src/accelerate/data_loader.py | 7 +++++-- src/accelerate/test_utils/scripts/test_script.py | 2 ++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index 8a294151ad0..fba7fa5385e 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -393,7 +393,8 @@ def end(self): class DataLoaderWrapper: """ - Class that wraps around a PyTorch `DataLoader` (or subclasses, such as torchdata's `StatefulDataLoader`). 
+ Class that wraps around a PyTorch `DataLoader` (or subclasses of `DataLoader`, such as torchdata's `StatefulDataLoader`). + """ def __init__(self, dataset, **kwargs): @@ -401,12 +402,14 @@ def __init__(self, dataset, **kwargs): self.dataloader = StatefulDataLoader(dataset, **kwargs) else: self.dataloader = DataLoader(dataset, **kwargs) - for attr in self.dataloader.__dict__.keys(): setattr(self, attr, getattr(self.dataloader, attr)) def __iter__(self): return self.dataloader.__iter__() + + def __len__(self): + return self.dataloader.__len__() class DataLoaderShard(DataLoaderWrapper, DataLoaderStateMixin): """ diff --git a/src/accelerate/test_utils/scripts/test_script.py b/src/accelerate/test_utils/scripts/test_script.py index ff09d9daaad..45292b54a0e 100644 --- a/src/accelerate/test_utils/scripts/test_script.py +++ b/src/accelerate/test_utils/scripts/test_script.py @@ -377,6 +377,8 @@ def check_seedable_sampler(): for batch in train_dl: new_items.append(batch["x"]) new_items = torch.cat(new_items) + print(original_items) + print(new_items) assert torch.allclose(original_items, new_items), "Did not obtain the same items with the same seed and epoch." From 17a2a195a49df65f1b9527c7b21080f6a4a4e3f0 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Fri, 21 Jun 2024 22:08:42 -0400 Subject: [PATCH 27/61] tmp --- src/accelerate/accelerator.py | 8 ++++---- src/accelerate/data_loader.py | 2 -- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index c5dac726cf8..46b859699be 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -34,7 +34,7 @@ from huggingface_hub import split_torch_state_dict_into_shards from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state -from .data_loader import DataLoaderDispatcher, DataLoaderWrapper, prepare_data_loader, skip_first_batches +from .data_loader import DataLoaderDispatcher, prepare_data_loader, skip_first_batches from .hooks import AlignDevicesHook from .logging import get_logger from .optimizer import AcceleratedOptimizer @@ -1175,7 +1175,7 @@ def print(self, *args, **kwargs): def _prepare_one(self, obj, first_pass=False, device_placement=None): # First pass of preparation: DataLoader, model, optimizer if first_pass: - if isinstance(obj, torch.utils.data.DataLoader) or isinstance(obj, DataLoaderWrapper): + if isinstance(obj, torch.utils.data.DataLoader): return self.prepare_data_loader(obj, device_placement=device_placement) elif isinstance(obj, torch.nn.Module): return self.prepare_model(obj, device_placement=device_placement) @@ -1587,7 +1587,7 @@ def _prepare_deepspeed(self, *args): is_dataloader_present = any(isinstance(obj, torch.utils.data.DataLoader) for obj in args) result = [ - self._prepare_one(obj, first_pass=True) if isinstance(obj, torch.utils.data.DataLoader) or isinstance(obj, DataLoaderWrapper) else obj + self._prepare_one(obj, first_pass=True) if isinstance(obj, torch.utils.data.DataLoader) else obj for obj in args ] @@ -1838,7 +1838,7 @@ def _prepare_megatron_lm(self, *args): scheduler = None batch_data = None for obj in args: - if isinstance(obj, torch.utils.data.DataLoader) or isinstance(obj, DataLoaderWrapper) and batch_data is None: + if isinstance(obj, torch.utils.data.DataLoader) and batch_data is None: batch_data = next(iter(obj)) elif isinstance(obj, torch.nn.Module): model = obj diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index fba7fa5385e..2a93fd96776 
100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -394,8 +394,6 @@ def end(self): class DataLoaderWrapper: """ Class that wraps around a PyTorch `DataLoader` (or subclasses of `DataLoader`, such as torchdata's `StatefulDataLoader`). - - """ def __init__(self, dataset, **kwargs): if False and is_torchdata_stateful_dataloader_available(): From b39a6060746d3c6d3224a36626c4d926fa8de000 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Fri, 21 Jun 2024 22:31:07 -0400 Subject: [PATCH 28/61] weird failing test --- src/accelerate/data_loader.py | 34 ++++++------------- .../test_utils/scripts/test_script.py | 2 -- 2 files changed, 11 insertions(+), 25 deletions(-) diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index 2a93fd96776..df36ac8fbae 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -37,8 +37,6 @@ synchronize_rng_states, ) -if is_torchdata_stateful_dataloader_available(): - from torchdata.stateful_dataloader import StatefulDataLoader logger = get_logger(__name__) @@ -391,27 +389,10 @@ def end(self): "Cleans up the gradient state after exiting the dataloader" self.gradient_state._remove_dataloader(self) -class DataLoaderWrapper: - """ - Class that wraps around a PyTorch `DataLoader` (or subclasses of `DataLoader`, such as torchdata's `StatefulDataLoader`). - """ - def __init__(self, dataset, **kwargs): - if False and is_torchdata_stateful_dataloader_available(): - self.dataloader = StatefulDataLoader(dataset, **kwargs) - else: - self.dataloader = DataLoader(dataset, **kwargs) - for attr in self.dataloader.__dict__.keys(): - setattr(self, attr, getattr(self.dataloader, attr)) - def __iter__(self): - return self.dataloader.__iter__() - - def __len__(self): - return self.dataloader.__len__() - -class DataLoaderShard(DataLoaderWrapper, DataLoaderStateMixin): +class DataLoaderShard(DataLoader, DataLoaderStateMixin): """ - Subclass of `DataLoaderWrapper` that will deal with device placement and current distributed setup. + Subclass of a PyTorch `DataLoader` that will deal with device placement and current distributed setup. Args: dataset (`torch.utils.data.dataset.Dataset`): @@ -580,9 +561,9 @@ def batch_sampler(self): return self._loader.batch_sampler -class DataLoaderDispatcher(DataLoaderWrapper, DataLoaderStateMixin): +class DataLoaderDispatcher(DataLoader, DataLoaderStateMixin): """ - Subclass of `DataLoaderWrapper` that will iterate and preprocess on process 0 only, then dispatch on each + Subclass of a PyTorch `DataLoader` that will iterate and preprocess on process 0 only, then dispatch on each process their part of the batch. 
Args: @@ -796,6 +777,13 @@ def set_sampler(self, sampler): if hasattr(self.batch_sampler, "batch_sampler"): self.batch_sampler.batch_sampler.sampler = sampler +if is_torchdata_stateful_dataloader_available(): + from torchdata.stateful_dataloader import StatefulDataLoader + + class StatefulDataLoaderShard(DataLoaderShard, StatefulDataLoader, DataLoaderStateMixin): + """ + Subclass of DataLoaderShard which inherits from torchdata's `StatefulDataLoader` + """ def get_sampler(dataloader): """ diff --git a/src/accelerate/test_utils/scripts/test_script.py b/src/accelerate/test_utils/scripts/test_script.py index 45292b54a0e..ff09d9daaad 100644 --- a/src/accelerate/test_utils/scripts/test_script.py +++ b/src/accelerate/test_utils/scripts/test_script.py @@ -377,8 +377,6 @@ def check_seedable_sampler(): for batch in train_dl: new_items.append(batch["x"]) new_items = torch.cat(new_items) - print(original_items) - print(new_items) assert torch.allclose(original_items, new_items), "Did not obtain the same items with the same seed and epoch." From d1e82e0ac4da2dc8abd69d8a9acbd11450ec3e5c Mon Sep 17 00:00:00 2001 From: byi8220 Date: Fri, 21 Jun 2024 22:54:14 -0400 Subject: [PATCH 29/61] trying multiple inheritance --- src/accelerate/data_loader.py | 3 ++- src/accelerate/utils/imports.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index df36ac8fbae..e1a403d28e5 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -777,13 +777,14 @@ def set_sampler(self, sampler): if hasattr(self.batch_sampler, "batch_sampler"): self.batch_sampler.batch_sampler.sampler = sampler + if is_torchdata_stateful_dataloader_available(): from torchdata.stateful_dataloader import StatefulDataLoader - class StatefulDataLoaderShard(DataLoaderShard, StatefulDataLoader, DataLoaderStateMixin): """ Subclass of DataLoaderShard which inherits from torchdata's `StatefulDataLoader` """ + pass def get_sampler(dataloader): """ diff --git a/src/accelerate/utils/imports.py b/src/accelerate/utils/imports.py index 8d97dc44767..2b6626eec84 100644 --- a/src/accelerate/utils/imports.py +++ b/src/accelerate/utils/imports.py @@ -419,7 +419,7 @@ def is_torchdata_available(): # TODO: Remove this function once stateful_dataloader is a stable feature in torchdata. 
def is_torchdata_stateful_dataloader_available(): - if not _is_package_available("torchdata"): + if not is_torchdata_available(): return False import torchdata return hasattr(torchdata, "stateful_dataloader") and hasattr(torchdata.stateful_dataloader, "StatefulDataLoader") \ No newline at end of file From d99d734cf8acda0a8889cf7f210997aaf91063bf Mon Sep 17 00:00:00 2001 From: byi8220 Date: Sat, 22 Jun 2024 20:32:26 -0400 Subject: [PATCH 30/61] DataLoaderAdapter --- src/accelerate/accelerator.py | 17 ++++--- src/accelerate/data_loader.py | 76 +++++++++++++++++++++------- src/accelerate/test_utils/testing.py | 14 +++++ tests/test_data_loader.py | 49 ++++++++++++++++++ 4 files changed, 133 insertions(+), 23 deletions(-) diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 46b859699be..a848ed42524 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -34,7 +34,7 @@ from huggingface_hub import split_torch_state_dict_into_shards from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state -from .data_loader import DataLoaderDispatcher, prepare_data_loader, skip_first_batches +from .data_loader import DataLoaderAdapter, DataLoaderDispatcher, prepare_data_loader, skip_first_batches from .hooks import AlignDevicesHook from .logging import get_logger from .optimizer import AcceleratedOptimizer @@ -570,6 +570,10 @@ def use_seedable_sampler(self): @property def non_blocking(self): return self.dataloader_config.non_blocking + + @property + def use_stateful_dataloader(self): + return self.dataloader_config.use_stateful_dataloader @property def project_dir(self): @@ -1175,7 +1179,7 @@ def print(self, *args, **kwargs): def _prepare_one(self, obj, first_pass=False, device_placement=None): # First pass of preparation: DataLoader, model, optimizer if first_pass: - if isinstance(obj, torch.utils.data.DataLoader): + if isinstance(obj, torch.utils.data.DataLoader) or isinstance(obj, DataLoaderAdapter): return self.prepare_data_loader(obj, device_placement=device_placement) elif isinstance(obj, torch.nn.Module): return self.prepare_model(obj, device_placement=device_placement) @@ -1585,9 +1589,9 @@ def _prepare_deepspeed(self, *args): deepspeed_plugin = self.state.deepspeed_plugin - is_dataloader_present = any(isinstance(obj, torch.utils.data.DataLoader) for obj in args) + is_dataloader_present = any((isinstance(obj, torch.utils.data.DataLoader) or isinstance(obj, DataLoaderAdapter)) for obj in args) result = [ - self._prepare_one(obj, first_pass=True) if isinstance(obj, torch.utils.data.DataLoader) else obj + self._prepare_one(obj, first_pass=True) if (isinstance(obj, torch.utils.data.DataLoader or isinstance(obj, DataLoaderAdapter))) else obj for obj in args ] @@ -1838,7 +1842,7 @@ def _prepare_megatron_lm(self, *args): scheduler = None batch_data = None for obj in args: - if isinstance(obj, torch.utils.data.DataLoader) and batch_data is None: + if (isinstance(obj, torch.utils.data.DataLoader) or isinstance(obj, DataLoaderAdapter)) and batch_data is None: batch_data = next(iter(obj)) elif isinstance(obj, torch.nn.Module): model = obj @@ -1867,7 +1871,7 @@ def _prepare_megatron_lm(self, *args): counter = 0 result = [] for obj in args: - if isinstance(obj, torch.utils.data.DataLoader): + if (isinstance(obj, torch.utils.data.DataLoader) or isinstance(obj, DataLoaderAdapter)): result.append(megatron_lm_prepare_data_loader(self, obj)) counter += 1 elif isinstance(obj, MegatronLMDummyDataLoader): @@ -2030,6 
+2034,7 @@ def prepare_data_loader( slice_fn_for_dispatch=slice_fn_for_dispatch, use_seedable_sampler=self.use_seedable_sampler, non_blocking=self.non_blocking, + use_stateful_dataloader=self.use_stateful_dataloader, ) self._dataloaders.append(prepared_data_loader) return prepared_data_loader diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index e1a403d28e5..0366b76aad7 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -20,6 +20,8 @@ from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler from accelerate.utils.imports import is_torchdata_stateful_dataloader_available +if is_torchdata_stateful_dataloader_available(): + from torchdata.stateful_dataloader import StatefulDataLoader from .logging import get_logger from .state import AcceleratorState, DistributedType, GradientState, is_torch_xla_available @@ -38,6 +40,7 @@ ) + logger = get_logger(__name__) # kwargs of the DataLoader in min version 1.4.0. @@ -389,10 +392,44 @@ def end(self): "Cleans up the gradient state after exiting the dataloader" self.gradient_state._remove_dataloader(self) +# TODO: Maybe generalize this class? +class DataLoaderAdapter: + """ + A class which wraps around a PyTorch `DataLoader` (or variants of it) to be used with the `Accelerator`. + """ + def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, **kwargs): + self.use_stateful_dataloader = use_stateful_dataloader + if use_stateful_dataloader and not is_torchdata_stateful_dataloader_available(): + raise ValueError("StatefulDataLoader is not available. Please install torchdata to use it.") + if use_stateful_dataloader: + self.base_dataloader = StatefulDataLoader(dataset, **kwargs) + else: + self.base_dataloader = DataLoader(dataset, **kwargs) + + for attr in self.base_dataloader.__dict__.keys(): + setattr(self, attr, getattr(self.base_dataloader, attr)) -class DataLoaderShard(DataLoader, DataLoaderStateMixin): + def __iter__(self): + return iter(self.base_dataloader) + + def __len__(self): + return len(self.base_dataloader) + + def load_state_dict(self): + """ + Only supported for `StatefulDataLoader`. + """ + return self.base_dataloader.load_state_dict() + + def state_dict(self): + """ + Only supported for `StatefulDataLoader`. + """ + return self.base_dataloader.state_dict() + +class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin): """ - Subclass of a PyTorch `DataLoader` that will deal with device placement and current distributed setup. + Subclass of `DataLoaderAdapter` that will deal with device placement and current distributed setup. Args: dataset (`torch.utils.data.dataset.Dataset`): @@ -411,6 +448,8 @@ class DataLoaderShard(DataLoader, DataLoaderStateMixin): A random number generator to keep synchronized across processes. skip_batches (`int`, *optional*, defaults to 0): The number of batches to skip at the beginning. + use_stateful_dataloader (`bool`, *optional*, defaults to `False`): + Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`. **kwargs (additional keyword arguments, *optional*): All other keyword arguments to pass to the regular `DataLoader` initialization. 
@@ -430,11 +469,12 @@ def __init__( rng_types=None, synchronized_generator=None, skip_batches=0, + use_stateful_dataloader=False, _drop_last: bool = False, _non_blocking: bool = False, **kwargs, ): - super().__init__(dataset, **kwargs) + super().__init__(dataset, use_stateful_dataloader, **kwargs) self.device = device self.rng_types = rng_types self.synchronized_generator = synchronized_generator @@ -561,9 +601,9 @@ def batch_sampler(self): return self._loader.batch_sampler -class DataLoaderDispatcher(DataLoader, DataLoaderStateMixin): +class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin): """ - Subclass of a PyTorch `DataLoader` that will iterate and preprocess on process 0 only, then dispatch on each + Subclass of `DataLoaderAdapter` that will iterate and preprocess on process 0 only, then dispatch on each process their part of the batch. Args: @@ -576,6 +616,8 @@ class DataLoaderDispatcher(DataLoader, DataLoaderStateMixin): size of the `dataloader` is a round multiple of `batch_size`. skip_batches (`int`, *optional*, defaults to 0): The number of batches to skip at the beginning of an iteration. + use_stateful_dataloader (`bool`, *optional*, defaults to `False`): + Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`. **Available attributes:** @@ -591,6 +633,7 @@ def __init__( dataset, split_batches: bool = False, skip_batches=0, + use_stateful_dataloader=False, _drop_last: bool = False, _non_blocking: bool = False, slice_fn=None, @@ -603,7 +646,7 @@ def __init__( # We need to save the shuffling state of the DataPipe if isinstance(dataset, ShufflerIterDataPipe): shuffle = dataset._shuffle_enabled - super().__init__(dataset, **kwargs) + super().__init__(dataset, use_stateful_dataloader, **kwargs) self.split_batches = split_batches if shuffle: torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle) @@ -778,14 +821,6 @@ def set_sampler(self, sampler): self.batch_sampler.batch_sampler.sampler = sampler -if is_torchdata_stateful_dataloader_available(): - from torchdata.stateful_dataloader import StatefulDataLoader - class StatefulDataLoaderShard(DataLoaderShard, StatefulDataLoader, DataLoaderStateMixin): - """ - Subclass of DataLoaderShard which inherits from torchdata's `StatefulDataLoader` - """ - pass - def get_sampler(dataloader): """ Get the sampler associated to the dataloader @@ -817,6 +852,7 @@ def prepare_data_loader( slice_fn_for_dispatch: Optional[Callable] = None, use_seedable_sampler: bool = False, non_blocking: bool = False, + use_stateful_dataloader: bool = False, ) -> DataLoader: """ Wraps a PyTorch `DataLoader` to generate batches for one of the processes only. @@ -878,6 +914,10 @@ def prepare_data_loader( non_blocking (`bool`, *optional*, defaults to `False`): If set to `True`, dataloader will utilize non-blocking host-to-device transfers. If the dataloader has `pin_memory` set to `True`, this will help to increase overlap between data transfer and computations. + use_stateful_dataloader (`bool`, *optional*, defaults to `False`): + "If set to true, the dataloader prepared by the Accelerator will be backed by " + "[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). This requires a version" + " of `torchdata` with StatefulDataLoader to be installed." 
Returns: @@ -1066,7 +1106,7 @@ def __len__(self): return len(self.batch_sampler) - self.skip_batches -class SkipDataLoader(DataLoader): +class SkipDataLoader(DataLoaderAdapter): """ Subclass of a PyTorch `DataLoader` that will skip the first batches. @@ -1075,12 +1115,14 @@ class SkipDataLoader(DataLoader): The dataset to use to build this datalaoder. skip_batches (`int`, *optional*, defaults to 0): The number of batches to skip at the beginning. + use_stateful_dataloader (`bool`, *optional*, defaults to `False`): + Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`. kwargs: All other keyword arguments to pass to the regular `DataLoader` initialization. """ - def __init__(self, dataset, skip_batches=0, **kwargs): - super().__init__(dataset, **kwargs) + def __init__(self, dataset, skip_batches=0, use_stateful_dataloader=False, **kwargs): + super().__init__(dataset, use_stateful_dataloader, **kwargs) self.skip_batches = skip_batches def __iter__(self): diff --git a/src/accelerate/test_utils/testing.py b/src/accelerate/test_utils/testing.py index 0f3877f2e81..76ad3253a93 100644 --- a/src/accelerate/test_utils/testing.py +++ b/src/accelerate/test_utils/testing.py @@ -410,6 +410,20 @@ def require_trackers(test_case): )(test_case) +def require_torchdata_stateful_dataloader(test_case): + """ + Decorator marking a test that requires torchdata.stateful_dataloader. + + These tests are skipped when torchdata with stateful_dataloader module isn't installed. + + """ + try: + import torchdata.stateful_dataloader # noqa F401 + except (ImportError, AssertionError): + return unittest.skip("test requires torchdata.stateful_dataloader")(test_case) + else: + return test_case + class TempDirTestCase(unittest.TestCase): """ A TestCase class that keeps a single `tempfile.TemporaryDirectory` open for the duration of the class, wipes its diff --git a/tests/test_data_loader.py b/tests/test_data_loader.py index 2f360d71bcb..86532d38d17 100644 --- a/tests/test_data_loader.py +++ b/tests/test_data_loader.py @@ -28,6 +28,16 @@ skip_first_batches, ) +from accelerate.utils.imports import is_torchdata_stateful_dataloader_available +from accelerate.test_utils.testing import require_torchdata_stateful_dataloader + +if is_torchdata_stateful_dataloader_available(): + from torchdata.stateful_dataloader import ( + StatefulDataLoader, + ) + from accelerate.data_loader import ( + DataLoaderAdapter, + ) class RandomIterableDataset(IterableDataset): # For testing, an iterable dataset of random length @@ -396,3 +406,42 @@ def test_end_of_dataloader_dispatcher(self): # Test it also works on the second iteration for idx, _ in enumerate(dataloader): assert dataloader.end_of_dataloader == (idx == 3) + + +class StatefulDataLoaderTester(unittest.TestCase): + + @require_torchdata_stateful_dataloader + def test_skip_data_loader(self): + dataloader = SkipDataLoader(list(range(16)), batch_size=4, skip_batches=2, use_stateful_dataloader=True) + + assert [t.tolist() for t in dataloader] == [[8, 9, 10, 11], [12, 13, 14, 15]] + + @require_torchdata_stateful_dataloader + def test_skip_first_batches(self): + dataloader = StatefulDataLoader(list(range(16)), batch_size=4) + new_dataloader = skip_first_batches(dataloader, num_batches=2) + + assert [t.tolist() for t in new_dataloader] == [[8, 9, 10, 11], [12, 13, 14, 15]] + + @require_torchdata_stateful_dataloader + def test_end_of_dataloader(self): + dataloader = DataLoaderShard(list(range(16)), batch_size=4, use_stateful_dataloader=True) + + for 
idx, _ in enumerate(dataloader): + assert dataloader.end_of_dataloader == (idx == 3) + + # Test it also works on the second iteration + for idx, _ in enumerate(dataloader): + assert dataloader.end_of_dataloader == (idx == 3) + + @require_torchdata_stateful_dataloader + def test_end_of_dataloader_dispatcher(self): + Accelerator() + dataloader = DataLoaderDispatcher(range(16), batch_size=4, use_stateful_dataloader=True) + + for idx, _ in enumerate(dataloader): + assert dataloader.end_of_dataloader == (idx == 3) + + # Test it also works on the second iteration + for idx, _ in enumerate(dataloader): + assert dataloader.end_of_dataloader == (idx == 3) From 39b2866b4278dbed85423cb127ac51acf5cc122a Mon Sep 17 00:00:00 2001 From: byi8220 Date: Sat, 22 Jun 2024 20:34:41 -0400 Subject: [PATCH 31/61] make style --- src/accelerate/accelerator.py | 16 +++++++++++----- src/accelerate/data_loader.py | 20 ++++++++++++-------- src/accelerate/test_utils/testing.py | 1 + src/accelerate/utils/dataclasses.py | 2 +- src/accelerate/utils/imports.py | 5 ++++- tests/test_data_loader.py | 9 +++------ 6 files changed, 32 insertions(+), 21 deletions(-) diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index a848ed42524..6d372a1af9b 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -570,7 +570,7 @@ def use_seedable_sampler(self): @property def non_blocking(self): return self.dataloader_config.non_blocking - + @property def use_stateful_dataloader(self): return self.dataloader_config.use_stateful_dataloader @@ -1589,9 +1589,13 @@ def _prepare_deepspeed(self, *args): deepspeed_plugin = self.state.deepspeed_plugin - is_dataloader_present = any((isinstance(obj, torch.utils.data.DataLoader) or isinstance(obj, DataLoaderAdapter)) for obj in args) + is_dataloader_present = any( + (isinstance(obj, torch.utils.data.DataLoader) or isinstance(obj, DataLoaderAdapter)) for obj in args + ) result = [ - self._prepare_one(obj, first_pass=True) if (isinstance(obj, torch.utils.data.DataLoader or isinstance(obj, DataLoaderAdapter))) else obj + self._prepare_one(obj, first_pass=True) + if (isinstance(obj, torch.utils.data.DataLoader or isinstance(obj, DataLoaderAdapter))) + else obj for obj in args ] @@ -1842,7 +1846,9 @@ def _prepare_megatron_lm(self, *args): scheduler = None batch_data = None for obj in args: - if (isinstance(obj, torch.utils.data.DataLoader) or isinstance(obj, DataLoaderAdapter)) and batch_data is None: + if ( + isinstance(obj, torch.utils.data.DataLoader) or isinstance(obj, DataLoaderAdapter) + ) and batch_data is None: batch_data = next(iter(obj)) elif isinstance(obj, torch.nn.Module): model = obj @@ -1871,7 +1877,7 @@ def _prepare_megatron_lm(self, *args): counter = 0 result = [] for obj in args: - if (isinstance(obj, torch.utils.data.DataLoader) or isinstance(obj, DataLoaderAdapter)): + if isinstance(obj, torch.utils.data.DataLoader) or isinstance(obj, DataLoaderAdapter): result.append(megatron_lm_prepare_data_loader(self, obj)) counter += 1 elif isinstance(obj, MegatronLMDummyDataLoader): diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index 0366b76aad7..ce7593effa3 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -20,6 +20,8 @@ from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler from accelerate.utils.imports import is_torchdata_stateful_dataloader_available + + if is_torchdata_stateful_dataloader_available(): from torchdata.stateful_dataloader import 
StatefulDataLoader @@ -40,7 +42,6 @@ ) - logger = get_logger(__name__) # kwargs of the DataLoader in min version 1.4.0. @@ -392,11 +393,13 @@ def end(self): "Cleans up the gradient state after exiting the dataloader" self.gradient_state._remove_dataloader(self) + # TODO: Maybe generalize this class? class DataLoaderAdapter: """ A class which wraps around a PyTorch `DataLoader` (or variants of it) to be used with the `Accelerator`. """ + def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, **kwargs): self.use_stateful_dataloader = use_stateful_dataloader if use_stateful_dataloader and not is_torchdata_stateful_dataloader_available(): @@ -405,16 +408,16 @@ def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, * self.base_dataloader = StatefulDataLoader(dataset, **kwargs) else: self.base_dataloader = DataLoader(dataset, **kwargs) - + for attr in self.base_dataloader.__dict__.keys(): setattr(self, attr, getattr(self.base_dataloader, attr)) def __iter__(self): return iter(self.base_dataloader) - + def __len__(self): return len(self.base_dataloader) - + def load_state_dict(self): """ Only supported for `StatefulDataLoader`. @@ -427,6 +430,7 @@ def state_dict(self): """ return self.base_dataloader.state_dict() + class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin): """ Subclass of `DataLoaderAdapter` that will deal with device placement and current distributed setup. @@ -603,8 +607,8 @@ def batch_sampler(self): class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin): """ - Subclass of `DataLoaderAdapter` that will iterate and preprocess on process 0 only, then dispatch on each - process their part of the batch. + Subclass of `DataLoaderAdapter` that will iterate and preprocess on process 0 only, then dispatch on each process + their part of the batch. Args: split_batches (`bool`, *optional*, defaults to `False`): @@ -916,8 +920,8 @@ def prepare_data_loader( `pin_memory` set to `True`, this will help to increase overlap between data transfer and computations. use_stateful_dataloader (`bool`, *optional*, defaults to `False`): "If set to true, the dataloader prepared by the Accelerator will be backed by " - "[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). This requires a version" - " of `torchdata` with StatefulDataLoader to be installed." + "[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). + This requires a version" " of `torchdata` with StatefulDataLoader to be installed." 
Returns: diff --git a/src/accelerate/test_utils/testing.py b/src/accelerate/test_utils/testing.py index 76ad3253a93..120fc4a0263 100644 --- a/src/accelerate/test_utils/testing.py +++ b/src/accelerate/test_utils/testing.py @@ -424,6 +424,7 @@ def require_torchdata_stateful_dataloader(test_case): else: return test_case + class TempDirTestCase(unittest.TestCase): """ A TestCase class that keeps a single `tempfile.TemporaryDirectory` open for the duration of the class, wipes its diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py index 7f4a7e7fc06..10cd50f5ccf 100644 --- a/src/accelerate/utils/dataclasses.py +++ b/src/accelerate/utils/dataclasses.py @@ -721,7 +721,7 @@ class DataLoaderConfiguration: default=False, metadata={ "help": "If set to true, the dataloader prepared by the Accelerator will be backed by " - "[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). This requires a version" + "[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). This requires a version" " of `torchdata` with StatefulDataLoader to be installed." }, ) diff --git a/src/accelerate/utils/imports.py b/src/accelerate/utils/imports.py index 2b6626eec84..b254e11750d 100644 --- a/src/accelerate/utils/imports.py +++ b/src/accelerate/utils/imports.py @@ -414,12 +414,15 @@ def is_xpu_available(check_device=False): def is_dvclive_available(): return _is_package_available("dvclive") + def is_torchdata_available(): return _is_package_available("torchdata") + # TODO: Remove this function once stateful_dataloader is a stable feature in torchdata. def is_torchdata_stateful_dataloader_available(): if not is_torchdata_available(): return False import torchdata - return hasattr(torchdata, "stateful_dataloader") and hasattr(torchdata.stateful_dataloader, "StatefulDataLoader") \ No newline at end of file + + return hasattr(torchdata, "stateful_dataloader") and hasattr(torchdata.stateful_dataloader, "StatefulDataLoader") diff --git a/tests/test_data_loader.py b/tests/test_data_loader.py index 86532d38d17..8e96e4bec85 100644 --- a/tests/test_data_loader.py +++ b/tests/test_data_loader.py @@ -27,17 +27,15 @@ SkipDataLoader, skip_first_batches, ) - -from accelerate.utils.imports import is_torchdata_stateful_dataloader_available from accelerate.test_utils.testing import require_torchdata_stateful_dataloader +from accelerate.utils.imports import is_torchdata_stateful_dataloader_available + if is_torchdata_stateful_dataloader_available(): from torchdata.stateful_dataloader import ( StatefulDataLoader, ) - from accelerate.data_loader import ( - DataLoaderAdapter, - ) + class RandomIterableDataset(IterableDataset): # For testing, an iterable dataset of random length @@ -409,7 +407,6 @@ def test_end_of_dataloader_dispatcher(self): class StatefulDataLoaderTester(unittest.TestCase): - @require_torchdata_stateful_dataloader def test_skip_data_loader(self): dataloader = SkipDataLoader(list(range(16)), batch_size=4, skip_batches=2, use_stateful_dataloader=True) From f58f609cfd5bedb806c80255ca131ac4a92d7edc Mon Sep 17 00:00:00 2001 From: byi8220 Date: Sat, 22 Jun 2024 21:48:04 -0400 Subject: [PATCH 32/61] Some dark magic dynamic reflection (for backwards compat) --- src/accelerate/data_loader.py | 43 +++++++++---------- .../scripts/external_deps/test_metrics.py | 1 - tests/test_data_loader.py | 9 ++-- 3 files changed, 26 insertions(+), 27 deletions(-) diff --git a/src/accelerate/data_loader.py 
b/src/accelerate/data_loader.py index ce7593effa3..1ab061a920f 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -394,10 +394,10 @@ def end(self): self.gradient_state._remove_dataloader(self) -# TODO: Maybe generalize this class? class DataLoaderAdapter: """ - A class which wraps around a PyTorch `DataLoader` (or variants of it) to be used with the `Accelerator`. + A class which wraps around a PyTorch `DataLoader` (or variants of it) to be used with the `Accelerator`. For + compatability reasons, this class inherits from the class it wraps around, so it can be used as a drop-in. """ def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, **kwargs): @@ -409,27 +409,18 @@ def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, * else: self.base_dataloader = DataLoader(dataset, **kwargs) + # Dynamically mixin the parent class. See https://stackoverflow.com/a/31075641 + # This is pretty awkward, but it's the only way to make `isinstance(obj, StatefulDataLoader)` work transparently. + # It would be better if DataLoaderAdapter does not inherit from DataLoader, but that would not be backwards compatible. + base_cls = self.__class__ + base_cls_name = self.__class__.__name__ + parent_cls_name = self.base_dataloader.__class__ + self.__class__ = type(base_cls_name, (base_cls, parent_cls_name), {}) + + # Allow this class to transparently pass through attributes from the underlying class for attr in self.base_dataloader.__dict__.keys(): setattr(self, attr, getattr(self.base_dataloader, attr)) - def __iter__(self): - return iter(self.base_dataloader) - - def __len__(self): - return len(self.base_dataloader) - - def load_state_dict(self): - """ - Only supported for `StatefulDataLoader`. - """ - return self.base_dataloader.load_state_dict() - - def state_dict(self): - """ - Only supported for `StatefulDataLoader`. 
- """ - return self.base_dataloader.state_dict() - class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin): """ @@ -1055,6 +1046,7 @@ def prepare_data_loader( _drop_last=dataloader.drop_last, _non_blocking=non_blocking, slice_fn=slice_fn_for_dispatch, + use_stateful_dataloader=use_stateful_dataloader, **kwargs, ) elif sampler_is_batch_sampler: @@ -1067,7 +1059,7 @@ def prepare_data_loader( _drop_last=dataloader.drop_last, _non_blocking=non_blocking, synchronized_generator=synchronized_generator, - **kwargs, + use_stateful_dataloader=use_stateful_dataloader**kwargs, ) else: dataloader = DataLoaderShard( @@ -1078,6 +1070,7 @@ def prepare_data_loader( synchronized_generator=synchronized_generator, _drop_last=dataloader.drop_last, _non_blocking=non_blocking, + use_stateful_dataloader=use_stateful_dataloader, **kwargs, ) @@ -1177,6 +1170,7 @@ def skip_first_batches(dataloader, num_batches=0): split_batches=dataloader.split_batches, batch_sampler=new_batch_sampler, _drop_last=dataloader._drop_last, + use_stateful_dataloader=dataloader.use_stateful_dataloader, **kwargs, ) elif isinstance(dataloader, DataLoaderShard): @@ -1193,12 +1187,17 @@ def skip_first_batches(dataloader, num_batches=0): device=dataloader.device, rng_types=dataloader.rng_types, synchronized_generator=dataloader.synchronized_generator, + use_stateful_dataloader=dataloader.use_stateful_dataloader, **kwargs, ) else: if new_batch_sampler is None: # Need to manually skip batches in the dataloader - dataloader = SkipDataLoader(dataset, skip_batches=num_batches, **kwargs) + dataloader = SkipDataLoader( + dataset, skip_batches=num_batches, use_stateful_dataloader=dataloader.use_stateful_dataloader, **kwargs + ) + elif is_torchdata_stateful_dataloader_available() and isinstance(dataloader, StatefulDataLoader): + dataloader = StatefulDataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs) else: dataloader = DataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs) diff --git a/src/accelerate/test_utils/scripts/external_deps/test_metrics.py b/src/accelerate/test_utils/scripts/external_deps/test_metrics.py index 9ac13aba626..9925e60a647 100755 --- a/src/accelerate/test_utils/scripts/external_deps/test_metrics.py +++ b/src/accelerate/test_utils/scripts/external_deps/test_metrics.py @@ -227,7 +227,6 @@ def test_gather_for_metrics_drop_last(): num_items = (10 * accelerator.num_processes) + 1 dataloader = DataLoader(range(num_items), batch_size=per_device_batch_size, drop_last=True) dataloader = accelerator.prepare(dataloader) - iterator = iter(dataloader) next(iterator) # Skip first batch tensor([0, 1, 2, 3, 4], device='cuda:0') batch = next(iterator) diff --git a/tests/test_data_loader.py b/tests/test_data_loader.py index 8e96e4bec85..8053562c4fc 100644 --- a/tests/test_data_loader.py +++ b/tests/test_data_loader.py @@ -410,20 +410,21 @@ class StatefulDataLoaderTester(unittest.TestCase): @require_torchdata_stateful_dataloader def test_skip_data_loader(self): dataloader = SkipDataLoader(list(range(16)), batch_size=4, skip_batches=2, use_stateful_dataloader=True) - + assert isinstance(dataloader, StatefulDataLoader) assert [t.tolist() for t in dataloader] == [[8, 9, 10, 11], [12, 13, 14, 15]] @require_torchdata_stateful_dataloader def test_skip_first_batches(self): dataloader = StatefulDataLoader(list(range(16)), batch_size=4) new_dataloader = skip_first_batches(dataloader, num_batches=2) - + assert isinstance(new_dataloader, StatefulDataLoader) assert [t.tolist() for t in new_dataloader] == [[8, 9, 10, 11], [12, 13, 14, 
15]] @require_torchdata_stateful_dataloader def test_end_of_dataloader(self): dataloader = DataLoaderShard(list(range(16)), batch_size=4, use_stateful_dataloader=True) - + assert dataloader.use_stateful_dataloader + assert isinstance(dataloader, StatefulDataLoader) for idx, _ in enumerate(dataloader): assert dataloader.end_of_dataloader == (idx == 3) @@ -435,7 +436,7 @@ def test_end_of_dataloader(self): def test_end_of_dataloader_dispatcher(self): Accelerator() dataloader = DataLoaderDispatcher(range(16), batch_size=4, use_stateful_dataloader=True) - + assert isinstance(dataloader, StatefulDataLoader) for idx, _ in enumerate(dataloader): assert dataloader.end_of_dataloader == (idx == 3) From f2119cfdc3f90de242c81899b87ab3f82ac78989 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Sat, 22 Jun 2024 22:32:26 -0400 Subject: [PATCH 33/61] typo --- src/accelerate/data_loader.py | 4 ++-- .../test_utils/scripts/external_deps/test_metrics.py | 6 ++++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index 1ab061a920f..442b7bc22f3 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -405,9 +405,9 @@ def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, * if use_stateful_dataloader and not is_torchdata_stateful_dataloader_available(): raise ValueError("StatefulDataLoader is not available. Please install torchdata to use it.") if use_stateful_dataloader: - self.base_dataloader = StatefulDataLoader(dataset, **kwargs) + self.base_dataloader = StatefulDataLoader(dataset, batch_sampler=batch_sampler **kwargs) else: - self.base_dataloader = DataLoader(dataset, **kwargs) + self.base_dataloader = DataLoader(dataset, batch_sampler=batch_sampler, **kwargs) # Dynamically mixin the parent class. See https://stackoverflow.com/a/31075641 # This is pretty awkward, but it's the only way to make `isinstance(obj, StatefulDataLoader)` work transparently. 
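A minimal, standalone sketch of the dynamic class mixin that `DataLoaderAdapter` relies on in the hunk above, assuming only `torch` is installed; the `Adapter` class, its attribute mirroring, and the usage at the bottom are illustrative and not part of the patch:

from torch.utils.data import DataLoader

class Adapter:
    """Illustrative wrapper that mixes the wrapped loader's class into its own type."""

    def __init__(self, dataset, **kwargs):
        # In the patch this is either a plain DataLoader or torchdata's StatefulDataLoader.
        self.base_dataloader = DataLoader(dataset, **kwargs)

        # Dynamically create a subclass of (Adapter, <wrapped class>) and rebind self to it,
        # so isinstance(self, DataLoader) (or StatefulDataLoader) passes transparently.
        base_cls = self.__class__
        self.__class__ = type(base_cls.__name__, (base_cls, self.base_dataloader.__class__), {})

        # Mirror the wrapped loader's instance attributes so they can be read directly on the wrapper.
        for attr in self.base_dataloader.__dict__:
            setattr(self, attr, getattr(self.base_dataloader, attr))

    def __iter__(self):
        return iter(self.base_dataloader)


loader = Adapter(list(range(8)), batch_size=4)
assert isinstance(loader, DataLoader)  # passes because of the dynamic mixin
assert [b.tolist() for b in loader] == [[0, 1, 2, 3], [4, 5, 6, 7]]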
diff --git a/src/accelerate/test_utils/scripts/external_deps/test_metrics.py b/src/accelerate/test_utils/scripts/external_deps/test_metrics.py index 9925e60a647..aca0f5ad07c 100755 --- a/src/accelerate/test_utils/scripts/external_deps/test_metrics.py +++ b/src/accelerate/test_utils/scripts/external_deps/test_metrics.py @@ -227,6 +227,7 @@ def test_gather_for_metrics_drop_last(): num_items = (10 * accelerator.num_processes) + 1 dataloader = DataLoader(range(num_items), batch_size=per_device_batch_size, drop_last=True) dataloader = accelerator.prepare(dataloader) + iterator = iter(dataloader) next(iterator) # Skip first batch tensor([0, 1, 2, 3, 4], device='cuda:0') batch = next(iterator) @@ -234,6 +235,11 @@ def test_gather_for_metrics_drop_last(): # Should return a full set of complete batches from each GPU num_expected_items = per_device_batch_size * accelerator.num_processes + print("dataloader.batch_size:", dataloader.batch_size) + print("accelerator.num_processes:", accelerator.num_processes) + print("gathered_items:", gathered_items) + print("batch:", batch) + print("len(dataloader):", len(dataloader)) assert gathered_items.size(0) == ( num_expected_items ), f"Expected number of items: {num_expected_items}, Actual: {gathered_items.size(0)}" From 7adec94478b19e897c2d5b73f3146a6eb653f5a0 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Tue, 25 Jun 2024 14:12:44 -0400 Subject: [PATCH 34/61] some tests --- src/accelerate/data_loader.py | 17 +++++- tests/test_accelerator.py | 98 +++++++++++++++++++++++++++++++++-- tests/test_data_loader.py | 52 +++++++++++++++++++ 3 files changed, 161 insertions(+), 6 deletions(-) diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index 442b7bc22f3..575d58e6ea5 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -387,10 +387,12 @@ def begin(self): if not self._drop_last: length = getattr(self.dataset, "total_dataset_length", len(self.dataset)) self.remainder = length % self.total_batch_size + print("adding dataloader", self) self.gradient_state._add_dataloader(self) def end(self): "Cleans up the gradient state after exiting the dataloader" + print("removing dataloader", self) self.gradient_state._remove_dataloader(self) @@ -405,7 +407,7 @@ def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, * if use_stateful_dataloader and not is_torchdata_stateful_dataloader_available(): raise ValueError("StatefulDataLoader is not available. 
Please install torchdata to use it.") if use_stateful_dataloader: - self.base_dataloader = StatefulDataLoader(dataset, batch_sampler=batch_sampler **kwargs) + self.base_dataloader = StatefulDataLoader(dataset, batch_sampler=batch_sampler, **kwargs) else: self.base_dataloader = DataLoader(dataset, batch_sampler=batch_sampler, **kwargs) @@ -421,6 +423,16 @@ def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, * for attr in self.base_dataloader.__dict__.keys(): setattr(self, attr, getattr(self.base_dataloader, attr)) + if hasattr(self.base_dataloader, "state_dict"): + self.dl_state_dict = self.base_dataloader.state_dict() + + def state_dict(self): + return self.dl_state_dict + + def _save_state_dict(self): + if hasattr(self.base_dataloader, "state_dict"): + self.dl_state_dict = self.base_dataloader.state_dict() + class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin): """ @@ -498,6 +510,7 @@ def __iter__(self): # But we still move it to the device so it is done before `StopIteration` is reached if self.device is not None: current_batch = send_to_device(current_batch, self.device, non_blocking=self._non_blocking) + self._save_state_dict() next_batch = next(dataloader_iter) if batch_index >= self.skip_batches: yield current_batch @@ -662,12 +675,14 @@ def _fetch_batches(self, iterator): try: if self.split_batches: # One batch of the main iterator is dispatched and split. + self._save_state_dict() batch = next(iterator) else: # num_processes batches of the main iterator are concatenated then dispatched and split. # We add the batches one by one so we have the remainder available when drop_last=False. batches = [] for _ in range(self.state.num_processes): + self._save_state_dict() batches.append(next(iterator)) try: batch = concatenate(batches, dim=0) diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py index 063b40b9aa5..41068072b68 100644 --- a/tests/test_accelerator.py +++ b/tests/test_accelerator.py @@ -27,18 +27,30 @@ from accelerate.accelerator import Accelerator from accelerate.state import GradientState, PartialState from accelerate.test_utils import require_bnb, require_multi_gpu, require_non_cpu, slow, torch_device -from accelerate.test_utils.testing import AccelerateTestCase, require_cuda, require_non_torch_xla +from accelerate.test_utils.testing import ( + AccelerateTestCase, + require_cuda, + require_non_torch_xla, + require_torchdata_stateful_dataloader, +) from accelerate.utils import patch_environment +from accelerate.utils.dataclasses import DataLoaderConfiguration +from accelerate.utils.imports import is_torchdata_stateful_dataloader_available from accelerate.utils.modeling import get_state_dict_from_offload, load_checkpoint_in_model -def create_components(): +if is_torchdata_stateful_dataloader_available(): + from torchdata.stateful_dataloader import ( + StatefulDataLoader, + ) + + +def create_components(dataset_size=3): model = torch.nn.Linear(2, 4) optimizer = torch.optim.AdamW(model.parameters(), lr=1.0) scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=2, epochs=1) - train_dl = DataLoader(TensorDataset(torch.tensor([1, 2, 3]))) - valid_dl = DataLoader(TensorDataset(torch.tensor([4, 5, 6]))) - + train_dl = DataLoader(TensorDataset(torch.tensor([i for i in range(1, dataset_size+1)]))) + valid_dl = DataLoader(TensorDataset(torch.tensor([i+dataset_size for i in range(1, dataset_size+1)]))) return model, optimizer, scheduler, train_dl, valid_dl @@ -575,3 +587,79 @@ def 
test_can_unwrap_model(self): # check that pickle roundtrip works model_loaded = pickle.loads(pickle.dumps(model)) model_loaded(inputs) + + # Ideally would be a parameterized test which works with either stateful or non-stateful dataloaders, but dependencies are a bit awkward. + @require_torchdata_stateful_dataloader + def test_prepared_objects_are_referenced_with_stateful_dataloader(self): + """Test that setting `use_stateful_dataloader=True` in `DataLoaderConfiguration` prepares a `StatefulDataLoader` object instead of a `DataLoader` object.""" + dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True) + accelerator = Accelerator(dataloader_config=dataloader_config) + model, optimizer, scheduler, train_dl, valid_dl = create_components() + + ( + prepared_model, + prepared_optimizer, + prepared_scheduler, + prepared_train_dl, + prepared_valid_dl, + ) = accelerator.prepare(model, optimizer, scheduler, train_dl, valid_dl) + + assert prepared_model in accelerator._models + assert prepared_optimizer in accelerator._optimizers + assert prepared_scheduler in accelerator._schedulers + assert prepared_train_dl in accelerator._dataloaders + assert prepared_valid_dl in accelerator._dataloaders + assert isinstance(prepared_train_dl, StatefulDataLoader) + assert isinstance(prepared_valid_dl, StatefulDataLoader) + + @parameterized.expand([True, False], name_func=parameterized_custom_name_func) + @require_torchdata_stateful_dataloader + def test_save_model_with_stateful_dataloader(self, use_safetensors): + """Test that saving and loading a model with a stateful dataloader returns the same model, and that the dataloader's iterator is restored properly.""" + dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True) + accelerator = Accelerator(dataloader_config=dataloader_config) + + model, optimizer, scheduler, train_dl, valid_dl = create_components(dataset_size=6) + ( + prepared_model, + prepared_optimizer, + prepared_scheduler, + prepared_train_dl, + prepared_valid_dl, + ) = accelerator.prepare(model, optimizer, scheduler, train_dl, valid_dl) + + assert isinstance(prepared_train_dl, StatefulDataLoader) + assert isinstance(prepared_valid_dl, StatefulDataLoader) + + print("len before iterating", len(prepared_train_dl)) + # Perform 3 training iterations to ensure the dataloader's iterator is advanced + for i, input in enumerate(prepared_train_dl): + print(i, input) + if i == 2: + state_dict = prepared_train_dl.state_dict() + break + + for i, input in enumerate(prepared_train_dl): + print("Pass of initial dict yielding input {}".format(input), prepared_train_dl.state_dict()) + + print("State dict to be loaded", state_dict) + prepared_train_dl.load_state_dict(state_dict) + print("State dict immediately after loading", prepared_train_dl.state_dict()) + + for i, input in enumerate(prepared_train_dl): + print("Pass of loaded dict yielding input {}".format(input), prepared_train_dl.state_dict()) + + + model_signature = get_signature(prepared_model) + with tempfile.TemporaryDirectory() as tmpdirname: + + # Save the model's state. 
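
For reviewers, a minimal sketch (not part of the diff) of the checkpoint/resume flow the test above exercises, written in end-user form; `make_train_dl()` is a hypothetical helper returning an ordinary `torch.utils.data.DataLoader`, and torchdata's StatefulDataLoader is assumed to be installed:

from accelerate import Accelerator
from accelerate.utils.dataclasses import DataLoaderConfiguration

accelerator = Accelerator(dataloader_config=DataLoaderConfiguration(use_stateful_dataloader=True))
train_dl = accelerator.prepare(make_train_dl())  # make_train_dl(): hypothetical helper, returns a DataLoader

for step, batch in enumerate(train_dl):
    ...  # forward/backward/optimizer step elided
    if step == 2:
        # A snapshot taken while holding batch 2 resumes at batch 3: the adapter records
        # the base loader's state each time it advances past a batch.
        dl_state = train_dl.state_dict()
        train_dl.end()  # unregister from GradientState when breaking out of the loop early
        break

resumed_dl = accelerator.prepare(make_train_dl())  # a fresh dataloader, e.g. in a restarted run
resumed_dl.load_state_dict(dl_state)               # iteration now continues from batch 3
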
+ accelerator.save_model(prepared_model, tmpdirname, safe_serialization=use_safetensors) + + # Load the saved model + loaded_model = prepared_model + load_checkpoint_in_model(loaded_model, tmpdirname) + # make sure loaded weights match + assert abs(model_signature - get_signature(prepared_model)) < 1e-3 + + # iterate through both dataloaders and assert their behaviors are identical \ No newline at end of file diff --git a/tests/test_data_loader.py b/tests/test_data_loader.py index 8053562c4fc..81cae2b4d21 100644 --- a/tests/test_data_loader.py +++ b/tests/test_data_loader.py @@ -15,6 +15,7 @@ import random import unittest +import torch from torch.utils.data import BatchSampler, DataLoader, IterableDataset from accelerate import Accelerator @@ -438,8 +439,59 @@ def test_end_of_dataloader_dispatcher(self): dataloader = DataLoaderDispatcher(range(16), batch_size=4, use_stateful_dataloader=True) assert isinstance(dataloader, StatefulDataLoader) for idx, _ in enumerate(dataloader): + print(idx) assert dataloader.end_of_dataloader == (idx == 3) # Test it also works on the second iteration for idx, _ in enumerate(dataloader): assert dataloader.end_of_dataloader == (idx == 3) + + @require_torchdata_stateful_dataloader + def test_dataloader_state_dict(self): + """ + Test that saving a stateful dataloader's state, then loading it back, gives the same results. + """ + dataset = list(range(16)) + dataloader = DataLoaderShard(dataset, batch_size=4, use_stateful_dataloader=True) + + assert dataloader.use_stateful_dataloader + assert isinstance(dataloader, StatefulDataLoader) + vals = [] + for idx, val in enumerate(dataloader): + vals.append(val) + if idx == 1: + sd = dataloader.state_dict() + assert len(vals) == 4 + + dataloader2 = DataLoaderShard(dataset, batch_size=4, use_stateful_dataloader=True) + dataloader2.load_state_dict(sd) + + data1 = vals[2:] + data2 = list(dataloader2) + for d1, d2 in zip(data1, data2): + assert torch.allclose(d1, d2) + + @require_torchdata_stateful_dataloader + def test_dataloader_dispatcher_state_dict(self): + """ + Test that saving a stateful dataloader's state, then loading it back, gives the same results. 
+ """ + dataset = list(range(16)) + dataloader = DataLoaderDispatcher(dataset, batch_size=4, use_stateful_dataloader=True) + + assert dataloader.use_stateful_dataloader + assert isinstance(dataloader, StatefulDataLoader) + vals = [] + for idx, val in enumerate(dataloader): + vals.append(val) + if idx == 1: + sd = dataloader.state_dict() + assert len(vals) == 4 + + dataloader2 = DataLoaderDispatcher(dataset, batch_size=4, use_stateful_dataloader=True) + dataloader2.load_state_dict(sd) + + data1 = vals[2:] + data2 = list(dataloader2) + for d1, d2 in zip(data1, data2): + assert torch.allclose(d1, d2) \ No newline at end of file From 8850af3202827a5c8780102fc336b7618ac56abb Mon Sep 17 00:00:00 2001 From: byi8220 Date: Tue, 25 Jun 2024 14:58:49 -0400 Subject: [PATCH 35/61] more mixin stuff --- src/accelerate/accelerator.py | 14 ++++---- src/accelerate/data_loader.py | 6 ++-- .../test_utils/scripts/test_sync.py | 2 ++ tests/test_data_loader.py | 32 ++++++++++++++++++- 4 files changed, 42 insertions(+), 12 deletions(-) diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 6d372a1af9b..52c59e07227 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -34,7 +34,7 @@ from huggingface_hub import split_torch_state_dict_into_shards from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state -from .data_loader import DataLoaderAdapter, DataLoaderDispatcher, prepare_data_loader, skip_first_batches +from .data_loader import DataLoaderDispatcher, prepare_data_loader, skip_first_batches from .hooks import AlignDevicesHook from .logging import get_logger from .optimizer import AcceleratedOptimizer @@ -1179,7 +1179,7 @@ def print(self, *args, **kwargs): def _prepare_one(self, obj, first_pass=False, device_placement=None): # First pass of preparation: DataLoader, model, optimizer if first_pass: - if isinstance(obj, torch.utils.data.DataLoader) or isinstance(obj, DataLoaderAdapter): + if isinstance(obj, torch.utils.data.DataLoader): return self.prepare_data_loader(obj, device_placement=device_placement) elif isinstance(obj, torch.nn.Module): return self.prepare_model(obj, device_placement=device_placement) @@ -1590,11 +1590,11 @@ def _prepare_deepspeed(self, *args): deepspeed_plugin = self.state.deepspeed_plugin is_dataloader_present = any( - (isinstance(obj, torch.utils.data.DataLoader) or isinstance(obj, DataLoaderAdapter)) for obj in args + (isinstance(obj, torch.utils.data.DataLoader)) for obj in args ) result = [ self._prepare_one(obj, first_pass=True) - if (isinstance(obj, torch.utils.data.DataLoader or isinstance(obj, DataLoaderAdapter))) + if (isinstance(obj, torch.utils.data.DataLoader)) else obj for obj in args ] @@ -1846,9 +1846,7 @@ def _prepare_megatron_lm(self, *args): scheduler = None batch_data = None for obj in args: - if ( - isinstance(obj, torch.utils.data.DataLoader) or isinstance(obj, DataLoaderAdapter) - ) and batch_data is None: + if isinstance(obj, torch.utils.data.DataLoader) and batch_data is None: batch_data = next(iter(obj)) elif isinstance(obj, torch.nn.Module): model = obj @@ -1877,7 +1875,7 @@ def _prepare_megatron_lm(self, *args): counter = 0 result = [] for obj in args: - if isinstance(obj, torch.utils.data.DataLoader) or isinstance(obj, DataLoaderAdapter): + if isinstance(obj, torch.utils.data.DataLoader): result.append(megatron_lm_prepare_data_loader(self, obj)) counter += 1 elif isinstance(obj, MegatronLMDummyDataLoader): diff --git a/src/accelerate/data_loader.py 
b/src/accelerate/data_loader.py index 575d58e6ea5..c7c188a200a 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -396,7 +396,7 @@ def end(self): self.gradient_state._remove_dataloader(self) -class DataLoaderAdapter: +class DataLoaderAdapter(DataLoaderStateMixin): """ A class which wraps around a PyTorch `DataLoader` (or variants of it) to be used with the `Accelerator`. For compatability reasons, this class inherits from the class it wraps around, so it can be used as a drop-in. @@ -434,7 +434,7 @@ def _save_state_dict(self): self.dl_state_dict = self.base_dataloader.state_dict() -class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin): +class DataLoaderShard(DataLoaderAdapter): """ Subclass of `DataLoaderAdapter` that will deal with device placement and current distributed setup. @@ -609,7 +609,7 @@ def batch_sampler(self): return self._loader.batch_sampler -class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin): +class DataLoaderDispatcher(DataLoaderAdapter): """ Subclass of `DataLoaderAdapter` that will iterate and preprocess on process 0 only, then dispatch on each process their part of the batch. diff --git a/src/accelerate/test_utils/scripts/test_sync.py b/src/accelerate/test_utils/scripts/test_sync.py index fd829231770..1abfe02da57 100644 --- a/src/accelerate/test_utils/scripts/test_sync.py +++ b/src/accelerate/test_utils/scripts/test_sync.py @@ -310,7 +310,9 @@ def test_dataloader_break(): first_dataloader = DataLoader(first_dset, batch_size=16) second_dset = RegressionDataset(length=96) second_dataloader = DataLoader(second_dset, batch_size=16) + print("Dataloaders to be prepared", first_dataloader, second_dataloader) first_dataloader, second_dataloader = accelerator.prepare(first_dataloader, second_dataloader) + print("Dataloaders prepared", first_dataloader, second_dataloader) assert accelerator.gradient_state.active_dataloader is None for iteration, _ in enumerate(first_dataloader): assert id(accelerator.gradient_state.active_dataloader) == id(first_dataloader) diff --git a/tests/test_data_loader.py b/tests/test_data_loader.py index 81cae2b4d21..5e6f028176a 100644 --- a/tests/test_data_loader.py +++ b/tests/test_data_loader.py @@ -23,6 +23,7 @@ BatchSamplerShard, DataLoaderDispatcher, DataLoaderShard, + DataLoaderStateMixin, IterableDatasetShard, SkipBatchSampler, SkipDataLoader, @@ -378,6 +379,20 @@ def test_skip_batch_sampler(self): new_batch_sampler = SkipBatchSampler(batch_sampler, 2) assert list(new_batch_sampler) == [[8, 9, 10, 11], [12, 13, 14, 15]] + def test_dataloader_inheritance(self): + """`DataLoaderAdapter`'s parent classes are dynamically constructed, assert that subclasses of DataLoaderAdapter are instances of DataLoader and DataLoaderStateMixin.""" + Accelerator() + skip_dl = SkipDataLoader(range(16), batch_size=4, skip_batches=2) + dl_shard = DataLoaderShard(range(16), batch_size=4) + dl_dispatcher = DataLoaderDispatcher(range(16), batch_size=4) + assert isinstance(skip_dl, DataLoader) + assert isinstance(dl_shard, DataLoader) + assert isinstance(dl_dispatcher, DataLoader) + + assert isinstance(skip_dl, DataLoaderStateMixin) + assert isinstance(dl_shard, DataLoaderStateMixin) + assert isinstance(dl_dispatcher, DataLoaderStateMixin) + def test_skip_data_loader(self): dataloader = SkipDataLoader(list(range(16)), batch_size=4, skip_batches=2) assert [t.tolist() for t in dataloader] == [[8, 9, 10, 11], [12, 13, 14, 15]] @@ -494,4 +509,19 @@ def test_dataloader_dispatcher_state_dict(self): data1 = vals[2:] data2 = 
list(dataloader2) for d1, d2 in zip(data1, data2): - assert torch.allclose(d1, d2) \ No newline at end of file + assert torch.allclose(d1, d2) + + @require_torchdata_stateful_dataloader + def test_dataloader_inheritance(self): + """`DataLoaderAdapter`'s parent classes are dynamically constructed, assert that when use_stateful_dataloader=True, subclasses of DataLoaderAdapter are instances of StatefulDataLoader and DataLoaderStateMixin.""" + Accelerator() + skip_dl = SkipDataLoader(range(16), batch_size=4, skip_batches=2, use_stateful_dataloader=True) + dl_shard = DataLoaderShard(range(16), batch_size=4, use_stateful_dataloader=True) + dl_dispatcher = DataLoaderDispatcher(range(16), batch_size=4, use_stateful_dataloader=True) + assert isinstance(skip_dl, StatefulDataLoader) + assert isinstance(dl_shard, StatefulDataLoader) + assert isinstance(dl_dispatcher, StatefulDataLoader) + + assert isinstance(skip_dl, DataLoaderStateMixin) + assert isinstance(dl_shard, DataLoaderStateMixin) + assert isinstance(dl_dispatcher, DataLoaderStateMixin) From 6ff0f6805d72b72b5a1f3ab3c8f5b828e25b8093 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Tue, 25 Jun 2024 16:33:54 -0400 Subject: [PATCH 36/61] maybe found broken test? --- src/accelerate/test_utils/scripts/test_sync.py | 4 +--- tests/test_accelerator.py | 1 + 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/accelerate/test_utils/scripts/test_sync.py b/src/accelerate/test_utils/scripts/test_sync.py index 1abfe02da57..672c4852046 100644 --- a/src/accelerate/test_utils/scripts/test_sync.py +++ b/src/accelerate/test_utils/scripts/test_sync.py @@ -305,14 +305,12 @@ def test_gradient_accumulation_with_opt_and_scheduler( def test_dataloader_break(): accelerator = Accelerator() - first_dset = RegressionDataset(length=80) first_dataloader = DataLoader(first_dset, batch_size=16) second_dset = RegressionDataset(length=96) second_dataloader = DataLoader(second_dset, batch_size=16) - print("Dataloaders to be prepared", first_dataloader, second_dataloader) first_dataloader, second_dataloader = accelerator.prepare(first_dataloader, second_dataloader) - print("Dataloaders prepared", first_dataloader, second_dataloader) + assert accelerator.gradient_state.active_dataloader is None for iteration, _ in enumerate(first_dataloader): assert id(accelerator.gradient_state.active_dataloader) == id(first_dataloader) diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py index 41068072b68..77bbb792811 100644 --- a/tests/test_accelerator.py +++ b/tests/test_accelerator.py @@ -630,6 +630,7 @@ def test_save_model_with_stateful_dataloader(self, use_safetensors): assert isinstance(prepared_train_dl, StatefulDataLoader) assert isinstance(prepared_valid_dl, StatefulDataLoader) + assert accelerator.gradient_state.active_dataloader is None print("len before iterating", len(prepared_train_dl)) # Perform 3 training iterations to ensure the dataloader's iterator is advanced From 4f28d2e24acd0c33044354d40dd5c981147cfee0 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Tue, 25 Jun 2024 17:25:05 -0400 Subject: [PATCH 37/61] this is a very invasive feature --- tests/test_accelerator.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py index 77bbb792811..2989e01e527 100644 --- a/tests/test_accelerator.py +++ b/tests/test_accelerator.py @@ -638,6 +638,8 @@ def test_save_model_with_stateful_dataloader(self, use_safetensors): print(i, input) if i == 2: state_dict = prepared_train_dl.state_dict() + # When 
breaking out without fully going through the iterator, must call end() to unregister this iterator from gradient state. + prepared_train_dl.end() break for i, input in enumerate(prepared_train_dl): @@ -663,4 +665,4 @@ def test_save_model_with_stateful_dataloader(self, use_safetensors): # make sure loaded weights match assert abs(model_signature - get_signature(prepared_model)) < 1e-3 - # iterate through both dataloaders and assert their behaviors are identical \ No newline at end of file + # iterate through both dataloaders and assert their behaviors are identical From a9b637dc8982185fa76179edf49b4aa8ab5e650d Mon Sep 17 00:00:00 2001 From: byi8220 Date: Wed, 26 Jun 2024 15:19:45 -0400 Subject: [PATCH 38/61] i think the feature is done? --- src/accelerate/accelerator.py | 8 +- src/accelerate/data_loader.py | 37 +++++++--- tests/test_accelerator.py | 133 +++++++++++++++++++++++++--------- tests/test_data_loader.py | 126 ++++++++++++++++++++++++++++---- 4 files changed, 241 insertions(+), 63 deletions(-) diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 52c59e07227..7693505dea8 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -1589,13 +1589,9 @@ def _prepare_deepspeed(self, *args): deepspeed_plugin = self.state.deepspeed_plugin - is_dataloader_present = any( - (isinstance(obj, torch.utils.data.DataLoader)) for obj in args - ) + is_dataloader_present = any((isinstance(obj, torch.utils.data.DataLoader)) for obj in args) result = [ - self._prepare_one(obj, first_pass=True) - if (isinstance(obj, torch.utils.data.DataLoader)) - else obj + self._prepare_one(obj, first_pass=True) if (isinstance(obj, torch.utils.data.DataLoader)) else obj for obj in args ] diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index c7c188a200a..b628c2f559d 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -387,12 +387,10 @@ def begin(self): if not self._drop_last: length = getattr(self.dataset, "total_dataset_length", len(self.dataset)) self.remainder = length % self.total_batch_size - print("adding dataloader", self) self.gradient_state._add_dataloader(self) def end(self): "Cleans up the gradient state after exiting the dataloader" - print("removing dataloader", self) self.gradient_state._remove_dataloader(self) @@ -404,6 +402,7 @@ class DataLoaderAdapter(DataLoaderStateMixin): def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, **kwargs): self.use_stateful_dataloader = use_stateful_dataloader + if use_stateful_dataloader and not is_torchdata_stateful_dataloader_available(): raise ValueError("StatefulDataLoader is not available. Please install torchdata to use it.") if use_stateful_dataloader: @@ -412,26 +411,41 @@ def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, * self.base_dataloader = DataLoader(dataset, batch_sampler=batch_sampler, **kwargs) # Dynamically mixin the parent class. See https://stackoverflow.com/a/31075641 - # This is pretty awkward, but it's the only way to make `isinstance(obj, StatefulDataLoader)` work transparently. - # It would be better if DataLoaderAdapter does not inherit from DataLoader, but that would not be backwards compatible. 
+ # In C++ terms, this is analogous to creating `DataLoaderAdapter : T`, where T is a DataLoader or + # StatefulDataLoader + # + # The same functionality could be achieved by directly creating the required subclasses for both {DataLoader, + # StatefulDataLoader}, however that could lead to much messier code, with duplicated classes and conditional + # dispatching scattered throughout various functions and files. + # + # This code is incredibly awkward but it's the only way to make `isinstance(obj, StatefulDataLoader)` work + # transparently. + # + # A more robust solution is for DataLoaderAdapter to not inherit from DataLoader (compose rather than inherit), + # but this would not be backwards compatible with existing code which assumes + # DataLoaderShard/DataLoaderDispatcher are DataLoaders. base_cls = self.__class__ base_cls_name = self.__class__.__name__ parent_cls_name = self.base_dataloader.__class__ self.__class__ = type(base_cls_name, (base_cls, parent_cls_name), {}) # Allow this class to transparently pass through attributes from the underlying class + if hasattr(self.base_dataloader, "state_dict"): + self.dl_state_dict = self.base_dataloader.state_dict() + for attr in self.base_dataloader.__dict__.keys(): setattr(self, attr, getattr(self.base_dataloader, attr)) - if hasattr(self.base_dataloader, "state_dict"): - self.dl_state_dict = self.base_dataloader.state_dict() - def state_dict(self): return self.dl_state_dict - + + def load_state_dict(self, state_dict): + super().load_state_dict(state_dict) + self.dl_state_dict = self.state_dict + def _save_state_dict(self): if hasattr(self.base_dataloader, "state_dict"): - self.dl_state_dict = self.base_dataloader.state_dict() + self.dl_state_dict = super().state_dict() class DataLoaderShard(DataLoaderAdapter): @@ -1074,7 +1088,8 @@ def prepare_data_loader( _drop_last=dataloader.drop_last, _non_blocking=non_blocking, synchronized_generator=synchronized_generator, - use_stateful_dataloader=use_stateful_dataloader**kwargs, + use_stateful_dataloader=use_stateful_dataloader, + **kwargs, ) else: dataloader = DataLoaderShard( @@ -1140,6 +1155,7 @@ def __init__(self, dataset, skip_batches=0, use_stateful_dataloader=False, **kwa def __iter__(self): for index, batch in enumerate(super().__iter__()): if index >= self.skip_batches: + self._save_state_dict() yield batch @@ -1215,5 +1231,4 @@ def skip_first_batches(dataloader, num_batches=0): dataloader = StatefulDataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs) else: dataloader = DataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs) - return dataloader diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py index 2989e01e527..1e12e3e5f4b 100644 --- a/tests/test_accelerator.py +++ b/tests/test_accelerator.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
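
A minimal standalone sketch (not part of the patch) of the dynamic-mixin trick described in the `DataLoaderAdapter.__init__` comments above: the instance's class is rebuilt at runtime so that `isinstance` checks against the wrapped loader's class succeed without hard-coding a base class.

from torch.utils.data import DataLoader

class Adapter:
    def __init__(self, dataset, **kwargs):
        self.base_dataloader = DataLoader(dataset, **kwargs)
        # Build a new type whose MRO is (Adapter, type(self.base_dataloader)) and assign it
        # to this instance; DataLoader.__init__ is never called on `self`.
        self.__class__ = type(
            self.__class__.__name__, (self.__class__, self.base_dataloader.__class__), {}
        )

    def __iter__(self):
        yield from self.base_dataloader

adapter = Adapter(list(range(8)), batch_size=4)
assert isinstance(adapter, DataLoader)      # passes thanks to the rebuilt class
assert [len(b) for b in adapter] == [4, 4]  # iteration is delegated to the wrapped loader
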
+import itertools import json import os import pickle @@ -25,6 +26,7 @@ from accelerate import DistributedType, infer_auto_device_map, init_empty_weights, load_checkpoint_and_dispatch from accelerate.accelerator import Accelerator +from accelerate.data_loader import skip_first_batches from accelerate.state import GradientState, PartialState from accelerate.test_utils import require_bnb, require_multi_gpu, require_non_cpu, slow, torch_device from accelerate.test_utils.testing import ( @@ -37,6 +39,7 @@ from accelerate.utils.dataclasses import DataLoaderConfiguration from accelerate.utils.imports import is_torchdata_stateful_dataloader_available from accelerate.utils.modeling import get_state_dict_from_offload, load_checkpoint_in_model +from accelerate.utils.random import set_seed if is_torchdata_stateful_dataloader_available(): @@ -45,12 +48,12 @@ ) -def create_components(dataset_size=3): +def create_components(): model = torch.nn.Linear(2, 4) optimizer = torch.optim.AdamW(model.parameters(), lr=1.0) scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=2, epochs=1) - train_dl = DataLoader(TensorDataset(torch.tensor([i for i in range(1, dataset_size+1)]))) - valid_dl = DataLoader(TensorDataset(torch.tensor([i+dataset_size for i in range(1, dataset_size+1)]))) + train_dl = DataLoader(TensorDataset(torch.tensor([1, 2, 3]))) + valid_dl = DataLoader(TensorDataset(torch.tensor([4, 5, 6]))) return model, optimizer, scheduler, train_dl, valid_dl @@ -65,6 +68,23 @@ def forward(self, x): return self.linear2(self.batchnorm(self.linear1(x))) +def create_dataloaders_for_test( + a=2, b=3, batch_size=3, n_train_batches: int = 12, n_valid_batches: int = 2, num_workers=0 +): + "Generates a tuple of dummy DataLoaders to test with" + + def get_dataset(n_batches): + x = torch.randn(batch_size * n_batches, 3) + y = torch.randn(batch_size * n_batches, 5) + return TensorDataset(x, y) + + train_dataset = get_dataset(n_train_batches) + valid_dataset = get_dataset(n_valid_batches) + train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers) + valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, num_workers=num_workers) + return (train_dataloader, valid_dataloader) + + def get_signature(model): return (model.weight.abs().sum() + model.bias.abs().sum()).item() @@ -78,6 +98,8 @@ def parameterized_custom_name_func(func, param_num, param): # customize the test name generator function as we want both params to appear in the sub-test # name, as by default it shows only the first param param_based_name = "use_safetensors" if param.args[0] is True else "use_pytorch" + if len(param.args) > 1: + param_based_name += f"_num_workers_{param.args[1]}" return f"{func.__name__}_{param_based_name}" @@ -612,14 +634,18 @@ def test_prepared_objects_are_referenced_with_stateful_dataloader(self): assert isinstance(prepared_train_dl, StatefulDataLoader) assert isinstance(prepared_valid_dl, StatefulDataLoader) - @parameterized.expand([True, False], name_func=parameterized_custom_name_func) + @parameterized.expand(itertools.product([True, False], [0, 2]), name_func=parameterized_custom_name_func) @require_torchdata_stateful_dataloader - def test_save_model_with_stateful_dataloader(self, use_safetensors): + def test_save_model_with_stateful_dataloader(self, use_safetensors, num_workers): """Test that saving and loading a model with a stateful dataloader returns the same model, and that the dataloader's iterator is restored properly.""" + set_seed(42) 
dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True) accelerator = Accelerator(dataloader_config=dataloader_config) - model, optimizer, scheduler, train_dl, valid_dl = create_components(dataset_size=6) + model, optimizer, scheduler, train_dl, valid_dl = create_components() + train_dl, valid_dl = create_dataloaders_for_test(num_workers=num_workers) + model = ModelForTest() + ( prepared_model, prepared_optimizer, @@ -630,39 +656,80 @@ def test_save_model_with_stateful_dataloader(self, use_safetensors): assert isinstance(prepared_train_dl, StatefulDataLoader) assert isinstance(prepared_valid_dl, StatefulDataLoader) - assert accelerator.gradient_state.active_dataloader is None - print("len before iterating", len(prepared_train_dl)) # Perform 3 training iterations to ensure the dataloader's iterator is advanced - for i, input in enumerate(prepared_train_dl): - print(i, input) - if i == 2: + num_batches_to_skip = 3 + model.train() + for step, batch in enumerate(prepared_train_dl): + x, y = batch + x.to(accelerator.device) + y.to(accelerator.device) + with accelerator.accumulate(prepared_model): + outputs = prepared_model(x) + loss = torch.nn.functional.mse_loss(outputs, y) + accelerator.backward(loss) + prepared_optimizer.step() + prepared_scheduler.step() + prepared_optimizer.zero_grad() + if step == num_batches_to_skip - 1: state_dict = prepared_train_dl.state_dict() # When breaking out without fully going through the iterator, must call end() to unregister this iterator from gradient state. + # TODO: Maybe this could be done automatically? prepared_train_dl.end() break + assert accelerator.gradient_state.active_dataloader is None - for i, input in enumerate(prepared_train_dl): - print("Pass of initial dict yielding input {}".format(input), prepared_train_dl.state_dict()) - - print("State dict to be loaded", state_dict) - prepared_train_dl.load_state_dict(state_dict) - print("State dict immediately after loading", prepared_train_dl.state_dict()) - - for i, input in enumerate(prepared_train_dl): - print("Pass of loaded dict yielding input {}".format(input), prepared_train_dl.state_dict()) - - - model_signature = get_signature(prepared_model) with tempfile.TemporaryDirectory() as tmpdirname: + # Save model for later use + accelerator.save_model(model, tmpdirname, safe_serialization=use_safetensors) - # Save the model's state. 
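
As context for the reference path used a few lines below in this hunk, a small sketch (not part of the diff) of `skip_first_batches`, the stateless alternative the test compares the restored `state_dict` against; `prepared_train_dl` and `num_batches_to_skip` are the names from the surrounding test:

from accelerate.data_loader import skip_first_batches

# Instead of restoring a saved state_dict, rebuild an iterator that simply drops the
# first `num_batches_to_skip` batches; both paths should resume at the same position.
reference_dl = skip_first_batches(prepared_train_dl, num_batches_to_skip)
for batch in reference_dl:
    ...  # training continues from the point a load_state_dict() resume would reach
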
- accelerator.save_model(prepared_model, tmpdirname, safe_serialization=use_safetensors) - - # Load the saved model - loaded_model = prepared_model - load_checkpoint_in_model(loaded_model, tmpdirname) - # make sure loaded weights match - assert abs(model_signature - get_signature(prepared_model)) < 1e-3 - - # iterate through both dataloaders and assert their behaviors are identical + # Starting from where we left off, train this model to the end of the DataLoader + prepared_train_dl = skip_first_batches(prepared_train_dl, num_batches_to_skip) + batches_seen_with_original_dl = 0 + for batch in prepared_train_dl: + x, y = batch + x.to(accelerator.device) + y.to(accelerator.device) + with accelerator.accumulate(prepared_model): + outputs = prepared_model(x) + loss = torch.nn.functional.mse_loss(outputs, y) + accelerator.backward(loss) + prepared_optimizer.step() + prepared_scheduler.step() + prepared_optimizer.zero_grad() + batches_seen_with_original_dl += 1 + + original_linear1 = prepared_model.linear1.weight.clone() + original_batchnorm = prepared_model.batchnorm.weight.clone() + original_linear2 = prepared_model.linear2.weight.clone() + + # Load the model and state dict + load_checkpoint_in_model(model, tmpdirname) + stateful_train_dl, _ = create_dataloaders_for_test(num_workers=num_workers) + prepared_stateful_train_dl = accelerator.prepare_data_loader(stateful_train_dl) + prepared_stateful_train_dl.load_state_dict(state_dict) + + # Train this to the end of the DataLoader + batches_seen_with_loaded_dl = 0 + for batch in prepared_stateful_train_dl: + x, y = batch + x.to(accelerator.device) + y.to(accelerator.device) + with accelerator.accumulate(prepared_model): + outputs = prepared_model(x) + loss = torch.nn.functional.mse_loss(outputs, y) + accelerator.backward(loss) + prepared_optimizer.step() + prepared_scheduler.step() + prepared_optimizer.zero_grad() + batches_seen_with_loaded_dl += 1 + + new_linear1 = prepared_model.linear1.weight + new_batchnorm = prepared_model.batchnorm.weight + new_linear2 = prepared_model.linear2.weight + + # Assert equalities + assert batches_seen_with_original_dl == batches_seen_with_loaded_dl + assert torch.allclose(original_linear1, new_linear1) + assert torch.allclose(original_batchnorm, new_batchnorm) + assert torch.allclose(original_linear2, new_linear2) diff --git a/tests/test_data_loader.py b/tests/test_data_loader.py index 5e6f028176a..e48c20246bf 100644 --- a/tests/test_data_loader.py +++ b/tests/test_data_loader.py @@ -16,6 +16,7 @@ import unittest import torch +from parameterized import parameterized from torch.utils.data import BatchSampler, DataLoader, IterableDataset from accelerate import Accelerator @@ -30,6 +31,7 @@ skip_first_batches, ) from accelerate.test_utils.testing import require_torchdata_stateful_dataloader +from accelerate.utils.dataclasses import DataLoaderConfiguration from accelerate.utils.imports import is_torchdata_stateful_dataloader_available @@ -39,6 +41,13 @@ ) +def parameterized_custom_name_func(func, param_num, param): + # customize the test name generator function as we want both params to appear in the sub-test + # name, as by default it shows only the first param + param_based_name = f"num_workers_{param.args[0]}" + return f"{func.__name__}_{param_based_name}" + + class RandomIterableDataset(IterableDataset): # For testing, an iterable dataset of random length def __init__(self, p_stop=0.01, max_length=1000): @@ -380,7 +389,10 @@ def test_skip_batch_sampler(self): assert list(new_batch_sampler) == [[8, 9, 10, 11], [12, 
13, 14, 15]] def test_dataloader_inheritance(self): - """`DataLoaderAdapter`'s parent classes are dynamically constructed, assert that subclasses of DataLoaderAdapter are instances of DataLoader and DataLoaderStateMixin.""" + """ + `DataLoaderAdapter`'s parent classes are dynamically constructed, assert that subclasses of DataLoaderAdapter + are instances of DataLoader and DataLoaderStateMixin. + """ Accelerator() skip_dl = SkipDataLoader(range(16), batch_size=4, skip_batches=2) dl_shard = DataLoaderShard(range(16), batch_size=4) @@ -454,7 +466,6 @@ def test_end_of_dataloader_dispatcher(self): dataloader = DataLoaderDispatcher(range(16), batch_size=4, use_stateful_dataloader=True) assert isinstance(dataloader, StatefulDataLoader) for idx, _ in enumerate(dataloader): - print(idx) assert dataloader.end_of_dataloader == (idx == 3) # Test it also works on the second iteration @@ -462,48 +473,53 @@ def test_end_of_dataloader_dispatcher(self): assert dataloader.end_of_dataloader == (idx == 3) @require_torchdata_stateful_dataloader - def test_dataloader_state_dict(self): + @parameterized.expand([0, 2], name_func=parameterized_custom_name_func) + def test_dataloader_state_dict(self, num_workers): """ Test that saving a stateful dataloader's state, then loading it back, gives the same results. """ dataset = list(range(16)) - dataloader = DataLoaderShard(dataset, batch_size=4, use_stateful_dataloader=True) + dataloader = DataLoaderShard(dataset, batch_size=4, use_stateful_dataloader=True, num_workers=num_workers) assert dataloader.use_stateful_dataloader assert isinstance(dataloader, StatefulDataLoader) - vals = [] + vals = [] for idx, val in enumerate(dataloader): vals.append(val) if idx == 1: sd = dataloader.state_dict() assert len(vals) == 4 - dataloader2 = DataLoaderShard(dataset, batch_size=4, use_stateful_dataloader=True) + dataloader2 = DataLoaderShard(dataset, batch_size=4, use_stateful_dataloader=True, num_workers=num_workers) dataloader2.load_state_dict(sd) data1 = vals[2:] data2 = list(dataloader2) for d1, d2 in zip(data1, data2): assert torch.allclose(d1, d2) - + @require_torchdata_stateful_dataloader - def test_dataloader_dispatcher_state_dict(self): + @parameterized.expand([0, 2], name_func=parameterized_custom_name_func) + def test_dataloader_dispatcher_state_dict(self, num_workers): """ Test that saving a stateful dataloader's state, then loading it back, gives the same results. 
""" + dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True) + Accelerator(dataloader_config=dataloader_config) dataset = list(range(16)) - dataloader = DataLoaderDispatcher(dataset, batch_size=4, use_stateful_dataloader=True) + dataloader = DataLoaderDispatcher(dataset, batch_size=4, use_stateful_dataloader=True, num_workers=num_workers) assert dataloader.use_stateful_dataloader assert isinstance(dataloader, StatefulDataLoader) - vals = [] + vals = [] for idx, val in enumerate(dataloader): vals.append(val) if idx == 1: sd = dataloader.state_dict() assert len(vals) == 4 - - dataloader2 = DataLoaderDispatcher(dataset, batch_size=4, use_stateful_dataloader=True) + dataloader2 = DataLoaderDispatcher( + dataset, batch_size=4, use_stateful_dataloader=True, num_workers=num_workers + ) dataloader2.load_state_dict(sd) data1 = vals[2:] @@ -513,7 +529,10 @@ def test_dataloader_dispatcher_state_dict(self): @require_torchdata_stateful_dataloader def test_dataloader_inheritance(self): - """`DataLoaderAdapter`'s parent classes are dynamically constructed, assert that when use_stateful_dataloader=True, subclasses of DataLoaderAdapter are instances of StatefulDataLoader and DataLoaderStateMixin.""" + """ + `DataLoaderAdapter`'s parent classes are dynamically constructed, assert that if use_stateful_dataloader=True, + subclasses of DataLoaderAdapter are instances of StatefulDataLoader and DataLoaderStateMixin. + """ Accelerator() skip_dl = SkipDataLoader(range(16), batch_size=4, skip_batches=2, use_stateful_dataloader=True) dl_shard = DataLoaderShard(range(16), batch_size=4, use_stateful_dataloader=True) @@ -525,3 +544,84 @@ def test_dataloader_inheritance(self): assert isinstance(skip_dl, DataLoaderStateMixin) assert isinstance(dl_shard, DataLoaderStateMixin) assert isinstance(dl_dispatcher, DataLoaderStateMixin) + + @require_torchdata_stateful_dataloader + @parameterized.expand([0, 2], name_func=parameterized_custom_name_func) + def test_stateful_dataloader_adapter_equivalent_to_torchdata_stateful_dataloader(self, num_workers): + """ + Assert that `state_dict()` and `load_state_dict()` for derived subclasses of `DataLoaderAdapter` produce + the same behavior as `state_dict()` and `load_state_dict()` for `StatefulDataLoader`. 
+ """ + dataset = list(range(64)) + + # Set the seed for reproducibility + def g(): + return torch.Generator().manual_seed(42) + + accelerator = Accelerator() + stateful_dl = StatefulDataLoader(dataset, batch_size=4, num_workers=num_workers, generator=g()) + skip_dl = SkipDataLoader( + dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True + ) + dl_shard = DataLoaderShard( + dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True + ) + dl_dispatcher = DataLoaderDispatcher( + dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True + ) + + stateful_dl_iter = iter(stateful_dl) + iterators_under_test = [iter(skip_dl), iter(dl_shard), iter(dl_dispatcher)] + + num_batches = 8 + # Iterate over all of the dataloaders identically, expect the same values + for _ in range(num_batches): + expected_val = next(stateful_dl_iter).to(accelerator.device) + for dl_iter in iterators_under_test: + val = next(dl_iter).to(accelerator.device) + assert torch.allclose(val, expected_val) + + # The adapters should all produce the same state_dict as the reference stateful dataloader + expected_state_dict = stateful_dl.state_dict() + skip_dl_state_dict = skip_dl.state_dict() + dl_shard_state_dict = dl_shard.state_dict() + dl_dispatcher_state_dict = dl_dispatcher.state_dict() + + assert expected_state_dict == skip_dl_state_dict + assert expected_state_dict == dl_shard_state_dict + assert expected_state_dict == dl_dispatcher_state_dict + + # Load the state dict into new dataloaders + manual_skip_dl = SkipDataLoader( + dataset, batch_size=4, num_workers=num_workers, generator=g(), skip_batches=8, use_stateful_dataloader=True + ) + loaded_stateful_dl = StatefulDataLoader(dataset, batch_size=4, num_workers=num_workers, generator=g()) + loaded_stateful_dl.load_state_dict(expected_state_dict) + loaded_skip_dl = SkipDataLoader( + dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True + ) + loaded_skip_dl.load_state_dict(expected_state_dict) + loaded_dl_shard = DataLoaderShard( + dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True + ) + loaded_dl_shard.load_state_dict(expected_state_dict) + loaded_dl_dispatcher = DataLoaderDispatcher( + dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True + ) + loaded_dl_dispatcher.load_state_dict(expected_state_dict) + + # Continue the iteration, expecting identical behavior across the board + iterators_under_test.extend( + [ + iter(manual_skip_dl), + iter(loaded_stateful_dl), + iter(loaded_skip_dl), + iter(loaded_dl_shard), + iter(loaded_dl_dispatcher), + ] + ) + for _ in range(num_batches): + expected_val = next(stateful_dl_iter).to(accelerator.device) + for dl_iter in iterators_under_test: + val = next(dl_iter).to(accelerator.device) + assert torch.allclose(val, expected_val) From 03845432de0403e848e66c0d0c599849f4236027 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Wed, 26 Jun 2024 16:20:09 -0400 Subject: [PATCH 39/61] better tests --- src/accelerate/utils/dataclasses.py | 2 +- tests/test_accelerator.py | 41 ++++++++++++++++++++++++++--- 2 files changed, 38 insertions(+), 5 deletions(-) diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py index 10cd50f5ccf..b2f26565e9c 100644 --- a/src/accelerate/utils/dataclasses.py +++ b/src/accelerate/utils/dataclasses.py @@ -689,7 +689,7 @@ class DataLoaderConfiguration: metadata={ "help": "If set 
to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process" " and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose" - " underlying dataset is an `IterableDataslet`, `False` otherwise." + " underlying dataset is an `IterableDataset`, `False` otherwise." }, ) even_batches: bool = field( diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py index 1e12e3e5f4b..7d3a34d81bc 100644 --- a/tests/test_accelerator.py +++ b/tests/test_accelerator.py @@ -100,6 +100,8 @@ def parameterized_custom_name_func(func, param_num, param): param_based_name = "use_safetensors" if param.args[0] is True else "use_pytorch" if len(param.args) > 1: param_based_name += f"_num_workers_{param.args[1]}" + if len(param.args) > 2: + param_based_name += "_dispatch_batches" if param.args[2] is True else "_no_dispatch_batches" return f"{func.__name__}_{param_based_name}" @@ -634,12 +636,15 @@ def test_prepared_objects_are_referenced_with_stateful_dataloader(self): assert isinstance(prepared_train_dl, StatefulDataLoader) assert isinstance(prepared_valid_dl, StatefulDataLoader) - @parameterized.expand(itertools.product([True, False], [0, 2]), name_func=parameterized_custom_name_func) + @parameterized.expand(itertools.product([True, False], [0, 2], [True, False]), name_func=parameterized_custom_name_func) @require_torchdata_stateful_dataloader - def test_save_model_with_stateful_dataloader(self, use_safetensors, num_workers): - """Test that saving and loading a model with a stateful dataloader returns the same model, and that the dataloader's iterator is restored properly.""" + def test_save_model_with_stateful_dataloader(self, use_safetensors, num_workers, dispatch_batches): + """ + Test that saving and loading a model with a stateful dataloader returns the same model, + and that the dataloader's iterator is restored properly.""" + print() set_seed(42) - dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True) + dataloader_config = DataLoaderConfiguration(dispatch_batches=dispatch_batches, use_stateful_dataloader=True) accelerator = Accelerator(dataloader_config=dataloader_config) model, optimizer, scheduler, train_dl, valid_dl = create_components() @@ -677,6 +682,7 @@ def test_save_model_with_stateful_dataloader(self, use_safetensors, num_workers) # TODO: Maybe this could be done automatically? prepared_train_dl.end() break + assert accelerator.gradient_state.active_dataloader is None with tempfile.TemporaryDirectory() as tmpdirname: @@ -733,3 +739,30 @@ def test_save_model_with_stateful_dataloader(self, use_safetensors, num_workers) assert torch.allclose(original_linear1, new_linear1) assert torch.allclose(original_batchnorm, new_batchnorm) assert torch.allclose(original_linear2, new_linear2) + + @require_torchdata_stateful_dataloader + def test_stateful_dataloader_dispatcher_deactivate_dataloaders(self): + """ + Test that we can break iteration in a DataLoaderDispatcher backed by a StatefulDataLoader partway through + in a way that removes it from the gradient state active dataloader list. 
+ """ + print() + set_seed(42) + dataloader_config = DataLoaderConfiguration(dispatch_batches=True, use_stateful_dataloader=True) + accelerator = Accelerator(dataloader_config=dataloader_config) + model, optimizer, scheduler, train_dl, valid_dl = create_components() + ( + prepared_model, + prepared_optimizer, + prepared_scheduler, + prepared_train_dl, + prepared_valid_dl, + ) = accelerator.prepare(model, optimizer, scheduler, train_dl, valid_dl) + + # Perform 3 training iterations to ensure the dataloader's iterator is advanced + num_batches_to_skip = 3 + for step, _ in enumerate(prepared_train_dl): + if step == num_batches_to_skip - 1: + prepared_train_dl.end() + break + assert accelerator.gradient_state.active_dataloader is None \ No newline at end of file From 0e0515d2438e624013733f4e95dccd8d6ccb3ce6 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Wed, 26 Jun 2024 16:45:22 -0400 Subject: [PATCH 40/61] discovered a bug --- tests/test_accelerator.py | 27 --------------------------- 1 file changed, 27 deletions(-) diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py index 7d3a34d81bc..9c8a481eba2 100644 --- a/tests/test_accelerator.py +++ b/tests/test_accelerator.py @@ -739,30 +739,3 @@ def test_save_model_with_stateful_dataloader(self, use_safetensors, num_workers, assert torch.allclose(original_linear1, new_linear1) assert torch.allclose(original_batchnorm, new_batchnorm) assert torch.allclose(original_linear2, new_linear2) - - @require_torchdata_stateful_dataloader - def test_stateful_dataloader_dispatcher_deactivate_dataloaders(self): - """ - Test that we can break iteration in a DataLoaderDispatcher backed by a StatefulDataLoader partway through - in a way that removes it from the gradient state active dataloader list. - """ - print() - set_seed(42) - dataloader_config = DataLoaderConfiguration(dispatch_batches=True, use_stateful_dataloader=True) - accelerator = Accelerator(dataloader_config=dataloader_config) - model, optimizer, scheduler, train_dl, valid_dl = create_components() - ( - prepared_model, - prepared_optimizer, - prepared_scheduler, - prepared_train_dl, - prepared_valid_dl, - ) = accelerator.prepare(model, optimizer, scheduler, train_dl, valid_dl) - - # Perform 3 training iterations to ensure the dataloader's iterator is advanced - num_batches_to_skip = 3 - for step, _ in enumerate(prepared_train_dl): - if step == num_batches_to_skip - 1: - prepared_train_dl.end() - break - assert accelerator.gradient_state.active_dataloader is None \ No newline at end of file From 809aca041d9ddc1b91bc10b7e34872e9c6f91cfb Mon Sep 17 00:00:00 2001 From: byi8220 Date: Wed, 26 Jun 2024 16:49:24 -0400 Subject: [PATCH 41/61] maybe fixed bug? 
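
Context for the one-line fix below (an inference from the diff, not text from the commit): `torch.utils.data.BatchSampler` exposes a public `.sampler` attribute and code that introspects batch samplers can rely on it, so the skipping wrapper now forwards it. A quick illustrative check, assuming the fix is applied:

from torch.utils.data import BatchSampler, SequentialSampler

from accelerate.data_loader import SkipBatchSampler

base = BatchSampler(SequentialSampler(range(16)), batch_size=4, drop_last=False)
skipped = SkipBatchSampler(base, skip_batches=2)

# The wrapper now mirrors BatchSampler's `.sampler` attribute...
assert skipped.sampler is base.sampler
# ...while still skipping the requested number of batches.
assert list(skipped) == [[8, 9, 10, 11], [12, 13, 14, 15]]
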
--- src/accelerate/data_loader.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index b628c2f559d..72a65c88b88 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -1118,6 +1118,7 @@ class SkipBatchSampler(BatchSampler): def __init__(self, batch_sampler, skip_batches=0): self.batch_sampler = batch_sampler + self.sampler = batch_sampler.sampler self.skip_batches = skip_batches def __iter__(self): From 5145c2d4b8cf273582f5dd51a0b717d173a71681 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Wed, 26 Jun 2024 16:49:35 -0400 Subject: [PATCH 42/61] make style --- tests/test_accelerator.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py index 9c8a481eba2..ff69fa2e774 100644 --- a/tests/test_accelerator.py +++ b/tests/test_accelerator.py @@ -636,7 +636,9 @@ def test_prepared_objects_are_referenced_with_stateful_dataloader(self): assert isinstance(prepared_train_dl, StatefulDataLoader) assert isinstance(prepared_valid_dl, StatefulDataLoader) - @parameterized.expand(itertools.product([True, False], [0, 2], [True, False]), name_func=parameterized_custom_name_func) + @parameterized.expand( + itertools.product([True, False], [0, 2], [True, False]), name_func=parameterized_custom_name_func + ) @require_torchdata_stateful_dataloader def test_save_model_with_stateful_dataloader(self, use_safetensors, num_workers, dispatch_batches): """ From ca74ff2261cb70fc811cb6422ee4f35d72164a7d Mon Sep 17 00:00:00 2001 From: byi8220 Date: Wed, 26 Jun 2024 17:43:10 -0400 Subject: [PATCH 43/61] hopefully this is PR ready --- src/accelerate/data_loader.py | 6 ++-- tests/test_data_loader.py | 67 ++++++++++++++++++++++------------- 2 files changed, 45 insertions(+), 28 deletions(-) diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index 72a65c88b88..a67d69b7e46 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -394,7 +394,7 @@ def end(self): self.gradient_state._remove_dataloader(self) -class DataLoaderAdapter(DataLoaderStateMixin): +class DataLoaderAdapter: """ A class which wraps around a PyTorch `DataLoader` (or variants of it) to be used with the `Accelerator`. For compatability reasons, this class inherits from the class it wraps around, so it can be used as a drop-in. @@ -448,7 +448,7 @@ def _save_state_dict(self): self.dl_state_dict = super().state_dict() -class DataLoaderShard(DataLoaderAdapter): +class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin): """ Subclass of `DataLoaderAdapter` that will deal with device placement and current distributed setup. @@ -623,7 +623,7 @@ def batch_sampler(self): return self._loader.batch_sampler -class DataLoaderDispatcher(DataLoaderAdapter): +class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin): """ Subclass of `DataLoaderAdapter` that will iterate and preprocess on process 0 only, then dispatch on each process their part of the batch. 
diff --git a/tests/test_data_loader.py b/tests/test_data_loader.py index e48c20246bf..ac5e2d40686 100644 --- a/tests/test_data_loader.py +++ b/tests/test_data_loader.py @@ -401,7 +401,6 @@ def test_dataloader_inheritance(self): assert isinstance(dl_shard, DataLoader) assert isinstance(dl_dispatcher, DataLoader) - assert isinstance(skip_dl, DataLoaderStateMixin) assert isinstance(dl_shard, DataLoaderStateMixin) assert isinstance(dl_dispatcher, DataLoaderStateMixin) @@ -541,7 +540,6 @@ def test_dataloader_inheritance(self): assert isinstance(dl_shard, StatefulDataLoader) assert isinstance(dl_dispatcher, StatefulDataLoader) - assert isinstance(skip_dl, DataLoaderStateMixin) assert isinstance(dl_shard, DataLoaderStateMixin) assert isinstance(dl_dispatcher, DataLoaderStateMixin) @@ -570,16 +568,30 @@ def g(): dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True ) - stateful_dl_iter = iter(stateful_dl) - iterators_under_test = [iter(skip_dl), iter(dl_shard), iter(dl_dispatcher)] + dataloaders_under_test = [skip_dl, dl_shard, dl_dispatcher] + + num_batches_to_skip = 8 + + def get_first_n_batches(dl, n, device): + """ + Iterate over the first `n` batches of a dataloader then break, returning the batches in a list. + """ + batches = [] + for idx, batch in enumerate(dl): + if idx == n-1: + if hasattr(dl, "end"): + dl.end() + break + batches.append(batch.to(device)) + return batches - num_batches = 8 # Iterate over all of the dataloaders identically, expect the same values - for _ in range(num_batches): - expected_val = next(stateful_dl_iter).to(accelerator.device) - for dl_iter in iterators_under_test: - val = next(dl_iter).to(accelerator.device) - assert torch.allclose(val, expected_val) + expected_batches = get_first_n_batches(stateful_dl, num_batches_to_skip, accelerator.device) + batches_from_dataloaders = [get_first_n_batches(dl, num_batches_to_skip, accelerator.device) for dl in dataloaders_under_test] + + for dl_batches in batches_from_dataloaders: + for expected, actual in zip(expected_batches, dl_batches): + assert torch.allclose(expected, actual) # The adapters should all produce the same state_dict as the reference stateful dataloader expected_state_dict = stateful_dl.state_dict() @@ -593,7 +605,7 @@ def g(): # Load the state dict into new dataloaders manual_skip_dl = SkipDataLoader( - dataset, batch_size=4, num_workers=num_workers, generator=g(), skip_batches=8, use_stateful_dataloader=True + dataset, batch_size=4, num_workers=num_workers, generator=g(), skip_batches=num_batches_to_skip, use_stateful_dataloader=True ) loaded_stateful_dl = StatefulDataLoader(dataset, batch_size=4, num_workers=num_workers, generator=g()) loaded_stateful_dl.load_state_dict(expected_state_dict) @@ -611,17 +623,22 @@ def g(): loaded_dl_dispatcher.load_state_dict(expected_state_dict) # Continue the iteration, expecting identical behavior across the board - iterators_under_test.extend( - [ - iter(manual_skip_dl), - iter(loaded_stateful_dl), - iter(loaded_skip_dl), - iter(loaded_dl_shard), - iter(loaded_dl_dispatcher), - ] - ) - for _ in range(num_batches): - expected_val = next(stateful_dl_iter).to(accelerator.device) - for dl_iter in iterators_under_test: - val = next(dl_iter).to(accelerator.device) - assert torch.allclose(val, expected_val) + def get_all_batches(dl, device): + """ + Iterate over all batches of a dataloader, returning (batches, num_batches_yielded) + """ + batches = [] + num_batches_yielded = 0 + for batch in dl: + batches.append(batch.to(device)) + 
num_batches_yielded += 1 + return (batches, num_batches_yielded) + + expected_batch_results = get_all_batches(loaded_stateful_dl, accelerator.device) + dataloader_batch_results = [get_all_batches(dl, accelerator.device) for dl in [manual_skip_dl, loaded_skip_dl, loaded_dl_shard, loaded_dl_dispatcher]] + for dl_results in dataloader_batch_results: + for expected, actual in zip(expected_batches, dl_batches): + assert torch.allclose(expected[0], actual[0]) + assert expected_batch_results[1] == dl_results[1] + + assert accelerator.gradient_state.active_dataloader is None \ No newline at end of file From a8f8bf310f48826e70ca9266797c647c483ed0c5 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Wed, 26 Jun 2024 18:00:17 -0400 Subject: [PATCH 44/61] properly skip tests --- src/accelerate/test_utils/testing.py | 10 ++++------ tests/test_data_loader.py | 22 ++++++++++++++++------ 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/src/accelerate/test_utils/testing.py b/src/accelerate/test_utils/testing.py index 120fc4a0263..8303d31661c 100644 --- a/src/accelerate/test_utils/testing.py +++ b/src/accelerate/test_utils/testing.py @@ -29,6 +29,7 @@ import torch import accelerate +from accelerate.utils.imports import is_torchdata_stateful_dataloader_available from ..state import AcceleratorState, PartialState from ..utils import ( @@ -417,12 +418,9 @@ def require_torchdata_stateful_dataloader(test_case): These tests are skipped when torchdata with stateful_dataloader module isn't installed. """ - try: - import torchdata.stateful_dataloader # noqa F401 - except (ImportError, AssertionError): - return unittest.skip("test requires torchdata.stateful_dataloader")(test_case) - else: - return test_case + return unittest.skipUnless( + is_torchdata_stateful_dataloader_available(), "test requires torchdata.stateful_dataloader" + )(test_case) class TempDirTestCase(unittest.TestCase): diff --git a/tests/test_data_loader.py b/tests/test_data_loader.py index ac5e2d40686..27b0fb437b7 100644 --- a/tests/test_data_loader.py +++ b/tests/test_data_loader.py @@ -578,7 +578,7 @@ def get_first_n_batches(dl, n, device): """ batches = [] for idx, batch in enumerate(dl): - if idx == n-1: + if idx == n - 1: if hasattr(dl, "end"): dl.end() break @@ -587,8 +587,10 @@ def get_first_n_batches(dl, n, device): # Iterate over all of the dataloaders identically, expect the same values expected_batches = get_first_n_batches(stateful_dl, num_batches_to_skip, accelerator.device) - batches_from_dataloaders = [get_first_n_batches(dl, num_batches_to_skip, accelerator.device) for dl in dataloaders_under_test] - + batches_from_dataloaders = [ + get_first_n_batches(dl, num_batches_to_skip, accelerator.device) for dl in dataloaders_under_test + ] + for dl_batches in batches_from_dataloaders: for expected, actual in zip(expected_batches, dl_batches): assert torch.allclose(expected, actual) @@ -605,7 +607,12 @@ def get_first_n_batches(dl, n, device): # Load the state dict into new dataloaders manual_skip_dl = SkipDataLoader( - dataset, batch_size=4, num_workers=num_workers, generator=g(), skip_batches=num_batches_to_skip, use_stateful_dataloader=True + dataset, + batch_size=4, + num_workers=num_workers, + generator=g(), + skip_batches=num_batches_to_skip, + use_stateful_dataloader=True, ) loaded_stateful_dl = StatefulDataLoader(dataset, batch_size=4, num_workers=num_workers, generator=g()) loaded_stateful_dl.load_state_dict(expected_state_dict) @@ -635,10 +642,13 @@ def get_all_batches(dl, device): return (batches, num_batches_yielded) 
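
A side note (not part of the patch) on why each loader earlier in this test is constructed with its own `g()` generator rather than one shared `torch.Generator`: fresh, identically seeded generators stay in lockstep, while a single shared object would be advanced by the first consumer and desynchronize the rest.

import torch

def g():
    return torch.Generator().manual_seed(42)

# Two independent generators with the same seed produce identical streams.
assert torch.equal(
    torch.randint(0, 100, (4,), generator=g()),
    torch.randint(0, 100, (4,), generator=g()),
)
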
expected_batch_results = get_all_batches(loaded_stateful_dl, accelerator.device) - dataloader_batch_results = [get_all_batches(dl, accelerator.device) for dl in [manual_skip_dl, loaded_skip_dl, loaded_dl_shard, loaded_dl_dispatcher]] + dataloader_batch_results = [ + get_all_batches(dl, accelerator.device) + for dl in [manual_skip_dl, loaded_skip_dl, loaded_dl_shard, loaded_dl_dispatcher] + ] for dl_results in dataloader_batch_results: for expected, actual in zip(expected_batches, dl_batches): assert torch.allclose(expected[0], actual[0]) assert expected_batch_results[1] == dl_results[1] - assert accelerator.gradient_state.active_dataloader is None \ No newline at end of file + assert accelerator.gradient_state.active_dataloader is None From 59738f4cef704064bfdd7de2dd68c19772bad37b Mon Sep 17 00:00:00 2001 From: byi8220 Date: Wed, 26 Jun 2024 18:08:26 -0400 Subject: [PATCH 45/61] parameterize --- tests/test_data_loader.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_data_loader.py b/tests/test_data_loader.py index 27b0fb437b7..b202caed788 100644 --- a/tests/test_data_loader.py +++ b/tests/test_data_loader.py @@ -471,8 +471,8 @@ def test_end_of_dataloader_dispatcher(self): for idx, _ in enumerate(dataloader): assert dataloader.end_of_dataloader == (idx == 3) - @require_torchdata_stateful_dataloader @parameterized.expand([0, 2], name_func=parameterized_custom_name_func) + @require_torchdata_stateful_dataloader def test_dataloader_state_dict(self, num_workers): """ Test that saving a stateful dataloader's state, then loading it back, gives the same results. @@ -497,8 +497,8 @@ def test_dataloader_state_dict(self, num_workers): for d1, d2 in zip(data1, data2): assert torch.allclose(d1, d2) - @require_torchdata_stateful_dataloader @parameterized.expand([0, 2], name_func=parameterized_custom_name_func) + @require_torchdata_stateful_dataloader def test_dataloader_dispatcher_state_dict(self, num_workers): """ Test that saving a stateful dataloader's state, then loading it back, gives the same results. @@ -543,8 +543,8 @@ def test_dataloader_inheritance(self): assert isinstance(dl_shard, DataLoaderStateMixin) assert isinstance(dl_dispatcher, DataLoaderStateMixin) - @require_torchdata_stateful_dataloader @parameterized.expand([0, 2], name_func=parameterized_custom_name_func) + @require_torchdata_stateful_dataloader def test_stateful_dataloader_adapter_equivalent_to_torchdata_stateful_dataloader(self, num_workers): """ Assert that `state_dict()` and `load_state_dict()` for derived subclasses of `DataLoaderAdapter` produce From 8f04c1e88784fbbaf4e911606db0e06444fffc45 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Mon, 15 Jul 2024 14:05:41 -0400 Subject: [PATCH 46/61] Update src/accelerate/utils/dataclasses.py Co-authored-by: Zach Mueller --- src/accelerate/utils/dataclasses.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py index 7aeeb8e09c5..713b8308fd0 100644 --- a/src/accelerate/utils/dataclasses.py +++ b/src/accelerate/utils/dataclasses.py @@ -723,7 +723,7 @@ class DataLoaderConfiguration: use_stateful_dataloader: bool = field( default=False, metadata={ - "help": "If set to true, the dataloader prepared by the Accelerator will be backed by " + "help": "If set to `True`, the dataloader prepared by the Accelerator will be backed by " "[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). 
This requires a version" " of `torchdata` with StatefulDataLoader to be installed." }, From 45db4b98fa3bc0f3c3d75980008480121724d298 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Mon, 15 Jul 2024 14:05:51 -0400 Subject: [PATCH 47/61] Update src/accelerate/data_loader.py Co-authored-by: Zach Mueller --- src/accelerate/data_loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index a67d69b7e46..94532b3e678 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -404,7 +404,7 @@ def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, * self.use_stateful_dataloader = use_stateful_dataloader if use_stateful_dataloader and not is_torchdata_stateful_dataloader_available(): - raise ValueError("StatefulDataLoader is not available. Please install torchdata to use it.") + raise ImportError("StatefulDataLoader is not available. Please install torchdata to use it.") if use_stateful_dataloader: self.base_dataloader = StatefulDataLoader(dataset, batch_sampler=batch_sampler, **kwargs) else: From 0ffc64b4efcd4e88bde4ed8a957a2013d886d9d7 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Mon, 15 Jul 2024 14:08:14 -0400 Subject: [PATCH 48/61] merge conflicts --- tests/test_accelerator.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py index ff69fa2e774..b8e3d7335e4 100644 --- a/tests/test_accelerator.py +++ b/tests/test_accelerator.py @@ -334,11 +334,7 @@ def test_save_model_offload(self, use_safetensors): assert torch.allclose(expected, output, atol=1e-5) @parameterized.expand([True, False], name_func=parameterized_custom_name_func) -<<<<<<< HEAD @require_non_cpu -======= - @require_cuda ->>>>>>> efa1e7d (checkout?) def test_get_state_dict_from_offload(self, use_safetensors): accelerator = Accelerator() From 8d2c6c3ad6045ffcd6e04bf1506404f08e5cdfe8 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Mon, 15 Jul 2024 14:23:30 -0400 Subject: [PATCH 49/61] move imports --- src/accelerate/data_loader.py | 10 ++++++---- src/accelerate/test_utils/testing.py | 2 +- src/accelerate/utils/__init__.py | 2 ++ tests/test_accelerator.py | 3 +-- tests/test_data_loader.py | 2 +- 5 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index 94532b3e678..ef3ae6bea47 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -19,12 +19,9 @@ import torch from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler -from accelerate.utils.imports import is_torchdata_stateful_dataloader_available +from accelerate.utils import is_torchdata_stateful_dataloader_available -if is_torchdata_stateful_dataloader_available(): - from torchdata.stateful_dataloader import StatefulDataLoader - from .logging import get_logger from .state import AcceleratorState, DistributedType, GradientState, is_torch_xla_available from .utils import ( @@ -402,6 +399,8 @@ class DataLoaderAdapter: def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, **kwargs): self.use_stateful_dataloader = use_stateful_dataloader + if is_torchdata_stateful_dataloader_available(): + from torchdata.stateful_dataloader import StatefulDataLoader if use_stateful_dataloader and not is_torchdata_stateful_dataloader_available(): raise ImportError("StatefulDataLoader is not available. 
Please install torchdata to use it.") @@ -1164,6 +1163,9 @@ def skip_first_batches(dataloader, num_batches=0): """ Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`. """ + if is_torchdata_stateful_dataloader_available(): + from torchdata.stateful_dataloader import StatefulDataLoader + dataset = dataloader.dataset sampler_is_batch_sampler = False if isinstance(dataset, IterableDataset): diff --git a/src/accelerate/test_utils/testing.py b/src/accelerate/test_utils/testing.py index 620c81eb5aa..3685fe97803 100644 --- a/src/accelerate/test_utils/testing.py +++ b/src/accelerate/test_utils/testing.py @@ -29,7 +29,6 @@ import torch import accelerate -from accelerate.utils.imports import is_torchdata_stateful_dataloader_available from ..state import AcceleratorState, PartialState from ..utils import ( @@ -53,6 +52,7 @@ is_timm_available, is_torch_version, is_torch_xla_available, + is_torchdata_stateful_dataloader_available, is_torchvision_available, is_transformers_available, is_triton_available, diff --git a/src/accelerate/utils/__init__.py b/src/accelerate/utils/__init__.py index 8a6d4e20886..ce6675ac2e7 100644 --- a/src/accelerate/utils/__init__.py +++ b/src/accelerate/utils/__init__.py @@ -107,6 +107,8 @@ is_tensorboard_available, is_timm_available, is_torch_xla_available, + is_torchdata_available, + is_torchdata_stateful_dataloader_available, is_torchvision_available, is_transformer_engine_available, is_transformers_available, diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py index b8e3d7335e4..97698723969 100644 --- a/tests/test_accelerator.py +++ b/tests/test_accelerator.py @@ -35,9 +35,8 @@ require_non_torch_xla, require_torchdata_stateful_dataloader, ) -from accelerate.utils import patch_environment +from accelerate.utils import patch_environment, is_torchdata_stateful_dataloader_available from accelerate.utils.dataclasses import DataLoaderConfiguration -from accelerate.utils.imports import is_torchdata_stateful_dataloader_available from accelerate.utils.modeling import get_state_dict_from_offload, load_checkpoint_in_model from accelerate.utils.random import set_seed diff --git a/tests/test_data_loader.py b/tests/test_data_loader.py index b202caed788..2397ae04a7f 100644 --- a/tests/test_data_loader.py +++ b/tests/test_data_loader.py @@ -32,7 +32,7 @@ ) from accelerate.test_utils.testing import require_torchdata_stateful_dataloader from accelerate.utils.dataclasses import DataLoaderConfiguration -from accelerate.utils.imports import is_torchdata_stateful_dataloader_available +from accelerate.utils import is_torchdata_stateful_dataloader_available if is_torchdata_stateful_dataloader_available(): From 6bfe871b961f3e940d7b4a3c067333f0d4b6a187 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Mon, 15 Jul 2024 15:18:49 -0400 Subject: [PATCH 50/61] make style --- src/accelerate/data_loader.py | 1 - tests/test_accelerator.py | 2 +- tests/test_data_loader.py | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index ef3ae6bea47..a03beb85fa1 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -21,7 +21,6 @@ from accelerate.utils import is_torchdata_stateful_dataloader_available - from .logging import get_logger from .state import AcceleratorState, DistributedType, GradientState, is_torch_xla_available from .utils import ( diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py index 97698723969..0c6633e2c77 100644 --- 
a/tests/test_accelerator.py +++ b/tests/test_accelerator.py @@ -35,7 +35,7 @@ require_non_torch_xla, require_torchdata_stateful_dataloader, ) -from accelerate.utils import patch_environment, is_torchdata_stateful_dataloader_available +from accelerate.utils import is_torchdata_stateful_dataloader_available, patch_environment from accelerate.utils.dataclasses import DataLoaderConfiguration from accelerate.utils.modeling import get_state_dict_from_offload, load_checkpoint_in_model from accelerate.utils.random import set_seed diff --git a/tests/test_data_loader.py b/tests/test_data_loader.py index 2397ae04a7f..b1108bf9f22 100644 --- a/tests/test_data_loader.py +++ b/tests/test_data_loader.py @@ -31,8 +31,8 @@ skip_first_batches, ) from accelerate.test_utils.testing import require_torchdata_stateful_dataloader -from accelerate.utils.dataclasses import DataLoaderConfiguration from accelerate.utils import is_torchdata_stateful_dataloader_available +from accelerate.utils.dataclasses import DataLoaderConfiguration if is_torchdata_stateful_dataloader_available(): From 6ff997e220c02ed3c82266ba477975006c7c51d0 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Wed, 17 Jul 2024 16:19:16 -0400 Subject: [PATCH 51/61] merges are breaking tests --- tests/test_accelerator.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py index bff963d1e22..f24d6b4e843 100644 --- a/tests/test_accelerator.py +++ b/tests/test_accelerator.py @@ -114,8 +114,10 @@ def parameterized_custom_name_func(func, param_num, param): # name, as by default it shows only the first param param_based_name = "use_safetensors" if param.args[0] is True else "use_pytorch" if len(param.args) > 1: - param_based_name += f"_num_workers_{param.args[1]}" + param_based_name += "_tied_weights" if param.args[1] is True else "" if len(param.args) > 2: + param_based_name += f"_num_workers_{param.args[1]}" + if len(param.args) > 3: param_based_name += "_dispatch_batches" if param.args[2] is True else "_no_dispatch_batches" return f"{func.__name__}_{param_based_name}" @@ -652,10 +654,10 @@ def test_prepared_objects_are_referenced_with_stateful_dataloader(self): assert isinstance(prepared_valid_dl, StatefulDataLoader) @parameterized.expand( - itertools.product([True, False], [0, 2], [True, False]), name_func=parameterized_custom_name_func + itertools.product([True, False], [True, False], [0, 2], [True, False]), name_func=parameterized_custom_name_func ) @require_torchdata_stateful_dataloader - def test_save_model_with_stateful_dataloader(self, use_safetensors, num_workers, dispatch_batches): + def test_save_model_with_stateful_dataloader(self, use_safetensors, tied_weights, num_workers, dispatch_batches): """ Test that saving and loading a model with a stateful dataloader returns the same model, and that the dataloader's iterator is restored properly.""" @@ -664,7 +666,7 @@ def test_save_model_with_stateful_dataloader(self, use_safetensors, num_workers, dataloader_config = DataLoaderConfiguration(dispatch_batches=dispatch_batches, use_stateful_dataloader=True) accelerator = Accelerator(dataloader_config=dataloader_config) - model, optimizer, scheduler, train_dl, valid_dl = create_components() + model, optimizer, scheduler, train_dl, valid_dl = create_components(tied_weights) train_dl, valid_dl = create_dataloaders_for_test(num_workers=num_workers) model = ModelForTest() From 4739524d4e4303771773c87c6cf9700e588888e2 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Wed, 17 Jul 2024 16:32:04 
-0400 Subject: [PATCH 52/61] fix test name --- tests/test_accelerator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py index f24d6b4e843..ec57c7d1596 100644 --- a/tests/test_accelerator.py +++ b/tests/test_accelerator.py @@ -116,9 +116,9 @@ def parameterized_custom_name_func(func, param_num, param): if len(param.args) > 1: param_based_name += "_tied_weights" if param.args[1] is True else "" if len(param.args) > 2: - param_based_name += f"_num_workers_{param.args[1]}" + param_based_name += f"_num_workers_{param.args[2]}" if len(param.args) > 3: - param_based_name += "_dispatch_batches" if param.args[2] is True else "_no_dispatch_batches" + param_based_name += "_dispatch_batches" if param.args[3] is True else "_no_dispatch_batches" return f"{func.__name__}_{param_based_name}" From 06597d4868ddd0c1b55d8184358851143524f6af Mon Sep 17 00:00:00 2001 From: byi8220 Date: Wed, 24 Jul 2024 15:26:35 -0400 Subject: [PATCH 53/61] Require safetensors>=0.4.3 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index fe42e8833f0..fb7d133f111 100644 --- a/setup.py +++ b/setup.py @@ -76,7 +76,7 @@ "pyyaml", "torch>=1.10.0", "huggingface_hub>=0.21.0", - "safetensors>=0.3.1", + "safetensors>=0.4.3", ], extras_require=extras, classifiers=[ From 4142c7ff46746a868b5bfff304a6c4bd33a0d83b Mon Sep 17 00:00:00 2001 From: byi8220 Date: Wed, 24 Jul 2024 15:28:23 -0400 Subject: [PATCH 54/61] undo last commit --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index fb7d133f111..fe42e8833f0 100644 --- a/setup.py +++ b/setup.py @@ -76,7 +76,7 @@ "pyyaml", "torch>=1.10.0", "huggingface_hub>=0.21.0", - "safetensors>=0.4.3", + "safetensors>=0.3.1", ], extras_require=extras, classifiers=[ From 35977ca6f1a1daaccb8effb0268110544e3208aa Mon Sep 17 00:00:00 2001 From: byi8220 Date: Mon, 29 Jul 2024 14:44:34 -0400 Subject: [PATCH 55/61] minor style --- src/accelerate/data_loader.py | 3 +-- .../test_utils/scripts/external_deps/test_metrics.py | 5 ----- tests/test_accelerator.py | 3 ++- 3 files changed, 3 insertions(+), 8 deletions(-) diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index a03beb85fa1..a9bb70218f5 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -19,8 +19,6 @@ import torch from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler -from accelerate.utils import is_torchdata_stateful_dataloader_available - from .logging import get_logger from .state import AcceleratorState, DistributedType, GradientState, is_torch_xla_available from .utils import ( @@ -32,6 +30,7 @@ get_data_structure, initialize_tensors, is_torch_version, + is_torchdata_stateful_dataloader_available, send_to_device, slice_tensors, synchronize_rng_states, diff --git a/src/accelerate/test_utils/scripts/external_deps/test_metrics.py b/src/accelerate/test_utils/scripts/external_deps/test_metrics.py index aca0f5ad07c..9ac13aba626 100755 --- a/src/accelerate/test_utils/scripts/external_deps/test_metrics.py +++ b/src/accelerate/test_utils/scripts/external_deps/test_metrics.py @@ -235,11 +235,6 @@ def test_gather_for_metrics_drop_last(): # Should return a full set of complete batches from each GPU num_expected_items = per_device_batch_size * accelerator.num_processes - print("dataloader.batch_size:", dataloader.batch_size) - print("accelerator.num_processes:", accelerator.num_processes) - 
print("gathered_items:", gathered_items) - print("batch:", batch) - print("len(dataloader):", len(dataloader)) assert gathered_items.size(0) == ( num_expected_items ), f"Expected number of items: {num_expected_items}, Actual: {gathered_items.size(0)}" diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py index ec57c7d1596..53fb461bdd8 100644 --- a/tests/test_accelerator.py +++ b/tests/test_accelerator.py @@ -654,7 +654,8 @@ def test_prepared_objects_are_referenced_with_stateful_dataloader(self): assert isinstance(prepared_valid_dl, StatefulDataLoader) @parameterized.expand( - itertools.product([True, False], [True, False], [0, 2], [True, False]), name_func=parameterized_custom_name_func + itertools.product([True, False], [True, False], [0, 2], [True, False]), + name_func=parameterized_custom_name_func, ) @require_torchdata_stateful_dataloader def test_save_model_with_stateful_dataloader(self, use_safetensors, tied_weights, num_workers, dispatch_batches): From 4188d4c5a1e6a529ed8257ccfefde86f54c6d12a Mon Sep 17 00:00:00 2001 From: byi8220 Date: Tue, 20 Aug 2024 12:27:10 -0400 Subject: [PATCH 56/61] address pr comments --- src/accelerate/accelerator.py | 8 +++++--- src/accelerate/data_loader.py | 27 +++++++++++++++------------ tests/test_accelerator.py | 2 +- 3 files changed, 21 insertions(+), 16 deletions(-) diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index beb220e8d32..4ed80537144 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -585,7 +585,9 @@ def non_blocking(self): @property def use_stateful_dataloader(self): - return self.dataloader_config.use_stateful_dataloader + if hasattr(self.dataloader_config, "use_stateful_dataloader"): + return self.dataloader_config.use_stateful_dataloader + return False @property def project_dir(self): @@ -1624,9 +1626,9 @@ def _prepare_deepspeed(self, *args): deepspeed_plugin = self.state.deepspeed_plugin - is_dataloader_present = any((isinstance(obj, torch.utils.data.DataLoader)) for obj in args) + is_dataloader_present = any(isinstance(obj, torch.utils.data.DataLoader) for obj in args) result = [ - self._prepare_one(obj, first_pass=True) if (isinstance(obj, torch.utils.data.DataLoader)) else obj + self._prepare_one(obj, first_pass=True) if isinstance(obj, torch.utils.data.DataLoader) else obj for obj in args ] diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index 590d477a1dc..aa6954d1901 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -401,7 +401,9 @@ def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, * from torchdata.stateful_dataloader import StatefulDataLoader if use_stateful_dataloader and not is_torchdata_stateful_dataloader_available(): - raise ImportError("StatefulDataLoader is not available. Please install torchdata to use it.") + raise ImportError( + "StatefulDataLoader is not available. Please install the nightly version of torchdata to use it." 
+ ) if use_stateful_dataloader: self.base_dataloader = StatefulDataLoader(dataset, batch_sampler=batch_sampler, **kwargs) else: @@ -430,8 +432,9 @@ def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, * if hasattr(self.base_dataloader, "state_dict"): self.dl_state_dict = self.base_dataloader.state_dict() - for attr in self.base_dataloader.__dict__.keys(): - setattr(self, attr, getattr(self.base_dataloader, attr)) + def __getattr__(self, name): + # Delegate attribute access to the internal dataloader + return getattr(self.base_dataloader, name) def state_dict(self): return self.dl_state_dict @@ -440,7 +443,7 @@ def load_state_dict(self, state_dict): super().load_state_dict(state_dict) self.dl_state_dict = self.state_dict - def _save_state_dict(self): + def _update_state_dict(self): if hasattr(self.base_dataloader, "state_dict"): self.dl_state_dict = super().state_dict() @@ -492,7 +495,7 @@ def __init__( _non_blocking: bool = False, **kwargs, ): - super().__init__(dataset, use_stateful_dataloader, **kwargs) + super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs) self.device = device self.rng_types = rng_types self.synchronized_generator = synchronized_generator @@ -521,7 +524,7 @@ def __iter__(self): # But we still move it to the device so it is done before `StopIteration` is reached if self.device is not None: current_batch = send_to_device(current_batch, self.device, non_blocking=self._non_blocking) - self._save_state_dict() + self._update_state_dict() next_batch = next(dataloader_iter) if batch_index >= self.skip_batches: yield current_batch @@ -670,7 +673,7 @@ def __init__( # We need to save the shuffling state of the DataPipe if isinstance(dataset, ShufflerIterDataPipe): shuffle = dataset._shuffle_enabled - super().__init__(dataset, use_stateful_dataloader, **kwargs) + super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs) self.split_batches = split_batches if shuffle: torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle) @@ -691,14 +694,14 @@ def _fetch_batches(self, iterator): try: if self.split_batches: # One batch of the main iterator is dispatched and split. - self._save_state_dict() + self._update_state_dict() batch = next(iterator) else: # num_processes batches of the main iterator are concatenated then dispatched and split. # We add the batches one by one so we have the remainder available when drop_last=False. batches = [] for _ in range(self.state.num_processes): - self._save_state_dict() + self._update_state_dict() batches.append(next(iterator)) try: batch = concatenate(batches, dim=0) @@ -943,7 +946,7 @@ def prepare_data_loader( use_stateful_dataloader (`bool`, *optional*, defaults to `False`): "If set to true, the dataloader prepared by the Accelerator will be backed by " "[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). - This requires a version" " of `torchdata` with StatefulDataLoader to be installed." + This requires the nightly version of `torchdata` that supports StatefulDataLoader to be installed." 
Returns: @@ -1152,13 +1155,13 @@ class SkipDataLoader(DataLoaderAdapter): """ def __init__(self, dataset, skip_batches=0, use_stateful_dataloader=False, **kwargs): - super().__init__(dataset, use_stateful_dataloader, **kwargs) + super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs) self.skip_batches = skip_batches def __iter__(self): for index, batch in enumerate(super().__iter__()): if index >= self.skip_batches: - self._save_state_dict() + self._update_state_dict() yield batch diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py index d8e5a2c7ab6..fe75e31158d 100644 --- a/tests/test_accelerator.py +++ b/tests/test_accelerator.py @@ -43,7 +43,7 @@ require_non_torch_xla, require_torchdata_stateful_dataloader, ) -from accelerate.utils import is_torchdata_stateful_dataloader_available, FP8RecipeKwargs, patch_environment +from accelerate.utils import FP8RecipeKwargs, is_torchdata_stateful_dataloader_available, patch_environment from accelerate.utils.dataclasses import DataLoaderConfiguration from accelerate.utils.modeling import get_state_dict_from_offload, load_checkpoint_in_model from accelerate.utils.random import set_seed From 51377a439313c2116e76564f0f9cabd99c28ad0a Mon Sep 17 00:00:00 2001 From: byi8220 Date: Tue, 20 Aug 2024 13:46:14 -0400 Subject: [PATCH 57/61] Torchdata version 0.8.0 is stable now --- src/accelerate/data_loader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index aa6954d1901..694ff710e50 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -402,7 +402,7 @@ def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, * if use_stateful_dataloader and not is_torchdata_stateful_dataloader_available(): raise ImportError( - "StatefulDataLoader is not available. Please install the nightly version of torchdata to use it." + "StatefulDataLoader is not available. Please install torchdata version 0.8.0 or higher to use it." ) if use_stateful_dataloader: self.base_dataloader = StatefulDataLoader(dataset, batch_sampler=batch_sampler, **kwargs) @@ -946,7 +946,7 @@ def prepare_data_loader( use_stateful_dataloader (`bool`, *optional*, defaults to `False`): "If set to true, the dataloader prepared by the Accelerator will be backed by " "[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). - This requires the nightly version of `torchdata` that supports StatefulDataLoader to be installed." + This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed." Returns: From f4b6bb5d8ebb2e350415fa4c6395f2b4078f7c01 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Tue, 20 Aug 2024 14:36:41 -0400 Subject: [PATCH 58/61] added docs and require torchdata>=0.8.0 for testing --- docs/source/basic_tutorials/migration.md | 6 ++++++ docs/source/concept_guides/internal_mechanism.md | 6 ++++++ setup.py | 1 + 3 files changed, 13 insertions(+) diff --git a/docs/source/basic_tutorials/migration.md b/docs/source/basic_tutorials/migration.md index 6220702e977..8fb2c32f981 100644 --- a/docs/source/basic_tutorials/migration.md +++ b/docs/source/basic_tutorials/migration.md @@ -219,3 +219,9 @@ During training, you may want to save the current state of the model, optimizer, To further customize where and how states are saved through [`~Accelerator.save_state`], use the [`~utils.ProjectConfiguration`] class. 
For example, if `automatic_checkpoint_naming` is enabled, each saved checkpoint is stored at `Accelerator.project_dir/checkpoints/checkpoint_{checkpoint_number}`. Any other stateful items to be stored should be registered with the [`~Accelerator.register_for_checkpointing`] method so they can be saved and loaded. Every object passed to this method to be stored must have a `load_state_dict` and `state_dict` function. + + + +If you have [`torchdata>=0.8.0`](https://github.com/pytorch/data/tree/main) installed, you can additionally pass `use_stateful_dataloader=True` into your [`~utils.DataLoaderConfiguration`]. This extends Accelerate's DataLoader classes with a `load_state_dict` and `state_dict` function, and makes it so `Accelerator.save_state` and `Accelerator.load_state` also track how far into the training dataset it has read when persisting the model. + + diff --git a/docs/source/concept_guides/internal_mechanism.md b/docs/source/concept_guides/internal_mechanism.md index e0b715dfa63..2410d882bb5 100644 --- a/docs/source/concept_guides/internal_mechanism.md +++ b/docs/source/concept_guides/internal_mechanism.md @@ -69,4 +69,10 @@ setting the same seed in the main random number generator in all processes. + + +If you have [`torchdata>=0.8.0`](https://github.com/pytorch/data/tree/main) installed, and you have passed `use_stateful_dataloader=True` into your [`~utils.DataLoaderConfiguration`], these classes will directly inherit from `StatefulDataLoader` instead, and maintain a `state_dict`. + + + For more details about the internals, see the [Internals page](package_reference/torch_wrappers). diff --git a/setup.py b/setup.py index 27d609cfa11..85a90bf82d8 100644 --- a/setup.py +++ b/setup.py @@ -27,6 +27,7 @@ "datasets", "diffusers", "evaluate", + "torchdata>=0.8.0", "torchpippy>=0.2.0", "transformers", "scipy", From d02dfcc023546b8119a740fb7ff7ce11fa9eabe7 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Wed, 21 Aug 2024 08:33:50 -0400 Subject: [PATCH 59/61] test base_dataloader attr doesn't cause infinite recursion --- tests/test_data_loader.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/test_data_loader.py b/tests/test_data_loader.py index b1108bf9f22..bc035e4a24d 100644 --- a/tests/test_data_loader.py +++ b/tests/test_data_loader.py @@ -404,6 +404,10 @@ def test_dataloader_inheritance(self): assert isinstance(dl_shard, DataLoaderStateMixin) assert isinstance(dl_dispatcher, DataLoaderStateMixin) + assert isinstance(skip_dl.base_dataloader, DataLoader) + assert isinstance(dl_shard.base_dataloader, DataLoader) + assert isinstance(dl_dispatcher.base_dataloader, DataLoader) + def test_skip_data_loader(self): dataloader = SkipDataLoader(list(range(16)), batch_size=4, skip_batches=2) assert [t.tolist() for t in dataloader] == [[8, 9, 10, 11], [12, 13, 14, 15]] @@ -543,6 +547,10 @@ def test_dataloader_inheritance(self): assert isinstance(dl_shard, DataLoaderStateMixin) assert isinstance(dl_dispatcher, DataLoaderStateMixin) + assert isinstance(skip_dl.base_dataloader, StatefulDataLoader) + assert isinstance(dl_shard.base_dataloader, StatefulDataLoader) + assert isinstance(dl_dispatcher.base_dataloader, StatefulDataLoader) + @parameterized.expand([0, 2], name_func=parameterized_custom_name_func) @require_torchdata_stateful_dataloader def test_stateful_dataloader_adapter_equivalent_to_torchdata_stateful_dataloader(self, num_workers): From 21bc4204a9d6d392afbc11aa4496b389b6305845 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Wed, 21 Aug 2024 11:28:04 -0400 Subject: [PATCH 60/61] 
address pr --- src/accelerate/data_loader.py | 9 ++++++++- src/accelerate/utils/dataclasses.py | 3 +-- src/accelerate/utils/imports.py | 10 +++++----- tests/test_accelerator.py | 9 ++------- tests/test_data_loader.py | 6 ++++++ 5 files changed, 22 insertions(+), 15 deletions(-) diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index 694ff710e50..3397bf2743f 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -428,11 +428,13 @@ def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, * parent_cls_name = self.base_dataloader.__class__ self.__class__ = type(base_cls_name, (base_cls, parent_cls_name), {}) - # Allow this class to transparently pass through attributes from the underlying class if hasattr(self.base_dataloader, "state_dict"): self.dl_state_dict = self.base_dataloader.state_dict() def __getattr__(self, name): + # Avoid infinite recursion if we try to access a nonexistent base_dataloader attribute. + if name == "base_dataloader": + raise AttributeError() # Delegate attribute access to the internal dataloader return getattr(self.base_dataloader, name) @@ -444,6 +446,11 @@ def load_state_dict(self, state_dict): self.dl_state_dict = self.state_dict def _update_state_dict(self): + # The state_dict of the underlying base_dataloader may be ahead of what is currently being yielded. + # E.g. the implementation of DataLoaderShard involves having an underlying iterator 1 element ahead of + # what it wants to yield. + # + # _update_state_dict is called to snapshot the state_dict that would properly recover the DataLoaderAdapter. if hasattr(self.base_dataloader, "state_dict"): self.dl_state_dict = super().state_dict() diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py index 2fbec48911a..7fa7810878e 100644 --- a/src/accelerate/utils/dataclasses.py +++ b/src/accelerate/utils/dataclasses.py @@ -781,8 +781,7 @@ class DataLoaderConfiguration: default=False, metadata={ "help": "If set to `True`, the dataloader prepared by the Accelerator will be backed by " - "[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). This requires a version" - " of `torchdata` with StatefulDataLoader to be installed." + "[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed." }, ) diff --git a/src/accelerate/utils/imports.py b/src/accelerate/utils/imports.py index b0255f3422e..3badfefd684 100644 --- a/src/accelerate/utils/imports.py +++ b/src/accelerate/utils/imports.py @@ -439,8 +439,8 @@ def is_torchdata_available(): # TODO: Remove this function once stateful_dataloader is a stable feature in torchdata. 
def is_torchdata_stateful_dataloader_available(): - if not is_torchdata_available(): - return False - import torchdata - - return hasattr(torchdata, "stateful_dataloader") and hasattr(torchdata.stateful_dataloader, "StatefulDataLoader") + package_exists = _is_package_available("torchdata") + if package_exists: + torchdata_version = version.parse(importlib.metadata.version("torchdata")) + return compare_versions(torchdata_version, ">=", "0.8.0") + return False diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py index fe75e31158d..42df216123a 100644 --- a/tests/test_accelerator.py +++ b/tests/test_accelerator.py @@ -50,9 +50,7 @@ if is_torchdata_stateful_dataloader_available(): - from torchdata.stateful_dataloader import ( - StatefulDataLoader, - ) + from torchdata.stateful_dataloader import StatefulDataLoader class ModelWithTiedWeights(torch.nn.Module): @@ -87,9 +85,7 @@ def forward(self, x): return self.linear2(self.batchnorm(self.linear1(x))) -def create_dataloaders_for_test( - a=2, b=3, batch_size=3, n_train_batches: int = 12, n_valid_batches: int = 2, num_workers=0 -): +def create_dataloaders_for_test(batch_size=3, n_train_batches: int = 12, n_valid_batches: int = 2, num_workers=0): "Generates a tuple of dummy DataLoaders to test with" def get_dataset(n_batches): @@ -685,7 +681,6 @@ def test_save_model_with_stateful_dataloader(self, use_safetensors, tied_weights """ Test that saving and loading a model with a stateful dataloader returns the same model, and that the dataloader's iterator is restored properly.""" - print() set_seed(42) dataloader_config = DataLoaderConfiguration(dispatch_batches=dispatch_batches, use_stateful_dataloader=True) accelerator = Accelerator(dataloader_config=dataloader_config) diff --git a/tests/test_data_loader.py b/tests/test_data_loader.py index bc035e4a24d..d60e599722d 100644 --- a/tests/test_data_loader.py +++ b/tests/test_data_loader.py @@ -15,6 +15,7 @@ import random import unittest +import pytest import torch from parameterized import parameterized from torch.utils.data import BatchSampler, DataLoader, IterableDataset @@ -408,6 +409,9 @@ def test_dataloader_inheritance(self): assert isinstance(dl_shard.base_dataloader, DataLoader) assert isinstance(dl_dispatcher.base_dataloader, DataLoader) + with pytest.raises(AttributeError): + _ = DataLoaderShard.base_dataloader + def test_skip_data_loader(self): dataloader = SkipDataLoader(list(range(16)), batch_size=4, skip_batches=2) assert [t.tolist() for t in dataloader] == [[8, 9, 10, 11], [12, 13, 14, 15]] @@ -498,6 +502,7 @@ def test_dataloader_state_dict(self, num_workers): data1 = vals[2:] data2 = list(dataloader2) + assert len(data1) == len(data2) for d1, d2 in zip(data1, data2): assert torch.allclose(d1, d2) @@ -527,6 +532,7 @@ def test_dataloader_dispatcher_state_dict(self, num_workers): data1 = vals[2:] data2 = list(dataloader2) + assert len(data1) == len(data2) for d1, d2 in zip(data1, data2): assert torch.allclose(d1, d2) From 74e2f53d841f5701d476f3fe5f6df8f97ad82e5c Mon Sep 17 00:00:00 2001 From: byi8220 Date: Wed, 21 Aug 2024 11:32:33 -0400 Subject: [PATCH 61/61] replace super().__iter__ with self.base_dataloader.__iter__ --- src/accelerate/data_loader.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index 3397bf2743f..e5b6364b4ab 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -442,7 +442,7 @@ def state_dict(self): return self.dl_state_dict def 
load_state_dict(self, state_dict): - super().load_state_dict(state_dict) + self.base_dataloader.load_state_dict(state_dict) self.dl_state_dict = self.state_dict def _update_state_dict(self): @@ -452,7 +452,7 @@ def _update_state_dict(self): # # _update_state_dict is called to snapshot the state_dict that would properly recover the DataLoaderAdapter. if hasattr(self.base_dataloader, "state_dict"): - self.dl_state_dict = super().state_dict() + self.dl_state_dict = self.base_dataloader.state_dict() class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin): @@ -518,7 +518,7 @@ def __iter__(self): self.begin() self.set_epoch(self.iteration) - dataloader_iter = super().__iter__() + dataloader_iter = self.base_dataloader.__iter__() # We iterate one batch ahead to check when we are at the end try: current_batch = next(dataloader_iter) @@ -749,9 +749,9 @@ def __iter__(self): # NOTE PyTorch DataLoader adds forward compatibilities for DataPipes, which broadcasts # shared seed to all dist processes. Thus, we need to create iterator for all dist processes. # But, we only iterate through the DataLoader on process 0. - main_iterator = super().__iter__() + main_iterator = self.base_dataloader.__iter__() elif self.state.process_index == 0: - main_iterator = super().__iter__() + main_iterator = self.base_dataloader.__iter__() stop_iteration = False self._stop_iteration = False first_batch = None @@ -1166,7 +1166,7 @@ def __init__(self, dataset, skip_batches=0, use_stateful_dataloader=False, **kwa self.skip_batches = skip_batches def __iter__(self): - for index, batch in enumerate(super().__iter__()): + for index, batch in enumerate(self.base_dataloader.__iter__()): if index >= self.skip_batches: self._update_state_dict() yield batch
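
For reference, a minimal usage sketch of the stateful-dataloader support this series adds, assuming `torchdata>=0.8.0` is installed. The `TensorDataset`, batch size, and early-break point are placeholders; the `DataLoaderConfiguration`, `Accelerator.prepare`, and prepared-dataloader `state_dict()`/`load_state_dict()` calls mirror the tests above.

# Sketch only: placeholder dataset/batch size, assumes torchdata>=0.8.0 is installed.
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator
from accelerate.utils.dataclasses import DataLoaderConfiguration

dataset = TensorDataset(torch.randn(64, 3))

# Opt in to torchdata.StatefulDataLoader-backed dataloaders.
dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
accelerator = Accelerator(dataloader_config=dataloader_config)
dataloader = accelerator.prepare(DataLoader(dataset, batch_size=4))

# Consume part of an epoch, then snapshot the dataloader's position.
for idx, _ in enumerate(dataloader):
    if idx == 2:
        if hasattr(dataloader, "end"):
            dataloader.end()  # clean up gradient-state tracking when stopping early
        break
state = dataloader.state_dict()

# A freshly prepared dataloader resumes from the saved position. Per the docs
# added in this series, Accelerator.save_state()/load_state() also carry this
# state when use_stateful_dataloader is enabled.
resumed = accelerator.prepare(DataLoader(dataset, batch_size=4))
resumed.load_state_dict(state)
for batch in resumed:
    ...

Because `DataLoaderShard` iterates one batch ahead of what it yields, the adapter snapshots the underlying dataloader's `state_dict()` via `_update_state_dict()` at yield time, so the saved state corresponds to the batch the caller actually received rather than the prefetched one.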