Skip to content

Commit

Permalink
[typing] prefect.concurrency
Browse files Browse the repository at this point in the history
  • Loading branch information
mjpieters committed Dec 18, 2024
1 parent 581510a commit 853011a
Show file tree
Hide file tree
Showing 30 changed files with 528 additions and 592 deletions.
2 changes: 1 addition & 1 deletion client/client_flow.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from prefect import flow, task
from prefect.concurrency import asyncio, events, services, sync # noqa: F401
from prefect.concurrency import services, sync # noqa: F401


def skip_remote_run():
Expand Down
215 changes: 150 additions & 65 deletions src/prefect/_internal/concurrency/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
import queue
import sys
import threading
from typing import Awaitable, Dict, Generic, List, Optional, Type, TypeVar, Union
from collections.abc import AsyncGenerator, Awaitable, Coroutine, Generator, Hashable
from typing import TYPE_CHECKING, Any, Generic, NoReturn, Optional, Union, cast

from typing_extensions import Self
from typing_extensions import Self, TypeVar, TypeVarTuple, Unpack

from prefect._internal.concurrency import logger
from prefect._internal.concurrency.api import create_call, from_sync
Expand All @@ -18,17 +19,19 @@
from prefect._internal.concurrency.threads import WorkerThread, get_global_loop

T = TypeVar("T")
Ts = TypeVarTuple("Ts")
R = TypeVar("R", infer_variance=True)


class QueueService(abc.ABC, Generic[T]):
_instances: Dict[int, Self] = {}
class _QueueServiceBase(abc.ABC, Generic[T]):
_instances: dict[int, Self] = {}
_instance_lock = threading.Lock()

def __init__(self, *args) -> None:
self._queue: queue.Queue = queue.Queue()
def __init__(self, *args: Hashable) -> None:
self._queue: queue.Queue[Optional[T]] = queue.Queue()
self._loop: Optional[asyncio.AbstractEventLoop] = None
self._done_event: Optional[asyncio.Event] = None
self._task: Optional[asyncio.Task] = None
self._task: Optional[asyncio.Task[None]] = None
self._stopped: bool = False
self._started: bool = False
self._key = hash(args)
Expand All @@ -41,14 +44,14 @@ def __init__(self, *args) -> None:
)
self._logger = logging.getLogger(f"{type(self).__name__}")

def start(self):
def start(self) -> None:
logger.debug("Starting service %r", self)
loop_thread = get_global_loop()

if not asyncio.get_running_loop() == loop_thread._loop:
if not asyncio.get_running_loop() == getattr(loop_thread, "_loop"):
raise RuntimeError("Services must run on the global loop thread.")

self._loop = loop_thread._loop
self._loop = asyncio.get_running_loop()
self._done_event = asyncio.Event()
self._task = self._loop.create_task(self._run())
self._queue_get_thread.start()
Expand All @@ -67,14 +70,16 @@ def start(self):
# _before_ the normal `atexit` hook is called resulting in failure to
# process items. This is particularly relevant for services which use an
# httpx client.
from threading import _register_atexit
from threading import (
_register_atexit, # pyright: ignore[reportUnknownVariableType, reportAttributeAccessIssue]
)

_register_atexit(self._at_exit)

def _at_exit(self):
def _at_exit(self) -> None:
self.drain(at_exit=True)

