Initial fix for broken Accelerator pickling
byi8220 committed Sep 3, 2024
1 parent b5235f2 commit 42a253b
Showing 3 changed files with 114 additions and 20 deletions.
85 changes: 65 additions & 20 deletions src/accelerate/data_loader.py
@@ -416,25 +416,6 @@ def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, *
        else:
            self.base_dataloader = DataLoader(dataset, batch_sampler=batch_sampler, **kwargs)

        # Dynamically mixin the parent class. See https://stackoverflow.com/a/31075641
        # In C++ terms, this is analogous to creating `DataLoaderAdapter<T> : T`, where T is a DataLoader or
        # StatefulDataLoader
        #
        # The same functionality could be achieved by directly creating the required subclasses for both {DataLoader,
        # StatefulDataLoader}, however that could lead to much messier code, with duplicated classes and conditional
        # dispatching scattered throughout various functions and files.
        #
        # This code is incredibly awkward but it's the only way to make `isinstance(obj, StatefulDataLoader)` work
        # transparently.
        #
        # A more robust solution is for DataLoaderAdapter to not inherit from DataLoader (compose rather than inherit),
        # but this would not be backwards compatible with existing code which assumes
        # DataLoaderShard/DataLoaderDispatcher are DataLoaders.
        base_cls = self.__class__
        base_cls_name = self.__class__.__name__
        parent_cls_name = self.base_dataloader.__class__
        self.__class__ = type(base_cls_name, (base_cls, parent_cls_name), {})

        if hasattr(self.base_dataloader, "state_dict"):
            self.dl_state_dict = self.base_dataloader.state_dict()

@@ -451,6 +432,18 @@ def state_dict(self):
    def load_state_dict(self, state_dict):
        self.base_dataloader.load_state_dict(state_dict)

    @property
    def __class__(self):
        """
        In order to maintain backwards compatibility with other code, we need to ensure `isinstance(obj, DataLoader)`
        returns `True`. This is because some downstream code assumes that `DataLoader` is the base class of the
        object.
        """
        return self.base_dataloader.__class__

    def __len__(self):
        return len(self.base_dataloader)
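
As an aside, here is a minimal standalone sketch (hypothetical `Wrapper`/`Wrapped` names, not part of this commit) of why the property above is enough: `isinstance` falls back to `obj.__class__` when `type(obj)` does not already match, so a composed wrapper can report the wrapped object's class without dynamic subclassing.

class Wrapped:
    pass


class Wrapper:
    # Compose rather than inherit; report the wrapped object's class.
    def __init__(self, inner):
        self.inner = inner

    @property
    def __class__(self):
        return self.inner.__class__


w = Wrapper(Wrapped())
assert isinstance(w, Wrapped)  # satisfied via the __class__ property
assert type(w) is Wrapper  # the real, module-level type is unchanged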

    def adjust_state_dict_for_prefetch(self):
        """
        Adjusts the state dict for prefetching. Natively, this will adjust all of the iters yielded keys in
@@ -488,6 +481,17 @@ def _update_state_dict(self):
        self.dl_state_dict["_iterator_finished"] = self.end_of_dataloader


class DataLoaderAdapterImpl(DataLoaderAdapter, DataLoader):
    pass


if is_torchdata_stateful_dataloader_available():
    from torchdata.stateful_dataloader import StatefulDataLoader

    class StatefulDataLoaderAdapterImpl(DataLoaderAdapter, StatefulDataLoader):
        pass
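
These module-level `*Impl` subclasses replace the `type(...)` call removed from `__init__` above. The motivation, sketched below with a hypothetical class name (not from this codebase): pickle stores classes by qualified name, and a class synthesized at runtime has no importable name to look up.

import pickle

# The declared name deliberately differs from the variable it is bound to,
# mimicking a class synthesized on the fly inside __init__:
Hidden = type("NowhereToBeFound", (), {})

try:
    pickle.dumps(Hidden())
except (pickle.PicklingError, AttributeError) as err:
    print(err)  # e.g. "attribute lookup NowhereToBeFound on __main__ failed"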


class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
    """
    Subclass of `DataLoaderAdapter` that will deal with device placement and current distributed setup.
@@ -580,6 +584,22 @@ def __iter__(self):
        self.iteration += 1
        self.end()

    def __reduce__(self):
        return (
            DataLoaderShard,
            (
                self.base_dataloader.dataset,
                self.device,
                self.rng_types,
                self.synchronized_generator,
                self.skip_batches,
                self.use_stateful_dataloader,
                self._drop_last,
                self._non_blocking,
            ),
            self.__dict__,
        )
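
For readers unfamiliar with the protocol: `__reduce__` returns a `(callable, args, state)` triple, and unpickling calls `callable(*args)` and then restores `state`; that is how attributes assigned after construction (such as `base_dataloader`) survive the round trip. A generic sketch with made-up names:

import pickle


class Counter:
    def __init__(self, start):
        self.start = start

    def __reduce__(self):
        # Rebuild via Counter(self.start), then restore the full __dict__,
        # including attributes that were assigned after construction.
        return (Counter, (self.start,), self.__dict__)


c = Counter(5)
c.extra = "assigned at runtime"
c2 = pickle.loads(pickle.dumps(c))
assert (c2.start, c2.extra) == (5, "assigned at runtime")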

    def set_epoch(self, epoch: int):
        # In case it is manually passed in, the user can set it to what they like
        if self.iteration != epoch:
@@ -865,14 +885,29 @@ def set_epoch(self, epoch: int):
            self.dataset.set_epoch(epoch)

    def __len__(self):
        whole_length = super().__len__()
        whole_length = self.base_dataloader.__len__()
        if self.split_batches:
            return whole_length
        elif self._drop_last:
            return whole_length // self.state.num_processes
        else:
            return math.ceil(whole_length / self.state.num_processes)
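
Plugging hypothetical numbers into the three branches above makes the arithmetic concrete, say 10 whole batches shared across 4 processes:

import math

whole_length, num_processes = 10, 4  # hypothetical values
# split_batches=True: each process sees every (smaller) batch -> 10
# drop_last=True: only full rounds of 4 count -> 10 // 4 == 2
# otherwise: the ragged final round still counts -> ceil(10 / 4) == 3
assert whole_length // num_processes == 2
assert math.ceil(whole_length / num_processes) == 3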

    def __reduce__(self):
        return (
            DataLoaderDispatcher,
            (
                self.base_dataloader.dataset,
                self.split_batches,
                self.skip_batches,
                self.use_stateful_dataloader,
                self._drop_last,
                self._non_blocking,
                self.slice_fn,
            ),
            self.__dict__,
        )

    @property
    def total_batch_size(self):
        return (
@@ -1211,6 +1246,16 @@ def __iter__(self):
                yield batch
        self.end()

    def __len__(self):
        return len(self.base_dataloader) - self.skip_batches

    def __reduce__(self):
        return (
            SkipDataLoader,
            (self.base_dataloader.dataset, self.skip_batches, self.use_stateful_dataloader),
            self.__dict__,
        )
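
Mirroring the tests further down, a plain in-memory dataset is enough to sanity-check the new `__len__` and `__reduce__` together; a sketch assuming a default single-process setup:

import pickle

from accelerate.data_loader import SkipDataLoader

skip_dl = SkipDataLoader(range(16), batch_size=4, skip_batches=2)
clone = pickle.loads(pickle.dumps(skip_dl))
assert len(clone) == len(skip_dl) == 2  # 4 batches in total, 2 skipped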


def skip_first_batches(dataloader, num_batches=0):
"""
34 changes: 34 additions & 0 deletions tests/test_accelerator.py
@@ -27,6 +27,7 @@

from accelerate import DistributedType, infer_auto_device_map, init_empty_weights, load_checkpoint_and_dispatch
from accelerate.accelerator import Accelerator
from accelerate.data_loader import DataLoaderDispatcher, DataLoaderShard, skip_first_batches
from accelerate.state import GradientState, PartialState
from accelerate.test_utils import (
    require_bnb,
@@ -647,6 +648,39 @@ def test_can_unwrap_model(self):
        model_loaded = pickle.loads(pickle.dumps(model))
        model_loaded(inputs)

    @parameterized.expand([True, False])
    def test_can_pickle_dataloader(self, dispatch_batches):
        """
        Test that pickling a prepared dataloader works.
        """
        data = torch.arange(10)
        ds = torch.utils.data.TensorDataset(data)
        dl = torch.utils.data.DataLoader(ds)
        # Currently, StatefulDataLoader doesn't seem to support pickling, so we aren't testing that functionality
        # TODO: Add support for pickling StatefulDataLoader
        dataloader_config = DataLoaderConfiguration(dispatch_batches=dispatch_batches, use_stateful_dataloader=False)
        accelerator = Accelerator(dataloader_config=dataloader_config)
        original_dl = accelerator.prepare(dl)
        prepared_model_dumps = pickle.dumps(accelerator)

        model_loaded = pickle.loads(prepared_model_dumps)
        # Assert equality of recovered and original dataloader
        assert isinstance(model_loaded._dataloaders[0], DataLoader)
        if dispatch_batches:
            assert isinstance(model_loaded._dataloaders[0], DataLoaderDispatcher)
        else:
            assert isinstance(model_loaded._dataloaders[0], DataLoaderShard)
        assert len(model_loaded._dataloaders[0]) == len(original_dl)
        assert [i for i in model_loaded._dataloaders[0]] == [i for i in original_dl]

        # Test skip dataloader works as expected as well
        skip_dl = skip_first_batches(original_dl, 2)
        assert isinstance(skip_dl, torch.utils.data.DataLoader)
        assert len(skip_dl) == len(original_dl) - 2
        orig_items = [i for i in original_dl]
        skip_dl_items = [i for i in skip_dl]
        assert orig_items[2:] == skip_dl_items

    # Ideally this would be a parameterized test that works with either stateful or non-stateful dataloaders, but the dependencies are a bit awkward.
    @require_torchdata_stateful_dataloader
    def test_prepared_objects_are_referenced_with_stateful_dataloader(self):
15 changes: 15 additions & 0 deletions tests/test_data_loader.py
@@ -420,6 +420,14 @@ def test_dataloader_inheritance(self):
        skip_dl = SkipDataLoader(range(16), batch_size=4, skip_batches=2)
        dl_shard = DataLoaderShard(range(16), batch_size=4)
        dl_dispatcher = DataLoaderDispatcher(range(16), batch_size=4)

        # Test dataloaders are instances of instantiated classes
        # These asserts look redundant, but it's worth checking since we are doing magic tricks such as dynamically overriding __class__
        assert isinstance(skip_dl, SkipDataLoader)
        assert isinstance(dl_shard, DataLoaderShard)
        assert isinstance(dl_dispatcher, DataLoaderDispatcher)

        # Test dataloaders are instances of base classes
        assert isinstance(skip_dl, DataLoader)
        assert isinstance(dl_shard, DataLoader)
        assert isinstance(dl_dispatcher, DataLoader)
@@ -556,6 +564,13 @@ def test_dataloader_inheritance(self):
        skip_dl = SkipDataLoader(range(16), batch_size=4, skip_batches=2, use_stateful_dataloader=True)
        dl_shard = DataLoaderShard(range(16), batch_size=4, use_stateful_dataloader=True)
        dl_dispatcher = DataLoaderDispatcher(range(16), batch_size=4, use_stateful_dataloader=True)

        # Test dataloaders are instances of instantiated classes
        # These asserts look redundant, but it's worth checking since we are doing magic tricks such as dynamically overriding __class__
        assert isinstance(skip_dl, SkipDataLoader)
        assert isinstance(dl_shard, DataLoaderShard)
        assert isinstance(dl_dispatcher, DataLoaderDispatcher)

        assert isinstance(skip_dl, StatefulDataLoader)
        assert isinstance(dl_shard, StatefulDataLoader)
        assert isinstance(dl_dispatcher, StatefulDataLoader)
