Fix potential deadlock in PyTorchDataLoader (#305)
`ProcessPoolExecutor` should be used as a context manager; otherwise the worker processes it spawns may never be shut down, which can deadlock the loader. The executor is now created inside each iteration and entered with `with`, so the pool is torn down whenever iteration finishes.
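For illustration, a minimal standalone sketch of the failure mode and of the context-manager pattern the fix adopts. The names below (`square`, `iter_results_leaky`, `iter_results_safe`) are invented for this sketch and are not code from the patch.

# Minimal sketch of the bug class this commit fixes (illustrative names).
from concurrent.futures import ProcessPoolExecutor


def square(x: int) -> int:
    return x * x


def iter_results_leaky(values: list[int]):
    # The executor is never shut down. If the consumer stops iterating
    # early, the worker processes keep running and the program can hang
    # waiting on them at interpreter exit.
    executor = ProcessPoolExecutor(max_workers=2)
    yield from executor.map(square, values)


def iter_results_safe(values: list[int]):
    # Used as a context manager, the executor calls shutdown(wait=True)
    # on exit, so workers are reaped even when iteration ends early or
    # an exception propagates out of the loop.
    with ProcessPoolExecutor(max_workers=2) as executor:
        yield from executor.map(square, values)


if __name__ == "__main__":
    for result in iter_results_safe([1, 2, 3]):
        print(result)

This mirrors the change to `PyTorchDataLoader.__iter__` below, which now enters the executor with `with executor, pipeline.auto_stop():` so the pool is shut down whenever iteration stops.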
mthrok authored Dec 31, 2024
1 parent 8d252a2 commit bcf57bb
Showing 1 changed file with 80 additions and 73 deletions.
src/spdl/dataloader/_pytorch_dataloader.py: 80 additions & 73 deletions
@@ -14,7 +14,7 @@
 import pickle
 import time
 from collections.abc import Callable, Iterable, Iterator
-from concurrent.futures import Executor, ProcessPoolExecutor
+from concurrent.futures import ProcessPoolExecutor
 from multiprocessing.shared_memory import SharedMemory
 from types import ModuleType
 from typing import cast, Sized, TYPE_CHECKING, TypeVar
@@ -31,11 +31,71 @@
 K = TypeVar("K")
 T = TypeVar("T")
 U = TypeVar("U")
+V = TypeVar("V")

 _LG: logging.Logger = logging.getLogger(__name__)


-class PyTorchDataLoader(Iterable[U]):
+################################################################################
+# ProcessExecutor
+################################################################################
+
+_DATASET: "torch.utils.data.dataset.Dataset[T]" = None  # pyre-ignore: [15]
+_COLLATE_FN: Callable = None  # pyre-ignore: [15]
+
+
+def _get_item(index: K) -> ...:
+    global _DATASET, _COLLATE_FN
+    return _COLLATE_FN(_DATASET[index])
+
+
+def _get_items(indices: list[K]) -> ...:
+    global _DATASET, _COLLATE_FN
+    if hasattr(_DATASET, "__getitems__"):
+        return _COLLATE_FN(_DATASET.__getitems__(indices))  # pyre-ignore: [16]
+    return _COLLATE_FN([_DATASET[index] for index in indices])
+
+
+def _init_dataset(name: str, collate_fn: Callable) -> None:
+    _LG.info("[%s] Initializing dataset.", os.getpid())
+    shmem = SharedMemory(name=name)
+    global _DATASET, _COLLATE_FN
+    _DATASET = pickle.loads(shmem.buf)
+    _COLLATE_FN = collate_fn
+
+
+def _get_executor(
+    name: str,
+    collate_fn: Callable[[list[T]], U],
+    num_workers: int,
+    mp_ctx: mp.context.BaseContext,
+) -> ProcessPoolExecutor:
+    executor = ProcessPoolExecutor(
+        max_workers=num_workers,
+        mp_context=mp_ctx,
+        initializer=_init_dataset,
+        initargs=(name, collate_fn),
+    )
+    return executor
+
+
+def _serialize_dataset(dataset: "torch.utils.data.dataset.Dataset[T]") -> SharedMemory:
+    _LG.info("Serializing dataset.")
+    t0 = time.monotonic()
+    data = pickle.dumps(dataset)
+    shmem = SharedMemory(create=True, size=len(data))
+    shmem.buf[:] = data
+    elapsed = time.monotonic() - t0
+    _LG.info(
+        "Written dataset into shared memory %s (%s MB) in %.2f seconds",
+        shmem.name,
+        f"{len(data) // 1_000_000:_d}",
+        elapsed,
+    )
+    return shmem
+
+
+class PyTorchDataLoader(Iterable[V]):
     """PyTorchDataLoader()

     A PyTorch-style data loader that works on map-style dataset.
@@ -67,7 +127,8 @@ def __init__(
         shmem: SharedMemory,  # to keep the reference alive
         sampler: "torch.utils.data.sampler.Sampler[K]",
         fetch_fn: Callable[[K], U] | Callable[[list[K]], U],
-        executor: Executor,
+        collate_fn: Callable[[list[T]], U],
+        mp_ctx: mp.context.BaseContext,
         num_workers: int,
         timeout: float | None,
         buffer_size: int,
@@ -77,7 +138,8 @@
         self._shmem: SharedMemory = shmem
         self._sampler = sampler
         self._fetch_fn = fetch_fn
-        self._executor = executor
+        self._collate_fn = collate_fn
+        self._mp_ctx = mp_ctx
         self._num_workers = num_workers
         self._buffer_size = buffer_size
         self._timeout = timeout
@@ -88,86 +150,31 @@ def __len__(self) -> int:
         return len(cast(Sized, self._sampler))

     def _get_pipeline(self) -> Pipeline:
-        return (
+        executor = _get_executor(
+            self._shmem.name, self._collate_fn, self._num_workers, self._mp_ctx
+        )
+        pipeline = (
             PipelineBuilder()
             .add_source(self._sampler)
             .pipe(
                 self._fetch_fn,
-                executor=self._executor,
+                executor=executor,
                 output_order=self._output_order,
                 concurrency=self._num_workers,
             )
             .add_sink(self._buffer_size)
             .build(num_threads=1)
         )
+        return executor, pipeline

-    def __iter__(self) -> Iterator[U]:
+    def __iter__(self) -> Iterator[V]:
         """Iterate on the dataset and yields samples/batches."""
-        pipeline = self._get_pipeline()
-        with pipeline.auto_stop():
+        executor, pipeline = self._get_pipeline()
+        with executor, pipeline.auto_stop():
             for item in pipeline.get_iterator(timeout=self._timeout):
                 yield item


-################################################################################
-# ProcessExecutor
-################################################################################
-
-_DATASET: "torch.utils.data.dataset.Dataset[T]" = None  # pyre-ignore: [15]
-_COLLATE_FN: Callable = None  # pyre-ignore: [15]
-
-
-def _get_item(index: K) -> ...:
-    global _DATASET, _COLLATE_FN
-    return _COLLATE_FN(_DATASET[index])
-
-
-def _get_items(indices: list[K]) -> ...:
-    global _DATASET, _COLLATE_FN
-    if hasattr(_DATASET, "__getitems__"):
-        return _COLLATE_FN(_DATASET.__getitems__(indices))  # pyre-ignore: [16]
-    return _COLLATE_FN([_DATASET[index] for index in indices])
-
-
-def _init_dataset(name: str, collate_fn: Callable) -> None:
-    _LG.info("[%s] Initializing dataset.", os.getpid())
-    shmem = SharedMemory(name=name)
-    global _DATASET, _COLLATE_FN
-    _DATASET = pickle.loads(shmem.buf)
-    _COLLATE_FN = collate_fn
-
-
-def _get_executor(
-    name: str,
-    collate_fn: Callable[[list[T]], U],
-    num_workers: int,
-    mp_ctx: mp.context.BaseContext,
-) -> ProcessPoolExecutor:
-    executor = ProcessPoolExecutor(
-        max_workers=num_workers,
-        mp_context=mp_ctx,
-        initializer=_init_dataset,
-        initargs=(name, collate_fn),
-    )
-    return executor
-
-
-def _serialize_dataset(dataset: "torch.utils.data.dataset.Dataset[T]") -> SharedMemory:
-    _LG.info("Serializing dataset.")
-    t0 = time.monotonic()
-    data = pickle.dumps(dataset)
-    shmem = SharedMemory(create=True, size=len(data))
-    shmem.buf[:] = data
-    elapsed = time.monotonic() - t0
-    _LG.info(
-        "Written dataset into shared memory %s (%s bytes) in %.2f seconds",
-        shmem.name,
-        f"{len(data):_d}",
-        elapsed,
-    )
-    return shmem
-
-
 ################################################################################
 # resolve sampler, fetch and collate
 ################################################################################
@@ -263,7 +270,7 @@ def get_pytorch_dataloader(
     generator: "torch.Generator | None" = None,
     *,
     prefetch_factor: int = 2,
-    persistent_workers: bool = True,
+    persistent_workers: bool = False,
     pin_memory_device: str | None = None,
 ) -> PyTorchDataLoader[U]:
     from torch.utils.data.dataloader import IterableDataset
@@ -280,8 +287,8 @@
     if pin_memory_device is not None:
         raise ValueError("`pin_memory_device` is not supported (yet).")

-    if not persistent_workers:
-        raise ValueError("`persistent_workers=False` is not supported. ")
+    if persistent_workers:
+        raise ValueError("`persistent_workers` is not supported.")

     if timeout is not None and timeout < 0:
         raise ValueError(f"`timeout` must be positive. Found: {timeout}.")
Expand Down Expand Up @@ -309,14 +316,14 @@ def get_pytorch_dataloader(
)
_LG.info("Using multiprocessing context: %s", mp_ctx.get_start_method())
shmem = _serialize_dataset(dataset)
executor = _get_executor(shmem.name, _collate_fn, num_workers, mp_ctx)

return PyTorchDataLoader(
dataset=dataset,
shmem=shmem,
sampler=_sampler,
fetch_fn=_fetch_fn,
executor=executor,
collate_fn=_collate_fn,
mp_ctx=mp_ctx,
num_workers=num_workers,
timeout=timeout,
buffer_size=buffer_size,

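A side note on the surrounding file: the dataset is pickled once into shared memory (`_serialize_dataset`) and each worker deserializes it in the `ProcessPoolExecutor` initializer (`_init_dataset`), so the tasks themselves only ship indices. Below is a self-contained sketch of that technique; the names `_load_payload` and `_lookup` are invented for the sketch and this is not SPDL's API.

# Standalone sketch of the shared-memory + initializer pattern used in
# this file: serialize an object once, let every worker load it at startup.
import pickle
from concurrent.futures import ProcessPoolExecutor
from multiprocessing.shared_memory import SharedMemory

_PAYLOAD = None  # populated in each worker by _load_payload


def _load_payload(shm_name: str) -> None:
    # Runs once per worker process: attach to the shared-memory block and
    # deserialize the object, instead of shipping it with every task.
    global _PAYLOAD
    shm = SharedMemory(name=shm_name)
    _PAYLOAD = pickle.loads(shm.buf)


def _lookup(index: int) -> int:
    return _PAYLOAD[index]


if __name__ == "__main__":
    data = list(range(1_000))
    blob = pickle.dumps(data)
    shm = SharedMemory(create=True, size=len(blob))
    shm.buf[: len(blob)] = blob  # pickle ignores any trailing bytes
    try:
        with ProcessPoolExecutor(
            max_workers=2, initializer=_load_payload, initargs=(shm.name,)
        ) as executor:
            print(list(executor.map(_lookup, [0, 500, 999])))
    finally:
        shm.close()
        shm.unlink()

The design choice: the deserialization cost is paid once per worker at startup rather than once per task, which keeps the per-batch IPC payload down to the sampled indices.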