fork dataloader __init__ instead of patching samplers (#1281)
* fork dataloader __init__ instead of patching samplers

* set super __initialized=True

* add comment

* remove commented code
andrewkho authored Jul 2, 2024
1 parent b0e25e2 commit b421e86
Showing 2 changed files with 163 additions and 29 deletions.
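
For context on why these samplers and the loader carry state at all: StatefulDataLoader exists to support mid-epoch checkpointing, and the diff below refers to this state_dict flow in several places. A minimal usage sketch, assuming a simple map-style dataset (the dataset and batch sizes here are illustrative, not part of this commit):

from torchdata.stateful_dataloader import StatefulDataLoader

dataset = list(range(100))  # any map-style dataset works; a plain list is enough here

dl = StatefulDataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

state = None
for i, batch in enumerate(dl):
    if i == 5:
        state = dl.state_dict()  # snapshot loader progress mid-epoch
        break

# A fresh loader resumes from the snapshot instead of restarting the epoch.
dl_resumed = StatefulDataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)
dl_resumed.load_state_dict(state)
for batch in dl_resumed:
    pass  # continues after the snapshotted position
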
8 changes: 0 additions & 8 deletions torchdata/stateful_dataloader/sampler.py
@@ -58,10 +58,6 @@ def __iter__(self):
        return _StatefulRandomSamplerIterator(self, super().__iter__())


torch.utils.data.sampler.RandomSampler = RandomSampler # type: ignore[misc]
torch.utils.data.dataloader.RandomSampler = RandomSampler # type: ignore[misc]


class BatchSampler(torch.utils.data.sampler.BatchSampler, Stateful):
    _SAMPLES_YIELDED = "samples_yielded"
    _SAMPLER_STATE = "sampler_state"
@@ -129,7 +125,3 @@ def __iter__(self):
                    batch = [0] * self.batch_size
            if idx_in_batch > 0:
                yield batch[:idx_in_batch]


torch.utils.data.sampler.BatchSampler = BatchSampler # type: ignore[misc]
torch.utils.data.dataloader.BatchSampler = BatchSampler # type: ignore[misc]
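
The deletions above remove the old integration strategy: importing this module rebound torch.utils.data.sampler.RandomSampler / BatchSampler (and their dataloader-module aliases) to the Stateful subclasses, which affected every DataLoader in the process. With this commit, only StatefulDataLoader opts in, by constructing the stateful samplers itself inside its forked __init__ (shown in the next file). A rough sketch of the explicit wiring, with illustrative arguments:

from torchdata.stateful_dataloader.sampler import BatchSampler, RandomSampler

dataset = list(range(100))  # illustrative map-style dataset

# Both classes subclass their torch.utils.data counterparts and implement the
# Stateful protocol, so they can report and restore how far they have advanced
# (the "samples_yielded" / "sampler_state" keys above).
sampler = RandomSampler(dataset)
batch_sampler = BatchSampler(sampler, batch_size=4, drop_last=False)

state = batch_sampler.state_dict()  # e.g. includes samples_yielded and the wrapped sampler's state

# No global patching: a plain torch.utils.data.DataLoader created elsewhere in
# the same process keeps the stock, non-stateful samplers.
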
184 changes: 163 additions & 21 deletions torchdata/stateful_dataloader/stateful_dataloader.py
@@ -34,9 +34,19 @@

from torch._utils import ExceptionWrapper

from torch.utils.data import _utils, DataLoader, Dataset, IterDataPipe, MapDataPipe, Sampler
from torch.utils.data import (
    _utils,
    DataLoader,
    Dataset,
    IterableDataset,
    IterDataPipe,
    MapDataPipe,
    Sampler,
    SequentialSampler,
)

from torch.utils.data.dataloader import _BaseDataLoaderIter
from torch.utils.data.dataloader import _BaseDataLoaderIter, _InfiniteConstantSampler
from torch.utils.data.datapipes.datapipe import _IterDataPipeSerializationWrapper, _MapDataPipeSerializationWrapper

from .incremental_state import (
    _DATASET_ITER_STATE,
@@ -46,7 +56,7 @@
    _IncrementalWorkerState,
    _WORKER_ID,
)
from .sampler import BatchSampler, RandomSampler # noqa
from .sampler import BatchSampler, RandomSampler
from .stateful import Stateful

from .worker import _AckStartup, _worker_loop, try_to_deserialize, try_to_serialize
@@ -192,25 +202,155 @@ def __init__(
        pin_memory_device: str = "",
        snapshot_every_n_steps: Optional[int] = 1,
    ):
        torch._C._log_api_usage_once("python.stateful_data_loader")

        if num_workers < 0:
            raise ValueError(
                "num_workers option should be non-negative; " "use num_workers=0 to disable multiprocessing."
            )

        if timeout < 0:
            raise ValueError("timeout option should be non-negative")

        if num_workers == 0 and prefetch_factor is not None:
            raise ValueError(
                "prefetch_factor option could only be specified in multiprocessing."
                "let num_workers > 0 to enable multiprocessing, otherwise set prefetch_factor to None."
            )
        elif num_workers > 0 and prefetch_factor is None:
            prefetch_factor = 2
        elif prefetch_factor is not None and prefetch_factor < 0:
            raise ValueError("prefetch_factor option should be non-negative")

        if persistent_workers and num_workers == 0:
            raise ValueError("persistent_workers option needs num_workers > 0")

        self.dataset = dataset
        self.num_workers = num_workers
        self.prefetch_factor = prefetch_factor
        self.pin_memory = pin_memory
        self.pin_memory_device = pin_memory_device
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn
        self.multiprocessing_context = multiprocessing_context

        # Adds forward compatibilities so classic DataLoader can work with DataPipes:
        #   _DataPipeSerializationWrapper container makes it easier to serialize without redefining pickler
        if isinstance(self.dataset, IterDataPipe):
            self.dataset = _IterDataPipeSerializationWrapper(self.dataset)
        elif isinstance(self.dataset, MapDataPipe):
            self.dataset = _MapDataPipeSerializationWrapper(self.dataset)

        # Arg-check dataset related before checking samplers because we want to
        # tell users that iterable-style datasets are incompatible with custom
        # samplers first, so that they don't learn that this combo doesn't work
        # after spending time fixing the custom sampler errors.
        if isinstance(dataset, IterableDataset):
            self._dataset_kind = _DatasetKind.Iterable
            # NOTE [ Custom Samplers and IterableDataset ]
            #
            # `IterableDataset` does not support custom `batch_sampler` or
            # `sampler` since the key is irrelevant (unless we support
            # generator-style dataset one day...).
            #
            # For `sampler`, we always create a dummy sampler. This is an
            # infinite sampler even when the dataset may have an implemented
            # finite `__len__` because in multi-process data loading, naive
            # settings will return duplicated data (which may be desired), and
            # thus using a sampler with length matching that of dataset will
            # cause data loss (you may have duplicates of the first couple
            # batches, but never see anything afterwards). Therefore,
            # `IterableDataset` always uses an infinite sampler, an instance of
            # `_InfiniteConstantSampler` defined above.
            #
            # A custom `batch_sampler` essentially only controls the batch size.
            # However, it is unclear how useful it would be since an iterable-style
            # dataset can handle that within itself. Moreover, it is pointless
            # in multi-process data loading as the assignment order of batches
            # to workers is an implementation detail so users can not control
            # how to batchify each worker's iterable. Thus, we disable this
            # option. If this turns out to be useful in future, we can re-enable
            # this, and support custom samplers that specify the assignments to
            # specific workers.
            if isinstance(dataset, IterDataPipe):
                if shuffle is not None:
                    dataset = torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle)
            # We cannot check `shuffle is not None` here, since previously `shuffle=False` was the default.
            elif shuffle not in {False, None}:
                raise ValueError(
                    f"DataLoader with IterableDataset: expected unspecified shuffle option, but got shuffle={shuffle}"
                )

            if sampler is not None:
                # See NOTE [ Custom Samplers and IterableDataset ]
                raise ValueError(
                    f"DataLoader with IterableDataset: expected unspecified sampler option, but got sampler={sampler}"
                )
            elif batch_sampler is not None:
                # See NOTE [ Custom Samplers and IterableDataset ]
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    f"batch_sampler option, but got batch_sampler={batch_sampler}"
                )
        else:
            shuffle = bool(shuffle)
            self._dataset_kind = _DatasetKind.Map

        if sampler is not None and shuffle:
            raise ValueError("sampler option is mutually exclusive with " "shuffle")

        if batch_sampler is not None:
            # auto_collation with custom batch_sampler
            if batch_size != 1 or shuffle or sampler is not None or drop_last:
                raise ValueError(
                    "batch_sampler option is mutually exclusive " "with batch_size, shuffle, sampler, and " "drop_last"
                )
            batch_size = None
            drop_last = False
        elif batch_size is None:
            # no auto_collation
            if drop_last:
                raise ValueError(
                    "batch_size=None option disables auto-batching " "and is mutually exclusive with drop_last"
                )

        if sampler is None: # give default samplers
            if self._dataset_kind == _DatasetKind.Iterable:
                # See NOTE [ Custom Samplers and IterableDataset ]
                sampler = _InfiniteConstantSampler()
            else: # map-style
                if shuffle:
                    sampler = RandomSampler(dataset, generator=generator) # type: ignore[arg-type]
                else:
                    sampler = SequentialSampler(dataset) # type: ignore[arg-type]

        if batch_size is not None and batch_sampler is None:
            # auto_collation without custom batch_sampler
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.batch_size = batch_size
        self.drop_last = drop_last
        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self.generator = generator

        if collate_fn is None:
            if self._auto_collation:
                collate_fn = _utils.collate.default_collate
            else:
                collate_fn = _utils.collate.default_convert

        self.collate_fn = collate_fn
        self.persistent_workers = persistent_workers

        # set DataLoader's __initialized attribute.
        self._DataLoader__initialized = True
        self._IterableDataset_len_called = None # See NOTE [ IterableDataset and __len__ ]

        self._iterator = None

        self.check_worker_number_rationality()

        super().__init__(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            collate_fn=collate_fn,
            pin_memory=pin_memory,
            drop_last=drop_last,
            timeout=timeout,
            worker_init_fn=worker_init_fn,
            multiprocessing_context=multiprocessing_context,
            generator=generator,
            prefetch_factor=prefetch_factor,
            persistent_workers=persistent_workers,
            pin_memory_device=pin_memory_device,
        )
        self.snapshot_every_n_steps = snapshot_every_n_steps
        self.next_iter_state: Optional[Dict[str, Any]] = None
        # When a state_dict is requested before __iter__ is called,
@@ -219,6 +359,8 @@ def __init__(
        # iterator on the next __iter__ call, and this flag is used for those cases.
        self._initial_iter_for_state_dict = False

        torch.set_vital("Dataloader", "enabled", "True") # type: ignore[attr-defined]

    def _get_iterator(self) -> "_StatefulBaseDataLoaderIter":
        it: _StatefulBaseDataLoaderIter
        if self.num_workers == 0:
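
A note on the `self._DataLoader__initialized = True` line in the forked __init__ (the "set super __initialized=True" bullet in the commit message): DataLoader sets a private `__initialized` flag near the end of its own __init__, stored under the name-mangled attribute `_DataLoader__initialized`, which its `__setattr__` consults to lock key attributes after construction. Since the forked __init__ never calls `super().__init__`, it assigns the mangled name directly. A minimal sketch of the pattern, using made-up class names rather than the real DataLoader internals:

class Base:
    def __init__(self):
        # Stored as `_Base__initialized` because of name mangling.
        self.__initialized = True

    def __setattr__(self, attr, val):
        # Once initialized, refuse further attribute assignment.
        if getattr(self, "_Base__initialized", False):
            raise ValueError(f"{attr} cannot be set after __init__")
        super().__setattr__(attr, val)

class ForkedChild(Base):
    def __init__(self):
        # Replicate Base.__init__'s work without calling it, then flip the
        # flag Base.__setattr__ checks by spelling out its mangled name.
        self.some_attr = 1
        self._Base__initialized = True

Writing `self.__initialized = True` inside ForkedChild would instead create `_ForkedChild__initialized`, which Base's check never looks at; that is why the base-class-mangled name has to be spelled out, exactly as the diff does with `_DataLoader__initialized`.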
