Skip to content

Commit

Permalink
Adds a private status endpoint to TaskWorker (#14089)
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisguidry authored Jun 17, 2024
1 parent cbb0788 commit 2589158
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 4 deletions.
78 changes: 75 additions & 3 deletions src/prefect/task_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@

import anyio
import anyio.abc
import uvicorn
from exceptiongroup import BaseExceptionGroup # novermin
from fastapi import FastAPI
from websockets.exceptions import InvalidStatusCode

from prefect import Task
Expand Down Expand Up @@ -89,10 +91,30 @@ def __init__(
self._executor = ThreadPoolExecutor(max_workers=limit if limit else None)
self._limiter = anyio.CapacityLimiter(limit) if limit else None

self.in_flight_task_runs = {
task.task_key: set() for task in self.tasks if isinstance(task, Task)
}

@property
def _client_id(self) -> str:
def client_id(self) -> str:
return f"{socket.gethostname()}-{os.getpid()}"

@property
def limit(self) -> Optional[int]:
    """Return the worker's concurrency limit.

    Returns:
        The limiter's ``total_tokens`` converted to ``int``, or ``None`` when
        this worker has no limiter.
    """
    limiter = self._limiter
    if not limiter:
        # No limiter means no limit to report.
        return None
    return int(limiter.total_tokens)

@property
def current_tasks(self) -> Optional[int]:
    """Return how many task runs the worker is handling right now.

    Returns:
        With a limiter, the limiter's ``borrowed_tokens`` converted to
        ``int``. Without one, the total number of run ids held across every
        set in ``in_flight_task_runs``.
    """
    limiter = self._limiter
    if limiter:
        return int(limiter.borrowed_tokens)

    # No limiter to ask, so count the tracked in-flight run ids directly.
    in_flight = 0
    for run_ids in self.in_flight_task_runs.values():
        in_flight += len(run_ids)
    return in_flight

@property
def available_tasks(self) -> Optional[int]:
    """Return how many more task runs the limiter would currently admit.

    Returns:
        The limiter's ``available_tokens`` converted to ``int``, or ``None``
        when this worker has no limiter.
    """
    limiter = self._limiter
    if not limiter:
        return None
    return int(limiter.available_tokens)

def handle_sigterm(self, signum, frame):
"""
Shuts down the task worker when a SIGTERM is received.
Expand Down Expand Up @@ -153,7 +175,7 @@ async def _subscribe_to_task_scheduling(self):
model=TaskRun,
path="/task_runs/subscriptions/scheduled",
keys=[task.task_key for task in self.tasks],
client_id=self._client_id,
client_id=self.client_id,
base_url=base_url,
):
logger.info(f"Received task run: {task_run.id} - {task_run.name}")
Expand All @@ -164,6 +186,7 @@ async def _subscribe_to_task_scheduling(self):
)

async def _safe_submit_scheduled_task_run(self, task_run: TaskRun):
self.in_flight_task_runs[task_run.task_key].add(task_run.id)
try:
await self._submit_scheduled_task_run(task_run)
except BaseException as exc:
Expand All @@ -172,6 +195,7 @@ async def _safe_submit_scheduled_task_run(self, task_run: TaskRun):
exc_info=exc,
)
finally:
self.in_flight_task_runs[task_run.task_key].discard(task_run.id)
if self._limiter:
self._limiter.release_on_behalf_of(task_run.id)

Expand Down Expand Up @@ -307,8 +331,31 @@ async def __aexit__(self, *exc_info):
await self._exit_stack.__aexit__(*exc_info)


def create_status_server(task_worker: TaskWorker) -> FastAPI:
    """Build a FastAPI application that reports the state of ``task_worker``.

    The application has a single ``GET /status`` route. Its response is a
    mapping with the worker's client id, its ``started`` / ``stopping``
    attributes, its limit, current and available task counts, and the sorted
    in-flight run ids for each task key.

    Args:
        task_worker: The worker whose state the route reads on every request.

    Returns:
        The configured FastAPI application.
    """
    status_app = FastAPI()

    @status_app.get("/status")
    def status():
        # Snapshot each in-flight set as a sorted list so the payload is
        # ordered; `sorted` already returns a new list.
        in_flight = {}
        for task_key, run_ids in task_worker.in_flight_task_runs.items():
            in_flight[task_key] = sorted(run_ids)

        payload = {
            "client_id": task_worker.client_id,
            "started": task_worker.started,
            "stopping": task_worker.stopping,
            "limit": task_worker.limit,
            "current": task_worker.current_tasks,
            "available": task_worker.available_tasks,
            "tasks": in_flight,
        }
        return payload

    return status_app


@sync_compatible
async def serve(*tasks: Task, limit: Optional[int] = 10):
async def serve(
*tasks: Task, limit: Optional[int] = 10, status_server_port: Optional[int] = None
):
"""Serve the provided tasks so that their runs may be submitted to and executed
in the engine. Tasks do not need to be within a flow run context to be submitted.
You must `.submit` the same task object that you pass to `serve`.
Expand All @@ -318,6 +365,9 @@ async def serve(*tasks: Task, limit: Optional[int] = 10):
given task, the task run will be submitted to the engine for execution.
- limit: The maximum number of tasks that can be run concurrently. Defaults to 10.
Pass `None` to remove the limit.
- status_server_port: An optional port on which to start an HTTP server
exposing status information about the task worker. If not provided, no
status server will run.
Example:
```python
Expand All @@ -339,6 +389,20 @@ def yell(message: str):
"""
task_worker = TaskWorker(*tasks, limit=limit)

status_server_task = None
if status_server_port is not None:
server = uvicorn.Server(
uvicorn.Config(
app=create_status_server(task_worker),
host="127.0.0.1",
port=status_server_port,
access_log=False,
log_level="warning",
)
)
loop = asyncio.get_event_loop()
status_server_task = loop.create_task(server.serve())

try:
await task_worker.start()

Expand All @@ -355,3 +419,11 @@ def yell(message: str):

except (asyncio.CancelledError, KeyboardInterrupt):
logger.info("Task worker interrupted, stopping...")

finally:
if status_server_task:
status_server_task.cancel()
try:
await status_server_task
except asyncio.CancelledError:
pass
2 changes: 1 addition & 1 deletion tests/test_task_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ async def test_task_worker_client_id_is_set():
task_worker = TaskWorker(...)
task_worker._client = MagicMock(api_url="http://localhost:4200")

assert task_worker._client_id == "foo-42"
assert task_worker.client_id == "foo-42"


async def test_task_worker_handles_aborted_task_run_submission(
Expand Down

0 comments on commit 2589158

Please sign in to comment.