Skip to content

Commit

Permalink
Enforce pyre-strict
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok committed Jan 23, 2025
1 parent fcabb37 commit d26558b
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 42 deletions.
5 changes: 3 additions & 2 deletions src/spdl/dataloader/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from typing import Any

# pyre-unsafe
from ._dataloader import DataLoader
from ._pytorch_dataloader import get_pytorch_dataloader, PyTorchDataLoader

Expand All @@ -18,12 +17,14 @@
"PyTorchDataLoader",
]

# pyre-strict


def __dir__() -> list[str]:
return __all__


def __getattr__(name: str) -> Any:
def __getattr__(name: str) -> Any: # pyre-ignore: [3]
# For backward compatibility
if name == "iterate_in_subprocess":
import warnings
Expand Down
14 changes: 12 additions & 2 deletions src/spdl/pipeline/_bg_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,26 @@

import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from queue import Empty, Queue
from threading import Event, Thread

from ._utils import _get_loop

__all__ = ["BackgroundConsumer"]

_LG = logging.getLogger(__name__)


def _get_loop(num_workers: int | None) -> asyncio.AbstractEventLoop:
loop = asyncio.new_event_loop()
loop.set_default_executor(
ThreadPoolExecutor(
max_workers=num_workers,
thread_name_prefix="spdl_",
)
)
return loop


def _async_executor(loop: asyncio.AbstractEventLoop, queue: Queue, stopped: Event):
tasks = set()

Expand Down
4 changes: 2 additions & 2 deletions src/spdl/pipeline/_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ async def _unwrap(task: asyncio.Task[U]) -> U:

inter_queue = AsyncQueue(concurrency)

