-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Replicate concurrency helper from v2 for v1 limits (#15037)
Co-authored-by: Alexander Streed <[email protected]> Co-authored-by: Chris Guidry <[email protected]>
- Loading branch information
1 parent
737f536
commit eac7892
Showing
21 changed files
with
1,481 additions
and
37 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
import asyncio | ||
from contextlib import asynccontextmanager | ||
from typing import AsyncGenerator, List, Optional, Union, cast | ||
from uuid import UUID | ||
|
||
import anyio | ||
import httpx | ||
import pendulum | ||
|
||
from ...client.schemas.responses import MinimalConcurrencyLimitResponse | ||
|
||
try: | ||
from pendulum import Interval | ||
except ImportError: | ||
# pendulum < 3 | ||
from pendulum.period import Period as Interval # type: ignore | ||
|
||
from prefect.client.orchestration import get_client | ||
|
||
from .context import ConcurrencyContext | ||
from .events import ( | ||
_emit_concurrency_acquisition_events, | ||
_emit_concurrency_release_events, | ||
) | ||
from .services import ConcurrencySlotAcquisitionService | ||
|
||
|
||
class ConcurrencySlotAcquisitionError(Exception): | ||
"""Raised when an unhandlable occurs while acquiring concurrency slots.""" | ||
|
||
|
||
class AcquireConcurrencySlotTimeoutError(TimeoutError): | ||
"""Raised when acquiring a concurrency slot times out.""" | ||
|
||
|
||
@asynccontextmanager | ||
async def concurrency( | ||
names: Union[str, List[str]], | ||
task_run_id: UUID, | ||
timeout_seconds: Optional[float] = None, | ||
) -> AsyncGenerator[None, None]: | ||
"""A context manager that acquires and releases concurrency slots from the | ||
given concurrency limits. | ||
Args: | ||
names: The names of the concurrency limits to acquire slots from. | ||
task_run_id: The name of the task_run_id that is incrementing the slots. | ||
timeout_seconds: The number of seconds to wait for the slots to be acquired before | ||
raising a `TimeoutError`. A timeout of `None` will wait indefinitely. | ||
Raises: | ||
TimeoutError: If the slots are not acquired within the given timeout. | ||
Example: | ||
A simple example of using the async `concurrency` context manager: | ||
```python | ||
from prefect.concurrency.v1.asyncio import concurrency | ||
async def resource_heavy(): | ||
async with concurrency("test", task_run_id): | ||
print("Resource heavy task") | ||
async def main(): | ||
await resource_heavy() | ||
``` | ||
""" | ||
if not names: | ||
yield | ||
return | ||
|
||
names_normalized: List[str] = names if isinstance(names, list) else [names] | ||
|
||
limits = await _acquire_concurrency_slots( | ||
names_normalized, | ||
task_run_id=task_run_id, | ||
timeout_seconds=timeout_seconds, | ||
) | ||
acquisition_time = pendulum.now("UTC") | ||
emitted_events = _emit_concurrency_acquisition_events(limits, task_run_id) | ||
|
||
try: | ||
yield | ||
finally: | ||
occupancy_period = cast(Interval, (pendulum.now("UTC") - acquisition_time)) | ||
try: | ||
await _release_concurrency_slots( | ||
names_normalized, task_run_id, occupancy_period.total_seconds() | ||
) | ||
except anyio.get_cancelled_exc_class(): | ||
# The task was cancelled before it could release the slots. Add the | ||
# slots to the cleanup list so they can be released when the | ||
# concurrency context is exited. | ||
if ctx := ConcurrencyContext.get(): | ||
ctx.cleanup_slots.append( | ||
(names_normalized, occupancy_period.total_seconds(), task_run_id) | ||
) | ||
|
||
_emit_concurrency_release_events(limits, emitted_events, task_run_id) | ||
|
||
|
||
async def _acquire_concurrency_slots( | ||
names: List[str], | ||
task_run_id: UUID, | ||
timeout_seconds: Optional[float] = None, | ||
) -> List[MinimalConcurrencyLimitResponse]: | ||
service = ConcurrencySlotAcquisitionService.instance(frozenset(names)) | ||
future = service.send((task_run_id, timeout_seconds)) | ||
response_or_exception = await asyncio.wrap_future(future) | ||
|
||
if isinstance(response_or_exception, Exception): | ||
if isinstance(response_or_exception, TimeoutError): | ||
raise AcquireConcurrencySlotTimeoutError( | ||
f"Attempt to acquire concurrency limits timed out after {timeout_seconds} second(s)" | ||
) from response_or_exception | ||
|
||
raise ConcurrencySlotAcquisitionError( | ||
f"Unable to acquire concurrency limits {names!r}" | ||
) from response_or_exception | ||
|
||
return _response_to_concurrency_limit_response(response_or_exception) | ||
|
||
|
||
async def _release_concurrency_slots( | ||
names: List[str], | ||
task_run_id: UUID, | ||
occupancy_seconds: float, | ||
) -> List[MinimalConcurrencyLimitResponse]: | ||
async with get_client() as client: | ||
response = await client.decrement_v1_concurrency_slots( | ||
names=names, | ||
task_run_id=task_run_id, | ||
occupancy_seconds=occupancy_seconds, | ||
) | ||
return _response_to_concurrency_limit_response(response) | ||
|
||
|
||
def _response_to_concurrency_limit_response( | ||
response: httpx.Response, | ||
) -> List[MinimalConcurrencyLimitResponse]: | ||
data = response.json() or [] | ||
return [ | ||
MinimalConcurrencyLimitResponse.model_validate(limit) for limit in data if data | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from contextvars import ContextVar | ||
from typing import List, Tuple | ||
from uuid import UUID | ||
|
||
from prefect.client.orchestration import get_client | ||
from prefect.context import ContextModel, Field | ||
|
||
|
||
class ConcurrencyContext(ContextModel): | ||
__var__: ContextVar = ContextVar("concurrency_v1") | ||
|
||
# Track the limits that have been acquired but were not able to be released | ||
# due to cancellation or some other error. These limits are released when | ||
# the context manager exits. | ||
cleanup_slots: List[Tuple[List[str], float, UUID]] = Field(default_factory=list) | ||
|
||
def __exit__(self, *exc_info): | ||
if self.cleanup_slots: | ||
with get_client(sync_client=True) as client: | ||
for names, occupancy_seconds, task_run_id in self.cleanup_slots: | ||
client.decrement_v1_concurrency_slots( | ||
names=names, | ||
occupancy_seconds=occupancy_seconds, | ||
task_run_id=task_run_id, | ||
) | ||
|
||
return super().__exit__(*exc_info) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
from typing import Dict, List, Literal, Optional, Union | ||
from uuid import UUID | ||
|
||
from prefect.client.schemas.responses import MinimalConcurrencyLimitResponse | ||
from prefect.events import Event, RelatedResource, emit_event | ||
|
||
|
||
def _emit_concurrency_event( | ||
phase: Union[Literal["acquired"], Literal["released"]], | ||
primary_limit: MinimalConcurrencyLimitResponse, | ||
related_limits: List[MinimalConcurrencyLimitResponse], | ||
task_run_id: UUID, | ||
follows: Union[Event, None] = None, | ||
) -> Union[Event, None]: | ||
resource: Dict[str, str] = { | ||
"prefect.resource.id": f"prefect.concurrency-limit.v1.{primary_limit.id}", | ||
"prefect.resource.name": primary_limit.name, | ||
"limit": str(primary_limit.limit), | ||
"task_run_id": str(task_run_id), | ||
} | ||
|
||
related = [ | ||
RelatedResource.model_validate( | ||
{ | ||
"prefect.resource.id": f"prefect.concurrency-limit.v1.{limit.id}", | ||
"prefect.resource.role": "concurrency-limit", | ||
} | ||
) | ||
for limit in related_limits | ||
if limit.id != primary_limit.id | ||
] | ||
|
||
return emit_event( | ||
f"prefect.concurrency-limit.v1.{phase}", | ||
resource=resource, | ||
related=related, | ||
follows=follows, | ||
) | ||
|
||
|
||
def _emit_concurrency_acquisition_events( | ||
limits: List[MinimalConcurrencyLimitResponse], | ||
task_run_id: UUID, | ||
) -> Dict[UUID, Optional[Event]]: | ||
events = {} | ||
for limit in limits: | ||
event = _emit_concurrency_event("acquired", limit, limits, task_run_id) | ||
events[limit.id] = event | ||
|
||
return events | ||
|
||
|
||
def _emit_concurrency_release_events( | ||
limits: List[MinimalConcurrencyLimitResponse], | ||
events: Dict[UUID, Optional[Event]], | ||
task_run_id: UUID, | ||
) -> None: | ||
for limit in limits: | ||
_emit_concurrency_event( | ||
"released", limit, limits, task_run_id, events[limit.id] | ||
) |
Oops, something went wrong.