Add iterate_in_subprocess and deprecate run_in_subprocess
`run_in_subprocess` is not descriptive enough
mthrok committed Jan 3, 2025
1 parent bcf57bb commit a04787f
Showing 2 changed files with 38 additions and 15 deletions.
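
A minimal migration sketch based on the signatures in this diff, assuming `iterate_in_subprocess` is importable from `spdl.dataloader`; `load_samples` and its arguments are hypothetical stand-ins. The old `args`/`kwargs` parameters are replaced by binding arguments yourself with `functools.partial`:

    from functools import partial

    from spdl.dataloader import iterate_in_subprocess


    def load_samples(path: str, n: int):
        # Stand-in for a real data-loading generator; it and its yields must be picklable.
        for i in range(n):
            yield f"{path}:{i}"


    if __name__ == "__main__":
        # Guard is needed because the default "forkserver" context re-imports __main__.
        # Old: run_in_subprocess(load_samples, args=("data.txt",), kwargs={"n": 4})
        # New: pass a zero-argument callable that returns the iterable.
        for item in iterate_in_subprocess(partial(load_samples, "data.txt", n=4)):
            print(item)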
8 changes: 8 additions & 0 deletions src/spdl/dataloader/__init__.py
@@ -43,4 +43,12 @@ def __getattr__(name: str) -> Any:
         )
         return getattr(spdl.pipeline, name)
 
+    if name == "run_in_subprocess":
+        warnings.warn(
+            "`run_in_subprocess` has been deprecated. "
+            "Use `iterate_in_subprocess` instead.",
+            stacklevel=2,
+        )
+        return _iterators.run_in_subprocess
+
     raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
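
The deprecation path relies on the module-level `__getattr__` hook (PEP 562), so the warning fires on attribute access rather than at import time. A sketch of what a caller would observe, assuming the package layout in this diff (no `category` is passed above, so the default `UserWarning` applies):

    import warnings

    import spdl.dataloader

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        _ = spdl.dataloader.run_in_subprocess  # attribute access triggers __getattr__

    assert any("iterate_in_subprocess" in str(w.message) for w in caught)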
45 changes: 30 additions & 15 deletions src/spdl/dataloader/_iterators.py
@@ -18,7 +18,7 @@
 )
 from typing import Any, TypeVar
 
-__all__ = ["run_in_subprocess", "MergeIterator"]
+__all__ = ["iterate_in_subprocess", "MergeIterator"]
 
 T = TypeVar("T")
 
@@ -30,7 +30,7 @@
 # pyre-strict
 
 ################################################################################
-# run_in_subprocess
+# iterate_in_subprocess
 ################################################################################
 
 # Message from parent to worker
@@ -45,12 +45,10 @@
 def _execute_iterator(
     msg_queue: mp.Queue,
     data_queue: mp.Queue,
-    fn: Callable[..., Iterator[T]],
-    args: tuple[...] | None,
-    kwargs: dict[str, Any] | None,
+    fn: Callable[[], Iterator[T]],
 ) -> None:
     try:
-        gen = fn(*(args or ()), **(kwargs or {}))
+        gen = iter(fn())
     except Exception:
         msg_queue.put(_MSG_GENERATOR_FAILED)
         raise
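
Because the worker now calls `iter(fn())`, the zero-argument callable may return any iterable, not only a generator. A small local sketch of that behavior (no subprocess involved; for the real path, `fn` and the yielded values must also be picklable):

    from functools import partial

    # Both return an iterable when called with no arguments; iter() turns the
    # result into an iterator that the worker can drive with next().
    fn_range = partial(range, 5)      # returns a range object
    fn_list = partial(list, "abc")    # returns a list

    assert list(iter(fn_range())) == [0, 1, 2, 3, 4]
    assert list(iter(fn_list())) == ["a", "b", "c"]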
@@ -81,10 +79,8 @@ def _execute_iterator(
             return
 
 
-def run_in_subprocess(
-    fn: Callable[..., Iterator[T]],
-    args: tuple[...] | None = None,
-    kwargs: dict[str, Any] | None = None,
+def iterate_in_subprocess(
+    fn: Callable[[], Iterable[T]],
     queue_size: int = 64,
     mp_context: str = "forkserver",
     timeout: float | None = None,
@@ -93,9 +89,8 @@ def run_in_subprocess(
     """Run an iterator in a separate process, and yield the results one by one.
 
     Args:
-        fn: Generator function.
-        args: Arguments to pass to the generator function.
-        kwargs: Keyword arguments to pass to the generator function.
+        fn: Function that returns an iterator. Use :py:func:`functools.partial` to
+            pass arguments to the function.
         queue_size: Maximum number of items to buffer in the queue.
         mp_context: Context to use for multiprocessing.
         timeout: Timeout for inactivity. If the generator function does not yield
@@ -107,7 +102,7 @@ def run_in_subprocess(
     .. note::
 
-        The generator function, its arguments and the result of generator must be picklable.
+        The function and the values yielded by the iterator of generator must be picklable.
     """
 
     ctx = mp.get_context(mp_context)
     msg_q = ctx.Queue()
@@ -119,7 +114,7 @@ def _drain() -> Iterator[T]:
 
     process = ctx.Process(
         target=_execute_iterator,
-        args=(msg_q, data_q, fn, args, kwargs),
+        args=(msg_q, data_q, fn),
         daemon=daemon,
     )
     process.start()
@@ -181,6 +176,26 @@ def _drain() -> Iterator[T]:
             _LG.warning("Failed to kill the worker process.")
 
 
+def run_in_subprocess(
+    fn: Callable[..., Iterable[T]],
+    args: tuple[...] | None = None,
+    kwargs: dict[str, Any] | None = None,
+    queue_size: int = 64,
+    mp_context: str = "forkserver",
+    timeout: float | None = None,
+    daemon: bool = False,
+) -> Iterator[T]:
+    from functools import partial
+
+    return iterate_in_subprocess(
+        fn=partial(fn, *(args or ()), **(kwargs or {})),
+        queue_size=queue_size,
+        mp_context=mp_context,
+        timeout=timeout,
+        daemon=daemon,
+    )
+
+
 ################################################################################
 # MergeIterator
 ################################################################################
