From 8432425d7c5a14f6df1715d4ac4d1eca1f00fccf Mon Sep 17 00:00:00 2001 From: "jake@prefect.io" Date: Mon, 17 Jun 2024 13:52:05 -0400 Subject: [PATCH 1/5] acquire in try-except --- src/prefect/task_worker.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/prefect/task_worker.py b/src/prefect/task_worker.py index 2f35d02c1eab..5d4fb296b244 100644 --- a/src/prefect/task_worker.py +++ b/src/prefect/task_worker.py @@ -157,14 +157,15 @@ async def _subscribe_to_task_scheduling(self): base_url=base_url, ): logger.info(f"Received task run: {task_run.id} - {task_run.name}") - if self._limiter: - await self._limiter.acquire_on_behalf_of(task_run.id) + self._runs_task_group.start_soon( self._safe_submit_scheduled_task_run, task_run ) async def _safe_submit_scheduled_task_run(self, task_run: TaskRun): try: + if self._limiter: + await self._limiter.acquire_on_behalf_of(task_run.id) await self._submit_scheduled_task_run(task_run) except BaseException as exc: logger.exception( @@ -284,8 +285,6 @@ async def _submit_scheduled_task_run(self, task_run: TaskRun): async def execute_task_run(self, task_run: TaskRun): """Execute a task run in the task worker.""" async with self if not self.started else asyncnullcontext(): - if self._limiter: - await self._limiter.acquire_on_behalf_of(task_run.id) await self._safe_submit_scheduled_task_run(task_run) async def __aenter__(self): From 3d95cd3b71c10e0a0a9298da88889af89c0b546a Mon Sep 17 00:00:00 2001 From: "jake@prefect.io" Date: Mon, 17 Jun 2024 14:44:18 -0400 Subject: [PATCH 2/5] acquire and release a token safely --- src/prefect/task_worker.py | 46 +++++++++++++++++++++++++++++++------- 1 file changed, 38 insertions(+), 8 deletions(-) diff --git a/src/prefect/task_worker.py b/src/prefect/task_worker.py index 5d4fb296b244..40a080ce5ace 100644 --- a/src/prefect/task_worker.py +++ b/src/prefect/task_worker.py @@ -8,6 +8,7 @@ from contextlib import AsyncExitStack from contextvars import copy_context from typing 
import List, Optional +from uuid import UUID import anyio import anyio.abc @@ -138,6 +139,26 @@ async def stop(self): raise StopTaskWorker + async def _acquire_token(self, task_run_id: UUID) -> bool: + try: + if self._limiter: + await self._limiter.acquire_on_behalf_of(task_run_id) + except RuntimeError: + logger.debug(f"Failed to acquire token for task run: {task_run_id!r}") + return False + + return True + + def _release_token(self, task_run_id: UUID) -> bool: + try: + if self._limiter: + self._limiter.release_on_behalf_of(task_run_id) + except RuntimeError: + logger.debug(f"Token for task run: {task_run_id!r} was never taken") + return False + + return True + async def _subscribe_to_task_scheduling(self): base_url = PREFECT_API_URL.value() if base_url is None: @@ -158,14 +179,18 @@ async def _subscribe_to_task_scheduling(self): ): logger.info(f"Received task run: {task_run.id} - {task_run.name}") - self._runs_task_group.start_soon( - self._safe_submit_scheduled_task_run, task_run - ) + token_acquired = await self._acquire_token(task_run.id) + if token_acquired: + self._runs_task_group.start_soon( + self._safe_submit_scheduled_task_run, task_run + ) + else: + logger.info( + f"Skipping task run {task_run.id!r} because limit is reached" + ) async def _safe_submit_scheduled_task_run(self, task_run: TaskRun): try: - if self._limiter: - await self._limiter.acquire_on_behalf_of(task_run.id) await self._submit_scheduled_task_run(task_run) except BaseException as exc: logger.exception( @@ -173,8 +198,7 @@ async def _safe_submit_scheduled_task_run(self, task_run: TaskRun): exc_info=exc, ) finally: - if self._limiter: - self._limiter.release_on_behalf_of(task_run.id) + self._release_token(task_run.id) async def _submit_scheduled_task_run(self, task_run: TaskRun): logger.debug( @@ -285,7 +309,13 @@ async def _submit_scheduled_task_run(self, task_run: TaskRun): async def execute_task_run(self, task_run: TaskRun): """Execute a task run in the task worker.""" async with self 
if not self.started else asyncnullcontext(): - await self._safe_submit_scheduled_task_run(task_run) + token_acquired = await self._acquire_token(task_run.id) + if token_acquired: + await self._safe_submit_scheduled_task_run(task_run) + else: + logger.info( + f"Skipping task run {task_run.id!r} because limit is reached" + ) async def __aenter__(self): logger.debug("Starting task worker...") From 32de8b2a207bfeaefe44cbfb78dfe3d71e8ed128 Mon Sep 17 00:00:00 2001 From: "jake@prefect.io" Date: Mon, 17 Jun 2024 14:52:52 -0400 Subject: [PATCH 3/5] fix logs --- src/prefect/task_worker.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/prefect/task_worker.py b/src/prefect/task_worker.py index 40a080ce5ace..8fa97edc5a01 100644 --- a/src/prefect/task_worker.py +++ b/src/prefect/task_worker.py @@ -144,7 +144,7 @@ async def _acquire_token(self, task_run_id: UUID) -> bool: if self._limiter: await self._limiter.acquire_on_behalf_of(task_run_id) except RuntimeError: - logger.debug(f"Failed to acquire token for task run: {task_run_id!r}") + logger.debug(f"Token already acquired for task run: {task_run_id!r}") return False return True @@ -154,7 +154,7 @@ def _release_token(self, task_run_id: UUID) -> bool: if self._limiter: self._limiter.release_on_behalf_of(task_run_id) except RuntimeError: - logger.debug(f"Token for task run: {task_run_id!r} was never taken") + logger.debug(f"No token to release for task run: {task_run_id!r}") return False return True @@ -186,7 +186,7 @@ async def _subscribe_to_task_scheduling(self): ) else: logger.info( - f"Skipping task run {task_run.id!r} because limit is reached" + f"TaskWorker run limit reached. 
Skipping task run {task_run.id!r}" ) async def _safe_submit_scheduled_task_run(self, task_run: TaskRun): From 578d36d2b504cfe51698f13d9b9fb9fd26440e46 Mon Sep 17 00:00:00 2001 From: "jake@prefect.io" Date: Mon, 17 Jun 2024 14:57:04 -0400 Subject: [PATCH 4/5] don't log unnecessarily --- src/prefect/task_worker.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/prefect/task_worker.py b/src/prefect/task_worker.py index 8fa97edc5a01..a30479ac46ea 100644 --- a/src/prefect/task_worker.py +++ b/src/prefect/task_worker.py @@ -184,10 +184,6 @@ async def _subscribe_to_task_scheduling(self): self._runs_task_group.start_soon( self._safe_submit_scheduled_task_run, task_run ) - else: - logger.info( - f"TaskWorker run limit reached. Skipping task run {task_run.id!r}" - ) async def _safe_submit_scheduled_task_run(self, task_run: TaskRun): try: @@ -312,10 +308,6 @@ async def execute_task_run(self, task_run: TaskRun): token_acquired = await self._acquire_token(task_run.id) if token_acquired: await self._safe_submit_scheduled_task_run(task_run) - else: - logger.info( - f"Skipping task run {task_run.id!r} because limit is reached" - ) async def __aenter__(self): logger.debug("Starting task worker...") From b780e758aab9a1f0c30119eb2f48b00f1bcf361d Mon Sep 17 00:00:00 2001 From: "jake@prefect.io" Date: Mon, 17 Jun 2024 15:56:53 -0400 Subject: [PATCH 5/5] add test --- tests/test_task_worker.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/test_task_worker.py b/tests/test_task_worker.py index a4fcf5e790d5..2de777007fb9 100644 --- a/tests/test_task_worker.py +++ b/tests/test_task_worker.py @@ -672,6 +672,36 @@ async def register_localfilesystem(self): """Register LocalFileSystem before running tests to avoid race conditions.""" await LocalFileSystem.register_type_and_schema() + async def test_task_worker_limiter_gracefully_handles_same_task_run( + self, prefect_client + ): + @task + def slow_task(): + import time + + time.sleep(1) + + 
task_worker = TaskWorker(slow_task, limit=1) + + task_run_future = slow_task.apply_async() + task_run = await prefect_client.read_task_run(task_run_future.task_run_id) + + try: + with anyio.move_on_after(1): + # run same task, one should acquire a token + # the other will gracefully be skipped. + async with task_worker: + await asyncio.gather( + task_worker.execute_task_run(task_run), + task_worker.execute_task_run(task_run), + ) + except asyncio.exceptions.CancelledError: + # we expect a cancelled error here + pass + + updated_task_run = await prefect_client.read_task_run(task_run.id) + assert updated_task_run.state.is_completed() + async def test_task_worker_respects_limit(self, mock_subscription, prefect_client): @task def slow_task():