def _stop(self, at_exit: bool = False):
def _stop(self, at_exit: bool = False) -> None:
"""
Stop running this instance.
Expand All @@ -100,27 +105,11 @@ def _stop(self, at_exit: bool = False):
# Signal completion to the loop
self._queue.put_nowait(None)

def send(self, item: T):
"""
Send an item to this instance of the service.
"""
with self._lock:
if self._stopped:
raise RuntimeError("Cannot put items in a stopped service instance.")

logger.debug("Service %r enqueuing item %r", self, item)
self._queue.put_nowait(self._prepare_item(item))

def _prepare_item(self, item: T) -> T:
"""
Prepare an item for submission to the service. This is called before
the item is sent to the service.
The default implementation returns the item unchanged.
"""
return item
@abc.abstractmethod
def send(self, item: Any) -> Any:
raise NotImplementedError

async def _run(self):
async def _run(self) -> None:
try:
async with self._lifespan():
await self._main_loop()
Expand All @@ -142,14 +131,15 @@ async def _run(self):
self._queue_get_thread.shutdown()

self._stopped = True
assert self._done_event is not None
self._done_event.set()

async def _main_loop(self):
async def _main_loop(self) -> None:
last_log_time = 0
log_interval = 4 # log every 4 seconds

while True:
item: T = await self._queue_get_thread.submit(
item: Optional[T] = await self._queue_get_thread.submit(
create_call(self._queue.get)
).aresult()

Expand Down Expand Up @@ -183,19 +173,17 @@ async def _main_loop(self):
self._queue.task_done()

@abc.abstractmethod
async def _handle(self, item: T):
"""
Process an item sent to the service.
"""
def _handle(self, item: Any) -> Any:
raise NotImplementedError

@contextlib.asynccontextmanager
async def _lifespan(self):
async def _lifespan(self) -> AsyncGenerator[None, Any]:
"""
Perform any setup and teardown for the service.
"""
yield

def _drain(self, at_exit: bool = False) -> concurrent.futures.Future:
def _drain(self, at_exit: bool = False) -> concurrent.futures.Future[bool]:
"""
Internal implementation for `drain`. Returns a future for sync/async interfaces.
"""
Expand All @@ -204,15 +192,17 @@ def _drain(self, at_exit: bool = False) -> concurrent.futures.Future:

self._stop(at_exit=at_exit)

assert self._done_event is not None
if self._done_event.is_set():
future = concurrent.futures.Future()
future.set_result(None)
future: concurrent.futures.Future[bool] = concurrent.futures.Future()
future.set_result(False)
return future

future = asyncio.run_coroutine_threadsafe(self._done_event.wait(), self._loop)
return future
assert self._loop is not None
task = cast(Coroutine[Any, Any, bool], self._done_event.wait())
return asyncio.run_coroutine_threadsafe(task, self._loop)

def drain(self, at_exit: bool = False) -> None:
def drain(self, at_exit: bool = False) -> Union[bool, Awaitable[bool]]:
"""
Stop this instance of the service and wait for remaining work to be completed.
Expand All @@ -225,41 +215,50 @@ def drain(self, at_exit: bool = False) -> None:
return future.result()

@classmethod
def drain_all(cls, timeout: Optional[float] = None) -> Union[Awaitable, None]:
def drain_all(
cls, timeout: Optional[float] = None
) -> Union[
tuple[
set[concurrent.futures.Future[bool]], set[concurrent.futures.Future[bool]]
],
Coroutine[
Any,
Any,
Optional[tuple[set[asyncio.Future[bool]], set[asyncio.Future[bool]]]],
],
]:
"""
Stop all instances of the service and wait for all remaining work to be
completed.
Returns an awaitable if called from an async context.
"""
futures = []
futures: list[concurrent.futures.Future[bool]] = []
with cls._instance_lock:
instances = tuple(cls._instances.values())

for instance in instances:
futures.append(instance._drain())

