diff --git a/client/client_flow.py b/client/client_flow.py index 429a4f950766..be486002f62d 100644 --- a/client/client_flow.py +++ b/client/client_flow.py @@ -1,5 +1,5 @@ from prefect import flow, task -from prefect.concurrency import asyncio, events, services, sync # noqa: F401 +from prefect.concurrency import asyncio, services, sync # noqa: F401 def skip_remote_run(): diff --git a/src/prefect/_internal/concurrency/services.py b/src/prefect/_internal/concurrency/services.py index ad54c076299c..1f992763ef83 100644 --- a/src/prefect/_internal/concurrency/services.py +++ b/src/prefect/_internal/concurrency/services.py @@ -1,15 +1,14 @@ import abc import asyncio -import atexit import concurrent.futures import contextlib import logging import queue -import sys import threading -from typing import Awaitable, Dict, Generic, List, Optional, Type, TypeVar, Union +from collections.abc import AsyncGenerator, Awaitable, Coroutine, Generator, Hashable +from typing import TYPE_CHECKING, Any, Generic, NoReturn, Optional, Union, cast -from typing_extensions import Self +from typing_extensions import Self, TypeVar, TypeVarTuple, Unpack from prefect._internal.concurrency import logger from prefect._internal.concurrency.api import create_call, from_sync @@ -18,17 +17,19 @@ from prefect._internal.concurrency.threads import WorkerThread, get_global_loop T = TypeVar("T") +Ts = TypeVarTuple("Ts") +R = TypeVar("R", infer_variance=True) -class QueueService(abc.ABC, Generic[T]): - _instances: Dict[int, Self] = {} +class _QueueServiceBase(abc.ABC, Generic[T]): + _instances: dict[int, Self] = {} _instance_lock = threading.Lock() - def __init__(self, *args) -> None: - self._queue: queue.Queue = queue.Queue() + def __init__(self, *args: Hashable) -> None: + self._queue: queue.Queue[Optional[T]] = queue.Queue() self._loop: Optional[asyncio.AbstractEventLoop] = None self._done_event: Optional[asyncio.Event] = None - self._task: Optional[asyncio.Task] = None + self._task: Optional[asyncio.Task[None]] = None self._stopped: bool = False self._started: bool = False self._key = hash(args) @@ -41,14 +42,14 @@ def __init__(self, *args) -> None: ) self._logger = logging.getLogger(f"{type(self).__name__}") - def start(self): + def start(self) -> None: logger.debug("Starting service %r", self) loop_thread = get_global_loop() - if not asyncio.get_running_loop() == loop_thread._loop: + if not asyncio.get_running_loop() == getattr(loop_thread, "_loop"): raise RuntimeError("Services must run on the global loop thread.") - self._loop = loop_thread._loop + self._loop = asyncio.get_running_loop() self._done_event = asyncio.Event() self._task = self._loop.create_task(self._run()) self._queue_get_thread.start() @@ -58,23 +59,18 @@ def start(self): loop_thread.add_shutdown_call(create_call(self.drain)) # Stop at interpreter exit by default - if sys.version_info < (3, 9): - atexit.register(self._at_exit) - else: - # See related issue at https://bugs.python.org/issue42647 - # Handling items may require spawning a thread and in 3.9 new threads - # cannot be spawned after the interpreter finalizes threads which happens - # _before_ the normal `atexit` hook is called resulting in failure to - # process items. This is particularly relevant for services which use an - # httpx client. 
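The hunk that follows drops the Python 3.8 `atexit.register` fallback and always registers the drain hook through `threading._register_atexit`: as the removed comment above explains, ordinary `atexit` hooks fire only after the interpreter has finalized threads, at which point the drain can no longer spawn the worker threads it needs. A minimal sketch of the registration pattern, using only the private CPython helper this diff already relies on (the callback name here is hypothetical):

    import threading

    def _drain_at_exit() -> None:
        # Runs during threading shutdown, before regular atexit hooks, so it may
        # still hand work to helper threads (e.g. flushing an httpx client).
        ...

    threading._register_atexit(_drain_at_exit)  # private API, CPython 3.9+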
- from threading import _register_atexit - - _register_atexit(self._at_exit) - - def _at_exit(self): + # Handling items may require spawning a thread and in 3.9 new threads + # cannot be spawned after the interpreter finalizes threads which + # happens _before_ the normal `atexit` hook is called resulting in + # failure to process items. This is particularly relevant for services + # which use an httpx client. See related issue at + # https://github.com/python/cpython/issues/86813 + threading._register_atexit(self._at_exit) # pyright: ignore[reportUnknownVariableType, reportAttributeAccessIssue] + + def _at_exit(self) -> None: self.drain(at_exit=True) - def _stop(self, at_exit: bool = False): + def _stop(self, at_exit: bool = False) -> None: """ Stop running this instance. @@ -100,27 +96,11 @@ def _stop(self, at_exit: bool = False): # Signal completion to the loop self._queue.put_nowait(None) - def send(self, item: T): - """ - Send an item to this instance of the service. - """ - with self._lock: - if self._stopped: - raise RuntimeError("Cannot put items in a stopped service instance.") - - logger.debug("Service %r enqueuing item %r", self, item) - self._queue.put_nowait(self._prepare_item(item)) - - def _prepare_item(self, item: T) -> T: - """ - Prepare an item for submission to the service. This is called before - the item is sent to the service. - - The default implementation returns the item unchanged. - """ - return item + @abc.abstractmethod + def send(self, item: Any) -> Any: + raise NotImplementedError - async def _run(self): + async def _run(self) -> None: try: async with self._lifespan(): await self._main_loop() @@ -142,14 +122,15 @@ async def _run(self): self._queue_get_thread.shutdown() self._stopped = True + assert self._done_event is not None self._done_event.set() - async def _main_loop(self): + async def _main_loop(self) -> None: last_log_time = 0 log_interval = 4 # log every 4 seconds while True: - item: T = await self._queue_get_thread.submit( + item: Optional[T] = await self._queue_get_thread.submit( create_call(self._queue.get) ).aresult() @@ -183,19 +164,17 @@ async def _main_loop(self): self._queue.task_done() @abc.abstractmethod - async def _handle(self, item: T): - """ - Process an item sent to the service. - """ + async def _handle(self, item: Any) -> Any: + raise NotImplementedError @contextlib.asynccontextmanager - async def _lifespan(self): + async def _lifespan(self) -> AsyncGenerator[None, Any]: """ Perform any setup and teardown for the service. """ yield - def _drain(self, at_exit: bool = False) -> concurrent.futures.Future: + def _drain(self, at_exit: bool = False) -> concurrent.futures.Future[bool]: """ Internal implementation for `drain`. Returns a future for sync/async interfaces. 
""" @@ -204,15 +183,17 @@ def _drain(self, at_exit: bool = False) -> concurrent.futures.Future: self._stop(at_exit=at_exit) + assert self._done_event is not None if self._done_event.is_set(): - future = concurrent.futures.Future() - future.set_result(None) + future: concurrent.futures.Future[bool] = concurrent.futures.Future() + future.set_result(False) return future - future = asyncio.run_coroutine_threadsafe(self._done_event.wait(), self._loop) - return future + assert self._loop is not None + task = cast(Coroutine[Any, Any, bool], self._done_event.wait()) + return asyncio.run_coroutine_threadsafe(task, self._loop) - def drain(self, at_exit: bool = False) -> None: + def drain(self, at_exit: bool = False) -> Union[bool, Awaitable[bool]]: """ Stop this instance of the service and wait for remaining work to be completed. @@ -226,15 +207,24 @@ def drain(self, at_exit: bool = False) -> None: @classmethod def drain_all( - cls, timeout: Optional[float] = None, at_exit=True - ) -> Union[Awaitable, None]: + cls, timeout: Optional[float] = None, at_exit: bool = True + ) -> Union[ + tuple[ + set[concurrent.futures.Future[bool]], set[concurrent.futures.Future[bool]] + ], + Coroutine[ + Any, + Any, + Optional[tuple[set[asyncio.Future[bool]], set[asyncio.Future[bool]]]], + ], + ]: """ Stop all instances of the service and wait for all remaining work to be completed. Returns an awaitable if called from an async context. """ - futures = [] + futures: list[concurrent.futures.Future[bool]] = [] with cls._instance_lock: instances = tuple(cls._instances.values()) @@ -242,26 +232,24 @@ def drain_all( futures.append(instance._drain(at_exit=at_exit)) if get_running_loop() is not None: - return ( - asyncio.wait( + if futures: + return asyncio.wait( [asyncio.wrap_future(fut) for fut in futures], timeout=timeout ) - if futures - # `wait` errors if it receives an empty list but we need to return a - # coroutine still - else asyncio.sleep(0) - ) + # `wait` errors if it receives an empty list but we need to return a + # coroutine still + return asyncio.sleep(0) else: return concurrent.futures.wait(futures, timeout=timeout) - def wait_until_empty(self): + def wait_until_empty(self) -> None: """ Wait until the queue is empty and all items have been processed. """ self._queue.join() @classmethod - def instance(cls: Type[Self], *args) -> Self: + def instance(cls, *args: Hashable) -> Self: """ Get an instance of the service. @@ -278,7 +266,7 @@ def _remove_instance(self): self._instances.pop(self._key, None) @classmethod - def _new_instance(cls, *args): + def _new_instance(cls, *args: Hashable) -> Self: """ Create and start a new instance of the service. """ @@ -295,6 +283,87 @@ def _new_instance(cls, *args): return instance +class QueueService(_QueueServiceBase[T]): + def send(self, item: T) -> None: + """ + Send an item to this instance of the service. + """ + with self._lock: + if self._stopped: + raise RuntimeError("Cannot put items in a stopped service instance.") + + logger.debug("Service %r enqueuing item %r", self, item) + self._queue.put_nowait(self._prepare_item(item)) + + def _prepare_item(self, item: T) -> T: + """ + Prepare an item for submission to the service. This is called before + the item is sent to the service. + + The default implementation returns the item unchanged. + """ + return item + + @abc.abstractmethod + async def _handle(self, item: T) -> None: + """ + Process an item sent to the service. 
+ """ + + +class FutureQueueService( + _QueueServiceBase[tuple[Unpack[Ts], concurrent.futures.Future[R]]] +): + """Queued service that provides a future that is signalled with the acquired result for each item + + If there was a failure acquiring, the future result is set to the exception. + + Type Parameters: + Ts: the tuple of types that make up sent arguments + R: the type returned for each item once acquired + + """ + + async def _handle( + self, item: tuple[Unpack[Ts], concurrent.futures.Future[R]] + ) -> None: + send_item, future = item[:-1], item[-1] + try: + response = await self.acquire(*send_item) + except Exception as exc: + # If the request to the increment endpoint fails in a non-standard + # way, we need to set the future's result so it'll be re-raised in + # the context of the caller. + future.set_exception(exc) + raise exc + else: + future.set_result(response) + + @abc.abstractmethod + async def acquire(self, *args: Unpack[Ts]) -> R: + raise NotImplementedError + + def send(self, item: tuple[Unpack[Ts]]) -> concurrent.futures.Future[R]: + with self._lock: + if self._stopped: + raise RuntimeError("Cannot put items in a stopped service instance.") + + logger.debug("Service %r enqueuing item %r", self, item) + future: concurrent.futures.Future[R] = concurrent.futures.Future() + self._queue.put_nowait((*self._prepare_item(item), future)) + + return future + + def _prepare_item(self, item: tuple[Unpack[Ts]]) -> tuple[Unpack[Ts]]: + """ + Prepare an item for submission to the service. This is called before + the item is sent to the service. + + The default implementation returns the item unchanged. + """ + return item + + class BatchedQueueService(QueueService[T]): """ A queue service that handles a batch of items instead of a single item at a time. @@ -310,7 +379,7 @@ async def _main_loop(self): done = False while not done: - batch = [] + batch: list[T] = [] batch_size = 0 # Pull items from the queue until we reach the batch size @@ -359,13 +428,15 @@ async def _main_loop(self): ) @abc.abstractmethod - async def _handle_batch(self, items: List[T]): + async def _handle_batch(self, items: list[T]) -> None: """ Process a batch of items sent to the service. 
""" - async def _handle(self, item: T): - assert False, "`_handle` should never be called for batched queue services" + async def _handle(self, item: T) -> NoReturn: + raise AssertionError( + "`_handle` should never be called for batched queue services" + ) def _get_size(self, item: T) -> int: """ @@ -376,12 +447,15 @@ def _get_size(self, item: T) -> int: @contextlib.contextmanager -def drain_on_exit(service: QueueService): +def drain_on_exit(service: QueueService[Any]) -> Generator[None, Any, None]: yield service.drain_all(at_exit=True) @contextlib.asynccontextmanager -async def drain_on_exit_async(service: QueueService): +async def drain_on_exit_async(service: QueueService[Any]) -> AsyncGenerator[None, Any]: yield - await service.drain_all(at_exit=True) + drain_all = service.drain_all(at_exit=True) + if TYPE_CHECKING: + assert not isinstance(drain_all, tuple) + await drain_all diff --git a/src/prefect/concurrency/_asyncio.py b/src/prefect/concurrency/_asyncio.py new file mode 100644 index 000000000000..a7a8ff54ec03 --- /dev/null +++ b/src/prefect/concurrency/_asyncio.py @@ -0,0 +1,87 @@ +import asyncio +from typing import Literal, Optional + +import httpx + +from prefect._internal.compatibility.deprecated import deprecated_parameter +from prefect.client.orchestration import get_client +from prefect.client.schemas.responses import MinimalConcurrencyLimitResponse +from prefect.logging.loggers import get_run_logger + +from .services import ConcurrencySlotAcquisitionService + + +class ConcurrencySlotAcquisitionError(Exception): + """Raised when an unhandlable occurs while acquiring concurrency slots.""" + + +class AcquireConcurrencySlotTimeoutError(TimeoutError): + """Raised when acquiring a concurrency slot times out.""" + + +@deprecated_parameter( + name="create_if_missing", + start_date="Sep 2024", + end_date="Oct 2024", + when=lambda x: x is not None, + help="Limits must be explicitly created before acquiring concurrency slots; see `strict` if you want to enforce this behavior.", +) +async def aacquire_concurrency_slots( + names: list[str], + slots: int, + mode: Literal["concurrency", "rate_limit"] = "concurrency", + timeout_seconds: Optional[float] = None, + create_if_missing: Optional[bool] = None, + max_retries: Optional[int] = None, + strict: bool = False, +) -> list[MinimalConcurrencyLimitResponse]: + service = ConcurrencySlotAcquisitionService.instance(frozenset(names)) + future = service.send( + (slots, mode, timeout_seconds, create_if_missing, max_retries) + ) + try: + response = await asyncio.wrap_future(future) + except TimeoutError as timeout: + raise AcquireConcurrencySlotTimeoutError( + f"Attempt to acquire concurrency slots timed out after {timeout_seconds} second(s)" + ) from timeout + except Exception as exc: + raise ConcurrencySlotAcquisitionError( + f"Unable to acquire concurrency slots on {names!r}" + ) from exc + + retval = _response_to_minimal_concurrency_limit_response(response) + + if not retval: + if strict: + raise ConcurrencySlotAcquisitionError( + f"Concurrency limits {names!r} must be created before acquiring slots" + ) + try: + logger = get_run_logger() + except Exception: + pass + else: + logger.warning( + f"Concurrency limits {names!r} do not exist - skipping acquisition." 
+ ) + + return retval + + +async def arelease_concurrency_slots( + names: list[str], slots: int, occupancy_seconds: float +) -> list[MinimalConcurrencyLimitResponse]: + async with get_client() as client: + response = await client.release_concurrency_slots( + names=names, slots=slots, occupancy_seconds=occupancy_seconds + ) + return _response_to_minimal_concurrency_limit_response(response) + + +def _response_to_minimal_concurrency_limit_response( + response: httpx.Response, +) -> list[MinimalConcurrencyLimitResponse]: + return [ + MinimalConcurrencyLimitResponse.model_validate(obj_) for obj_ in response.json() + ] diff --git a/src/prefect/concurrency/events.py b/src/prefect/concurrency/_events.py similarity index 76% rename from src/prefect/concurrency/events.py rename to src/prefect/concurrency/_events.py index c5a7598c7f47..acd49b156dd3 100644 --- a/src/prefect/concurrency/events.py +++ b/src/prefect/concurrency/_events.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Literal, Optional, Union +from typing import Literal, Optional, Union from uuid import UUID from prefect.client.schemas.responses import MinimalConcurrencyLimitResponse @@ -8,11 +8,11 @@ def _emit_concurrency_event( phase: Union[Literal["acquired"], Literal["released"]], primary_limit: MinimalConcurrencyLimitResponse, - related_limits: List[MinimalConcurrencyLimitResponse], + related_limits: list[MinimalConcurrencyLimitResponse], slots: int, follows: Union[Event, None] = None, ) -> Union[Event, None]: - resource: Dict[str, str] = { + resource: dict[str, str] = { "prefect.resource.id": f"prefect.concurrency-limit.{primary_limit.id}", "prefect.resource.name": primary_limit.name, "slots-acquired": str(slots), @@ -38,11 +38,11 @@ def _emit_concurrency_event( ) -def _emit_concurrency_acquisition_events( - limits: List[MinimalConcurrencyLimitResponse], +def emit_concurrency_acquisition_events( + limits: list[MinimalConcurrencyLimitResponse], occupy: int, -) -> Dict[UUID, Optional[Event]]: - events = {} +) -> dict[UUID, Optional[Event]]: + events: dict[UUID, Optional[Event]] = {} for limit in limits: event = _emit_concurrency_event("acquired", limit, limits, occupy) events[limit.id] = event @@ -50,10 +50,10 @@ def _emit_concurrency_acquisition_events( return events -def _emit_concurrency_release_events( - limits: List[MinimalConcurrencyLimitResponse], +def emit_concurrency_release_events( + limits: list[MinimalConcurrencyLimitResponse], occupy: int, - events: Dict[UUID, Optional[Event]], + events: dict[UUID, Optional[Event]], ) -> None: for limit in limits: _emit_concurrency_event("released", limit, limits, occupy, events[limit.id]) diff --git a/src/prefect/concurrency/asyncio.py b/src/prefect/concurrency/asyncio.py index 5d419a6c079f..a42c4edcfabf 100644 --- a/src/prefect/concurrency/asyncio.py +++ b/src/prefect/concurrency/asyncio.py @@ -1,42 +1,25 @@ -import asyncio +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager -from typing import AsyncGenerator, List, Literal, Optional, Union, cast +from typing import Optional, Union import anyio -import httpx import pendulum -from prefect._internal.compatibility.deprecated import deprecated_parameter - -try: - from pendulum import Interval -except ImportError: - # pendulum < 3 - from pendulum.period import Period as Interval # type: ignore - -from prefect.client.orchestration import get_client -from prefect.client.schemas.responses import MinimalConcurrencyLimitResponse -from prefect.logging.loggers import get_run_logger - -from .context import 
ConcurrencyContext -from .events import ( - _emit_concurrency_acquisition_events, - _emit_concurrency_release_events, +from ._asyncio import ( + AcquireConcurrencySlotTimeoutError as AcquireConcurrencySlotTimeoutError, ) -from .services import ConcurrencySlotAcquisitionService - - -class ConcurrencySlotAcquisitionError(Exception): - """Raised when an unhandlable occurs while acquiring concurrency slots.""" - - -class AcquireConcurrencySlotTimeoutError(TimeoutError): - """Raised when acquiring a concurrency slot times out.""" +from ._asyncio import ConcurrencySlotAcquisitionError as ConcurrencySlotAcquisitionError +from ._asyncio import aacquire_concurrency_slots, arelease_concurrency_slots +from ._events import ( + emit_concurrency_acquisition_events, + emit_concurrency_release_events, +) +from .context import ConcurrencyContext @asynccontextmanager async def concurrency( - names: Union[str, List[str]], + names: Union[str, list[str]], occupy: int = 1, timeout_seconds: Optional[float] = None, max_retries: Optional[int] = None, @@ -78,7 +61,7 @@ async def main(): names = names if isinstance(names, list) else [names] - limits = await _aacquire_concurrency_slots( + limits = await aacquire_concurrency_slots( names, occupy, timeout_seconds=timeout_seconds, @@ -87,14 +70,14 @@ async def main(): strict=strict, ) acquisition_time = pendulum.now("UTC") - emitted_events = _emit_concurrency_acquisition_events(limits, occupy) + emitted_events = emit_concurrency_acquisition_events(limits, occupy) try: yield finally: - occupancy_period = cast(Interval, (pendulum.now("UTC") - acquisition_time)) + occupancy_period = pendulum.now("UTC") - acquisition_time try: - await _arelease_concurrency_slots( + await arelease_concurrency_slots( names, occupy, occupancy_period.total_seconds() ) except anyio.get_cancelled_exc_class(): @@ -106,11 +89,11 @@ async def main(): (names, occupy, occupancy_period.total_seconds()) ) - _emit_concurrency_release_events(limits, occupy, emitted_events) + emit_concurrency_release_events(limits, occupy, emitted_events) async def rate_limit( - names: Union[str, List[str]], + names: Union[str, list[str]], occupy: int = 1, timeout_seconds: Optional[float] = None, create_if_missing: Optional[bool] = None, @@ -137,7 +120,7 @@ async def rate_limit( names = names if isinstance(names, list) else [names] - limits = await _aacquire_concurrency_slots( + limits = await aacquire_concurrency_slots( names, occupy, mode="rate_limit", @@ -145,71 +128,4 @@ async def rate_limit( create_if_missing=create_if_missing, strict=strict, ) - _emit_concurrency_acquisition_events(limits, occupy) - - -@deprecated_parameter( - name="create_if_missing", - start_date="Sep 2024", - end_date="Oct 2024", - when=lambda x: x is not None, - help="Limits must be explicitly created before acquiring concurrency slots; see `strict` if you want to enforce this behavior.", -) -async def _aacquire_concurrency_slots( - names: List[str], - slots: int, - mode: Literal["concurrency", "rate_limit"] = "concurrency", - timeout_seconds: Optional[float] = None, - create_if_missing: Optional[bool] = None, - max_retries: Optional[int] = None, - strict: bool = False, -) -> List[MinimalConcurrencyLimitResponse]: - service = ConcurrencySlotAcquisitionService.instance(frozenset(names)) - future = service.send( - (slots, mode, timeout_seconds, create_if_missing, max_retries) - ) - response_or_exception = await asyncio.wrap_future(future) - - if isinstance(response_or_exception, Exception): - if isinstance(response_or_exception, TimeoutError): - 
raise AcquireConcurrencySlotTimeoutError( - f"Attempt to acquire concurrency slots timed out after {timeout_seconds} second(s)" - ) from response_or_exception - - raise ConcurrencySlotAcquisitionError( - f"Unable to acquire concurrency slots on {names!r}" - ) from response_or_exception - - retval = _response_to_minimal_concurrency_limit_response(response_or_exception) - - if strict and not retval: - raise ConcurrencySlotAcquisitionError( - f"Concurrency limits {names!r} must be created before acquiring slots" - ) - elif not retval: - try: - logger = get_run_logger() - logger.warning( - f"Concurrency limits {names!r} do not exist - skipping acquisition." - ) - except Exception: - pass - return retval - - -async def _arelease_concurrency_slots( - names: List[str], slots: int, occupancy_seconds: float -) -> List[MinimalConcurrencyLimitResponse]: - async with get_client() as client: - response = await client.release_concurrency_slots( - names=names, slots=slots, occupancy_seconds=occupancy_seconds - ) - return _response_to_minimal_concurrency_limit_response(response) - - -def _response_to_minimal_concurrency_limit_response( - response: httpx.Response, -) -> List[MinimalConcurrencyLimitResponse]: - return [ - MinimalConcurrencyLimitResponse.model_validate(obj_) for obj_ in response.json() - ] + emit_concurrency_acquisition_events(limits, occupy) diff --git a/src/prefect/concurrency/context.py b/src/prefect/concurrency/context.py index 9fc3b40ddb80..986f36281c3a 100644 --- a/src/prefect/concurrency/context.py +++ b/src/prefect/concurrency/context.py @@ -1,19 +1,21 @@ from contextvars import ContextVar -from typing import List, Tuple +from typing import Any, ClassVar + +from typing_extensions import Self from prefect.client.orchestration import get_client from prefect.context import ContextModel, Field class ConcurrencyContext(ContextModel): - __var__: ContextVar = ContextVar("concurrency") + __var__: ClassVar[ContextVar[Self]] = ContextVar("concurrency") # Track the slots that have been acquired but were not able to be released # due to cancellation or some other error. These slots are released when # the context manager exits. 
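Each deferred entry is a `(names, occupy, occupancy_seconds)` triple, matching the annotation on the field below. A hedged sketch of the fallback path (the helper name here is hypothetical; `ConcurrencyContext.get()`, `cleanup_slots`, and `release_concurrency_slots` all appear elsewhere in this PR):

    from prefect.concurrency.context import ConcurrencyContext

    def defer_release(names: list[str], occupy: int, occupancy_seconds: float) -> None:
        # On cancellation the async context manager cannot release its slots, so
        # it stashes the triple on the active ConcurrencyContext instead.
        if ctx := ConcurrencyContext.get():
            ctx.cleanup_slots.append((names, occupy, occupancy_seconds))
        # ConcurrencyContext.__exit__ later replays each entry through the sync
        # client via client.release_concurrency_slots(...).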
- cleanup_slots: List[Tuple[List[str], int, float]] = Field(default_factory=list) + cleanup_slots: list[tuple[list[str], int, float]] = Field(default_factory=list) - def __exit__(self, *exc_info): + def __exit__(self, *exc_info: Any) -> None: if self.cleanup_slots: with get_client(sync_client=True) as client: for names, occupy, occupancy_seconds in self.cleanup_slots: diff --git a/src/prefect/concurrency/services.py b/src/prefect/concurrency/services.py index 64e847cad582..530ea7ceb303 100644 --- a/src/prefect/concurrency/services.py +++ b/src/prefect/concurrency/services.py @@ -1,31 +1,30 @@ import asyncio -import concurrent.futures +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager -from typing import ( - TYPE_CHECKING, - AsyncGenerator, - FrozenSet, - Optional, - Tuple, -) +from typing import TYPE_CHECKING, Optional import httpx from starlette import status +from typing_extensions import TypeAlias, Unpack from prefect._internal.concurrency import logger -from prefect._internal.concurrency.services import QueueService +from prefect._internal.concurrency.services import FutureQueueService from prefect.client.orchestration import get_client from prefect.utilities.timeout import timeout_async if TYPE_CHECKING: from prefect.client.orchestration import PrefectClient +_Item: TypeAlias = tuple[int, str, Optional[float], Optional[bool], Optional[int]] -class ConcurrencySlotAcquisitionService(QueueService): - def __init__(self, concurrency_limit_names: FrozenSet[str]): + +class ConcurrencySlotAcquisitionService( + FutureQueueService[Unpack[_Item], httpx.Response] +): + def __init__(self, concurrency_limit_names: frozenset[str]): super().__init__(concurrency_limit_names) - self._client: "PrefectClient" - self.concurrency_limit_names = sorted(list(concurrency_limit_names)) + self._client: PrefectClient + self.concurrency_limit_names: list[str] = sorted(list(concurrency_limit_names)) @asynccontextmanager async def _lifespan(self) -> AsyncGenerator[None, None]: @@ -33,32 +32,7 @@ async def _lifespan(self) -> AsyncGenerator[None, None]: self._client = client yield - async def _handle( - self, - item: Tuple[ - int, - str, - Optional[float], - concurrent.futures.Future, - Optional[bool], - Optional[int], - ], - ) -> None: - occupy, mode, timeout_seconds, future, create_if_missing, max_retries = item - try: - response = await self.acquire_slots( - occupy, mode, timeout_seconds, create_if_missing, max_retries - ) - except Exception as exc: - # If the request to the increment endpoint fails in a non-standard - # way, we need to set the future's result so that the caller can - # handle the exception and then re-raise. - future.set_result(exc) - raise exc - else: - future.set_result(response) - - async def acquire_slots( + async def acquire( self, slots: int, mode: str, @@ -69,44 +43,22 @@ async def acquire_slots( with timeout_async(seconds=timeout_seconds): while True: try: - response = await self._client.increment_concurrency_slots( + return await self._client.increment_concurrency_slots( names=self.concurrency_limit_names, slots=slots, mode=mode, create_if_missing=create_if_missing, ) - except Exception as exc: - if ( - isinstance(exc, httpx.HTTPStatusError) - and exc.response.status_code == status.HTTP_423_LOCKED - ): - if max_retries is not None and max_retries <= 0: - raise exc - retry_after = float(exc.response.headers["Retry-After"]) - logger.debug( - f"Unable to acquire concurrency slot. Retrying in {retry_after} second(s)." 
- ) - await asyncio.sleep(retry_after) - if max_retries is not None: - max_retries -= 1 - else: - raise exc - else: - return response - - def send( - self, item: Tuple[int, str, Optional[float], Optional[bool], Optional[int]] - ) -> concurrent.futures.Future: - with self._lock: - if self._stopped: - raise RuntimeError("Cannot put items in a stopped service instance.") + except httpx.HTTPStatusError as exc: + if not exc.response.status_code == status.HTTP_423_LOCKED: + raise - logger.debug("Service %r enqueuing item %r", self, item) - future: concurrent.futures.Future = concurrent.futures.Future() - - occupy, mode, timeout_seconds, create_if_missing, max_retries = item - self._queue.put_nowait( - (occupy, mode, timeout_seconds, future, create_if_missing, max_retries) - ) - - return future + if max_retries is not None and max_retries <= 0: + raise exc + retry_after = float(exc.response.headers["Retry-After"]) + logger.debug( + f"Unable to acquire concurrency slot. Retrying in {retry_after} second(s)." + ) + await asyncio.sleep(retry_after) + if max_retries is not None: + max_retries -= 1 diff --git a/src/prefect/concurrency/sync.py b/src/prefect/concurrency/sync.py index 2f6bf47a3df6..88aa69f47c45 100644 --- a/src/prefect/concurrency/sync.py +++ b/src/prefect/concurrency/sync.py @@ -1,71 +1,54 @@ +from collections.abc import Generator from contextlib import contextmanager -from typing import ( - Generator, - List, - Optional, - TypeVar, - Union, - cast, -) +from typing import Optional, TypeVar, Union import pendulum from typing_extensions import Literal -from prefect.utilities.asyncutils import run_coro_as_sync - -try: - from pendulum import Interval -except ImportError: - # pendulum < 3 - from pendulum.period import Period as Interval # type: ignore - from prefect.client.schemas.responses import MinimalConcurrencyLimitResponse +from prefect.utilities.asyncutils import run_coro_as_sync -from .asyncio import ( - _aacquire_concurrency_slots, - _arelease_concurrency_slots, +from ._asyncio import ( + aacquire_concurrency_slots, + arelease_concurrency_slots, ) -from .events import ( - _emit_concurrency_acquisition_events, - _emit_concurrency_release_events, +from ._events import ( + emit_concurrency_acquisition_events, + emit_concurrency_release_events, ) T = TypeVar("T") def _release_concurrency_slots( - names: List[str], slots: int, occupancy_seconds: float -) -> List[MinimalConcurrencyLimitResponse]: + names: list[str], slots: int, occupancy_seconds: float +) -> list[MinimalConcurrencyLimitResponse]: result = run_coro_as_sync( - _arelease_concurrency_slots(names, slots, occupancy_seconds) + arelease_concurrency_slots(names, slots, occupancy_seconds) ) - if result is None: - raise RuntimeError("Failed to release concurrency slots") return result def _acquire_concurrency_slots( - names: List[str], + names: list[str], slots: int, mode: Literal["concurrency", "rate_limit"] = "concurrency", timeout_seconds: Optional[float] = None, create_if_missing: Optional[bool] = None, max_retries: Optional[int] = None, strict: bool = False, -) -> List[MinimalConcurrencyLimitResponse]: +) -> list[MinimalConcurrencyLimitResponse]: result = run_coro_as_sync( - _aacquire_concurrency_slots( + aacquire_concurrency_slots( names, slots, mode, timeout_seconds, create_if_missing, max_retries, strict ) ) - if result is None: - raise RuntimeError("Failed to acquire concurrency slots") return result @contextmanager def concurrency( - names: Union[str, List[str]], + names: Union[str, list[str]], occupy: int = 1, 
timeout_seconds: Optional[float] = None, max_retries: Optional[int] = None, @@ -107,7 +90,7 @@ def main(): names = names if isinstance(names, list) else [names] - limits: List[MinimalConcurrencyLimitResponse] = _acquire_concurrency_slots( + limits: list[MinimalConcurrencyLimitResponse] = _acquire_concurrency_slots( names, occupy, timeout_seconds=timeout_seconds, @@ -116,22 +99,18 @@ def main(): max_retries=max_retries, ) acquisition_time = pendulum.now("UTC") - emitted_events = _emit_concurrency_acquisition_events(limits, occupy) + emitted_events = emit_concurrency_acquisition_events(limits, occupy) try: yield finally: - occupancy_period = cast(Interval, pendulum.now("UTC") - acquisition_time) - _release_concurrency_slots( - names, - occupy, - occupancy_period.total_seconds(), - ) - _emit_concurrency_release_events(limits, occupy, emitted_events) + occupancy_period = pendulum.now("UTC") - acquisition_time + _release_concurrency_slots(names, occupy, occupancy_period.total_seconds()) + emit_concurrency_release_events(limits, occupy, emitted_events) def rate_limit( - names: Union[str, List[str]], + names: Union[str, list[str]], occupy: int = 1, timeout_seconds: Optional[float] = None, create_if_missing: Optional[bool] = None, @@ -166,4 +145,4 @@ def rate_limit( create_if_missing=create_if_missing, strict=strict, ) - _emit_concurrency_acquisition_events(limits, occupy) + emit_concurrency_acquisition_events(limits, occupy) diff --git a/src/prefect/concurrency/v1/_asyncio.py b/src/prefect/concurrency/v1/_asyncio.py new file mode 100644 index 000000000000..d8ac04c4adb8 --- /dev/null +++ b/src/prefect/concurrency/v1/_asyncio.py @@ -0,0 +1,63 @@ +import asyncio +from typing import Optional +from uuid import UUID + +import httpx + +from prefect.client.orchestration import get_client +from prefect.client.schemas.responses import MinimalConcurrencyLimitResponse +from prefect.utilities.asyncutils import sync_compatible + +from .services import ConcurrencySlotAcquisitionService + + +class ConcurrencySlotAcquisitionError(Exception): + """Raised when an unhandlable occurs while acquiring concurrency slots.""" + + +class AcquireConcurrencySlotTimeoutError(TimeoutError): + """Raised when acquiring a concurrency slot times out.""" + + +@sync_compatible +async def acquire_concurrency_slots( + names: list[str], + task_run_id: UUID, + timeout_seconds: Optional[float] = None, +) -> list[MinimalConcurrencyLimitResponse]: + service = ConcurrencySlotAcquisitionService.instance(frozenset(names)) + future = service.send((task_run_id, timeout_seconds)) + try: + response = await asyncio.wrap_future(future) + except TimeoutError as timeout: + raise AcquireConcurrencySlotTimeoutError( + f"Attempt to acquire concurrency limits timed out after {timeout_seconds} second(s)" + ) from timeout + except Exception as exc: + raise ConcurrencySlotAcquisitionError( + f"Unable to acquire concurrency limits {names!r}" + ) from exc + else: + return _response_to_concurrency_limit_response(response) + + +@sync_compatible +async def release_concurrency_slots( + names: list[str], task_run_id: UUID, occupancy_seconds: float +) -> list[MinimalConcurrencyLimitResponse]: + async with get_client() as client: + response = await client.decrement_v1_concurrency_slots( + names=names, + task_run_id=task_run_id, + occupancy_seconds=occupancy_seconds, + ) + return _response_to_concurrency_limit_response(response) + + +def _response_to_concurrency_limit_response( + response: httpx.Response, +) -> list[MinimalConcurrencyLimitResponse]: + data: 
list[MinimalConcurrencyLimitResponse] = response.json() or [] + return [ + MinimalConcurrencyLimitResponse.model_validate(limit) for limit in data if data + ] diff --git a/src/prefect/concurrency/v1/events.py b/src/prefect/concurrency/v1/_events.py similarity index 65% rename from src/prefect/concurrency/v1/events.py rename to src/prefect/concurrency/v1/_events.py index 3fa5193e6fea..f3924cb1a5d1 100644 --- a/src/prefect/concurrency/v1/events.py +++ b/src/prefect/concurrency/v1/_events.py @@ -1,18 +1,18 @@ -from typing import Dict, List, Literal, Optional, Union +from typing import Literal, Optional, Union from uuid import UUID from prefect.client.schemas.responses import MinimalConcurrencyLimitResponse from prefect.events import Event, RelatedResource, emit_event -def _emit_concurrency_event( +def emit_concurrency_event( phase: Union[Literal["acquired"], Literal["released"]], primary_limit: MinimalConcurrencyLimitResponse, - related_limits: List[MinimalConcurrencyLimitResponse], + related_limits: list[MinimalConcurrencyLimitResponse], task_run_id: UUID, follows: Union[Event, None] = None, ) -> Union[Event, None]: - resource: Dict[str, str] = { + resource: dict[str, str] = { "prefect.resource.id": f"prefect.concurrency-limit.v1.{primary_limit.id}", "prefect.resource.name": primary_limit.name, "limit": str(primary_limit.limit), @@ -38,24 +38,22 @@ def _emit_concurrency_event( ) -def _emit_concurrency_acquisition_events( - limits: List[MinimalConcurrencyLimitResponse], +def emit_concurrency_acquisition_events( + limits: list[MinimalConcurrencyLimitResponse], task_run_id: UUID, -) -> Dict[UUID, Optional[Event]]: - events = {} +) -> dict[UUID, Optional[Event]]: + events: dict[UUID, Optional[Event]] = {} for limit in limits: - event = _emit_concurrency_event("acquired", limit, limits, task_run_id) + event = emit_concurrency_event("acquired", limit, limits, task_run_id) events[limit.id] = event return events -def _emit_concurrency_release_events( - limits: List[MinimalConcurrencyLimitResponse], - events: Dict[UUID, Optional[Event]], +def emit_concurrency_release_events( + limits: list[MinimalConcurrencyLimitResponse], + events: dict[UUID, Optional[Event]], task_run_id: UUID, ) -> None: for limit in limits: - _emit_concurrency_event( - "released", limit, limits, task_run_id, events[limit.id] - ) + emit_concurrency_event("released", limit, limits, task_run_id, events[limit.id]) diff --git a/src/prefect/concurrency/v1/asyncio.py b/src/prefect/concurrency/v1/asyncio.py index 7f888adc7172..35e4c1cc4c37 100644 --- a/src/prefect/concurrency/v1/asyncio.py +++ b/src/prefect/concurrency/v1/asyncio.py @@ -1,42 +1,30 @@ -import asyncio +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager -from typing import AsyncGenerator, List, Optional, Union, cast +from typing import TYPE_CHECKING, Optional, Union from uuid import UUID import anyio -import httpx import pendulum -from ...client.schemas.responses import MinimalConcurrencyLimitResponse - -try: - from pendulum import Interval -except ImportError: - # pendulum < 3 - from pendulum.period import Period as Interval # type: ignore - -from prefect.client.orchestration import get_client -from prefect.utilities.asyncutils import sync_compatible - -from .context import ConcurrencyContext -from .events import ( - _emit_concurrency_acquisition_events, - _emit_concurrency_release_events, +from prefect.concurrency.v1._asyncio import ( + acquire_concurrency_slots, + release_concurrency_slots, ) -from .services import 
ConcurrencySlotAcquisitionService - - -class ConcurrencySlotAcquisitionError(Exception): - """Raised when an unhandlable occurs while acquiring concurrency slots.""" - +from prefect.concurrency.v1._events import ( + emit_concurrency_acquisition_events, + emit_concurrency_release_events, +) +from prefect.concurrency.v1.context import ConcurrencyContext -class AcquireConcurrencySlotTimeoutError(TimeoutError): - """Raised when acquiring a concurrency slot times out.""" +from ._asyncio import ( + AcquireConcurrencySlotTimeoutError as AcquireConcurrencySlotTimeoutError, +) +from ._asyncio import ConcurrencySlotAcquisitionError as ConcurrencySlotAcquisitionError @asynccontextmanager async def concurrency( - names: Union[str, List[str]], + names: Union[str, list[str]], task_run_id: UUID, timeout_seconds: Optional[float] = None, ) -> AsyncGenerator[None, None]: @@ -69,24 +57,30 @@ async def main(): yield return - names_normalized: List[str] = names if isinstance(names, list) else [names] + names_normalized: list[str] = names if isinstance(names, list) else [names] - limits = await _acquire_concurrency_slots( + acquire_slots = acquire_concurrency_slots( names_normalized, task_run_id=task_run_id, timeout_seconds=timeout_seconds, ) + if TYPE_CHECKING: + assert not isinstance(acquire_slots, list) + limits = await acquire_slots acquisition_time = pendulum.now("UTC") - emitted_events = _emit_concurrency_acquisition_events(limits, task_run_id) + emitted_events = emit_concurrency_acquisition_events(limits, task_run_id) try: yield finally: - occupancy_period = cast(Interval, (pendulum.now("UTC") - acquisition_time)) + occupancy_period = pendulum.now("UTC") - acquisition_time try: - await _release_concurrency_slots( + release_slots = release_concurrency_slots( names_normalized, task_run_id, occupancy_period.total_seconds() ) + if TYPE_CHECKING: + assert not isinstance(release_slots, list) + await release_slots except anyio.get_cancelled_exc_class(): # The task was cancelled before it could release the slots. 
Add the # slots to the cleanup list so they can be released when the @@ -96,51 +90,4 @@ async def main(): (names_normalized, occupancy_period.total_seconds(), task_run_id) ) - _emit_concurrency_release_events(limits, emitted_events, task_run_id) - - -@sync_compatible -async def _acquire_concurrency_slots( - names: List[str], - task_run_id: UUID, - timeout_seconds: Optional[float] = None, -) -> List[MinimalConcurrencyLimitResponse]: - service = ConcurrencySlotAcquisitionService.instance(frozenset(names)) - future = service.send((task_run_id, timeout_seconds)) - response_or_exception = await asyncio.wrap_future(future) - - if isinstance(response_or_exception, Exception): - if isinstance(response_or_exception, TimeoutError): - raise AcquireConcurrencySlotTimeoutError( - f"Attempt to acquire concurrency limits timed out after {timeout_seconds} second(s)" - ) from response_or_exception - - raise ConcurrencySlotAcquisitionError( - f"Unable to acquire concurrency limits {names!r}" - ) from response_or_exception - - return _response_to_concurrency_limit_response(response_or_exception) - - -@sync_compatible -async def _release_concurrency_slots( - names: List[str], - task_run_id: UUID, - occupancy_seconds: float, -) -> List[MinimalConcurrencyLimitResponse]: - async with get_client() as client: - response = await client.decrement_v1_concurrency_slots( - names=names, - task_run_id=task_run_id, - occupancy_seconds=occupancy_seconds, - ) - return _response_to_concurrency_limit_response(response) - - -def _response_to_concurrency_limit_response( - response: httpx.Response, -) -> List[MinimalConcurrencyLimitResponse]: - data = response.json() or [] - return [ - MinimalConcurrencyLimitResponse.model_validate(limit) for limit in data if data - ] + emit_concurrency_release_events(limits, emitted_events, task_run_id) diff --git a/src/prefect/concurrency/v1/context.py b/src/prefect/concurrency/v1/context.py index f413c84ed1f4..faaac13a4523 100644 --- a/src/prefect/concurrency/v1/context.py +++ b/src/prefect/concurrency/v1/context.py @@ -1,20 +1,22 @@ from contextvars import ContextVar -from typing import List, Tuple +from typing import Any, ClassVar from uuid import UUID +from typing_extensions import Self + from prefect.client.orchestration import get_client from prefect.context import ContextModel, Field class ConcurrencyContext(ContextModel): - __var__: ContextVar = ContextVar("concurrency_v1") + __var__: ClassVar[ContextVar[Self]] = ContextVar("concurrency_v1") # Track the limits that have been acquired but were not able to be released # due to cancellation or some other error. These limits are released when # the context manager exits. 
- cleanup_slots: List[Tuple[List[str], float, UUID]] = Field(default_factory=list) + cleanup_slots: list[tuple[list[str], float, UUID]] = Field(default_factory=list) - def __exit__(self, *exc_info): + def __exit__(self, *exc_info: Any) -> None: if self.cleanup_slots: with get_client(sync_client=True) as client: for names, occupancy_seconds, task_run_id in self.cleanup_slots: diff --git a/src/prefect/concurrency/v1/services.py b/src/prefect/concurrency/v1/services.py index 1199c7ef3373..ad8d4b742b45 100644 --- a/src/prefect/concurrency/v1/services.py +++ b/src/prefect/concurrency/v1/services.py @@ -1,21 +1,16 @@ import asyncio -import concurrent.futures +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from json import JSONDecodeError -from typing import ( - TYPE_CHECKING, - AsyncGenerator, - FrozenSet, - Optional, - Tuple, -) +from typing import TYPE_CHECKING, Optional from uuid import UUID import httpx from starlette import status +from typing_extensions import Unpack from prefect._internal.concurrency import logger -from prefect._internal.concurrency.services import QueueService +from prefect._internal.concurrency.services import FutureQueueService from prefect.client.orchestration import get_client from prefect.utilities.timeout import timeout_async @@ -27,11 +22,13 @@ class ConcurrencySlotAcquisitionServiceError(Exception): """Raised when an error occurs while acquiring concurrency slots.""" -class ConcurrencySlotAcquisitionService(QueueService): - def __init__(self, concurrency_limit_names: FrozenSet[str]): +class ConcurrencySlotAcquisitionService( + FutureQueueService[Unpack[tuple[UUID, Optional[float]]], httpx.Response] +): + def __init__(self, concurrency_limit_names: frozenset[str]) -> None: super().__init__(concurrency_limit_names) - self._client: "PrefectClient" - self.concurrency_limit_names = sorted(list(concurrency_limit_names)) + self._client: PrefectClient + self.concurrency_limit_names: list[str] = sorted(list(concurrency_limit_names)) @asynccontextmanager async def _lifespan(self) -> AsyncGenerator[None, None]: @@ -39,78 +36,35 @@ async def _lifespan(self) -> AsyncGenerator[None, None]: self._client = client yield - async def _handle( - self, - item: Tuple[ - UUID, - concurrent.futures.Future, - Optional[float], - ], - ) -> None: - task_run_id, future, timeout_seconds = item - try: - response = await self.acquire_slots(task_run_id, timeout_seconds) - except Exception as exc: - # If the request to the increment endpoint fails in a non-standard - # way, we need to set the future's result so that the caller can - # handle the exception and then re-raise. - future.set_result(exc) - raise exc - else: - future.set_result(response) - - async def acquire_slots( - self, - task_run_id: UUID, - timeout_seconds: Optional[float] = None, + async def acquire( + self, task_run_id: UUID, timeout_seconds: Optional[float] = None ) -> httpx.Response: with timeout_async(seconds=timeout_seconds): while True: try: - response = await self._client.increment_v1_concurrency_slots( + return await self._client.increment_v1_concurrency_slots( task_run_id=task_run_id, names=self.concurrency_limit_names, ) - except Exception as exc: - if ( - isinstance(exc, httpx.HTTPStatusError) - and exc.response.status_code == status.HTTP_423_LOCKED - ): - retry_after = exc.response.headers.get("Retry-After") - if retry_after: - retry_after = float(retry_after) - await asyncio.sleep(retry_after) - else: - # We received a 423 but no Retry-After header. 
This - # should indicate that the server told us to abort - # because the concurrency limit is set to 0, i.e. - # effectively disabled. - try: - reason = exc.response.json()["detail"] - except (JSONDecodeError, KeyError): - logger.error( - "Failed to parse response from concurrency limit 423 Locked response: %s", - exc.response.content, - ) - reason = "Concurrency limit is locked (server did not specify the reason)" - raise ConcurrencySlotAcquisitionServiceError( - reason - ) from exc + except httpx.HTTPStatusError as exc: + if not exc.response.status_code == status.HTTP_423_LOCKED: + raise + retry_after = exc.response.headers.get("Retry-After") + if retry_after: + retry_after = float(retry_after) + await asyncio.sleep(retry_after) else: - raise exc # type: ignore - else: - return response - - def send(self, item: Tuple[UUID, Optional[float]]) -> concurrent.futures.Future: - with self._lock: - if self._stopped: - raise RuntimeError("Cannot put items in a stopped service instance.") - - logger.debug("Service %r enqueuing item %r", self, item) - future: concurrent.futures.Future = concurrent.futures.Future() - - task_run_id, timeout_seconds = item - self._queue.put_nowait((task_run_id, future, timeout_seconds)) - - return future + # We received a 423 but no Retry-After header. This + # should indicate that the server told us to abort + # because the concurrency limit is set to 0, i.e. + # effectively disabled. + try: + reason = exc.response.json()["detail"] + except (JSONDecodeError, KeyError): + logger.error( + "Failed to parse response from concurrency limit 423 Locked response: %s", + exc.response.content, + ) + reason = "Concurrency limit is locked (server did not specify the reason)" + raise ConcurrencySlotAcquisitionServiceError(reason) from exc diff --git a/src/prefect/concurrency/v1/sync.py b/src/prefect/concurrency/v1/sync.py index 6e557b344502..287de878e4be 100644 --- a/src/prefect/concurrency/v1/sync.py +++ b/src/prefect/concurrency/v1/sync.py @@ -1,31 +1,15 @@ +import asyncio +from collections.abc import Generator from contextlib import contextmanager -from typing import ( - Generator, - List, - Optional, - TypeVar, - Union, - cast, -) +from typing import Optional, TypeVar, Union from uuid import UUID import pendulum -from ...client.schemas.responses import MinimalConcurrencyLimitResponse - -try: - from pendulum import Interval -except ImportError: - # pendulum < 3 - from pendulum.period import Period as Interval # type: ignore - -from .asyncio import ( - _acquire_concurrency_slots, - _release_concurrency_slots, -) -from .events import ( - _emit_concurrency_acquisition_events, - _emit_concurrency_release_events, +from ._asyncio import acquire_concurrency_slots, release_concurrency_slots +from ._events import ( + emit_concurrency_acquisition_events, + emit_concurrency_release_events, ) T = TypeVar("T") @@ -33,7 +17,7 @@ @contextmanager def concurrency( - names: Union[str, List[str]], + names: Union[str, list[str]], task_run_id: UUID, timeout_seconds: Optional[float] = None, ) -> Generator[None, None, None]: @@ -69,23 +53,20 @@ def main(): names = names if isinstance(names, list) else [names] - limits: List[MinimalConcurrencyLimitResponse] = _acquire_concurrency_slots( - names, - timeout_seconds=timeout_seconds, - task_run_id=task_run_id, - _sync=True, + force = {"_sync": True} + result = acquire_concurrency_slots( + names, timeout_seconds=timeout_seconds, task_run_id=task_run_id, **force ) + assert not asyncio.iscoroutine(result) + limits = result acquisition_time = 
pendulum.now("UTC") - emitted_events = _emit_concurrency_acquisition_events(limits, task_run_id) + emitted_events = emit_concurrency_acquisition_events(limits, task_run_id) try: yield finally: - occupancy_period = cast(Interval, pendulum.now("UTC") - acquisition_time) - _release_concurrency_slots( - names, - task_run_id, - occupancy_period.total_seconds(), - _sync=True, + occupancy_period = pendulum.now("UTC") - acquisition_time + release_concurrency_slots( + names, task_run_id, occupancy_period.total_seconds(), **force ) - _emit_concurrency_release_events(limits, emitted_events, task_run_id) + emit_concurrency_release_events(limits, emitted_events, task_run_id) diff --git a/src/prefect/context.py b/src/prefect/context.py index 7f92ad7528ca..31298b0b1b43 100644 --- a/src/prefect/context.py +++ b/src/prefect/context.py @@ -12,7 +12,7 @@ from collections.abc import AsyncGenerator, Generator, Mapping from contextlib import ExitStack, asynccontextmanager, contextmanager from contextvars import ContextVar, Token -from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, TypeVar, Union from pydantic import BaseModel, ConfigDict, Field, PrivateAttr from typing_extensions import Self @@ -47,11 +47,6 @@ from prefect.flows import Flow from prefect.tasks import Task -# Define the global settings context variable -# This will be populated downstream but must be null here to facilitate loading the -# default settings. -GLOBAL_SETTINGS_CONTEXT = None # type: ignore - def serialize_context() -> dict[str, Any]: """ @@ -75,7 +70,7 @@ def serialize_context() -> dict[str, Any]: def hydrated_context( serialized_context: Optional[dict[str, Any]] = None, client: Union[PrefectClient, SyncPrefectClient, None] = None, -): +) -> Generator[None, Any, None]: with ExitStack() as stack: if serialized_context: # Set up settings context @@ -112,10 +107,15 @@ class ContextModel(BaseModel): a context manager """ + if TYPE_CHECKING: + # subclasses can pass through keyword arguments to the pydantic base model + def __init__(self, **kwargs: Any) -> None: + ... + # The context variable for storing data must be defined by the child class - __var__: ContextVar[Self] + __var__: ClassVar[ContextVar[Self]] _token: Optional[Token[Self]] = PrivateAttr(None) - model_config = ConfigDict( + model_config: ClassVar[ConfigDict] = ConfigDict( arbitrary_types_allowed=True, extra="forbid", ) @@ -128,7 +128,7 @@ def __enter__(self) -> Self: self._token = self.__var__.set(self) return self - def __exit__(self, *_): + def __exit__(self, *_: Any) -> None: if not self._token: raise RuntimeError( "Asymmetric use of context. Context exit called without an enter." @@ -143,7 +143,7 @@ def get(cls: type[Self]) -> Optional[Self]: def model_copy( self: Self, *, update: Optional[Mapping[str, Any]] = None, deep: bool = False - ): + ) -> Self: """ Duplicate the context model, optionally choosing which fields to include, exclude, or change. 
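The `__enter__`/`__exit__` pair above is the ContextVar token dance that every `ContextModel` subclass in this diff now types explicitly (`__var__: ClassVar[ContextVar[Self]]`). A stripped-down sketch of the pattern, independent of Prefect (the `Scope` class and `_current` variable are illustration only):

    from contextvars import ContextVar, Token
    from typing import Optional

    _current: ContextVar["Scope"] = ContextVar("scope")

    class Scope:
        def __init__(self) -> None:
            self._token: Optional[Token["Scope"]] = None

        def __enter__(self) -> "Scope":
            self._token = _current.set(self)  # remember what was active before
            return self

        def __exit__(self, *exc_info: object) -> None:
            if self._token is None:
                raise RuntimeError("Context exit called without an enter.")
            _current.reset(self._token)  # restore the previous value
            self._token = None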
@@ -191,19 +191,19 @@ class SyncClientContext(ContextModel): assert c1 is ctx.client """ - __var__: ContextVar[Self] = ContextVar("sync-client-context") + __var__: ClassVar[ContextVar[Self]] = ContextVar("sync-client-context") client: SyncPrefectClient _httpx_settings: Optional[dict[str, Any]] = PrivateAttr(None) _context_stack: int = PrivateAttr(0) - def __init__(self, httpx_settings: Optional[dict[str, Any]] = None): + def __init__(self, httpx_settings: Optional[dict[str, Any]] = None) -> None: super().__init__( - client=get_client(sync_client=True, httpx_settings=httpx_settings), # type: ignore[reportCallIssue] + client=get_client(sync_client=True, httpx_settings=httpx_settings), ) self._httpx_settings = httpx_settings self._context_stack = 0 - def __enter__(self): + def __enter__(self) -> Self: self._context_stack += 1 if self._context_stack == 1: self.client.__enter__() @@ -212,20 +212,20 @@ def __enter__(self): else: return self - def __exit__(self, *exc_info: Any): + def __exit__(self, *exc_info: Any) -> None: self._context_stack -= 1 if self._context_stack == 0: - self.client.__exit__(*exc_info) # type: ignore[reportUnknownMemberType] - return super().__exit__(*exc_info) # type: ignore[reportUnknownMemberType] + self.client.__exit__(*exc_info) + return super().__exit__(*exc_info) @classmethod @contextmanager - def get_or_create(cls) -> Generator["SyncClientContext", None, None]: - ctx = SyncClientContext.get() + def get_or_create(cls) -> Generator[Self, None, None]: + ctx = cls.get() if ctx: yield ctx else: - with SyncClientContext() as ctx: + with cls() as ctx: yield ctx @@ -249,14 +249,14 @@ class AsyncClientContext(ContextModel): assert c1 is ctx.client """ - __var__ = ContextVar("async-client-context") + __var__: ClassVar[ContextVar[Self]] = ContextVar("async-client-context") client: PrefectClient _httpx_settings: Optional[dict[str, Any]] = PrivateAttr(None) _context_stack: int = PrivateAttr(0) def __init__(self, httpx_settings: Optional[dict[str, Any]] = None): super().__init__( - client=get_client(sync_client=False, httpx_settings=httpx_settings), # type: ignore[reportCallIssue] + client=get_client(sync_client=False, httpx_settings=httpx_settings) ) self._httpx_settings = httpx_settings self._context_stack = 0 @@ -273,8 +273,8 @@ async def __aenter__(self: Self) -> Self: async def __aexit__(self: Self, *exc_info: Any) -> None: self._context_stack -= 1 if self._context_stack == 0: - await self.client.__aexit__(*exc_info) # type: ignore[reportUnknownMemberType] - return super().__exit__(*exc_info) # type: ignore[reportUnknownMemberType] + await self.client.__aexit__(*exc_info) + return super().__exit__(*exc_info) @classmethod @asynccontextmanager @@ -297,7 +297,7 @@ class RunContext(ContextModel): client: The Prefect client instance being used for API communication """ - def __init__(self, *args: Any, **kwargs: Any): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) start_client_metrics_server() @@ -356,7 +356,7 @@ class EngineContext(RunContext): # Events worker to emit events events: Optional[EventsWorker] = None - __var__: ContextVar[Self] = ContextVar("flow_run") + __var__: ClassVar[ContextVar[Self]] = ContextVar("flow_run") def serialize(self: Self, include_secrets: bool = True) -> dict[str, Any]: return self.model_dump( @@ -398,7 +398,7 @@ class TaskRunContext(RunContext): result_store: ResultStore persist_result: bool = Field(default_factory=get_default_persist_setting_for_tasks) - __var__ = ContextVar("task_run") + __var__: 
ClassVar[ContextVar[Self]] = ContextVar("task_run") def serialize(self: Self, include_secrets: bool = True) -> dict[str, Any]: return self.model_dump( @@ -429,11 +429,11 @@ class TagsContext(ContextModel): current_tags: set[str] = Field(default_factory=set) @classmethod - def get(cls) -> "TagsContext": + def get(cls) -> Self: # Return an empty `TagsContext` instead of `None` if no context exists - return cls.__var__.get(TagsContext()) + return cls.__var__.get(cls()) - __var__: ContextVar[Self] = ContextVar("tags") + __var__: ClassVar[ContextVar[Self]] = ContextVar("tags") class SettingsContext(ContextModel): @@ -450,15 +450,21 @@ class SettingsContext(ContextModel): profile: Profile settings: Settings - __var__: ContextVar[Self] = ContextVar("settings") + __var__: ClassVar[ContextVar[Self]] = ContextVar("settings") def __hash__(self: Self) -> int: return hash(self.settings) @classmethod - def get(cls) -> "SettingsContext": + def get(cls) -> Optional["SettingsContext"]: # Return the global context instead of `None` if no context exists - return super().get() or GLOBAL_SETTINGS_CONTEXT + try: + return super().get() or GLOBAL_SETTINGS_CONTEXT + except NameError: + # GLOBAL_SETTINGS_CONTEXT has not yet been set; in order to create + # it profiles need to be loaded, and that process calls + # SettingsContext.get(). + return None def get_run_context() -> Union[FlowRunContext, TaskRunContext]: @@ -559,10 +565,10 @@ def tags(*new_tags: str) -> Generator[set[str], None, None]: @contextmanager def use_profile( - profile: Union[Profile, str, Any], + profile: Union[Profile, str], override_environment_variables: bool = False, include_current_context: bool = True, -): +) -> Generator[SettingsContext, Any, None]: """ Switch to a profile for the duration of this context. @@ -584,11 +590,12 @@ def use_profile( profiles = prefect.settings.load_profiles() profile = profiles[profile] - if not isinstance(profile, Profile): - raise TypeError( - f"Unexpected type {type(profile).__name__!r} for `profile`. " - "Expected 'str' or 'Profile'." - ) + if not TYPE_CHECKING: + if not isinstance(profile, Profile): + raise TypeError( + f"Unexpected type {type(profile).__name__!r} for `profile`. " + "Expected 'str' or 'Profile'." + ) # Create a copy of the profiles settings as we will mutate it profile_settings = profile.settings.copy() @@ -609,7 +616,7 @@ def use_profile( yield ctx -def root_settings_context(): +def root_settings_context() -> SettingsContext: """ Return the settings context that will exist as the root context for the module. @@ -659,9 +666,9 @@ def root_settings_context(): # an override in the `SettingsContext.get` method. -GLOBAL_SETTINGS_CONTEXT: SettingsContext = root_settings_context() # type: ignore[reportConstantRedefinition] +GLOBAL_SETTINGS_CONTEXT: SettingsContext = root_settings_context() # 2024-07-02: This surfaces an actionable error message for removed objects # in Prefect 3.0 upgrade. 
-__getattr__ = getattr_migration(__name__)
+__getattr__: Callable[[str], Any] = getattr_migration(__name__)
diff --git a/src/prefect/task_engine.py b/src/prefect/task_engine.py
index b0c0e5bcc20d..6a6a2722268f 100644
--- a/src/prefect/task_engine.py
+++ b/src/prefect/task_engine.py
@@ -756,9 +756,11 @@ def start(
         dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
     ) -> Generator[None, None, None]:
         with self.initialize_run(task_run_id=task_run_id, dependencies=dependencies):
-            with trace.use_span(
-                self._telemetry.span
-            ) if self._telemetry.span else nullcontext():
+            with (
+                trace.use_span(self._telemetry.span)
+                if self._telemetry.span
+                else nullcontext()
+            ):
                 self.begin_run()
                 try:
                     yield
@@ -1295,9 +1297,11 @@ async def start(
         async with self.initialize_run(
             task_run_id=task_run_id, dependencies=dependencies
         ):
-            with trace.use_span(
-                self._telemetry.span
-            ) if self._telemetry.span else nullcontext():
+            with (
+                trace.use_span(self._telemetry.span)
+                if self._telemetry.span
+                else nullcontext()
+            ):
                 await self.begin_run()
                 try:
                     yield
diff --git a/tests/concurrency/test_acquire_concurrency_slots.py b/tests/concurrency/test_acquire_concurrency_slots.py
index a6d1817051a7..1848d67384f2 100644
--- a/tests/concurrency/test_acquire_concurrency_slots.py
+++ b/tests/concurrency/test_acquire_concurrency_slots.py
@@ -4,9 +4,7 @@
 from httpx import Response

 from prefect.client.schemas.responses import MinimalConcurrencyLimitResponse
-from prefect.concurrency.asyncio import (
-    _aacquire_concurrency_slots,
-)
+from prefect.concurrency._asyncio import aacquire_concurrency_slots


 async def test_calls_increment_client_method():
@@ -23,7 +21,7 @@ async def test_calls_increment_client_method():
         )
         increment_concurrency_slots.return_value = response

-        await _aacquire_concurrency_slots(
+        await aacquire_concurrency_slots(
             names=["test-1", "test-2"], slots=1, mode="concurrency"
         )
         increment_concurrency_slots.assert_called_once_with(
@@ -48,5 +46,5 @@ async def test_returns_minimal_concurrency_limit():
         )
         increment_concurrency_slots.return_value = response

-        result = await _aacquire_concurrency_slots(["test-1", "test-2"], 1)
+        result = await aacquire_concurrency_slots(["test-1", "test-2"], 1)
         assert result == limits
diff --git a/tests/concurrency/test_concurrency_asyncio.py b/tests/concurrency/test_concurrency_asyncio.py
index ff1306e95e01..97b9445021e6 100644
--- a/tests/concurrency/test_concurrency_asyncio.py
+++ b/tests/concurrency/test_concurrency_asyncio.py
@@ -5,10 +5,12 @@
 from starlette import status

 from prefect import flow, task
+from prefect.concurrency._asyncio import (
+    aacquire_concurrency_slots,
+    arelease_concurrency_slots,
+)
 from prefect.concurrency.asyncio import (
     ConcurrencySlotAcquisitionError,
-    _aacquire_concurrency_slots,
-    _arelease_concurrency_slots,
     concurrency,
     rate_limit,
 )
@@ -28,12 +30,12 @@ async def resource_heavy():

     assert not executed
     with mock.patch(
-        "prefect.concurrency.asyncio._aacquire_concurrency_slots",
-        wraps=_aacquire_concurrency_slots,
+        "prefect.concurrency.asyncio.aacquire_concurrency_slots",
+        wraps=aacquire_concurrency_slots,
     ) as acquire_spy:
         with mock.patch(
-            "prefect.concurrency.asyncio._arelease_concurrency_slots",
-            wraps=_arelease_concurrency_slots,
+            "prefect.concurrency.asyncio.arelease_concurrency_slots",
+            wraps=arelease_concurrency_slots,
         ) as release_spy:
             await resource_heavy()

@@ -221,12 +223,12 @@ async def resource_heavy():

     assert not executed
     with mock.patch(
-        "prefect.concurrency.asyncio._aacquire_concurrency_slots",
-        wraps=_aacquire_concurrency_slots,
+        "prefect.concurrency.asyncio.aacquire_concurrency_slots",
+        wraps=aacquire_concurrency_slots,
     ) as acquire_spy:
         with mock.patch(
-            "prefect.concurrency.asyncio._arelease_concurrency_slots",
-            wraps=_arelease_concurrency_slots,
+            "prefect.concurrency.asyncio.arelease_concurrency_slots",
+            wraps=arelease_concurrency_slots,
         ) as release_spy:
             await resource_heavy()

@@ -377,7 +379,7 @@ async def resource_heavy():
         wraps=lambda *args, **kwargs: None,
     ) as acquire_spy:
         with mock.patch(
-            "prefect.concurrency.sync._arelease_concurrency_slots",
+            "prefect.concurrency.sync.arelease_concurrency_slots",
             wraps=lambda *args, **kwargs: None,
         ) as release_spy:
             await resource_heavy()
@@ -401,12 +403,12 @@ async def resource_heavy():

     assert not executed
     with mock.patch(
-        "prefect.concurrency.asyncio._aacquire_concurrency_slots",
-        wraps=_aacquire_concurrency_slots,
+        "prefect.concurrency.asyncio.aacquire_concurrency_slots",
+        wraps=aacquire_concurrency_slots,
     ) as acquire_spy:
         with mock.patch(
-            "prefect.concurrency.asyncio._arelease_concurrency_slots",
-            wraps=_arelease_concurrency_slots,
+            "prefect.concurrency.asyncio.arelease_concurrency_slots",
+            wraps=arelease_concurrency_slots,
         ) as release_spy:
             await resource_heavy()

@@ -447,7 +449,7 @@ async def resource_heavy():
         wraps=lambda *args, **kwargs: None,
     ) as acquire_spy:
         with mock.patch(
-            "prefect.concurrency.sync._arelease_concurrency_slots",
+            "prefect.concurrency.sync.arelease_concurrency_slots",
             wraps=lambda *args, **kwargs: None,
         ) as release_spy:
             await resource_heavy()
diff --git a/tests/concurrency/test_concurrency_slot_acquisition_service.py b/tests/concurrency/test_concurrency_slot_acquisition_service.py
index 668099cf64a4..e7c880de46e8 100644
--- a/tests/concurrency/test_concurrency_slot_acquisition_service.py
+++ b/tests/concurrency/test_concurrency_slot_acquisition_service.py
@@ -69,7 +69,7 @@ async def test_retries_failed_call_respects_retry_after_header(mocked_client):
     limit_names = sorted(["api", "database"])
     service = ConcurrencySlotAcquisitionService.instance(frozenset(limit_names))

-    with mock.patch("prefect.concurrency.asyncio.asyncio.sleep") as sleep:
+    with mock.patch("asyncio.sleep") as sleep:
         future = service.send((1, "concurrency", None, True, None))
         await service.drain()
         returned_response = await asyncio.wrap_future(future)
@@ -111,7 +111,8 @@ async def test_basic_exception_returns_exception(mocked_client):

     future = service.send((1, "concurrency", None, True, None))
     await service.drain()

-    exception = await asyncio.wrap_future(future)
-    assert isinstance(exception, Exception)
-    assert exception == exc
+    with pytest.raises(Exception) as info:
+        await asyncio.wrap_future(future)
+
+    assert info.value == exc
diff --git a/tests/concurrency/test_release_concurrency_slots.py b/tests/concurrency/test_release_concurrency_slots.py
index 98d477f724f3..ee9225bb4700 100644
--- a/tests/concurrency/test_release_concurrency_slots.py
+++ b/tests/concurrency/test_release_concurrency_slots.py
@@ -4,9 +4,7 @@
 from httpx import Response

 from prefect.client.schemas.responses import MinimalConcurrencyLimitResponse
-from prefect.concurrency.asyncio import (
-    _arelease_concurrency_slots,
-)
+from prefect.concurrency._asyncio import arelease_concurrency_slots


 async def test_calls_release_client_method():
@@ -23,7 +21,7 @@ async def test_calls_release_client_method():
         )
         client_release_concurrency_slots.return_value = response

-        await _arelease_concurrency_slots(
+        await arelease_concurrency_slots(
             names=["test-1", "test-2"], slots=1, occupancy_seconds=1.0
         )
         client_release_concurrency_slots.assert_called_once_with(
@@ -47,5 +45,5 @@ async def test_returns_minimal_concurrency_limit():
         )
         client_release_concurrency_slots.return_value = response

-        result = await _arelease_concurrency_slots(["test-1", "test-2"], 1, 1.0)
+        result = await arelease_concurrency_slots(["test-1", "test-2"], 1, 1.0)
         assert result == limits
diff --git a/tests/concurrency/v1/test_concurrency_asyncio.py b/tests/concurrency/v1/test_concurrency_asyncio.py
index 2a3941522622..c2dbf47ee722 100644
--- a/tests/concurrency/v1/test_concurrency_asyncio.py
+++ b/tests/concurrency/v1/test_concurrency_asyncio.py
@@ -6,12 +6,11 @@
 from starlette import status

 from prefect import flow, task
-from prefect.concurrency.v1.asyncio import (
-    ConcurrencySlotAcquisitionError,
-    _acquire_concurrency_slots,
-    _release_concurrency_slots,
-    concurrency,
+from prefect.concurrency.v1._asyncio import (
+    acquire_concurrency_slots,
+    release_concurrency_slots,
 )
+from prefect.concurrency.v1.asyncio import ConcurrencySlotAcquisitionError, concurrency
 from prefect.events.clients import AssertingEventsClient
 from prefect.events.worker import EventsWorker
 from prefect.server.schemas.core import ConcurrencyLimit
@@ -29,12 +28,12 @@ async def resource_heavy():

     assert not executed
     with mock.patch(
-        "prefect.concurrency.v1.asyncio._acquire_concurrency_slots",
-        wraps=_acquire_concurrency_slots,
+        "prefect.concurrency.v1.asyncio.acquire_concurrency_slots",
+        wraps=acquire_concurrency_slots,
     ) as acquire_spy:
         with mock.patch(
-            "prefect.concurrency.v1.asyncio._release_concurrency_slots",
-            wraps=_release_concurrency_slots,
+            "prefect.concurrency.v1.asyncio.release_concurrency_slots",
+            wraps=release_concurrency_slots,
         ) as release_spy:
             await resource_heavy()

@@ -262,11 +261,11 @@ async def resource_heavy():

     assert not executed
     with mock.patch(
-        "prefect.concurrency.v1.asyncio._acquire_concurrency_slots",
+        "prefect.concurrency.v1._asyncio.acquire_concurrency_slots",
         wraps=lambda *args, **kwargs: None,
     ) as acquire_spy:
         with mock.patch(
-            "prefect.concurrency.v1.asyncio._release_concurrency_slots",
+            "prefect.concurrency.v1._asyncio.release_concurrency_slots",
             wraps=lambda *args, **kwargs: None,
         ) as release_spy:
             await resource_heavy()
diff --git a/tests/concurrency/v1/test_concurrency_limit_acquisition_service.py b/tests/concurrency/v1/test_concurrency_limit_acquisition_service.py
index bcadee6bdd30..59251bd1e97a 100644
--- a/tests/concurrency/v1/test_concurrency_limit_acquisition_service.py
+++ b/tests/concurrency/v1/test_concurrency_limit_acquisition_service.py
@@ -68,7 +68,7 @@ async def test_retries_failed_call_respects_retry_after_header(mocked_client):
     limit_names = sorted(["api", "database"])
     service = ConcurrencySlotAcquisitionService.instance(frozenset(limit_names))

-    with mock.patch("prefect.concurrency.v1.asyncio.asyncio.sleep") as sleep:
+    with mock.patch("asyncio.sleep") as sleep:
         future = service.send((task_run_id, None))
         service.drain()
         returned_response = await asyncio.wrap_future(future)
@@ -112,7 +112,7 @@ async def test_basic_exception_returns_exception(mocked_client):

     future = service.send((task_run_id, None))
     await service.drain()

-    exception = await asyncio.wrap_future(future)
+    with pytest.raises(Exception) as info:
+        await asyncio.wrap_future(future)

-    assert isinstance(exception, Exception)
-    assert exception == exc
+    assert info.value == exc
diff --git a/tests/concurrency/v1/test_concurrency_sync.py b/tests/concurrency/v1/test_concurrency_sync.py
index d4be64158956..bf16f1bb0d8b 100644
--- a/tests/concurrency/v1/test_concurrency_sync.py
+++ b/tests/concurrency/v1/test_concurrency_sync.py
@@ -6,9 +6,9 @@
 from starlette import status

 from prefect import flow, task
-from prefect.concurrency.v1.asyncio import (
-    _acquire_concurrency_slots,
-    _release_concurrency_slots,
+from prefect.concurrency.v1._asyncio import (
+    acquire_concurrency_slots,
+    release_concurrency_slots,
 )
 from prefect.concurrency.v1.sync import concurrency
 from prefect.events.clients import AssertingEventsClient
@@ -28,12 +28,12 @@ def resource_heavy():

     assert not executed
     with mock.patch(
-        "prefect.concurrency.v1.sync._acquire_concurrency_slots",
-        wraps=_acquire_concurrency_slots,
+        "prefect.concurrency.v1.sync.acquire_concurrency_slots",
+        wraps=acquire_concurrency_slots,
     ) as acquire_spy:
         with mock.patch(
-            "prefect.concurrency.v1.sync._release_concurrency_slots",
-            wraps=_release_concurrency_slots,
+            "prefect.concurrency.v1.sync.release_concurrency_slots",
+            wraps=release_concurrency_slots,
         ) as release_spy:
             resource_heavy()

@@ -201,11 +201,11 @@ def resource_heavy():

     assert not executed
     with mock.patch(
-        "prefect.concurrency.v1.sync._acquire_concurrency_slots",
+        "prefect.concurrency.v1.sync.acquire_concurrency_slots",
         wraps=lambda *args, **kwargs: None,
     ) as acquire_spy:
         with mock.patch(
-            "prefect.concurrency.v1.sync._release_concurrency_slots",
+            "prefect.concurrency.v1.sync.release_concurrency_slots",
             wraps=lambda *args, **kwargs: None,
         ) as release_spy:
             resource_heavy()
diff --git a/tests/concurrency/v1/test_decrement_concurrency_slots.py b/tests/concurrency/v1/test_decrement_concurrency_slots.py
index 697c12b081a2..0a214f5e4d04 100644
--- a/tests/concurrency/v1/test_decrement_concurrency_slots.py
+++ b/tests/concurrency/v1/test_decrement_concurrency_slots.py
@@ -4,7 +4,7 @@
 from httpx import Response

 from prefect.client.schemas.responses import MinimalConcurrencyLimitResponse
-from prefect.concurrency.v1.asyncio import _release_concurrency_slots
+from prefect.concurrency.v1._asyncio import release_concurrency_slots


 async def test_calls_release_client_method():
@@ -23,7 +23,7 @@ async def test_calls_release_client_method():
         )
         client_decrement_v1_concurrency_slots.return_value = response

-        await _release_concurrency_slots(
+        await release_concurrency_slots(
             names=["test-1", "test-2"], task_run_id=task_run_id, occupancy_seconds=1.0
         )
         client_decrement_v1_concurrency_slots.assert_called_once_with(
@@ -49,7 +49,7 @@ async def test_returns_minimal_concurrency_limit():
         )
         client_decrement_v1_concurrency_slots.return_value = response

-        result = await _release_concurrency_slots(
+        result = await release_concurrency_slots(
             ["test-1", "test-2"],
             task_run_id,
             1.0,
diff --git a/tests/concurrency/v1/test_increment_concurrency_limits.py b/tests/concurrency/v1/test_increment_concurrency_limits.py
index 1ee3f0894349..b856cad8df7a 100644
--- a/tests/concurrency/v1/test_increment_concurrency_limits.py
+++ b/tests/concurrency/v1/test_increment_concurrency_limits.py
@@ -4,9 +4,7 @@
 from httpx import Response

 from prefect.client.schemas.responses import MinimalConcurrencyLimitResponse
-from prefect.concurrency.asyncio import (
-    _aacquire_concurrency_slots,
-)
+from prefect.concurrency._asyncio import aacquire_concurrency_slots


 async def test_calls_increment_client_method():
@@ -27,7 +25,7 @@ async def test_calls_increment_client_method():
         )
         increment_concurrency_slots.return_value = response

-        await _aacquire_concurrency_slots(
+        await aacquire_concurrency_slots(
             names=["test-1", "test-2"], slots=1, mode="concurrency"
         )
         increment_concurrency_slots.assert_called_once_with(
@@ -56,5 +54,5 @@ async def test_returns_minimal_concurrency_limit():
         )
         increment_concurrency_slots.return_value = response

-        result = await _aacquire_concurrency_slots(["test-1", "test-2"], 1)
+        result = await aacquire_concurrency_slots(["test-1", "test-2"], 1)
         assert result == limits
diff --git a/tests/test_task_engine.py b/tests/test_task_engine.py
index 711b19f0ae79..bf4aacee026a 100644
--- a/tests/test_task_engine.py
+++ b/tests/test_task_engine.py
@@ -19,9 +19,9 @@
 from prefect.client.schemas.objects import StateType
 from prefect.concurrency.asyncio import concurrency as aconcurrency
 from prefect.concurrency.sync import concurrency
-from prefect.concurrency.v1.asyncio import (
-    _acquire_concurrency_slots,
-    _release_concurrency_slots,
+from prefect.concurrency.v1._asyncio import (
+    acquire_concurrency_slots,
+    release_concurrency_slots,
 )
 from prefect.context import (
     EngineContext,
@@ -34,10 +34,7 @@
 from prefect.logging import get_run_logger
 from prefect.results import ResultRecord, ResultStore
 from prefect.server.schemas.core import ConcurrencyLimitV2
-from prefect.settings import (
-    PREFECT_TASK_DEFAULT_RETRIES,
-    temporary_settings,
-)
+from prefect.settings import PREFECT_TASK_DEFAULT_RETRIES, temporary_settings
 from prefect.states import Completed, Running, State
 from prefect.task_engine import (
     AsyncTaskRunEngine,
@@ -2512,12 +2509,12 @@ async def bar():
             return 42

         with mock.patch(
-            "prefect.concurrency.v1.asyncio._acquire_concurrency_slots",
-            wraps=_acquire_concurrency_slots,
+            "prefect.concurrency.v1.asyncio.acquire_concurrency_slots",
+            wraps=acquire_concurrency_slots,
         ) as acquire_spy:
             with mock.patch(
-                "prefect.concurrency.v1.asyncio._release_concurrency_slots",
-                wraps=_release_concurrency_slots,
+                "prefect.concurrency.v1.asyncio.release_concurrency_slots",
+                wraps=release_concurrency_slots,
             ) as release_spy:
                 await bar()

@@ -2540,12 +2537,12 @@ def bar():
             return 42

         with mock.patch(
-            "prefect.concurrency.v1.sync._acquire_concurrency_slots",
-            wraps=_acquire_concurrency_slots,
+            "prefect.concurrency.v1.sync.acquire_concurrency_slots",
+            wraps=acquire_concurrency_slots,
         ) as acquire_spy:
             with mock.patch(
-                "prefect.concurrency.v1.sync._release_concurrency_slots",
-                wraps=_release_concurrency_slots,
+                "prefect.concurrency.v1.sync.release_concurrency_slots",
+                wraps=release_concurrency_slots,
             ) as release_spy:
                 bar()

@@ -2571,12 +2568,12 @@ def bar():
             return 42

         with mock.patch(
-            "prefect.concurrency.v1.sync._acquire_concurrency_slots",
-            wraps=_acquire_concurrency_slots,
+            "prefect.concurrency.v1.sync.acquire_concurrency_slots",
+            wraps=acquire_concurrency_slots,
         ) as acquire_spy:
             with mock.patch(
-                "prefect.concurrency.v1.sync._release_concurrency_slots",
-                wraps=_release_concurrency_slots,
+                "prefect.concurrency.v1.sync.release_concurrency_slots",
+                wraps=release_concurrency_slots,
             ) as release_spy:
                 with tags("limit-tag"):
                     bar()

@@ -2603,12 +2600,12 @@ async def bar():
             return 42

         with mock.patch(
-            "prefect.concurrency.v1.asyncio._acquire_concurrency_slots",
-            wraps=_acquire_concurrency_slots,
+            "prefect.concurrency.v1.asyncio.acquire_concurrency_slots",
+            wraps=acquire_concurrency_slots,
         ) as acquire_spy:
             with mock.patch(
-                "prefect.concurrency.v1.asyncio._release_concurrency_slots",
-                wraps=_release_concurrency_slots,
+                "prefect.concurrency.v1.asyncio.release_concurrency_slots",
+                wraps=release_concurrency_slots,
             ) as release_spy:
                 with tags("limit-tag"):
                     await bar()

@@ -2628,12 +2625,12 @@ async def bar():
             return 42

         with mock.patch(
-            "prefect.concurrency.v1.asyncio._acquire_concurrency_slots",
-            wraps=_acquire_concurrency_slots,
+            "prefect.concurrency.v1._asyncio.acquire_concurrency_slots",
+            wraps=acquire_concurrency_slots,
         ) as acquire_spy:
             with mock.patch(
-                "prefect.concurrency.v1.asyncio._release_concurrency_slots",
-                wraps=_release_concurrency_slots,
+                "prefect.concurrency.v1._asyncio.release_concurrency_slots",
+                wraps=release_concurrency_slots,
             ) as release_spy:
                 await bar()

@@ -2646,12 +2643,12 @@ def bar():
             return 42

         with mock.patch(
-            "prefect.concurrency.v1.sync._acquire_concurrency_slots",
-            wraps=_acquire_concurrency_slots,
+            "prefect.concurrency.v1.sync.acquire_concurrency_slots",
+            wraps=acquire_concurrency_slots,
         ) as acquire_spy:
             with mock.patch(
-                "prefect.concurrency.v1.sync._release_concurrency_slots",
-                wraps=_release_concurrency_slots,
+                "prefect.concurrency.v1.sync.release_concurrency_slots",
+                wraps=release_concurrency_slots,
             ) as release_spy:
                 bar()

@@ -2668,8 +2665,8 @@ async def bar():
             return 42

         with mock.patch(
-            "prefect.concurrency.v1.asyncio._acquire_concurrency_slots",
-            wraps=_acquire_concurrency_slots,
+            "prefect.concurrency.v1.asyncio.acquire_concurrency_slots",
+            wraps=acquire_concurrency_slots,
         ) as acquire_spy:
             await bar()