Skip to content

Commit

Permalink
Adds a private status endpoint to TaskWorker
Browse files Browse the repository at this point in the history
This endpoint, at `http://127.0.0.1:4422/status` reports the current state of
the `TaskWorker`, including what its limiting parameters look like and which
task runs of each key are currently in-flight.  This is a rough-and-ready
monitoring endpoint that we can flesh out with more info and tests as we go, but
for now, we need it for load testing and reliability troubleshooting.
  • Loading branch information
chrisguidry committed Jun 17, 2024
1 parent d28cf8b commit 8aa43df
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 3 deletions.
67 changes: 65 additions & 2 deletions src/prefect/task_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@

import anyio
import anyio.abc
import uvicorn
from exceptiongroup import BaseExceptionGroup # novermin
from fastapi import FastAPI
from websockets.exceptions import InvalidStatusCode

from prefect import Task
Expand Down Expand Up @@ -89,10 +91,28 @@ def __init__(
self._executor = ThreadPoolExecutor(max_workers=limit if limit else None)
self._limiter = anyio.CapacityLimiter(limit) if limit else None

self.in_flight_task_runs = {task.task_key: set() for task in self.tasks}

@property
def _client_id(self) -> str:
def client_id(self) -> str:
return f"{socket.gethostname()}-{os.getpid()}"

@property
def limit(self) -> Optional[int]:
return int(self._limiter.total_tokens) if self._limiter else None

@property
def current_tasks(self) -> Optional[int]:
return (
int(self._limiter.borrowed_tokens)
if self._limiter
else sum(len(runs) for runs in self.in_flight_task_runs.values())
)

@property
def available_tasks(self) -> Optional[int]:
return int(self._limiter.available_tokens) if self._limiter else None

def handle_sigterm(self, signum, frame):
"""
Shuts down the task worker when a SIGTERM is received.
Expand Down Expand Up @@ -153,7 +173,7 @@ async def _subscribe_to_task_scheduling(self):
model=TaskRun,
path="/task_runs/subscriptions/scheduled",
keys=[task.task_key for task in self.tasks],
client_id=self._client_id,
client_id=self.client_id,
base_url=base_url,
):
logger.info(f"Received task run: {task_run.id} - {task_run.name}")
Expand All @@ -164,6 +184,7 @@ async def _subscribe_to_task_scheduling(self):
)

async def _safe_submit_scheduled_task_run(self, task_run: TaskRun):
self.in_flight_task_runs[task_run.task_key].add(task_run.id)
try:
await self._submit_scheduled_task_run(task_run)
except BaseException as exc:
Expand All @@ -172,6 +193,7 @@ async def _safe_submit_scheduled_task_run(self, task_run: TaskRun):
exc_info=exc,
)
finally:
self.in_flight_task_runs[task_run.task_key].discard(task_run.id)
if self._limiter:
self._limiter.release_on_behalf_of(task_run.id)

Expand Down Expand Up @@ -307,6 +329,27 @@ async def __aexit__(self, *exc_info):
await self._exit_stack.__aexit__(*exc_info)


def create_status_server(task_worker: TaskWorker) -> FastAPI:
status_app = FastAPI()

@status_app.get("/status")
def status():
return {
"client_id": task_worker.client_id,
"started": task_worker.started,
"stopping": task_worker.stopping,
"limit": task_worker.limit,
"current": task_worker.current_tasks,
"available": task_worker.available_tasks,
"tasks": {
key: list(sorted(tasks))
for key, tasks in task_worker.in_flight_task_runs.items()
},
}

return status_app


@sync_compatible
async def serve(*tasks: Task, limit: Optional[int] = 10):
"""Serve the provided tasks so that their runs may be submitted to and executed.
Expand Down Expand Up @@ -339,6 +382,19 @@ def yell(message: str):
"""
task_worker = TaskWorker(*tasks, limit=limit)

loop = asyncio.get_event_loop()

server = uvicorn.Server(
uvicorn.Config(
app=create_status_server(task_worker),
host="127.0.0.1",
port=4422,
access_log=False,
log_level="warning",
)
)
status_server = loop.create_task(server.serve())

try:
await task_worker.start()

Expand All @@ -355,3 +411,10 @@ def yell(message: str):

except (asyncio.CancelledError, KeyboardInterrupt):
logger.info("Task worker interrupted, stopping...")

finally:
status_server.cancel()
try:
await status_server
except asyncio.CancelledError:
pass
2 changes: 1 addition & 1 deletion tests/test_task_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ async def test_task_worker_client_id_is_set():
task_worker = TaskWorker(...)
task_worker._client = MagicMock(api_url="http://localhost:4200")

assert task_worker._client_id == "foo-42"
assert task_worker.client_id == "foo-42"


async def test_task_worker_handles_aborted_task_run_submission(
Expand Down

0 comments on commit 8aa43df

Please sign in to comment.