diff --git a/docs/source/basic_tutorials/migration.md b/docs/source/basic_tutorials/migration.md index 6220702e977..8fb2c32f981 100644 --- a/docs/source/basic_tutorials/migration.md +++ b/docs/source/basic_tutorials/migration.md @@ -219,3 +219,9 @@ During training, you may want to save the current state of the model, optimizer, To further customize where and how states are saved through [`~Accelerator.save_state`], use the [`~utils.ProjectConfiguration`] class. For example, if `automatic_checkpoint_naming` is enabled, each saved checkpoint is stored at `Accelerator.project_dir/checkpoints/checkpoint_{checkpoint_number}`. Any other stateful items to be stored should be registered with the [`~Accelerator.register_for_checkpointing`] method so they can be saved and loaded. Every object passed to this method to be stored must have a `load_state_dict` and `state_dict` function. + + + +If you have [`torchdata>=0.8.0`](https://github.com/pytorch/data/tree/main) installed, you can additionally pass `use_stateful_dataloader=True` into your [`~utils.DataLoaderConfiguration`]. This extends Accelerate's DataLoader classes with a `load_state_dict` and `state_dict` function, and makes it so `Accelerator.save_state` and `Accelerator.load_state` also track how far into the training dataloader the run has progressed when persisting the model. + + diff --git a/docs/source/concept_guides/internal_mechanism.md b/docs/source/concept_guides/internal_mechanism.md index e0b715dfa63..2410d882bb5 100644 --- a/docs/source/concept_guides/internal_mechanism.md +++ b/docs/source/concept_guides/internal_mechanism.md @@ -69,4 +69,10 @@ setting the same seed in the main random number generator in all processes. + + +If you have [`torchdata>=0.8.0`](https://github.com/pytorch/data/tree/main) installed, and you have passed `use_stateful_dataloader=True` into your [`~utils.DataLoaderConfiguration`], these classes will directly inherit from `StatefulDataLoader` instead, and maintain a `state_dict`. + + + For more details about the internals, see the [Internals page](package_reference/torch_wrappers).
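To make the checkpointing behaviour described in the migration-guide hunk above concrete, here is a minimal usage sketch. It is not part of the patch: it assumes `torchdata>=0.8.0` is installed, and the toy model, dataset, learning rate, project directory name, and checkpoint path are illustrative only.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration
from accelerate.utils.dataclasses import DataLoaderConfiguration

# Back every prepared dataloader with torchdata's StatefulDataLoader.
dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
project_config = ProjectConfiguration(project_dir="my_project", automatic_checkpoint_naming=True)
accelerator = Accelerator(dataloader_config=dataloader_config, project_config=project_config)

# Illustrative model and data.
model = torch.nn.Linear(3, 5)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
dataset = TensorDataset(torch.randn(64, 3), torch.randn(64, 5))
dataloader = DataLoader(dataset, batch_size=8)

model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for step, (x, y) in enumerate(dataloader):
    loss = torch.nn.functional.mse_loss(model(x), y)
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
    if step == 3:
        # The saved state now also records the dataloader's position.
        accelerator.save_state()
        # Unregister the partially consumed iterator before breaking out early
        # (mirrors the note in the tests below).
        dataloader.end()
        break

# Later (e.g. after a restart): model, optimizer *and* dataloader position are restored,
# so iteration resumes mid-epoch instead of from the first batch.
# The path assumes this was the first automatically named checkpoint.
accelerator.load_state("my_project/checkpoints/checkpoint_0")
```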
diff --git a/setup.py b/setup.py index 27d609cfa11..85a90bf82d8 100644 --- a/setup.py +++ b/setup.py @@ -27,6 +27,7 @@ "datasets", "diffusers", "evaluate", + "torchdata>=0.8.0", "torchpippy>=0.2.0", "transformers", "scipy", diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 3f5f1279132..4ed80537144 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -583,6 +583,12 @@ def use_seedable_sampler(self): def non_blocking(self): return self.dataloader_config.non_blocking + @property + def use_stateful_dataloader(self): + if hasattr(self.dataloader_config, "use_stateful_dataloader"): + return self.dataloader_config.use_stateful_dataloader + return False + @property def project_dir(self): return self.project_configuration.project_dir @@ -2068,6 +2074,7 @@ def prepare_data_loader( slice_fn_for_dispatch=slice_fn_for_dispatch, use_seedable_sampler=self.use_seedable_sampler, non_blocking=self.non_blocking, + use_stateful_dataloader=self.use_stateful_dataloader, ) self._dataloaders.append(prepared_data_loader) return prepared_data_loader diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index f0e88c645ec..e5b6364b4ab 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -30,6 +30,7 @@ get_data_structure, initialize_tensors, is_torch_version, + is_torchdata_stateful_dataloader_available, send_to_device, slice_tensors, synchronize_rng_states, @@ -388,9 +389,75 @@ def end(self): self.gradient_state._remove_dataloader(self) -class DataLoaderShard(DataLoader, DataLoaderStateMixin): +class DataLoaderAdapter: """ - Subclass of a PyTorch `DataLoader` that will deal with device placement and current distributed setup. + A class which wraps around a PyTorch `DataLoader` (or variants of it) to be used with the `Accelerator`. For + compatability reasons, this class inherits from the class it wraps around, so it can be used as a drop-in. + """ + + def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, **kwargs): + self.use_stateful_dataloader = use_stateful_dataloader + if is_torchdata_stateful_dataloader_available(): + from torchdata.stateful_dataloader import StatefulDataLoader + + if use_stateful_dataloader and not is_torchdata_stateful_dataloader_available(): + raise ImportError( + "StatefulDataLoader is not available. Please install torchdata version 0.8.0 or higher to use it." + ) + if use_stateful_dataloader: + self.base_dataloader = StatefulDataLoader(dataset, batch_sampler=batch_sampler, **kwargs) + else: + self.base_dataloader = DataLoader(dataset, batch_sampler=batch_sampler, **kwargs) + + # Dynamically mixin the parent class. See https://stackoverflow.com/a/31075641 + # In C++ terms, this is analogous to creating `DataLoaderAdapter : T`, where T is a DataLoader or + # StatefulDataLoader + # + # The same functionality could be achieved by directly creating the required subclasses for both {DataLoader, + # StatefulDataLoader}, however that could lead to much messier code, with duplicated classes and conditional + # dispatching scattered throughout various functions and files. + # + # This code is incredibly awkward but it's the only way to make `isinstance(obj, StatefulDataLoader)` work + # transparently. + # + # A more robust solution is for DataLoaderAdapter to not inherit from DataLoader (compose rather than inherit), + # but this would not be backwards compatible with existing code which assumes + # DataLoaderShard/DataLoaderDispatcher are DataLoaders. 
+ base_cls = self.__class__ + base_cls_name = self.__class__.__name__ + parent_cls_name = self.base_dataloader.__class__ + self.__class__ = type(base_cls_name, (base_cls, parent_cls_name), {}) + + if hasattr(self.base_dataloader, "state_dict"): + self.dl_state_dict = self.base_dataloader.state_dict() + + def __getattr__(self, name): + # Avoid infinite recursion if we try to access a nonexistent base_dataloader attribute. + if name == "base_dataloader": + raise AttributeError() + # Delegate attribute access to the internal dataloader + return getattr(self.base_dataloader, name) + + def state_dict(self): + return self.dl_state_dict + + def load_state_dict(self, state_dict): + self.base_dataloader.load_state_dict(state_dict) + self.dl_state_dict = self.state_dict + + def _update_state_dict(self): + # The state_dict of the underlying base_dataloader may be ahead of what is currently being yielded. + # E.g. the implementation of DataLoaderShard involves having an underlying iterator 1 element ahead of + # what it wants to yield. + # + # _update_state_dict is called to snapshot the state_dict that would properly recover the DataLoaderAdapter. + if hasattr(self.base_dataloader, "state_dict"): + self.dl_state_dict = self.base_dataloader.state_dict() + + +class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin): + """ + Subclass of `DataLoaderAdapter` that will deal with device placement and current distributed setup. Args: dataset (`torch.utils.data.dataset.Dataset`): @@ -409,6 +476,8 @@ class DataLoaderShard(DataLoader, DataLoaderStateMixin): A random number generator to keep synchronized across processes. skip_batches (`int`, *optional*, defaults to 0): The number of batches to skip at the beginning. + use_stateful_dataloader (`bool`, *optional*, defaults to `False`): + Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`. **kwargs (additional keyword arguments, *optional*): All other keyword arguments to pass to the regular `DataLoader` initialization. @@ -428,11 +497,12 @@ def __init__( rng_types=None, synchronized_generator=None, skip_batches=0, + use_stateful_dataloader=False, _drop_last: bool = False, _non_blocking: bool = False, **kwargs, ): - super().__init__(dataset, **kwargs) + super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs) self.device = device self.rng_types = rng_types self.synchronized_generator = synchronized_generator @@ -448,7 +518,7 @@ def __iter__(self): self.begin() self.set_epoch(self.iteration) - dataloader_iter = super().__iter__() + dataloader_iter = self.base_dataloader.__iter__() # We iterate one batch ahead to check when we are at the end try: current_batch = next(dataloader_iter) @@ -461,6 +531,7 @@ def __iter__(self): # But we still move it to the device so it is done before `StopIteration` is reached if self.device is not None: current_batch = send_to_device(current_batch, self.device, non_blocking=self._non_blocking) + self._update_state_dict() next_batch = next(dataloader_iter) if batch_index >= self.skip_batches: yield current_batch @@ -564,10 +635,10 @@ def dataloader(self): return self._loader -class DataLoaderDispatcher(DataLoader, DataLoaderStateMixin): +class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin): """ - Subclass of a PyTorch `DataLoader` that will iterate and preprocess on process 0 only, then dispatch on each - process their part of the batch. 
+ Subclass of `DataLoaderAdapter` that will iterate and preprocess on process 0 only, then dispatch on each process + their part of the batch. Args: split_batches (`bool`, *optional*, defaults to `False`): @@ -579,6 +650,8 @@ class DataLoaderDispatcher(DataLoader, DataLoaderStateMixin): size of the `dataloader` is a round multiple of `batch_size`. skip_batches (`int`, *optional*, defaults to 0): The number of batches to skip at the beginning of an iteration. + use_stateful_dataloader (`bool`, *optional*, defaults to `False`): + Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`. **Available attributes:** @@ -594,6 +667,7 @@ def __init__( dataset, split_batches: bool = False, skip_batches=0, + use_stateful_dataloader=False, _drop_last: bool = False, _non_blocking: bool = False, slice_fn=None, @@ -606,7 +680,7 @@ def __init__( # We need to save the shuffling state of the DataPipe if isinstance(dataset, ShufflerIterDataPipe): shuffle = dataset._shuffle_enabled - super().__init__(dataset, **kwargs) + super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs) self.split_batches = split_batches if shuffle: torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle) @@ -627,12 +701,14 @@ def _fetch_batches(self, iterator): try: if self.split_batches: # One batch of the main iterator is dispatched and split. + self._update_state_dict() batch = next(iterator) else: # num_processes batches of the main iterator are concatenated then dispatched and split. # We add the batches one by one so we have the remainder available when drop_last=False. batches = [] for _ in range(self.state.num_processes): + self._update_state_dict() batches.append(next(iterator)) try: batch = concatenate(batches, dim=0) @@ -673,9 +749,9 @@ def __iter__(self): # NOTE PyTorch DataLoader adds forward compatibilities for DataPipes, which broadcasts # shared seed to all dist processes. Thus, we need to create iterator for all dist processes. # But, we only iterate through the DataLoader on process 0. - main_iterator = super().__iter__() + main_iterator = self.base_dataloader.__iter__() elif self.state.process_index == 0: - main_iterator = super().__iter__() + main_iterator = self.base_dataloader.__iter__() stop_iteration = False self._stop_iteration = False first_batch = None @@ -812,6 +888,7 @@ def prepare_data_loader( slice_fn_for_dispatch: Optional[Callable] = None, use_seedable_sampler: bool = False, non_blocking: bool = False, + use_stateful_dataloader: bool = False, ) -> DataLoader: """ Wraps a PyTorch `DataLoader` to generate batches for one of the processes only. @@ -873,6 +950,10 @@ def prepare_data_loader( non_blocking (`bool`, *optional*, defaults to `False`): If set to `True`, dataloader will utilize non-blocking host-to-device transfers. If the dataloader has `pin_memory` set to `True`, this will help to increase overlap between data transfer and computations. + use_stateful_dataloader (`bool`, *optional*, defaults to `False`): + If set to `True`, the dataloader prepared by the Accelerator will be backed by + [torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). + This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed.
Returns: @@ -1006,6 +1087,7 @@ def prepare_data_loader( _drop_last=dataloader.drop_last, _non_blocking=non_blocking, slice_fn=slice_fn_for_dispatch, + use_stateful_dataloader=use_stateful_dataloader, **kwargs, ) elif sampler_is_batch_sampler: @@ -1018,6 +1100,7 @@ def prepare_data_loader( _drop_last=dataloader.drop_last, _non_blocking=non_blocking, synchronized_generator=synchronized_generator, + use_stateful_dataloader=use_stateful_dataloader, **kwargs, ) else: @@ -1029,6 +1112,7 @@ def prepare_data_loader( synchronized_generator=synchronized_generator, _drop_last=dataloader.drop_last, _non_blocking=non_blocking, + use_stateful_dataloader=use_stateful_dataloader, **kwargs, ) @@ -1046,6 +1130,7 @@ class SkipBatchSampler(BatchSampler): def __init__(self, batch_sampler, skip_batches=0): self.batch_sampler = batch_sampler + self.sampler = batch_sampler.sampler self.skip_batches = skip_batches def __iter__(self): @@ -1061,7 +1146,7 @@ def __len__(self): return len(self.batch_sampler) - self.skip_batches -class SkipDataLoader(DataLoader): +class SkipDataLoader(DataLoaderAdapter): """ Subclass of a PyTorch `DataLoader` that will skip the first batches. @@ -1070,17 +1155,20 @@ class SkipDataLoader(DataLoader): The dataset to use to build this datalaoder. skip_batches (`int`, *optional*, defaults to 0): The number of batches to skip at the beginning. + use_stateful_dataloader (`bool`, *optional*, defaults to `False`): + Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`. kwargs: All other keyword arguments to pass to the regular `DataLoader` initialization. """ - def __init__(self, dataset, skip_batches=0, **kwargs): - super().__init__(dataset, **kwargs) + def __init__(self, dataset, skip_batches=0, use_stateful_dataloader=False, **kwargs): + super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs) self.skip_batches = skip_batches def __iter__(self): - for index, batch in enumerate(super().__iter__()): + for index, batch in enumerate(self.base_dataloader.__iter__()): if index >= self.skip_batches: + self._update_state_dict() yield batch @@ -1088,6 +1176,9 @@ def skip_first_batches(dataloader, num_batches=0): """ Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`. 
""" + if is_torchdata_stateful_dataloader_available(): + from torchdata.stateful_dataloader import StatefulDataLoader + state = PartialState() if state.distributed_type == DistributedType.XLA: device = dataloader.device @@ -1131,6 +1222,7 @@ def skip_first_batches(dataloader, num_batches=0): split_batches=dataloader.split_batches, batch_sampler=new_batch_sampler, _drop_last=dataloader._drop_last, + use_stateful_dataloader=dataloader.use_stateful_dataloader, **kwargs, ) elif isinstance(dataloader, DataLoaderShard): @@ -1147,12 +1239,17 @@ def skip_first_batches(dataloader, num_batches=0): device=dataloader.device, rng_types=dataloader.rng_types, synchronized_generator=dataloader.synchronized_generator, + use_stateful_dataloader=dataloader.use_stateful_dataloader, **kwargs, ) else: if new_batch_sampler is None: # Need to manually skip batches in the dataloader - dataloader = SkipDataLoader(dataset, skip_batches=num_batches, **kwargs) + dataloader = SkipDataLoader( + dataset, skip_batches=num_batches, use_stateful_dataloader=dataloader.use_stateful_dataloader, **kwargs + ) + elif is_torchdata_stateful_dataloader_available() and isinstance(dataloader, StatefulDataLoader): + dataloader = StatefulDataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs) else: dataloader = DataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs) diff --git a/src/accelerate/test_utils/scripts/test_sync.py b/src/accelerate/test_utils/scripts/test_sync.py index 1029f475f6c..20b5b2c752e 100644 --- a/src/accelerate/test_utils/scripts/test_sync.py +++ b/src/accelerate/test_utils/scripts/test_sync.py @@ -305,12 +305,12 @@ def test_gradient_accumulation_with_opt_and_scheduler( def test_dataloader_break(): accelerator = Accelerator() - first_dset = RegressionDataset(length=80) first_dataloader = DataLoader(first_dset, batch_size=16) second_dset = RegressionDataset(length=96) second_dataloader = DataLoader(second_dset, batch_size=16) first_dataloader, second_dataloader = accelerator.prepare(first_dataloader, second_dataloader) + assert accelerator.gradient_state.active_dataloader is None for iteration, _ in enumerate(first_dataloader): assert id(accelerator.gradient_state.active_dataloader) == id(first_dataloader) diff --git a/src/accelerate/test_utils/testing.py b/src/accelerate/test_utils/testing.py index 43a55e2339f..f9ade458eac 100644 --- a/src/accelerate/test_utils/testing.py +++ b/src/accelerate/test_utils/testing.py @@ -52,6 +52,7 @@ is_timm_available, is_torch_version, is_torch_xla_available, + is_torchdata_stateful_dataloader_available, is_torchvision_available, is_transformer_engine_available, is_transformers_available, @@ -429,6 +430,18 @@ def require_trackers(test_case): )(test_case) +def require_torchdata_stateful_dataloader(test_case): + """ + Decorator marking a test that requires torchdata.stateful_dataloader. + + These tests are skipped when torchdata with stateful_dataloader module isn't installed. 
+ + """ + return unittest.skipUnless( + is_torchdata_stateful_dataloader_available(), "test requires torchdata.stateful_dataloader" + )(test_case) + + class TempDirTestCase(unittest.TestCase): """ A TestCase class that keeps a single `tempfile.TemporaryDirectory` open for the duration of the class, wipes its diff --git a/src/accelerate/utils/__init__.py b/src/accelerate/utils/__init__.py index 07ba2fdcf8f..ed6c77d8de6 100644 --- a/src/accelerate/utils/__init__.py +++ b/src/accelerate/utils/__init__.py @@ -107,6 +107,8 @@ is_tensorboard_available, is_timm_available, is_torch_xla_available, + is_torchdata_available, + is_torchdata_stateful_dataloader_available, is_torchvision_available, is_transformer_engine_available, is_transformers_available, diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py index 0f35a294736..7fa7810878e 100644 --- a/src/accelerate/utils/dataclasses.py +++ b/src/accelerate/utils/dataclasses.py @@ -749,7 +749,7 @@ class DataLoaderConfiguration: metadata={ "help": "If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process" " and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose" - " underlying dataset is an `IterableDataslet`, `False` otherwise." + " underlying dataset is an `IterableDataset`, `False` otherwise." }, ) even_batches: bool = field( @@ -777,6 +777,13 @@ class DataLoaderConfiguration: " prepared dataloader has `pin_memory` set to `True` to work properly." }, ) + use_stateful_dataloader: bool = field( + default=False, + metadata={ + "help": "If set to `True`, the dataloader prepared by the Accelerator will be backed by " + "[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed." + }, + ) @dataclass diff --git a/src/accelerate/utils/imports.py b/src/accelerate/utils/imports.py index 592e62e6172..3badfefd684 100644 --- a/src/accelerate/utils/imports.py +++ b/src/accelerate/utils/imports.py @@ -431,3 +431,16 @@ def is_xpu_available(check_device=False): def is_dvclive_available(): return _is_package_available("dvclive") + + +def is_torchdata_available(): + return _is_package_available("torchdata") + + +# TODO: Remove this function once stateful_dataloader is a stable feature in torchdata. +def is_torchdata_stateful_dataloader_available(): + package_exists = _is_package_available("torchdata") + if package_exists: + torchdata_version = version.parse(importlib.metadata.version("torchdata")) + return compare_versions(torchdata_version, ">=", "0.8.0") + return False diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py index 4ef7a94b281..42df216123a 100644 --- a/tests/test_accelerator.py +++ b/tests/test_accelerator.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import itertools import json import os import pickle @@ -26,6 +27,7 @@ from accelerate import DistributedType, infer_auto_device_map, init_empty_weights, load_checkpoint_and_dispatch from accelerate.accelerator import Accelerator +from accelerate.data_loader import skip_first_batches from accelerate.state import GradientState, PartialState from accelerate.test_utils import ( require_bnb, @@ -35,9 +37,20 @@ slow, torch_device, ) -from accelerate.test_utils.testing import AccelerateTestCase, require_cuda, require_non_torch_xla -from accelerate.utils import FP8RecipeKwargs, patch_environment +from accelerate.test_utils.testing import ( + AccelerateTestCase, + require_cuda, + require_non_torch_xla, + require_torchdata_stateful_dataloader, +) +from accelerate.utils import FP8RecipeKwargs, is_torchdata_stateful_dataloader_available, patch_environment +from accelerate.utils.dataclasses import DataLoaderConfiguration from accelerate.utils.modeling import get_state_dict_from_offload, load_checkpoint_in_model +from accelerate.utils.random import set_seed + + +if is_torchdata_stateful_dataloader_available(): + from torchdata.stateful_dataloader import StatefulDataLoader class ModelWithTiedWeights(torch.nn.Module): @@ -58,7 +71,6 @@ def create_components(tied_weights=False): scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=2, epochs=1) train_dl = DataLoader(TensorDataset(torch.tensor([1, 2, 3]))) valid_dl = DataLoader(TensorDataset(torch.tensor([4, 5, 6]))) - return model, optimizer, scheduler, train_dl, valid_dl @@ -73,6 +85,21 @@ def forward(self, x): return self.linear2(self.batchnorm(self.linear1(x))) +def create_dataloaders_for_test(batch_size=3, n_train_batches: int = 12, n_valid_batches: int = 2, num_workers=0): + "Generates a tuple of dummy DataLoaders to test with" + + def get_dataset(n_batches): + x = torch.randn(batch_size * n_batches, 3) + y = torch.randn(batch_size * n_batches, 5) + return TensorDataset(x, y) + + train_dataset = get_dataset(n_train_batches) + valid_dataset = get_dataset(n_valid_batches) + train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers) + valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, num_workers=num_workers) + return (train_dataloader, valid_dataloader) + + def get_signature(model): return sum(param.abs().sum().item() for param in model.parameters()) @@ -89,7 +116,12 @@ def parameterized_custom_name_func(func, param_num, param): # customize the test name generator function as we want both params to appear in the sub-test # name, as by default it shows only the first param param_based_name = "use_safetensors" if param.args[0] is True else "use_pytorch" - param_based_name += "_tied_weights" if (len(param.args) == 2 and param.args[1] is True) else "" + if len(param.args) > 1: + param_based_name += "_tied_weights" if param.args[1] is True else "" + if len(param.args) > 2: + param_based_name += f"_num_workers_{param.args[2]}" + if len(param.args) > 3: + param_based_name += "_dispatch_batches" if param.args[3] is True else "_no_dispatch_batches" return f"{func.__name__}_{param_based_name}" @@ -615,3 +647,133 @@ def test_can_unwrap_model(self): # check that pickle roundtrip works model_loaded = pickle.loads(pickle.dumps(model)) model_loaded(inputs) + + # Ideally would be a parameterized test which works with either stateful or non-stateful dataloaders, but dependencies are a bit awkward. 
+ @require_torchdata_stateful_dataloader + def test_prepared_objects_are_referenced_with_stateful_dataloader(self): + """Test that setting `use_stateful_dataloader=True` in `DataLoaderConfiguration` prepares a `StatefulDataLoader` object instead of a `DataLoader` object.""" + dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True) + accelerator = Accelerator(dataloader_config=dataloader_config) + model, optimizer, scheduler, train_dl, valid_dl = create_components() + + ( + prepared_model, + prepared_optimizer, + prepared_scheduler, + prepared_train_dl, + prepared_valid_dl, + ) = accelerator.prepare(model, optimizer, scheduler, train_dl, valid_dl) + + assert prepared_model in accelerator._models + assert prepared_optimizer in accelerator._optimizers + assert prepared_scheduler in accelerator._schedulers + assert prepared_train_dl in accelerator._dataloaders + assert prepared_valid_dl in accelerator._dataloaders + assert isinstance(prepared_train_dl, StatefulDataLoader) + assert isinstance(prepared_valid_dl, StatefulDataLoader) + + @parameterized.expand( + itertools.product([True, False], [True, False], [0, 2], [True, False]), + name_func=parameterized_custom_name_func, + ) + @require_torchdata_stateful_dataloader + def test_save_model_with_stateful_dataloader(self, use_safetensors, tied_weights, num_workers, dispatch_batches): + """ + Test that saving and loading a model with a stateful dataloader returns the same model, + and that the dataloader's iterator is restored properly.""" + set_seed(42) + dataloader_config = DataLoaderConfiguration(dispatch_batches=dispatch_batches, use_stateful_dataloader=True) + accelerator = Accelerator(dataloader_config=dataloader_config) + + model, optimizer, scheduler, train_dl, valid_dl = create_components(tied_weights) + train_dl, valid_dl = create_dataloaders_for_test(num_workers=num_workers) + model = ModelForTest() + + ( + prepared_model, + prepared_optimizer, + prepared_scheduler, + prepared_train_dl, + prepared_valid_dl, + ) = accelerator.prepare(model, optimizer, scheduler, train_dl, valid_dl) + + assert isinstance(prepared_train_dl, StatefulDataLoader) + assert isinstance(prepared_valid_dl, StatefulDataLoader) + + # Perform 3 training iterations to ensure the dataloader's iterator is advanced + num_batches_to_skip = 3 + model.train() + for step, batch in enumerate(prepared_train_dl): + x, y = batch + x.to(accelerator.device) + y.to(accelerator.device) + with accelerator.accumulate(prepared_model): + outputs = prepared_model(x) + loss = torch.nn.functional.mse_loss(outputs, y) + accelerator.backward(loss) + prepared_optimizer.step() + prepared_scheduler.step() + prepared_optimizer.zero_grad() + if step == num_batches_to_skip - 1: + state_dict = prepared_train_dl.state_dict() + # When breaking out without fully going through the iterator, must call end() to unregister this iterator from gradient state. + # TODO: Maybe this could be done automatically? 
+ prepared_train_dl.end() + break + + assert accelerator.gradient_state.active_dataloader is None + + with tempfile.TemporaryDirectory() as tmpdirname: + # Save model for later use + accelerator.save_model(model, tmpdirname, safe_serialization=use_safetensors) + + # Starting from where we left off, train this model to the end of the DataLoader + prepared_train_dl = skip_first_batches(prepared_train_dl, num_batches_to_skip) + batches_seen_with_original_dl = 0 + for batch in prepared_train_dl: + x, y = batch + x.to(accelerator.device) + y.to(accelerator.device) + with accelerator.accumulate(prepared_model): + outputs = prepared_model(x) + loss = torch.nn.functional.mse_loss(outputs, y) + accelerator.backward(loss) + prepared_optimizer.step() + prepared_scheduler.step() + prepared_optimizer.zero_grad() + batches_seen_with_original_dl += 1 + + original_linear1 = prepared_model.linear1.weight.clone() + original_batchnorm = prepared_model.batchnorm.weight.clone() + original_linear2 = prepared_model.linear2.weight.clone() + + # Load the model and state dict + load_checkpoint_in_model(model, tmpdirname) + stateful_train_dl, _ = create_dataloaders_for_test(num_workers=num_workers) + prepared_stateful_train_dl = accelerator.prepare_data_loader(stateful_train_dl) + prepared_stateful_train_dl.load_state_dict(state_dict) + + # Train this to the end of the DataLoader + batches_seen_with_loaded_dl = 0 + for batch in prepared_stateful_train_dl: + x, y = batch + x.to(accelerator.device) + y.to(accelerator.device) + with accelerator.accumulate(prepared_model): + outputs = prepared_model(x) + loss = torch.nn.functional.mse_loss(outputs, y) + accelerator.backward(loss) + prepared_optimizer.step() + prepared_scheduler.step() + prepared_optimizer.zero_grad() + batches_seen_with_loaded_dl += 1 + + new_linear1 = prepared_model.linear1.weight + new_batchnorm = prepared_model.batchnorm.weight + new_linear2 = prepared_model.linear2.weight + + # Assert equalities + assert batches_seen_with_original_dl == batches_seen_with_loaded_dl + assert torch.allclose(original_linear1, new_linear1) + assert torch.allclose(original_batchnorm, new_batchnorm) + assert torch.allclose(original_linear2, new_linear2) diff --git a/tests/test_data_loader.py b/tests/test_data_loader.py index 2f360d71bcb..d60e599722d 100644 --- a/tests/test_data_loader.py +++ b/tests/test_data_loader.py @@ -15,6 +15,9 @@ import random import unittest +import pytest +import torch +from parameterized import parameterized from torch.utils.data import BatchSampler, DataLoader, IterableDataset from accelerate import Accelerator @@ -22,11 +25,28 @@ BatchSamplerShard, DataLoaderDispatcher, DataLoaderShard, + DataLoaderStateMixin, IterableDatasetShard, SkipBatchSampler, SkipDataLoader, skip_first_batches, ) +from accelerate.test_utils.testing import require_torchdata_stateful_dataloader +from accelerate.utils import is_torchdata_stateful_dataloader_available +from accelerate.utils.dataclasses import DataLoaderConfiguration + + +if is_torchdata_stateful_dataloader_available(): + from torchdata.stateful_dataloader import ( + StatefulDataLoader, + ) + + +def parameterized_custom_name_func(func, param_num, param): + # customize the test name generator function as we want both params to appear in the sub-test + # name, as by default it shows only the first param + param_based_name = f"num_workers_{param.args[0]}" + return f"{func.__name__}_{param_based_name}" class RandomIterableDataset(IterableDataset): @@ -369,6 +389,29 @@ def test_skip_batch_sampler(self): 
new_batch_sampler = SkipBatchSampler(batch_sampler, 2) assert list(new_batch_sampler) == [[8, 9, 10, 11], [12, 13, 14, 15]] + def test_dataloader_inheritance(self): + """ + `DataLoaderAdapter`'s parent classes are dynamically constructed, assert that subclasses of DataLoaderAdapter + are instances of DataLoader and DataLoaderStateMixin. + """ + Accelerator() + skip_dl = SkipDataLoader(range(16), batch_size=4, skip_batches=2) + dl_shard = DataLoaderShard(range(16), batch_size=4) + dl_dispatcher = DataLoaderDispatcher(range(16), batch_size=4) + assert isinstance(skip_dl, DataLoader) + assert isinstance(dl_shard, DataLoader) + assert isinstance(dl_dispatcher, DataLoader) + + assert isinstance(dl_shard, DataLoaderStateMixin) + assert isinstance(dl_dispatcher, DataLoaderStateMixin) + + assert isinstance(skip_dl.base_dataloader, DataLoader) + assert isinstance(dl_shard.base_dataloader, DataLoader) + assert isinstance(dl_dispatcher.base_dataloader, DataLoader) + + with pytest.raises(AttributeError): + _ = DataLoaderShard.base_dataloader + def test_skip_data_loader(self): dataloader = SkipDataLoader(list(range(16)), batch_size=4, skip_batches=2) assert [t.tolist() for t in dataloader] == [[8, 9, 10, 11], [12, 13, 14, 15]] @@ -396,3 +439,230 @@ def test_end_of_dataloader_dispatcher(self): # Test it also works on the second iteration for idx, _ in enumerate(dataloader): assert dataloader.end_of_dataloader == (idx == 3) + + +class StatefulDataLoaderTester(unittest.TestCase): + @require_torchdata_stateful_dataloader + def test_skip_data_loader(self): + dataloader = SkipDataLoader(list(range(16)), batch_size=4, skip_batches=2, use_stateful_dataloader=True) + assert isinstance(dataloader, StatefulDataLoader) + assert [t.tolist() for t in dataloader] == [[8, 9, 10, 11], [12, 13, 14, 15]] + + @require_torchdata_stateful_dataloader + def test_skip_first_batches(self): + dataloader = StatefulDataLoader(list(range(16)), batch_size=4) + new_dataloader = skip_first_batches(dataloader, num_batches=2) + assert isinstance(new_dataloader, StatefulDataLoader) + assert [t.tolist() for t in new_dataloader] == [[8, 9, 10, 11], [12, 13, 14, 15]] + + @require_torchdata_stateful_dataloader + def test_end_of_dataloader(self): + dataloader = DataLoaderShard(list(range(16)), batch_size=4, use_stateful_dataloader=True) + assert dataloader.use_stateful_dataloader + assert isinstance(dataloader, StatefulDataLoader) + for idx, _ in enumerate(dataloader): + assert dataloader.end_of_dataloader == (idx == 3) + + # Test it also works on the second iteration + for idx, _ in enumerate(dataloader): + assert dataloader.end_of_dataloader == (idx == 3) + + @require_torchdata_stateful_dataloader + def test_end_of_dataloader_dispatcher(self): + Accelerator() + dataloader = DataLoaderDispatcher(range(16), batch_size=4, use_stateful_dataloader=True) + assert isinstance(dataloader, StatefulDataLoader) + for idx, _ in enumerate(dataloader): + assert dataloader.end_of_dataloader == (idx == 3) + + # Test it also works on the second iteration + for idx, _ in enumerate(dataloader): + assert dataloader.end_of_dataloader == (idx == 3) + + @parameterized.expand([0, 2], name_func=parameterized_custom_name_func) + @require_torchdata_stateful_dataloader + def test_dataloader_state_dict(self, num_workers): + """ + Test that saving a stateful dataloader's state, then loading it back, gives the same results. 
+ """ + dataset = list(range(16)) + dataloader = DataLoaderShard(dataset, batch_size=4, use_stateful_dataloader=True, num_workers=num_workers) + + assert dataloader.use_stateful_dataloader + assert isinstance(dataloader, StatefulDataLoader) + vals = [] + for idx, val in enumerate(dataloader): + vals.append(val) + if idx == 1: + sd = dataloader.state_dict() + assert len(vals) == 4 + + dataloader2 = DataLoaderShard(dataset, batch_size=4, use_stateful_dataloader=True, num_workers=num_workers) + dataloader2.load_state_dict(sd) + + data1 = vals[2:] + data2 = list(dataloader2) + assert len(data1) == len(data2) + for d1, d2 in zip(data1, data2): + assert torch.allclose(d1, d2) + + @parameterized.expand([0, 2], name_func=parameterized_custom_name_func) + @require_torchdata_stateful_dataloader + def test_dataloader_dispatcher_state_dict(self, num_workers): + """ + Test that saving a stateful dataloader's state, then loading it back, gives the same results. + """ + dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True) + Accelerator(dataloader_config=dataloader_config) + dataset = list(range(16)) + dataloader = DataLoaderDispatcher(dataset, batch_size=4, use_stateful_dataloader=True, num_workers=num_workers) + + assert dataloader.use_stateful_dataloader + assert isinstance(dataloader, StatefulDataLoader) + vals = [] + for idx, val in enumerate(dataloader): + vals.append(val) + if idx == 1: + sd = dataloader.state_dict() + assert len(vals) == 4 + dataloader2 = DataLoaderDispatcher( + dataset, batch_size=4, use_stateful_dataloader=True, num_workers=num_workers + ) + dataloader2.load_state_dict(sd) + + data1 = vals[2:] + data2 = list(dataloader2) + assert len(data1) == len(data2) + for d1, d2 in zip(data1, data2): + assert torch.allclose(d1, d2) + + @require_torchdata_stateful_dataloader + def test_dataloader_inheritance(self): + """ + `DataLoaderAdapter`'s parent classes are dynamically constructed, assert that if use_stateful_dataloader=True, + subclasses of DataLoaderAdapter are instances of StatefulDataLoader and DataLoaderStateMixin. + """ + Accelerator() + skip_dl = SkipDataLoader(range(16), batch_size=4, skip_batches=2, use_stateful_dataloader=True) + dl_shard = DataLoaderShard(range(16), batch_size=4, use_stateful_dataloader=True) + dl_dispatcher = DataLoaderDispatcher(range(16), batch_size=4, use_stateful_dataloader=True) + assert isinstance(skip_dl, StatefulDataLoader) + assert isinstance(dl_shard, StatefulDataLoader) + assert isinstance(dl_dispatcher, StatefulDataLoader) + + assert isinstance(dl_shard, DataLoaderStateMixin) + assert isinstance(dl_dispatcher, DataLoaderStateMixin) + + assert isinstance(skip_dl.base_dataloader, StatefulDataLoader) + assert isinstance(dl_shard.base_dataloader, StatefulDataLoader) + assert isinstance(dl_dispatcher.base_dataloader, StatefulDataLoader) + + @parameterized.expand([0, 2], name_func=parameterized_custom_name_func) + @require_torchdata_stateful_dataloader + def test_stateful_dataloader_adapter_equivalent_to_torchdata_stateful_dataloader(self, num_workers): + """ + Assert that `state_dict()` and `load_state_dict()` for derived subclasses of `DataLoaderAdapter` produce + the same behavior as `state_dict()` and `load_state_dict()` for `StatefulDataLoader`. 
+ """ + dataset = list(range(64)) + + # Set the seed for reproducibility + def g(): + return torch.Generator().manual_seed(42) + + accelerator = Accelerator() + stateful_dl = StatefulDataLoader(dataset, batch_size=4, num_workers=num_workers, generator=g()) + skip_dl = SkipDataLoader( + dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True + ) + dl_shard = DataLoaderShard( + dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True + ) + dl_dispatcher = DataLoaderDispatcher( + dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True + ) + + dataloaders_under_test = [skip_dl, dl_shard, dl_dispatcher] + + num_batches_to_skip = 8 + + def get_first_n_batches(dl, n, device): + """ + Iterate over the first `n` batches of a dataloader then break, returning the batches in a list. + """ + batches = [] + for idx, batch in enumerate(dl): + if idx == n - 1: + if hasattr(dl, "end"): + dl.end() + break + batches.append(batch.to(device)) + return batches + + # Iterate over all of the dataloaders identically, expect the same values + expected_batches = get_first_n_batches(stateful_dl, num_batches_to_skip, accelerator.device) + batches_from_dataloaders = [ + get_first_n_batches(dl, num_batches_to_skip, accelerator.device) for dl in dataloaders_under_test + ] + + for dl_batches in batches_from_dataloaders: + for expected, actual in zip(expected_batches, dl_batches): + assert torch.allclose(expected, actual) + + # The adapters should all produce the same state_dict as the reference stateful dataloader + expected_state_dict = stateful_dl.state_dict() + skip_dl_state_dict = skip_dl.state_dict() + dl_shard_state_dict = dl_shard.state_dict() + dl_dispatcher_state_dict = dl_dispatcher.state_dict() + + assert expected_state_dict == skip_dl_state_dict + assert expected_state_dict == dl_shard_state_dict + assert expected_state_dict == dl_dispatcher_state_dict + + # Load the state dict into new dataloaders + manual_skip_dl = SkipDataLoader( + dataset, + batch_size=4, + num_workers=num_workers, + generator=g(), + skip_batches=num_batches_to_skip, + use_stateful_dataloader=True, + ) + loaded_stateful_dl = StatefulDataLoader(dataset, batch_size=4, num_workers=num_workers, generator=g()) + loaded_stateful_dl.load_state_dict(expected_state_dict) + loaded_skip_dl = SkipDataLoader( + dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True + ) + loaded_skip_dl.load_state_dict(expected_state_dict) + loaded_dl_shard = DataLoaderShard( + dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True + ) + loaded_dl_shard.load_state_dict(expected_state_dict) + loaded_dl_dispatcher = DataLoaderDispatcher( + dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True + ) + loaded_dl_dispatcher.load_state_dict(expected_state_dict) + + # Continue the iteration, expecting identical behavior across the board + def get_all_batches(dl, device): + """ + Iterate over all batches of a dataloader, returning (batches, num_batches_yielded) + """ + batches = [] + num_batches_yielded = 0 + for batch in dl: + batches.append(batch.to(device)) + num_batches_yielded += 1 + return (batches, num_batches_yielded) + + expected_batch_results = get_all_batches(loaded_stateful_dl, accelerator.device) + dataloader_batch_results = [ + get_all_batches(dl, accelerator.device) + for dl in [manual_skip_dl, loaded_skip_dl, loaded_dl_shard, loaded_dl_dispatcher] + 
] + for dl_results in dataloader_batch_results: + for expected, actual in zip(expected_batch_results[0], dl_results[0]): + assert torch.allclose(expected, actual) + assert expected_batch_results[1] == dl_results[1] + + assert accelerator.gradient_state.active_dataloader is None
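For reviewers, the dynamic mix-in trick that `DataLoaderAdapter` relies on (see the comment block in `data_loader.py` above) can be illustrated in isolation. This is a simplified sketch, not code from the patch; `Adapter` and `Greeter` are made-up names used only for the illustration.

```python
class Greeter:
    def __init__(self, name):
        self.name = name

    def greet(self):
        return f"hello {self.name}"


class Adapter:
    """Minimal illustration of the dynamic mix-in used by `DataLoaderAdapter`."""

    def __init__(self, base):
        self.base = base
        # Mix the wrapped object's class into this instance's MRO so that
        # `isinstance(adapter, type(base))` holds without writing one subclass
        # per wrapped type.
        self.__class__ = type(type(self).__name__, (type(self), type(base)), {})

    def __getattr__(self, name):
        # Only called when normal lookup fails; guard against recursing on `base`.
        if name == "base":
            raise AttributeError(name)
        return getattr(self.base, name)


adapter = Adapter(Greeter("world"))
assert isinstance(adapter, Greeter)      # True thanks to the dynamic mix-in
assert adapter.greet() == "hello world"  # `self.name` resolves via __getattr__ to the wrapped object
```

This is also why the tests above can assert `isinstance(dl_dispatcher, StatefulDataLoader)` even though `DataLoaderAdapter` never calls the wrapped class's `__init__` on itself: state that lives on the wrapped `base_dataloader` is reached through `__getattr__`.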