fork dataloader __init__ instead of patching samplers #1281

Merged · 4 commits · Jul 2, 2024
8 changes: 0 additions & 8 deletions torchdata/stateful_dataloader/sampler.py
@@ -58,10 +58,6 @@ def __iter__(self):
return _StatefulRandomSamplerIterator(self, super().__iter__())


torch.utils.data.sampler.RandomSampler = RandomSampler # type: ignore[misc]
torch.utils.data.dataloader.RandomSampler = RandomSampler # type: ignore[misc]


class BatchSampler(torch.utils.data.sampler.BatchSampler, Stateful):
_SAMPLES_YIELDED = "samples_yielded"
_SAMPLER_STATE = "sampler_state"
@@ -129,7 +125,3 @@ def __iter__(self):
batch = [0] * self.batch_size
if idx_in_batch > 0:
yield batch[:idx_in_batch]


torch.utils.data.sampler.BatchSampler = BatchSampler # type: ignore[misc]
torch.utils.data.dataloader.BatchSampler = BatchSampler # type: ignore[misc]
184 changes: 163 additions & 21 deletions torchdata/stateful_dataloader/stateful_dataloader.py
@@ -34,9 +34,19 @@

from torch._utils import ExceptionWrapper

from torch.utils.data import _utils, DataLoader, Dataset, IterDataPipe, MapDataPipe, Sampler
from torch.utils.data import (
_utils,
DataLoader,
Dataset,
IterableDataset,
IterDataPipe,
MapDataPipe,
Sampler,
SequentialSampler,
)

from torch.utils.data.dataloader import _BaseDataLoaderIter
from torch.utils.data.dataloader import _BaseDataLoaderIter, _InfiniteConstantSampler
from torch.utils.data.datapipes.datapipe import _IterDataPipeSerializationWrapper, _MapDataPipeSerializationWrapper

from .incremental_state import (
_DATASET_ITER_STATE,
@@ -46,7 +56,7 @@
_IncrementalWorkerState,
_WORKER_ID,
)
from .sampler import BatchSampler, RandomSampler # noqa
from .sampler import BatchSampler, RandomSampler
from .stateful import Stateful

from .worker import _AckStartup, _worker_loop, try_to_deserialize, try_to_serialize
@@ -192,25 +202,155 @@ def __init__(
pin_memory_device: str = "",
snapshot_every_n_steps: Optional[int] = 1,
):
torch._C._log_api_usage_once("python.stateful_data_loader")

if num_workers < 0:
raise ValueError(
"num_workers option should be non-negative; " "use num_workers=0 to disable multiprocessing."
)

if timeout < 0:
raise ValueError("timeout option should be non-negative")

if num_workers == 0 and prefetch_factor is not None:
raise ValueError(
"prefetch_factor option could only be specified in multiprocessing."
"let num_workers > 0 to enable multiprocessing, otherwise set prefetch_factor to None."
)
elif num_workers > 0 and prefetch_factor is None:
prefetch_factor = 2
elif prefetch_factor is not None and prefetch_factor < 0:
raise ValueError("prefetch_factor option should be non-negative")

if persistent_workers and num_workers == 0:
raise ValueError("persistent_workers option needs num_workers > 0")

self.dataset = dataset
self.num_workers = num_workers
self.prefetch_factor = prefetch_factor
self.pin_memory = pin_memory
self.pin_memory_device = pin_memory_device
self.timeout = timeout
self.worker_init_fn = worker_init_fn
self.multiprocessing_context = multiprocessing_context

# Adds forward compatibilities so classic DataLoader can work with DataPipes:
# _DataPipeSerializationWrapper container makes it easier to serialize without redefining pickler
if isinstance(self.dataset, IterDataPipe):
self.dataset = _IterDataPipeSerializationWrapper(self.dataset)
elif isinstance(self.dataset, MapDataPipe):
self.dataset = _MapDataPipeSerializationWrapper(self.dataset)

# Arg-check dataset related before checking samplers because we want to
# tell users that iterable-style datasets are incompatible with custom
# samplers first, so that they don't learn that this combo doesn't work
# after spending time fixing the custom sampler errors.
if isinstance(dataset, IterableDataset):
self._dataset_kind = _DatasetKind.Iterable
# NOTE [ Custom Samplers and IterableDataset ]
#
# `IterableDataset` does not support custom `batch_sampler` or
# `sampler` since the key is irrelevant (unless we support
# generator-style dataset one day...).
#
# For `sampler`, we always create a dummy sampler. This is an
# infinite sampler even when the dataset may have an implemented
# finite `__len__` because in multi-process data loading, naive
# settings will return duplicated data (which may be desired), and
# thus using a sampler with length matching that of dataset will
            # cause data loss (you may have duplicates of the first couple
# batches, but never see anything afterwards). Therefore,
            # `IterableDataset` always uses an infinite sampler, an instance of
# `_InfiniteConstantSampler` defined above.
#
# A custom `batch_sampler` essentially only controls the batch size.
# However, it is unclear how useful it would be since an iterable-style
# dataset can handle that within itself. Moreover, it is pointless
# in multi-process data loading as the assignment order of batches
# to workers is an implementation detail so users can not control
# how to batchify each worker's iterable. Thus, we disable this
# option. If this turns out to be useful in future, we can re-enable
# this, and support custom samplers that specify the assignments to
# specific workers.
if isinstance(dataset, IterDataPipe):
if shuffle is not None:
dataset = torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle)
# We cannot check `shuffle is not None` here, since previously `shuffle=False` was the default.
elif shuffle not in {False, None}:
raise ValueError(
f"DataLoader with IterableDataset: expected unspecified shuffle option, but got shuffle={shuffle}"
)

if sampler is not None:
# See NOTE [ Custom Samplers and IterableDataset ]
raise ValueError(
f"DataLoader with IterableDataset: expected unspecified sampler option, but got sampler={sampler}"
)
elif batch_sampler is not None:
# See NOTE [ Custom Samplers and IterableDataset ]
raise ValueError(
"DataLoader with IterableDataset: expected unspecified "
f"batch_sampler option, but got batch_sampler={batch_sampler}"
)
else:
shuffle = bool(shuffle)
self._dataset_kind = _DatasetKind.Map

if sampler is not None and shuffle:
raise ValueError("sampler option is mutually exclusive with " "shuffle")

if batch_sampler is not None:
# auto_collation with custom batch_sampler
if batch_size != 1 or shuffle or sampler is not None or drop_last:
raise ValueError(
"batch_sampler option is mutually exclusive " "with batch_size, shuffle, sampler, and " "drop_last"
)
batch_size = None
drop_last = False
elif batch_size is None:
# no auto_collation
if drop_last:
raise ValueError(
"batch_size=None option disables auto-batching " "and is mutually exclusive with drop_last"
)

if sampler is None: # give default samplers
if self._dataset_kind == _DatasetKind.Iterable:
# See NOTE [ Custom Samplers and IterableDataset ]
sampler = _InfiniteConstantSampler()
else: # map-style
if shuffle:
sampler = RandomSampler(dataset, generator=generator) # type: ignore[arg-type]
else:
sampler = SequentialSampler(dataset) # type: ignore[arg-type]

if batch_size is not None and batch_sampler is None:
# auto_collation without custom batch_sampler
batch_sampler = BatchSampler(sampler, batch_size, drop_last)

self.batch_size = batch_size
self.drop_last = drop_last
self.sampler = sampler
self.batch_sampler = batch_sampler
self.generator = generator

if collate_fn is None:
if self._auto_collation:
collate_fn = _utils.collate.default_collate
else:
collate_fn = _utils.collate.default_convert

self.collate_fn = collate_fn
self.persistent_workers = persistent_workers

# set DataLoader's __initialized attribute.
self._DataLoader__initialized = True
self._IterableDataset_len_called = None # See NOTE [ IterableDataset and __len__ ]

self._iterator = None

self.check_worker_number_rationality()

super().__init__(
dataset=dataset,
batch_size=batch_size,
shuffle=shuffle,
sampler=sampler,
batch_sampler=batch_sampler,
num_workers=num_workers,
collate_fn=collate_fn,
pin_memory=pin_memory,
drop_last=drop_last,
timeout=timeout,
worker_init_fn=worker_init_fn,
multiprocessing_context=multiprocessing_context,
generator=generator,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
pin_memory_device=pin_memory_device,
)
self.snapshot_every_n_steps = snapshot_every_n_steps
self.next_iter_state: Optional[Dict[str, Any]] = None
# When a state_dict is requested before __iter__ is called,
@@ -219,6 +359,8 @@ def __init__(
# iterator on the next __iter__ call, and this flag is used for those cases.
self._initial_iter_for_state_dict = False

torch.set_vital("Dataloader", "enabled", "True") # type: ignore[attr-defined]

def _get_iterator(self) -> "_StatefulBaseDataLoaderIter":
it: _StatefulBaseDataLoaderIter
if self.num_workers == 0:
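
With the stock DataLoader.__init__ logic forked in as shown above, StatefulDataLoader performs its own argument validation and default-sampler selection (the stateful RandomSampler/BatchSampler for map-style datasets, _InfiniteConstantSampler for iterable ones) instead of relying on patched torch globals. A usage sketch under that assumption, with illustrative dataset and batch sizes:

# Usage sketch (illustrative values): pause and resume a StatefulDataLoader.
from torchdata.stateful_dataloader import StatefulDataLoader

loader = StatefulDataLoader(range(100), batch_size=4, shuffle=True, num_workers=0)

it = iter(loader)
for _ in range(5):
    next(it)                      # consume a few batches
snapshot = loader.state_dict()    # captures sampler and iterator progress

resumed = StatefulDataLoader(range(100), batch_size=4, shuffle=True, num_workers=0)
resumed.load_state_dict(snapshot)
# iter(resumed) now continues from the sixth batch rather than restarting.
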