Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
byi8220 committed Sep 3, 2024
1 parent 42a253b commit 364a7fd
Showing 1 changed file with 2 additions and 15 deletions.
17 changes: 2 additions & 15 deletions src/accelerate/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,9 +435,8 @@ def load_state_dict(self, state_dict):
@property
def __class__(self):
"""
In order to maintain backwards compatability with other code, we need to ensure `isinstance(obj, DataLoader)`
returs true. This is because some downstream code assumes that the `DataLoader` is the base class of the
object.
In order to maintain backwards compatability with other code, we need to ensure `isinstance(obj, DataLoader)` returs true.
This is because some downstream code assumes that the `DataLoader` is the base class of the object.
"""
return self.base_dataloader.__class__

Expand Down Expand Up @@ -480,18 +479,6 @@ def _update_state_dict(self):
# Then tag if we are at the end of the dataloader
self.dl_state_dict["_iterator_finished"] = self.end_of_dataloader


class DataLoaderAdapterImpl(DataLoaderAdapter, DataLoader):
pass


if is_torchdata_stateful_dataloader_available():
from torchdata.stateful_dataloader import StatefulDataLoader

class StatefulDataLoaderAdapterImpl(DataLoaderAdapter, StatefulDataLoader):
pass


class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
"""
Subclass of `DataLoaderAdapter` that will deal with device placement and current distributed setup.
Expand Down

0 comments on commit 364a7fd

Please sign in to comment.