Skip to content

Commit

Permalink
port task worker api (#14536)
Browse files Browse the repository at this point in the history
  • Loading branch information
zzstoatzz authored Jul 16, 2024
1 parent 387a2a4 commit 109dcd9
Show file tree
Hide file tree
Showing 16 changed files with 514 additions and 42 deletions.
141 changes: 137 additions & 4 deletions docs/3.0rc/api-ref/rest-api/server/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -6500,6 +6500,67 @@
}
}
},
"/api/task_workers/filter": {
"post": {
"tags": [
"Task Workers"
],
"summary": "Read Task Workers",
"description": "Read active task workers. Optionally filter by task keys.",
"operationId": "read_task_workers_task_workers_filter_post",
"parameters": [
{
"name": "x-prefect-api-version",
"in": "header",
"required": false,
"schema": {
"type": "string",
"title": "X-Prefect-Api-Version"
}
}
],
"requestBody": {
"content": {
"application/json": {
"schema": {
"allOf": [
{
"$ref": "#/components/schemas/Body_read_task_workers_task_workers_filter_post"
}
],
"title": "Body"
}
}
}
},
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"type": "array",
"items": {
"$ref": "#/components/schemas/TaskWorkerResponse"
},
"title": "Response Read Task Workers Task Workers Filter Post"
}
}
}
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPValidationError"
}
}
}
}
}
}
},
"/api/work_queues/": {
"post": {
"tags": [
Expand Down Expand Up @@ -14250,30 +14311,42 @@
"default": 0
},
"flows": {
"allOf": [
"anyOf": [
{
"$ref": "#/components/schemas/FlowFilter"
},
{
"type": "null"
}
]
},
"flow_runs": {
"allOf": [
"anyOf": [
{
"$ref": "#/components/schemas/FlowRunFilter"
},
{
"type": "null"
}
]
},
"task_runs": {
"allOf": [
"anyOf": [
{
"$ref": "#/components/schemas/TaskRunFilter"
},
{
"type": "null"
}
]
},
"deployments": {
"allOf": [
"anyOf": [
{
"$ref": "#/components/schemas/DeploymentFilter"
},
{
"type": "null"
}
]
},
Expand All @@ -14286,6 +14359,23 @@
"type": "object",
"title": "Body_read_task_runs_task_runs_filter_post"
},
"Body_read_task_workers_task_workers_filter_post": {
"properties": {
"task_worker_filter": {
"anyOf": [
{
"$ref": "#/components/schemas/TaskWorkerFilter"
},
{
"type": "null"
}
],
"description": "The task worker filter"
}
},
"type": "object",
"title": "Body_read_task_workers_task_workers_filter_post"
},
"Body_read_variables_variables_filter_post": {
"properties": {
"offset": {
Expand Down Expand Up @@ -23951,6 +24041,49 @@
"title": "TaskRunUpdate",
"description": "Data used by the Prefect REST API to update a task run"
},
"TaskWorkerFilter": {
"properties": {
"task_keys": {
"items": {
"type": "string"
},
"type": "array",
"title": "Task Keys"
}
},
"type": "object",
"required": [
"task_keys"
],
"title": "TaskWorkerFilter"
},
"TaskWorkerResponse": {
"properties": {
"identifier": {
"type": "string",
"title": "Identifier"
},
"task_keys": {
"items": {
"type": "string"
},
"type": "array",
"title": "Task Keys"
},
"timestamp": {
"type": "string",
"format": "date-time",
"title": "Timestamp"
}
},
"type": "object",
"required": [
"identifier",
"task_keys",
"timestamp"
],
"title": "TaskWorkerResponse"
},
"TimeUnit": {
"type": "string",
"enum": [
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
---
openapi: post /api/task_workers/filter
---
3 changes: 3 additions & 0 deletions docs/3.0rc/api-ref/server/task-workers/read-task-workers.mdx
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
---
openapi: post /api/task_workers/filter
---
6 changes: 6 additions & 0 deletions docs/mint.json
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,12 @@
"3.0rc/api-ref/rest-api/server/work-pools/delete-worker"
]
},
{
"group": "Task Workers",
"pages": [
"3.0rc/api-ref/rest-api/server/task-workers/read-task-workers"
]
},
{
"group": "Work Queues",
"pages": [
Expand Down
2 changes: 1 addition & 1 deletion src/prefect/client/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import prefect.context
import prefect.settings
from prefect.client.base import PrefectHttpxAsyncClient
from prefect.client.schemas import Workspace
from prefect.client.schemas.objects import Workspace
from prefect.exceptions import ObjectNotFound, PrefectException
from prefect.settings import (
PREFECT_API_KEY,
Expand Down
1 change: 1 addition & 0 deletions src/prefect/server/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
saved_searches,
task_run_states,
task_runs,
task_workers,
templates,
ui,
variables,
Expand Down
1 change: 1 addition & 0 deletions src/prefect/server/api/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
api.block_types.router,
api.block_documents.router,
api.workers.router,
api.task_workers.router,
api.work_queues.router,
api.artifacts.router,
api.block_schemas.router,
Expand Down
25 changes: 20 additions & 5 deletions src/prefect/server/api/task_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import asyncio
import datetime
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
from uuid import UUID

import pendulum
Expand Down Expand Up @@ -188,10 +188,10 @@ async def read_task_runs(
sort: schemas.sorting.TaskRunSort = Body(schemas.sorting.TaskRunSort.ID_DESC),
limit: int = dependencies.LimitBody(),
offset: int = Body(0, ge=0),
flows: schemas.filters.FlowFilter = None,
flow_runs: schemas.filters.FlowRunFilter = None,
task_runs: schemas.filters.TaskRunFilter = None,
deployments: schemas.filters.DeploymentFilter = None,
flows: Optional[schemas.filters.FlowFilter] = None,
flow_runs: Optional[schemas.filters.FlowRunFilter] = None,
task_runs: Optional[schemas.filters.TaskRunFilter] = None,
deployments: Optional[schemas.filters.DeploymentFilter] = None,
db: PrefectDBInterface = Depends(provide_database_interface),
) -> List[schemas.core.TaskRun]:
"""
Expand Down Expand Up @@ -296,13 +296,24 @@ async def scheduled_task_subscription(websocket: WebSocket):
code=4001, reason="Protocol violation: expected 'keys' in subscribe message"
)

if not (client_id := subscription.get("client_id")):
return await websocket.close(
code=4001,
reason="Protocol violation: expected 'client_id' in subscribe message",
)

subscribed_queue = MultiQueue(task_keys)

logger.info(f"Task worker {client_id!r} subscribed to task keys {task_keys!r}")

while True:
try:
# observe here so that all workers with active websockets are tracked
await models.task_workers.observe_worker(task_keys, client_id)
task_run = await asyncio.wait_for(subscribed_queue.get(), timeout=1)
except asyncio.TimeoutError:
if not await subscriptions.still_connected(websocket):
await models.task_workers.forget_worker(client_id)
return
continue

Expand All @@ -319,7 +330,11 @@ async def scheduled_task_subscription(websocket: WebSocket):
code=4001, reason="Protocol violation: expected 'ack' message"
)

await models.task_workers.observe_worker([task_run.task_key], client_id)

except subscriptions.NORMAL_DISCONNECT_EXCEPTIONS:
# If sending fails or pong fails, put the task back into the retry queue
await asyncio.shield(TaskQueue.for_key(task_run.task_key).retry(task_run))
return
finally:
await models.task_workers.forget_worker(client_id)
31 changes: 31 additions & 0 deletions src/prefect/server/api/task_workers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import List, Optional

from fastapi import Body
from pydantic import BaseModel

from prefect.server import models
from prefect.server.models.task_workers import TaskWorkerResponse
from prefect.server.utilities.server import PrefectRouter

router = PrefectRouter(prefix="/task_workers", tags=["Task Workers"])


class TaskWorkerFilter(BaseModel):
task_keys: List[str]


@router.post("/filter")
async def read_task_workers(
task_worker_filter: Optional[TaskWorkerFilter] = Body(
default=None, description="The task worker filter", embed=True
),
) -> List[TaskWorkerResponse]:
"""Read active task workers. Optionally filter by task keys."""

if task_worker_filter and task_worker_filter.task_keys:
return await models.task_workers.get_workers_for_task_keys(
task_keys=task_worker_filter.task_keys,
)

else:
return await models.task_workers.get_all_workers()
1 change: 1 addition & 0 deletions src/prefect/server/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
saved_searches,
task_run_states,
task_runs,
task_workers,
variables,
work_queues,
workers,
Expand Down
Loading

0 comments on commit 109dcd9

Please sign in to comment.