diff --git a/src/prefect/task_worker.py b/src/prefect/task_worker.py index 2f35d02c1eab..b5c40e9611e6 100644 --- a/src/prefect/task_worker.py +++ b/src/prefect/task_worker.py @@ -11,7 +11,9 @@ import anyio import anyio.abc +import uvicorn from exceptiongroup import BaseExceptionGroup # novermin +from fastapi import FastAPI from websockets.exceptions import InvalidStatusCode from prefect import Task @@ -89,10 +91,28 @@ def __init__( self._executor = ThreadPoolExecutor(max_workers=limit if limit else None) self._limiter = anyio.CapacityLimiter(limit) if limit else None + self.in_flight_task_runs = {task.task_key: set() for task in self.tasks} + @property - def _client_id(self) -> str: + def client_id(self) -> str: return f"{socket.gethostname()}-{os.getpid()}" + @property + def limit(self) -> Optional[int]: + return int(self._limiter.total_tokens) if self._limiter else None + + @property + def current_tasks(self) -> Optional[int]: + return ( + int(self._limiter.borrowed_tokens) + if self._limiter + else sum(len(runs) for runs in self.in_flight_task_runs.values()) + ) + + @property + def available_tasks(self) -> Optional[int]: + return int(self._limiter.available_tokens) if self._limiter else None + def handle_sigterm(self, signum, frame): """ Shuts down the task worker when a SIGTERM is received. @@ -153,7 +173,7 @@ async def _subscribe_to_task_scheduling(self): model=TaskRun, path="/task_runs/subscriptions/scheduled", keys=[task.task_key for task in self.tasks], - client_id=self._client_id, + client_id=self.client_id, base_url=base_url, ): logger.info(f"Received task run: {task_run.id} - {task_run.name}") @@ -164,6 +184,7 @@ async def _subscribe_to_task_scheduling(self): ) async def _safe_submit_scheduled_task_run(self, task_run: TaskRun): + self.in_flight_task_runs[task_run.task_key].add(task_run.id) try: await self._submit_scheduled_task_run(task_run) except BaseException as exc: @@ -172,6 +193,7 @@ async def _safe_submit_scheduled_task_run(self, task_run: TaskRun): exc_info=exc, ) finally: + self.in_flight_task_runs[task_run.task_key].discard(task_run.id) if self._limiter: self._limiter.release_on_behalf_of(task_run.id) @@ -307,6 +329,27 @@ async def __aexit__(self, *exc_info): await self._exit_stack.__aexit__(*exc_info) +def create_status_server(task_worker: TaskWorker) -> FastAPI: + status_app = FastAPI() + + @status_app.get("/status") + def status(): + return { + "client_id": task_worker.client_id, + "started": task_worker.started, + "stopping": task_worker.stopping, + "limit": task_worker.limit, + "current": task_worker.current_tasks, + "available": task_worker.available_tasks, + "tasks": { + key: list(sorted(tasks)) + for key, tasks in task_worker.in_flight_task_runs.items() + }, + } + + return status_app + + @sync_compatible async def serve(*tasks: Task, limit: Optional[int] = 10): """Serve the provided tasks so that their runs may be submitted to and executed. @@ -339,6 +382,19 @@ def yell(message: str): """ task_worker = TaskWorker(*tasks, limit=limit) + loop = asyncio.get_event_loop() + + server = uvicorn.Server( + uvicorn.Config( + app=create_status_server(task_worker), + host="127.0.0.1", + port=4422, + access_log=False, + log_level="warning", + ) + ) + status_server = loop.create_task(server.serve()) + try: await task_worker.start() @@ -355,3 +411,10 @@ def yell(message: str): except (asyncio.CancelledError, KeyboardInterrupt): logger.info("Task worker interrupted, stopping...") + + finally: + status_server.cancel() + try: + await status_server + except asyncio.CancelledError: + pass diff --git a/tests/test_task_worker.py b/tests/test_task_worker.py index a4fcf5e790d5..62068ca85900 100644 --- a/tests/test_task_worker.py +++ b/tests/test_task_worker.py @@ -131,7 +131,7 @@ async def test_task_worker_client_id_is_set(): task_worker = TaskWorker(...) task_worker._client = MagicMock(api_url="http://localhost:4200") - assert task_worker._client_id == "foo-42" + assert task_worker.client_id == "foo-42" async def test_task_worker_handles_aborted_task_run_submission(