Skip to content

Commit

Permalink
Fix most of typing in pipeline builder
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok committed Jan 23, 2025
1 parent 686c00a commit 90751c6
Showing 1 changed file with 58 additions and 51 deletions.
109 changes: 58 additions & 51 deletions src/spdl/pipeline/_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
# pyre-strict

import asyncio
import inspect
Expand All @@ -18,15 +18,15 @@
Callable,
Coroutine,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import Executor, ThreadPoolExecutor
from contextlib import asynccontextmanager, contextmanager
from functools import partial
from typing import TypeVar
from typing import Any, AsyncGenerator, Generic, TypeVar

from . import _convert
from ._convert import _to_async_gen, Callables
from ._convert import _to_async_gen, Callables, convert_to_async
from ._hook import (
_stage_hooks,
_task_hooks,
Expand All @@ -40,18 +40,18 @@

__all__ = ["PipelineFailure", "PipelineBuilder", "_get_op_name"]

_LG = logging.getLogger(__name__)
_LG: logging.Logger = logging.getLogger(__name__)

T = TypeVar("T")
U = TypeVar("U")


# Sentinel objects used to instruct AsyncPipeline to take special actions.
class _Sentinel:
def __init__(self, name):
def __init__(self, name: str) -> None:
self.name = name

def __str__(self):
def __str__(self) -> str:
return self.name


Expand Down Expand Up @@ -98,7 +98,7 @@ def __str__(self):
# unless the task was cancelled.
#
@asynccontextmanager
async def _put_eof_when_done(queue):
async def _put_eof_when_done(queue: AsyncQueue) -> AsyncGenerator[None, None]:
# Note:
# `asyncio.CancelledError` is a subclass of BaseException, so it won't be
# caught in the following, and EOF won't be passed to the output queue.
Expand Down Expand Up @@ -145,7 +145,9 @@ def _pipe(
else hooks
)

afunc = _convert.convert_to_async(op, executor)
afunc: Callable[[T], Awaitable[U]] = ( # pyre-ignore: [9]
convert_to_async(op, executor)
)

if inspect.iscoroutinefunction(afunc):

Expand Down Expand Up @@ -200,7 +202,7 @@ async def _wrap(coro: AsyncIterator[U]) -> None:

@_put_eof_when_done(output_queue)
@_stage_hooks(hooks)
async def pipe():
async def pipe() -> None:
i, tasks = 0, set()
while True:
item = await input_queue.get()
Expand Down Expand Up @@ -273,7 +275,7 @@ def _ordered_pipe(
if concurrency < 1:
raise ValueError("`concurrency` value must be >= 1")

hooks = (
hooks_: Sequence[PipelineHook] = (
[TaskStatsHook(name, concurrency, interval=report_stats_interval)]
if hooks is None
else hooks
Expand All @@ -282,11 +284,13 @@ def _ordered_pipe(
# This has been checked in `PipelineBuilder.pipe()`
assert not inspect.isasyncgenfunction(op)

afunc = _convert.convert_to_async(op, executor)
afunc: Callable[[T], Awaitable[U]] = ( # pyre-ignore: [9]
convert_to_async(op, executor)
)

async def _wrap(item: T) -> asyncio.Task[U]:
async def _with_hooks():
async with _task_hooks(hooks):
async def _with_hooks() -> U:
async with _task_hooks(hooks_):
return await afunc(item)

return create_task(_with_hooks())
Expand All @@ -296,7 +300,7 @@ async def _unwrap(task: asyncio.Task[U]) -> U:

inter_queue = AsyncQueue(concurrency)

coro1 = _pipe( # pyre-ignore: [1001]
coro1: Awaitable[None] = _pipe( # pyre-ignore: [1001]
input_queue,
_wrap,
inter_queue,
Expand All @@ -306,7 +310,7 @@ async def _unwrap(task: asyncio.Task[U]) -> U:
hooks=[],
)

coro2 = _pipe( # pyre-ignore: [1001]
coro2: Awaitable[None] = _pipe( # pyre-ignore: [1001]
inter_queue,
_unwrap,
output_queue,
Expand All @@ -317,10 +321,9 @@ async def _unwrap(task: asyncio.Task[U]) -> U:
)

@_put_eof_when_done(output_queue)
@_stage_hooks(hooks)
async def ordered_pipe():
tasks = {create_task(coro1), create_task(coro2)}
await asyncio.wait(tasks)
@_stage_hooks(hooks_)
async def ordered_pipe() -> None:
await asyncio.wait({create_task(coro1), create_task(coro2)})

return ordered_pipe()

Expand All @@ -335,13 +338,14 @@ def _enqueue(
queue: AsyncQueue[T],
max_items: int | None = None,
) -> Coroutine:
if not hasattr(src, "__aiter__"):
src = _to_async_gen(iter, None)(src)
src_: AsyncIterable[T] = ( # pyre-ignore: [9]
src if hasattr(src, "__aiter__") else _to_async_gen(iter, None)(src)
)

@_put_eof_when_done(queue)
async def enqueue():
async def enqueue() -> None:
num_items = 0
async for item in src:
async for item in src_:
if item is not _SKIP:
await queue.put(item)
num_items += 1
Expand All @@ -357,7 +361,7 @@ async def enqueue():


@contextmanager
def _sink_stats():
def _sink_stats() -> Iterator[tuple[StatsCounter, StatsCounter]]:
get_counter = StatsCounter()
put_counter = StatsCounter()
t0 = time.monotonic()
Expand All @@ -377,7 +381,7 @@ def _sink_stats():
)


async def _sink(input_queue: AsyncQueue[T], output_queue: AsyncQueue[T]):
async def _sink(input_queue: AsyncQueue[T], output_queue: AsyncQueue[T]) -> None:
with _sink_stats() as (get_counter, put_counter):
while True:
with get_counter.count():
Expand All @@ -404,7 +408,7 @@ class PipelineFailure(RuntimeError):
Thrown by :py:class:`spdl.pipeline.Pipeline` when pipeline encounters an error.
"""

def __init__(self, errs):
def __init__(self, errs: dict[str, Exception]) -> None:
msg = []
for k, v in errs.items():
e = str(v)
Expand Down Expand Up @@ -439,7 +443,7 @@ async def _run_pipeline_coroutines(
# demonstrate the behavior.
# https://gist.github.com/mthrok/3a1c11c2d8012e29f4835679ac0baaee
try:
done, pending = await asyncio.wait(
_, pending = await asyncio.wait(
pending, return_when=asyncio.FIRST_EXCEPTION
)
except asyncio.CancelledError:
Expand Down Expand Up @@ -475,30 +479,32 @@ async def _run_pipeline_coroutines(
################################################################################


def disaggregate(items):
def disaggregate(items: Sequence[T]) -> Iterator[T]:
for item in items:
yield item


class PipelineBuilder:
class PipelineBuilder(Generic[T]):
"""Build :py:class:`~spdl.pipeline.Pipeline` object.
See :py:class:`~spdl.pipeline.Pipeline` for details.
"""

def __init__(self):
self._source = None
def __init__(self) -> None:
self._source: Iterable | AsyncIterable | None = None
self._source_buffer_size = 1

self._process_args: list[tuple[str, dict, int]] = []
self._process_args: list[tuple[str, dict[str, Any], int]] = []

self._sink_buffer_size = None
self._sink_buffer_size: int | None = None
self._num_aggregate = 0
self._num_disaggregate = 0

def add_source(
self, source: Iterable[T] | AsyncIterable[T], **_kwargs
) -> "PipelineBuilder":
self,
source: Iterable[T] | AsyncIterable[T],
**_kwargs, # pyre-ignore: [2]
) -> "PipelineBuilder[T]":
"""Attach an iterator to the source buffer.
.. code-block::
Expand Down Expand Up @@ -543,8 +549,8 @@ def pipe(
report_stats_interval: float | None = None,
output_order: str = "completion",
kwargs: dict[str, ...] | None = None,
**_kwargs,
) -> "PipelineBuilder":
**_kwargs, # pyre-ignore: [2]
) -> "PipelineBuilder[T]":
"""Apply an operation to items in the pipeline.
.. code-block::
Expand Down Expand Up @@ -673,7 +679,7 @@ def aggregate(
drop_last: bool = False,
hooks: Sequence[PipelineHook] | None = None,
report_stats_interval: float | None = None,
) -> "PipelineBuilder":
) -> "PipelineBuilder[T]":
"""Buffer the items in the pipeline.
Args:
Expand All @@ -682,17 +688,17 @@ def aggregate(
hooks: See :py:meth:`pipe`.
report_stats_interval: See :py:meth:`pipe`.
"""
vals = [[]]
vals: list[list[T]] = [[]]

def aggregate(i):
def aggregate(i: T) -> list[T]:
if i is not _EOF:
vals[0].append(i)

if (i is _EOF and vals[0]) or (len(vals[0]) >= num_items):
ret = vals.pop(0)
vals.append([])
return ret
return _SKIP
return _SKIP # pyre-ignore: [7]

name = f"aggregate_{self._num_aggregate}({num_items}, {drop_last=})"
self._num_aggregate += 1
Expand Down Expand Up @@ -720,7 +726,7 @@ def disaggregate(
*,
hooks: Sequence[PipelineHook] | None = None,
report_stats_interval: float | None = None,
) -> "PipelineBuilder":
) -> "PipelineBuilder[T]":
"""Disaggregate the items in the pipeline.
Args:
Expand All @@ -746,7 +752,7 @@ def disaggregate(
)
return self

def add_sink(self, buffer_size: int) -> "PipelineBuilder":
def add_sink(self, buffer_size: int) -> "PipelineBuilder[T]":
"""Attach a buffer to the end of the pipeline.
.. code-block::
Expand Down Expand Up @@ -778,6 +784,7 @@ def _build(self) -> tuple[Coroutine[None, None, None], list[AsyncQueue]]:

# source
queues.append(AsyncQueue(self._source_buffer_size))
assert self._source is not None
coros.append(
(
"AsyncPipeline::0_source",
Expand Down Expand Up @@ -816,12 +823,12 @@ def _build(self) -> tuple[Coroutine[None, None, None], list[AsyncQueue]]:

def _get_desc(self) -> list[str]:
parts = []
src_repr = (
self._source.__name__
if hasattr(self._source, "__name__")
else type(self._source).__name__
)
parts.append(f" - src: {src_repr}")
if self._source is not None:
src_repr = getattr(self._source, "__name__", type(self._source).__name__)
parts.append(f" - src: {src_repr}")
else:
parts.append(" - src: n/a")

if self._source_buffer_size != 1:
parts.append(f" Buffer: buffer_size={self._source_buffer_size}")

Expand Down Expand Up @@ -851,7 +858,7 @@ def _get_desc(self) -> list[str]:
def __str__(self) -> str:
return "\n".join([repr(self), *self._get_desc()])

def build(self, *, num_threads: int | None = None) -> Pipeline:
def build(self, *, num_threads: int | None = None) -> Pipeline[T]:
"""Build the pipeline.
Args:
Expand Down

0 comments on commit 90751c6

Please sign in to comment.