From 109dcd9005dc419eb9f4737a613101bc0edf2538 Mon Sep 17 00:00:00 2001 From: nate nowack Date: Tue, 16 Jul 2024 12:58:58 -0500 Subject: [PATCH] port task worker api (#14536) --- .../3.0rc/api-ref/rest-api/server/schema.json | 141 +++++++++++++++++- .../server/task-workers/read-task-workers.mdx | 3 + .../server/task-workers/read-task-workers.mdx | 3 + docs/mint.json | 6 + src/prefect/client/cloud.py | 2 +- src/prefect/server/api/__init__.py | 1 + src/prefect/server/api/server.py | 1 + src/prefect/server/api/task_runs.py | 25 +++- src/prefect/server/api/task_workers.py | 31 ++++ src/prefect/server/models/__init__.py | 1 + src/prefect/server/models/task_workers.py | 103 +++++++++++++ src/prefect/task_worker.py | 2 +- tests/server/models/test_task_workers.py | 60 ++++++++ .../api/test_task_run_subscriptions.py | 132 ++++++++++++---- .../orchestration/api/test_task_workers.py | 43 ++++++ tests/test_task_worker.py | 2 +- 16 files changed, 514 insertions(+), 42 deletions(-) create mode 100644 docs/3.0rc/api-ref/rest-api/server/task-workers/read-task-workers.mdx create mode 100644 docs/3.0rc/api-ref/server/task-workers/read-task-workers.mdx create mode 100644 src/prefect/server/api/task_workers.py create mode 100644 src/prefect/server/models/task_workers.py create mode 100644 tests/server/models/test_task_workers.py create mode 100644 tests/server/orchestration/api/test_task_workers.py diff --git a/docs/3.0rc/api-ref/rest-api/server/schema.json b/docs/3.0rc/api-ref/rest-api/server/schema.json index 5ba97258f363..700ae8efdc04 100644 --- a/docs/3.0rc/api-ref/rest-api/server/schema.json +++ b/docs/3.0rc/api-ref/rest-api/server/schema.json @@ -6500,6 +6500,67 @@ } } }, + "/api/task_workers/filter": { + "post": { + "tags": [ + "Task Workers" + ], + "summary": "Read Task Workers", + "description": "Read active task workers. 
Optionally filter by task keys.", + "operationId": "read_task_workers_task_workers_filter_post", + "parameters": [ + { + "name": "x-prefect-api-version", + "in": "header", + "required": false, + "schema": { + "type": "string", + "title": "X-Prefect-Api-Version" + } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "allOf": [ + { + "$ref": "#/components/schemas/Body_read_task_workers_task_workers_filter_post" + } + ], + "title": "Body" + } + } + } + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "type": "array", + "items": { + "$ref": "#/components/schemas/TaskWorkerResponse" + }, + "title": "Response Read Task Workers Task Workers Filter Post" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, "/api/work_queues/": { "post": { "tags": [ @@ -14250,30 +14311,42 @@ "default": 0 }, "flows": { - "allOf": [ + "anyOf": [ { "$ref": "#/components/schemas/FlowFilter" + }, + { + "type": "null" } ] }, "flow_runs": { - "allOf": [ + "anyOf": [ { "$ref": "#/components/schemas/FlowRunFilter" + }, + { + "type": "null" } ] }, "task_runs": { - "allOf": [ + "anyOf": [ { "$ref": "#/components/schemas/TaskRunFilter" + }, + { + "type": "null" } ] }, "deployments": { - "allOf": [ + "anyOf": [ { "$ref": "#/components/schemas/DeploymentFilter" + }, + { + "type": "null" } ] }, @@ -14286,6 +14359,23 @@ "type": "object", "title": "Body_read_task_runs_task_runs_filter_post" }, + "Body_read_task_workers_task_workers_filter_post": { + "properties": { + "task_worker_filter": { + "anyOf": [ + { + "$ref": "#/components/schemas/TaskWorkerFilter" + }, + { + "type": "null" + } + ], + "description": "The task worker filter" + } + }, + "type": "object", + "title": "Body_read_task_workers_task_workers_filter_post" + }, "Body_read_variables_variables_filter_post": { "properties": { "offset": { @@ -23951,6 +24041,49 @@ "title": "TaskRunUpdate", "description": "Data used by the Prefect REST API to update a task run" }, + "TaskWorkerFilter": { + "properties": { + "task_keys": { + "items": { + "type": "string" + }, + "type": "array", + "title": "Task Keys" + } + }, + "type": "object", + "required": [ + "task_keys" + ], + "title": "TaskWorkerFilter" + }, + "TaskWorkerResponse": { + "properties": { + "identifier": { + "type": "string", + "title": "Identifier" + }, + "task_keys": { + "items": { + "type": "string" + }, + "type": "array", + "title": "Task Keys" + }, + "timestamp": { + "type": "string", + "format": "date-time", + "title": "Timestamp" + } + }, + "type": "object", + "required": [ + "identifier", + "task_keys", + "timestamp" + ], + "title": "TaskWorkerResponse" + }, "TimeUnit": { "type": "string", "enum": [ diff --git a/docs/3.0rc/api-ref/rest-api/server/task-workers/read-task-workers.mdx b/docs/3.0rc/api-ref/rest-api/server/task-workers/read-task-workers.mdx new file mode 100644 index 000000000000..fc3b933e044d --- /dev/null +++ b/docs/3.0rc/api-ref/rest-api/server/task-workers/read-task-workers.mdx @@ -0,0 +1,3 @@ +--- +openapi: post /api/task_workers/filter +--- \ No newline at end of file diff --git a/docs/3.0rc/api-ref/server/task-workers/read-task-workers.mdx b/docs/3.0rc/api-ref/server/task-workers/read-task-workers.mdx new file mode 100644 index 000000000000..fc3b933e044d --- /dev/null +++ b/docs/3.0rc/api-ref/server/task-workers/read-task-workers.mdx @@ -0,0 +1,3 
@@ +--- +openapi: post /api/task_workers/filter +--- \ No newline at end of file diff --git a/docs/mint.json b/docs/mint.json index dc1ec1b729b5..0120728dfc3d 100644 --- a/docs/mint.json +++ b/docs/mint.json @@ -565,6 +565,12 @@ "3.0rc/api-ref/rest-api/server/work-pools/delete-worker" ] }, + { + "group": "Task Workers", + "pages": [ + "3.0rc/api-ref/rest-api/server/task-workers/read-task-workers" + ] + }, { "group": "Work Queues", "pages": [ diff --git a/src/prefect/client/cloud.py b/src/prefect/client/cloud.py index ae4b5c4a4ce1..1676d418f333 100644 --- a/src/prefect/client/cloud.py +++ b/src/prefect/client/cloud.py @@ -9,7 +9,7 @@ import prefect.context import prefect.settings from prefect.client.base import PrefectHttpxAsyncClient -from prefect.client.schemas import Workspace +from prefect.client.schemas.objects import Workspace from prefect.exceptions import ObjectNotFound, PrefectException from prefect.settings import ( PREFECT_API_KEY, diff --git a/src/prefect/server/api/__init__.py b/src/prefect/server/api/__init__.py index 9258a6d713bc..a5d4c7383cdb 100644 --- a/src/prefect/server/api/__init__.py +++ b/src/prefect/server/api/__init__.py @@ -24,6 +24,7 @@ saved_searches, task_run_states, task_runs, + task_workers, templates, ui, variables, diff --git a/src/prefect/server/api/server.py b/src/prefect/server/api/server.py index e867dc7fa2ac..71634fff30e7 100644 --- a/src/prefect/server/api/server.py +++ b/src/prefect/server/api/server.py @@ -85,6 +85,7 @@ api.block_types.router, api.block_documents.router, api.workers.router, + api.task_workers.router, api.work_queues.router, api.artifacts.router, api.block_schemas.router, diff --git a/src/prefect/server/api/task_runs.py b/src/prefect/server/api/task_runs.py index b417f13aaf4e..645962a9ca24 100644 --- a/src/prefect/server/api/task_runs.py +++ b/src/prefect/server/api/task_runs.py @@ -4,7 +4,7 @@ import asyncio import datetime -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from uuid import UUID import pendulum @@ -188,10 +188,10 @@ async def read_task_runs( sort: schemas.sorting.TaskRunSort = Body(schemas.sorting.TaskRunSort.ID_DESC), limit: int = dependencies.LimitBody(), offset: int = Body(0, ge=0), - flows: schemas.filters.FlowFilter = None, - flow_runs: schemas.filters.FlowRunFilter = None, - task_runs: schemas.filters.TaskRunFilter = None, - deployments: schemas.filters.DeploymentFilter = None, + flows: Optional[schemas.filters.FlowFilter] = None, + flow_runs: Optional[schemas.filters.FlowRunFilter] = None, + task_runs: Optional[schemas.filters.TaskRunFilter] = None, + deployments: Optional[schemas.filters.DeploymentFilter] = None, db: PrefectDBInterface = Depends(provide_database_interface), ) -> List[schemas.core.TaskRun]: """ @@ -296,13 +296,24 @@ async def scheduled_task_subscription(websocket: WebSocket): code=4001, reason="Protocol violation: expected 'keys' in subscribe message" ) + if not (client_id := subscription.get("client_id")): + return await websocket.close( + code=4001, + reason="Protocol violation: expected 'client_id' in subscribe message", + ) + subscribed_queue = MultiQueue(task_keys) + logger.info(f"Task worker {client_id!r} subscribed to task keys {task_keys!r}") + while True: try: + # observe here so that all workers with active websockets are tracked + await models.task_workers.observe_worker(task_keys, client_id) task_run = await asyncio.wait_for(subscribed_queue.get(), timeout=1) except asyncio.TimeoutError: if not await subscriptions.still_connected(websocket): + await 
models.task_workers.forget_worker(client_id) return continue @@ -319,7 +330,11 @@ async def scheduled_task_subscription(websocket: WebSocket): code=4001, reason="Protocol violation: expected 'ack' message" ) + await models.task_workers.observe_worker([task_run.task_key], client_id) + except subscriptions.NORMAL_DISCONNECT_EXCEPTIONS: # If sending fails or pong fails, put the task back into the retry queue await asyncio.shield(TaskQueue.for_key(task_run.task_key).retry(task_run)) return + finally: + await models.task_workers.forget_worker(client_id) diff --git a/src/prefect/server/api/task_workers.py b/src/prefect/server/api/task_workers.py new file mode 100644 index 000000000000..b3ebc3edb6ff --- /dev/null +++ b/src/prefect/server/api/task_workers.py @@ -0,0 +1,31 @@ +from typing import List, Optional + +from fastapi import Body +from pydantic import BaseModel + +from prefect.server import models +from prefect.server.models.task_workers import TaskWorkerResponse +from prefect.server.utilities.server import PrefectRouter + +router = PrefectRouter(prefix="/task_workers", tags=["Task Workers"]) + + +class TaskWorkerFilter(BaseModel): + task_keys: List[str] + + +@router.post("/filter") +async def read_task_workers( + task_worker_filter: Optional[TaskWorkerFilter] = Body( + default=None, description="The task worker filter", embed=True + ), +) -> List[TaskWorkerResponse]: + """Read active task workers. Optionally filter by task keys.""" + + if task_worker_filter and task_worker_filter.task_keys: + return await models.task_workers.get_workers_for_task_keys( + task_keys=task_worker_filter.task_keys, + ) + + else: + return await models.task_workers.get_all_workers() diff --git a/src/prefect/server/models/__init__.py b/src/prefect/server/models/__init__.py index 1a1b27e4ebb3..00c1707cd1f9 100644 --- a/src/prefect/server/models/__init__.py +++ b/src/prefect/server/models/__init__.py @@ -19,6 +19,7 @@ saved_searches, task_run_states, task_runs, + task_workers, variables, work_queues, workers, diff --git a/src/prefect/server/models/task_workers.py b/src/prefect/server/models/task_workers.py new file mode 100644 index 000000000000..b2e2353bba4d --- /dev/null +++ b/src/prefect/server/models/task_workers.py @@ -0,0 +1,103 @@ +import time +from collections import defaultdict +from typing import Dict, List, Set + +from pydantic import BaseModel +from pydantic_extra_types.pendulum_dt import DateTime +from typing_extensions import TypeAlias + +TaskKey: TypeAlias = str +WorkerId: TypeAlias = str + + +class TaskWorkerResponse(BaseModel): + identifier: WorkerId + task_keys: List[TaskKey] + timestamp: DateTime + + +class InMemoryTaskWorkerTracker: + def __init__(self): + self.workers: dict[WorkerId, Set[TaskKey]] = {} + self.task_keys: Dict[TaskKey, Set[WorkerId]] = defaultdict(set) + self.worker_timestamps: Dict[WorkerId, float] = {} + + async def observe_worker( + self, + task_keys: List[TaskKey], + worker_id: WorkerId, + ) -> None: + self.workers[worker_id] = self.workers.get(worker_id, set()) | set(task_keys) + self.worker_timestamps[worker_id] = time.monotonic() + + for task_key in task_keys: + self.task_keys[task_key].add(worker_id) + + async def forget_worker( + self, + worker_id: WorkerId, + ) -> None: + if worker_id in self.workers: + task_keys = self.workers.pop(worker_id) + for task_key in task_keys: + self.task_keys[task_key].discard(worker_id) + if not self.task_keys[task_key]: + del self.task_keys[task_key] + self.worker_timestamps.pop(worker_id, None) + + async def get_workers_for_task_keys( + 
self, + task_keys: List[TaskKey], + ) -> List[TaskWorkerResponse]: + if not task_keys: + return await self.get_all_workers() + active_workers = set().union(*(self.task_keys[key] for key in task_keys)) + return [self._create_worker_response(worker_id) for worker_id in active_workers] + + async def get_all_workers(self) -> List[TaskWorkerResponse]: + return [ + self._create_worker_response(worker_id) + for worker_id in self.worker_timestamps.keys() + ] + + def _create_worker_response(self, worker_id: WorkerId) -> TaskWorkerResponse: + timestamp = time.monotonic() - self.worker_timestamps[worker_id] + return TaskWorkerResponse( + identifier=worker_id, + task_keys=list(self.workers.get(worker_id, set())), + timestamp=DateTime.utcnow().subtract(seconds=timestamp), + ) + + def reset(self): + """Testing utility to reset the state of the task worker tracker""" + self.workers.clear() + self.task_keys.clear() + self.worker_timestamps.clear() + + +# Global instance of the task worker tracker +task_worker_tracker = InMemoryTaskWorkerTracker() + + +# Main utilities to be used in the API layer +async def observe_worker( + task_keys: List[TaskKey], + worker_id: WorkerId, +) -> None: + await task_worker_tracker.observe_worker(task_keys, worker_id) + + +async def forget_worker( + worker_id: WorkerId, +) -> None: + await task_worker_tracker.forget_worker(worker_id) + + +async def get_workers_for_task_keys( + task_keys: List[TaskKey], +) -> List[TaskWorkerResponse]: + return await task_worker_tracker.get_workers_for_task_keys(task_keys) + + +async def get_all_workers() -> List[TaskWorkerResponse]: + return await task_worker_tracker.get_all_workers() diff --git a/src/prefect/task_worker.py b/src/prefect/task_worker.py index fef421f852f5..fa1bc9003f5e 100644 --- a/src/prefect/task_worker.py +++ b/src/prefect/task_worker.py @@ -325,7 +325,7 @@ async def _submit_scheduled_task_run(self, task_run: TaskRun): if task_run_url := url_for(task_run): logger.info( - f"Submitting task run {task_run.name!r} to engine. View run in the UI at {task_run_url!r}" + f"Submitting task run {task_run.name!r} to engine. 
View in the UI: {task_run_url}" ) if task.isasync: diff --git a/tests/server/models/test_task_workers.py b/tests/server/models/test_task_workers.py new file mode 100644 index 000000000000..23d732bcd86b --- /dev/null +++ b/tests/server/models/test_task_workers.py @@ -0,0 +1,60 @@ +import pytest + +from prefect.server.models.task_workers import InMemoryTaskWorkerTracker + + +@pytest.fixture +async def tracker(): + return InMemoryTaskWorkerTracker() + + +@pytest.mark.parametrize( + "task_keys,task_worker_id", + [(["task1", "task2"], "worker1"), (["task3"], "worker2"), ([], "worker3")], + ids=["task_keys", "no_task_keys", "empty_task_keys"], +) +async def test_observe_and_get_worker(tracker, task_keys, task_worker_id): + await tracker.observe_worker(task_keys, task_worker_id) + workers = await tracker.get_all_workers() + assert len(workers) == 1 + assert workers[0].identifier == task_worker_id + assert set(workers[0].task_keys) == set(task_keys) + + +@pytest.mark.parametrize( + "initial_tasks,forget_id,expected_count", + [ + ({"worker1": ["task1"], "worker2": ["task2"]}, "worker1", 1), + ({"worker1": ["task1"]}, "worker1", 0), + ({"worker1": ["task1"]}, "worker2", 1), + ], + ids=["forget_worker", "forget_no_worker", "forget_empty_worker"], +) +async def test_forget_worker(tracker, initial_tasks, forget_id, expected_count): + for worker, tasks in initial_tasks.items(): + await tracker.observe_worker(tasks, worker) + await tracker.forget_worker(forget_id) + workers = await tracker.get_all_workers() + assert len(workers) == expected_count + + +@pytest.mark.parametrize( + "observed_workers,query_tasks,expected_workers", + [ + ( + {"worker1": ["task1", "task2"], "worker2": ["task2", "task3"]}, + ["task2"], + {"worker1", "worker2"}, + ), + ({"worker1": ["task1"], "worker2": ["task2"]}, ["task3"], set()), + ({"worker1": ["task1"], "worker2": ["task2"]}, [], {"worker1", "worker2"}), + ], + ids=["filter_tasks", "filter_tasks_and_task_keys", "no_filter"], +) +async def test_get_workers_for_task_keys( + tracker, observed_workers, query_tasks, expected_workers +): + for worker, tasks in observed_workers.items(): + await tracker.observe_worker(tasks, worker) + workers = await tracker.get_workers_for_task_keys(query_tasks) + assert {w.identifier for w in workers} == expected_workers diff --git a/tests/server/orchestration/api/test_task_run_subscriptions.py b/tests/server/orchestration/api/test_task_run_subscriptions.py index 8bf1df7a80cd..1d92fd27ceed 100644 --- a/tests/server/orchestration/api/test_task_run_subscriptions.py +++ b/tests/server/orchestration/api/test_task_run_subscriptions.py @@ -1,4 +1,6 @@ import asyncio +import os +import socket from collections import Counter from contextlib import contextmanager from typing import Generator, List @@ -26,6 +28,11 @@ def reset_task_queues() -> Generator[None, None, None]: task_runs.TaskQueue.reset() +@pytest.fixture +def client_id() -> str: + return f"{socket.gethostname()}-{os.getpid()}" + + def auth_dance(socket: WebSocketTestSession): socket.send_json({"type": "auth", "token": None}) response = socket.receive_json() @@ -66,8 +73,8 @@ def drain( @pytest.fixture -async def taskA_run1(reset_task_queues) -> TaskRun: - queued = TaskRun( +async def taskA_run1(reset_task_queues) -> ServerTaskRun: + queued = ServerTaskRun( id=uuid4(), flow_run_id=None, task_key="mytasks.taskA", @@ -77,9 +84,11 @@ async def taskA_run1(reset_task_queues) -> TaskRun: return queued -def test_receiving_task_run(app: FastAPI, taskA_run1: TaskRun): +def test_receiving_task_run(app: 
FastAPI, taskA_run1: TaskRun, client_id: str): with authenticated_socket(app) as socket: - socket.send_json({"type": "subscribe", "keys": ["mytasks.taskA"]}) + socket.send_json( + {"type": "subscribe", "keys": ["mytasks.taskA"], "client_id": client_id} + ) (received,) = drain(socket) @@ -87,9 +96,8 @@ def test_receiving_task_run(app: FastAPI, taskA_run1: TaskRun): @pytest.fixture -async def taskA_run2(reset_task_queues) -> TaskRun: - queued = TaskRun( - id=uuid4(), +async def taskA_run2(reset_task_queues) -> ServerTaskRun: + queued = ServerTaskRun( flow_run_id=None, task_key="mytasks.taskA", dynamic_key="mytasks.taskA-1", @@ -99,10 +107,12 @@ async def taskA_run2(reset_task_queues) -> TaskRun: def test_acknowledging_between_each_run( - app: FastAPI, taskA_run1: TaskRun, taskA_run2: TaskRun + app: FastAPI, taskA_run1: TaskRun, taskA_run2: TaskRun, client_id: str ): with authenticated_socket(app) as socket: - socket.send_json({"type": "subscribe", "keys": ["mytasks.taskA"]}) + socket.send_json( + {"type": "subscribe", "keys": ["mytasks.taskA"], "client_id": client_id} + ) (first, second) = drain(socket, 2) @@ -114,7 +124,7 @@ def test_acknowledging_between_each_run( @pytest.fixture async def mixed_bag_of_tasks(reset_task_queues) -> None: await task_runs.TaskQueue.enqueue( - TaskRun( + TaskRun( # type: ignore id=uuid4(), flow_run_id=None, task_key="mytasks.taskA", @@ -123,7 +133,7 @@ async def mixed_bag_of_tasks(reset_task_queues) -> None: ) await task_runs.TaskQueue.enqueue( - TaskRun( + TaskRun( # type: ignore id=uuid4(), flow_run_id=None, task_key="mytasks.taskA", @@ -133,7 +143,7 @@ async def mixed_bag_of_tasks(reset_task_queues) -> None: # this one should not be delivered await task_runs.TaskQueue.enqueue( - TaskRun( + TaskRun( # type: ignore id=uuid4(), flow_run_id=None, task_key="nope.not.this.one", @@ -142,7 +152,7 @@ async def mixed_bag_of_tasks(reset_task_queues) -> None: ) await task_runs.TaskQueue.enqueue( - TaskRun( + TaskRun( # type: ignore id=uuid4(), flow_run_id=None, task_key="other_tasks.taskB", @@ -153,11 +163,16 @@ async def mixed_bag_of_tasks(reset_task_queues) -> None: def test_server_only_delivers_tasks_for_subscribed_keys( app: FastAPI, - mixed_bag_of_tasks, + mixed_bag_of_tasks: List[TaskRun], + client_id: str, ): with authenticated_socket(app) as socket: socket.send_json( - {"type": "subscribe", "keys": ["mytasks.taskA", "other_tasks.taskB"]} + { + "type": "subscribe", + "keys": ["mytasks.taskA", "other_tasks.taskB"], + "client_id": client_id, + } ) received = drain(socket, 3) @@ -169,10 +184,10 @@ def test_server_only_delivers_tasks_for_subscribed_keys( @pytest.fixture -async def ten_task_A_runs(reset_task_queues) -> List[TaskRun]: - queued: List[TaskRun] = [] +async def ten_task_A_runs(reset_task_queues) -> List[ServerTaskRun]: + queued: List[ServerTaskRun] = [] for _ in range(10): - run = TaskRun( + run = ServerTaskRun( id=uuid4(), flow_run_id=None, task_key="mytasks.taskA", @@ -184,14 +199,18 @@ async def ten_task_A_runs(reset_task_queues) -> List[TaskRun]: def test_only_one_socket_gets_each_task_run( - app: FastAPI, ten_task_A_runs: List[TaskRun] + app: FastAPI, ten_task_A_runs: List[TaskRun], client_id: str ): received1: List[TaskRun] = [] received2: List[TaskRun] = [] with authenticated_socket(app) as first, authenticated_socket(app) as second: - first.send_json({"type": "subscribe", "keys": ["mytasks.taskA"]}) - second.send_json({"type": "subscribe", "keys": ["mytasks.taskA"]}) + first.send_json( + {"type": "subscribe", "keys": ["mytasks.taskA"], "client_id": 
client_id} + ) + second.send_json( + {"type": "subscribe", "keys": ["mytasks.taskA"], "client_id": client_id} + ) for i in range(5): received1 += drain(first, 1, quit=(i == 4)) @@ -216,9 +235,13 @@ def test_only_one_socket_gets_each_task_run( assert received_ids.issubset(queued_ids) -def test_server_redelivers_unacknowledged_runs(app: FastAPI, taskA_run1: TaskRun): +def test_server_redelivers_unacknowledged_runs( + app: FastAPI, taskA_run1: TaskRun, client_id: str +): with authenticated_socket(app) as socket: - socket.send_json({"type": "subscribe", "keys": ["mytasks.taskA"]}) + socket.send_json( + {"type": "subscribe", "keys": ["mytasks.taskA"], "client_id": client_id} + ) received = socket.receive_json() assert received["id"] == str(taskA_run1.id) @@ -227,14 +250,18 @@ def test_server_redelivers_unacknowledged_runs(app: FastAPI, taskA_run1: TaskRun socket.close() with authenticated_socket(app) as socket: - socket.send_json({"type": "subscribe", "keys": ["mytasks.taskA"]}) + socket.send_json( + {"type": "subscribe", "keys": ["mytasks.taskA"], "client_id": client_id} + ) (received,) = drain(socket) assert received.id == taskA_run1.id @pytest.fixture -async def preexisting_runs(session: AsyncSession, reset_task_queues) -> List[TaskRun]: +async def preexisting_runs( + session: AsyncSession, reset_task_queues +) -> List[ServerTaskRun]: stored_runA = ServerTaskRun.model_validate( await models.task_runs.create_task_run( session, @@ -267,9 +294,12 @@ async def preexisting_runs(session: AsyncSession, reset_task_queues) -> List[Tas def test_server_restores_scheduled_task_runs_at_startup( app: FastAPI, preexisting_runs: List[TaskRun], + client_id: str, ): with authenticated_socket(app) as socket: - socket.send_json({"type": "subscribe", "keys": ["mytasks.taskA"]}) + socket.send_json( + {"type": "subscribe", "keys": ["mytasks.taskA"], "client_id": client_id} + ) received = drain(socket, expecting=len(preexisting_runs)) @@ -288,7 +318,7 @@ async def test_task_queue_scheduled_size_limit(self): queue = task_runs.TaskQueue.for_key(task_key) for _ in range(max_scheduled_size): - task_run = TaskRun( + task_run = ServerTaskRun( id=uuid4(), flow_run_id=None, task_key=task_key, @@ -299,7 +329,7 @@ async def test_task_queue_scheduled_size_limit(self): with patch("asyncio.sleep", return_value=None), pytest.raises( asyncio.TimeoutError ): - extra_task_run = TaskRun( + extra_task_run = ServerTaskRun( id=uuid4(), flow_run_id=None, task_key=task_key, @@ -321,7 +351,7 @@ async def test_task_queue_retry_size_limit(self): queue = task_runs.TaskQueue.for_key(task_key) - task_run = TaskRun( + task_run = ServerTaskRun( id=uuid4(), flow_run_id=None, task_key=task_key, dynamic_key=f"{task_key}-1" ) await queue.retry(task_run) @@ -329,7 +359,7 @@ async def test_task_queue_retry_size_limit(self): with patch("asyncio.sleep", return_value=None), pytest.raises( asyncio.TimeoutError ): - extra_task_run = TaskRun( + extra_task_run = ServerTaskRun( id=uuid4(), flow_run_id=None, task_key=task_key, @@ -340,3 +370,45 @@ async def test_task_queue_retry_size_limit(self): assert ( queue._retry_queue.qsize() == max_retry_size ), "Retry queue size should be at its configured limit" + + +@pytest.fixture +def reset_tracker(): + models.task_workers.task_worker_tracker.reset() + yield + models.task_workers.task_worker_tracker.reset() + + +class TestTaskWorkerTracking: + @pytest.mark.parametrize( + "num_connections,task_keys,expected_workers", + [ + (2, ["taskA", "taskB"], 1), + (1, ["taskA", "taskB", "taskC"], 1), + ], + 
ids=["multiple_connections_single_worker", "single_connection_multiple_tasks"], + ) + @pytest.mark.usefixtures("reset_tracker") + async def test_task_worker_basic_tracking( + self, + app, + num_connections, + task_keys, + expected_workers, + client_id, + prefect_client, + ): + for _ in range(num_connections): + with authenticated_socket(app) as socket: + socket.send_json( + {"type": "subscribe", "keys": task_keys, "client_id": client_id} + ) + + response = await prefect_client._client.post("/task_workers/filter") + assert response.status_code == 200 + tracked_workers = response.json() + assert len(tracked_workers) == expected_workers + + for worker in tracked_workers: + assert worker["identifier"] == client_id + assert set(worker["task_keys"]) == set(task_keys) diff --git a/tests/server/orchestration/api/test_task_workers.py b/tests/server/orchestration/api/test_task_workers.py new file mode 100644 index 000000000000..530d9f9fa498 --- /dev/null +++ b/tests/server/orchestration/api/test_task_workers.py @@ -0,0 +1,43 @@ +import pytest + +from prefect.server.models.task_workers import observe_worker + + +@pytest.mark.parametrize( + "initial_workers,certain_tasks,expected_count", + [ + ({"worker1": ["task1"]}, None, 1), + ({"worker1": ["task1"], "worker2": ["task2"]}, ["task1"], 1), + ({"worker1": ["task1"], "worker2": ["task2"]}, None, 2), + ({"worker1": ["task1", "task2"], "worker2": ["task2", "task3"]}, ["task2"], 2), + ], + ids=[ + "one_worker_no_filter", + "one_worker_filter", + "two_workers_no_filter", + "two_workers_filter", + ], +) +async def test_read_task_workers( + prefect_client, initial_workers, certain_tasks, expected_count +): + for worker, tasks in initial_workers.items(): + await observe_worker(tasks, worker) + + response = await prefect_client._client.post( + "/task_workers/filter", + json={"task_worker_filter": {"task_keys": certain_tasks}} + if certain_tasks + else None, + ) + + assert response.status_code == 200 + data = response.json() + assert len(data) == expected_count + + if expected_count > 0: + for worker in data: + assert worker["identifier"] in initial_workers + assert set(worker["task_keys"]).issubset( + set(initial_workers[worker["identifier"]]) + ) diff --git a/tests/test_task_worker.py b/tests/test_task_worker.py index 64957469dd63..99dce2562fd8 100644 --- a/tests/test_task_worker.py +++ b/tests/test_task_worker.py @@ -199,7 +199,7 @@ async def test_task_worker_emits_run_ui_url_upon_submission( with temporary_settings({PREFECT_UI_URL: "http://test/api"}): await task_worker.execute_task_run(task_run) - assert "in the UI at 'http://test/api/runs/task-run/" in caplog.text + assert "in the UI: http://test/api/runs/task-run/" in caplog.text @pytest.mark.usefixtures("mock_task_worker_start")