Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

acquire/release token on TaskWorker without errors #14084

Merged
merged 6 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 31 additions & 10 deletions src/prefect/task_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from contextlib import AsyncExitStack
from contextvars import copy_context
from typing import List, Optional
from uuid import UUID

import anyio
import anyio.abc
Expand Down Expand Up @@ -160,6 +161,26 @@ async def stop(self):

raise StopTaskWorker

async def _acquire_token(self, task_run_id: UUID) -> bool:
    """Try to take a concurrency token on behalf of ``task_run_id``.

    When no limiter is configured there is nothing to acquire and the call
    succeeds trivially.

    Args:
        task_run_id: ID of the task run that will own the token.

    Returns:
        ``True`` if the token was acquired or no limiter is configured;
        ``False`` if the limiter raised ``RuntimeError``, which is logged at
        debug level as the token already being held for this task run.
    """
    try:
        limiter = self._limiter
        if limiter:
            await limiter.acquire_on_behalf_of(task_run_id)
    except RuntimeError:
        # Swallow the error so callers can simply skip this task run.
        logger.debug(f"Token already acquired for task run: {task_run_id!r}")
        return False
    else:
        return True

def _release_token(self, task_run_id: UUID) -> bool:
    """Give back the concurrency token held on behalf of ``task_run_id``.

    When no limiter is configured there is nothing to release and the call
    succeeds trivially.

    Args:
        task_run_id: ID of the task run whose token should be released.

    Returns:
        ``True`` if the token was released or no limiter is configured;
        ``False`` if the limiter raised ``RuntimeError``, which is logged at
        debug level as there being no token to release.
    """
    try:
        limiter = self._limiter
        if limiter:
            limiter.release_on_behalf_of(task_run_id)
    except RuntimeError:
        # Releasing is best-effort: report the miss instead of raising.
        logger.debug(f"No token to release for task run: {task_run_id!r}")
        return False
    else:
        return True

async def _subscribe_to_task_scheduling(self):
base_url = PREFECT_API_URL.value()
if base_url is None:
Expand All @@ -179,11 +200,12 @@ async def _subscribe_to_task_scheduling(self):
base_url=base_url,
):
logger.info(f"Received task run: {task_run.id} - {task_run.name}")
if self._limiter:
await self._limiter.acquire_on_behalf_of(task_run.id)
self._runs_task_group.start_soon(
self._safe_submit_scheduled_task_run, task_run
)

token_acquired = await self._acquire_token(task_run.id)
if token_acquired:
self._runs_task_group.start_soon(
self._safe_submit_scheduled_task_run, task_run
)

async def _safe_submit_scheduled_task_run(self, task_run: TaskRun):
self.in_flight_task_runs[task_run.task_key].add(task_run.id)
Expand All @@ -196,8 +218,7 @@ async def _safe_submit_scheduled_task_run(self, task_run: TaskRun):
)
finally:
self.in_flight_task_runs[task_run.task_key].discard(task_run.id)
if self._limiter:
self._limiter.release_on_behalf_of(task_run.id)
self._release_token(task_run.id)

async def _submit_scheduled_task_run(self, task_run: TaskRun):
logger.debug(
Expand Down Expand Up @@ -308,9 +329,9 @@ async def _submit_scheduled_task_run(self, task_run: TaskRun):
async def execute_task_run(self, task_run: TaskRun):
    """Execute a task run in the task worker.

    If the worker has not been started yet, it is entered as an async
    context manager for the duration of the call; otherwise a no-op async
    context is used.

    The run is submitted only when a limiter token could be acquired for
    it. If ``_acquire_token`` reports failure, the call returns without
    submitting anything.

    Args:
        task_run: The task run to execute.
    """
    async with self if not self.started else asyncnullcontext():
        # Acquire through the helper rather than calling the limiter
        # directly: the helper turns the limiter's RuntimeError into a
        # False return, so a duplicate request is skipped instead of
        # raising, and the run is submitted at most once.
        token_acquired = await self._acquire_token(task_run.id)
        if token_acquired:
            await self._safe_submit_scheduled_task_run(task_run)

async def __aenter__(self):
logger.debug("Starting task worker...")
Expand Down
30 changes: 30 additions & 0 deletions tests/test_task_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,36 @@ async def register_localfilesystem(self):
"""Register LocalFileSystem before running tests to avoid race conditions."""
await LocalFileSystem.register_type_and_schema()

async def test_task_worker_limiter_gracefully_handles_same_task_run(
self, prefect_client
):
@task
def slow_task():
import time

time.sleep(1)

task_worker = TaskWorker(slow_task, limit=1)

task_run_future = slow_task.apply_async()
task_run = await prefect_client.read_task_run(task_run_future.task_run_id)

try:
with anyio.move_on_after(1):
# run same task, one should acquire a token
# the other will gracefully be skipped.
async with task_worker:
await asyncio.gather(
task_worker.execute_task_run(task_run),
task_worker.execute_task_run(task_run),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On `main`, this fails with the error: `this borrower is already holding one of this CapacityLimiter's tokens`.

)
except asyncio.exceptions.CancelledError:
# we expect a cancelled error here
pass

updated_task_run = await prefect_client.read_task_run(task_run.id)
assert updated_task_run.state.is_completed()

async def test_task_worker_respects_limit(self, mock_subscription, prefect_client):
@task
def slow_task():
Expand Down