coro1: Awaitable[None] = _pipe( # pyre-ignore: [1001]
coro1: Coroutine[None, None, None] = _pipe( # pyre-ignore: [1001]
input_queue,
_wrap,
inter_queue,
Expand All @@ -310,7 +310,7 @@ async def _unwrap(task: asyncio.Task[U]) -> U:
hooks=[],
)

coro2: Awaitable[None] = _pipe( # pyre-ignore: [1001]
coro2: Coroutine[None, None, None] = _pipe( # pyre-ignore: [1001]
inter_queue,
_unwrap,
output_queue,
Expand Down
53 changes: 33 additions & 20 deletions src/spdl/pipeline/_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
# pyre-strict

import asyncio
import logging
import time
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Iterator, Sequence
from asyncio import Task
from collections.abc import AsyncIterator, Callable, Coroutine, Iterator, Sequence
from contextlib import asynccontextmanager, AsyncExitStack, contextmanager
from typing import AsyncContextManager, TypeVar

Expand All @@ -25,7 +26,7 @@
"StatsCounter",
]

_LG = logging.getLogger(__name__)
_LG: logging.Logger = logging.getLogger(__name__)


T = TypeVar("T")
Expand All @@ -39,7 +40,7 @@ def _time_str(val: float) -> str:


class StatsCounter:
def __init__(self):
def __init__(self) -> None:
self.num_items: int = 0
self.ave_time: float = 0.0

Expand All @@ -58,7 +59,7 @@ def count(self) -> Iterator[None]:
elapsed = time.monotonic() - t0
self.update(elapsed)

def __str__(self):
def __str__(self) -> str:
return _time_str(self.ave_time)


Expand Down Expand Up @@ -162,7 +163,7 @@ async def stage_hook(self):
"""

@asynccontextmanager
async def stage_hook(self):
async def stage_hook(self) -> AsyncIterator[None]:
"""Perform custom action when the pipeline stage is initialized and completed.
.. important::
Expand Down Expand Up @@ -207,8 +208,8 @@ async def stask_hook(self):
yield


def _stage_hooks(hooks: Sequence[PipelineHook]):
hs = [hook.stage_hook() for hook in hooks]
def _stage_hooks(hooks: Sequence[PipelineHook]) -> AsyncContextManager[None]:
hs: list[AsyncContextManager[None]] = [hook.stage_hook() for hook in hooks]

if not all(hasattr(h, "__aenter__") and hasattr(h, "__aexit__") for h in hs):
raise ValueError(
Expand All @@ -228,7 +229,7 @@ async def stage_hooks() -> AsyncIterator[None]:


def _task_hooks(hooks: Sequence[PipelineHook]) -> AsyncContextManager[None]:
hs = [hook.task_hook() for hook in hooks]
hs: list[AsyncContextManager[None]] = [hook.task_hook() for hook in hooks]

if not all(hasattr(h, "__aenter__") or hasattr(h, "__aexit__") for h in hs):
raise ValueError(
Expand All @@ -247,8 +248,10 @@ async def task_hooks() -> AsyncIterator[None]:
return task_hooks()


async def _periodic_dispatch(afun, interval):
tasks = set()
async def _periodic_dispatch(
afun: Callable[[], Coroutine[None, None, None]], interval: float
) -> None:
tasks: set[Task] = set()
while True:
await asyncio.sleep(interval)

Expand All @@ -265,7 +268,12 @@ class TaskStatsHook(PipelineHook):
concurrency: Concurrency of the stage. Only used for logging.
"""

def __init__(self, name: str, concurrency: int, interval: float | None = None):
def __init__(
self,
name: str,
concurrency: int,
interval: float | None = None,
) -> None:
self.name = name
self.concurrency = concurrency
self.interval = interval
Expand All @@ -275,21 +283,20 @@ def __init__(self, name: str, concurrency: int, interval: float | None = None):
self.ave_time = 0.0

# For interval
self._int_task = None
self._int_task: Task | None = None
self._int_t0 = 0.0
self._int_num_tasks = 0
self._int_num_success = 0
self._int_ave_time = 0.0

@asynccontextmanager
async def stage_hook(self):
async def stage_hook(self) -> AsyncIterator[None]:
"""Track the stage runtime and log the task stats."""
if self.interval is not None:
coro = _periodic_dispatch(self._log_interval_stats, self.interval)
self._int_t0 = time.monotonic()
self._int_task = create_task(
_periodic_dispatch(self._log_interval_stats, self.interval),
name="periodic_dispatch",
ignore_cancelled=True,
coro, name="periodic_dispatch", ignore_cancelled=True
)

t0 = time.monotonic()
Expand All @@ -302,7 +309,7 @@ async def stage_hook(self):
self._log_stats(elapsed, self.num_tasks, self.num_success, self.ave_time)

@asynccontextmanager
async def task_hook(self):
async def task_hook(self) -> AsyncIterator[None]:
"""Track task runtime and success rate."""
t0 = time.monotonic()
try:
Expand All @@ -319,7 +326,7 @@ async def task_hook(self):
self.num_success += 1
self.ave_time += (elapsed - self.ave_time) / self.num_success

async def _log_interval_stats(self):
async def _log_interval_stats(self) -> None:
t0 = time.monotonic()
num_success = self.num_success
num_tasks = self.num_tasks
Expand All @@ -345,7 +352,13 @@ async def _log_interval_stats(self):
self._int_num_success = num_success
self._int_ave_time = ave_time

def _log_stats(self, elapsed, num_tasks, num_success, ave_time):
def _log_stats(
self,
elapsed: float,
num_tasks: int,
num_success: int,
ave_time: float,
) -> None:
_LG.info(
"[%s]\tCompleted %5d tasks (%3d failed) in %s. "
"QPS: %.2f (Concurrency: %3d). "
Expand Down
27 changes: 11 additions & 16 deletions src/spdl/pipeline/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,23 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
# pyre-strict

import asyncio
import logging
import sys
import traceback
from concurrent.futures import ThreadPoolExecutor
from asyncio import Task
from collections.abc import Coroutine, Generator
from typing import Any, TypeVar

__all__ = [
"create_task",
]

_LG = logging.getLogger(__name__)
_LG: logging.Logger = logging.getLogger(__name__)


def _get_loop(num_workers: int | None) -> asyncio.AbstractEventLoop:
loop = asyncio.new_event_loop()
loop.set_default_executor(
ThreadPoolExecutor(
max_workers=num_workers,
thread_name_prefix="spdl_",
)
)
return loop
T = TypeVar("T")


# Note:
Expand All @@ -41,7 +34,7 @@ def _get_loop(num_workers: int | None) -> asyncio.AbstractEventLoop:
# task was created.
# Otherwise the log will point to the location somewhere deep in `asyncio` module
# which is not very helpful.
def _log_exception(task, stacklevel, ignore_cancelled):
def _log_exception(task: Task, stacklevel: int, ignore_cancelled: bool) -> None:
try:
task.result()
except asyncio.exceptions.CancelledError:
Expand Down Expand Up @@ -69,8 +62,10 @@ def _log_exception(task, stacklevel, ignore_cancelled):


def create_task(
coro, name: str | None = None, ignore_cancelled: bool = True
) -> asyncio.Task:
coro: Coroutine[Any, Any, T] | Generator[Any, None, T], # pyre-ignore: [2]
name: str | None = None,
ignore_cancelled: bool = True,
) -> Task[T]:
"""Wrapper around :py:func:`asyncio.create_task`. Add logging callback."""
task = asyncio.create_task(coro, name=name)
task.add_done_callback(
Expand Down

0 comments on commit d26558b

Please sign in to comment.