if get_running_loop() is not None:
return (
asyncio.wait(
if futures:
return asyncio.wait(
[asyncio.wrap_future(fut) for fut in futures], timeout=timeout
)
if futures
# `wait` errors if it receives an empty list but we need to return a
# coroutine still
else asyncio.sleep(0)
)
# `wait` errors if it receives an empty list but we need to return a
# coroutine still
return asyncio.sleep(0)
else:
return concurrent.futures.wait(futures, timeout=timeout)

def wait_until_empty(self):
def wait_until_empty(self) -> None:
"""
Wait until the queue is empty and all items have been processed.
"""
self._queue.join()

@classmethod
def instance(cls: Type[Self], *args) -> Self:
def instance(cls, *args: Hashable) -> Self:
"""
Get an instance of the service.
Expand All @@ -276,7 +275,7 @@ def _remove_instance(self):
self._instances.pop(self._key, None)

@classmethod
def _new_instance(cls, *args):
def _new_instance(cls, *args: Hashable) -> Self:
"""
Create and start a new instance of the service.
"""
Expand All @@ -293,6 +292,87 @@ def _new_instance(cls, *args):
return instance


class QueueService(_QueueServiceBase[T]):
def send(self, item: T) -> None:
"""
Send an item to this instance of the service.
"""
with self._lock:
if self._stopped:
raise RuntimeError("Cannot put items in a stopped service instance.")

logger.debug("Service %r enqueuing item %r", self, item)
self._queue.put_nowait(self._prepare_item(item))

def _prepare_item(self, item: T) -> T:
"""
Prepare an item for submission to the service. This is called before
the item is sent to the service.
The default implementation returns the item unchanged.
"""
return item

@abc.abstractmethod
async def _handle(self, item: T) -> None:
"""
Process an item sent to the service.
"""


class FutureQueueService(
_QueueServiceBase[tuple[Unpack[Ts], concurrent.futures.Future[R]]]
):
"""Queued service that provides a future that is signalled with the acquired result for each item
If there was a failure acquiring, the future result is set to the exception.
Type Parameters:
Ts: the tuple of types that make up sent arguments
R: the type returned for each item once acquired
"""

async def _handle(
self, item: tuple[Unpack[Ts], concurrent.futures.Future[R]]
) -> None:
send_item, future = item[:-1], item[-1]
try:
response = await self.acquire(*send_item)
except Exception as exc:
# If the request to the increment endpoint fails in a non-standard
# way, we need to set the future's result so it'll be re-raised in
# the context of the caller.
future.set_exception(exc)
raise exc
else:
future.set_result(response)

@abc.abstractmethod
async def acquire(self, *args: Unpack[Ts]) -> R:
raise NotImplementedError

def send(self, item: tuple[Unpack[Ts]]) -> concurrent.futures.Future[R]:
with self._lock:
if self._stopped:
raise RuntimeError("Cannot put items in a stopped service instance.")

logger.debug("Service %r enqueuing item %r", self, item)
future: concurrent.futures.Future[R] = concurrent.futures.Future()
self._queue.put_nowait((*self._prepare_item(item), future))

return future

def _prepare_item(self, item: tuple[Unpack[Ts]]) -> tuple[Unpack[Ts]]:
"""
Prepare an item for submission to the service. This is called before
the item is sent to the service.
The default implementation returns the item unchanged.
"""
return item


class BatchedQueueService(QueueService[T]):
"""
A queue service that handles a batch of items instead of a single item at a time.
Expand All @@ -308,7 +388,7 @@ async def _main_loop(self):
done = False

while not done:
batch = []
batch: list[T] = []
batch_size = 0

# Pull items from the queue until we reach the batch size
Expand Down Expand Up @@ -357,13 +437,15 @@ async def _main_loop(self):
)

@abc.abstractmethod
async def _handle_batch(self, items: List[T]):
async def _handle_batch(self, items: list[T]) -> None:
"""
Process a batch of items sent to the service.
"""

async def _handle(self, item: T):
assert False, "`_handle` should never be called for batched queue services"
async def _handle(self, item: T) -> NoReturn:
raise AssertionError(
"`_handle` should never be called for batched queue services"
)

def _get_size(self, item: T) -> int:
"""
Expand All @@ -374,12 +456,15 @@ def _get_size(self, item: T) -> int:


@contextlib.contextmanager
def drain_on_exit(service: QueueService):
def drain_on_exit(service: QueueService[Any]) -> Generator[None, Any, None]:
yield
service.drain_all()


@contextlib.asynccontextmanager
async def drain_on_exit_async(service: QueueService):
async def drain_on_exit_async(service: QueueService[Any]) -> AsyncGenerator[None, Any]:
yield
await service.drain_all()
drain_all = service.drain_all()
if TYPE_CHECKING:
assert not isinstance(drain_all, tuple)
await drain_all
3 changes: 3 additions & 0 deletions src/prefect/concurrency/.ruff.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
extend-select = ["UP"]
target-version = "py39"

Loading

0 comments on commit 853011a

Please sign in to comment.