From ceba4482c253a4a8d0e2d75fa646a9e0d8679495 Mon Sep 17 00:00:00 2001 From: Chris Guidry Date: Thu, 1 Feb 2024 16:23:34 -0500 Subject: [PATCH] Tests for autonomous task subscriptions (#11801) --- src/prefect/server/api/task_runs.py | 54 +++- src/prefect/server/utilities/subscriptions.py | 18 +- .../server/api/test_task_run_subscriptions.py | 293 ++++++++++++++++++ 3 files changed, 349 insertions(+), 16 deletions(-) create mode 100644 tests/server/api/test_task_run_subscriptions.py diff --git a/src/prefect/server/api/task_runs.py b/src/prefect/server/api/task_runs.py index fca381ad6a8b..78a8796c2e94 100644 --- a/src/prefect/server/api/task_runs.py +++ b/src/prefect/server/api/task_runs.py @@ -23,10 +23,11 @@ import prefect.server.schemas as schemas from prefect.logging import get_logger from prefect.server.api.run_history import run_history -from prefect.server.database.dependencies import provide_database_interface +from prefect.server.database.dependencies import inject_db, provide_database_interface from prefect.server.database.interface import PrefectDBInterface from prefect.server.orchestration import dependencies as orchestration_dependencies from prefect.server.orchestration.policies import BaseOrchestrationPolicy +from prefect.server.schemas import filters, states from prefect.server.schemas.responses import OrchestrationResult from prefect.server.utilities import subscriptions from prefect.server.utilities.schemas import DateTimeTZ @@ -35,6 +36,7 @@ logger = get_logger("server.api") + router = PrefectRouter(prefix="/task_runs", tags=["Task Runs"]) @@ -296,15 +298,18 @@ async def scheduled_task_subscription(websocket: WebSocket): if not websocket: return + await restore_scheduled_tasks() + scheduled_queue = scheduled_task_runs_queue() retry_queue = retry_task_runs_queue() while True: - task_run: schemas.core.TaskRun = None - # First, check if there's anything in the retry queue - if not retry_queue.empty(): - task_run = await retry_queue.get() - else: + task_run: schemas.core.TaskRun + + try: + # First, check if there's anything in the retry queue + task_run = retry_queue.get_nowait() + except asyncio.QueueEmpty: task_run = await scheduled_queue.get() try: @@ -318,3 +323,40 @@ async def scheduled_task_subscription(websocket: WebSocket): # If sending fails or pong fails, put the task back into the retry queue await retry_queue.put(task_run) break + + +_scheduled_tasks_already_restored: bool = False + + +@inject_db +async def restore_scheduled_tasks(db: PrefectDBInterface): + global _scheduled_tasks_already_restored + if _scheduled_tasks_already_restored: + return + + _scheduled_tasks_already_restored = True + + if not PREFECT_EXPERIMENTAL_ENABLE_TASK_SCHEDULING.value(): + return + + async with db.session_context() as session: + task_runs = await models.task_runs.read_task_runs( + session=session, + task_run_filter=filters.TaskRunFilter( + flow_run_id=filters.TaskRunFilterFlowRunId(is_null_=True), + state=filters.TaskRunFilterState( + type=filters.TaskRunFilterStateType( + any_=[states.StateType.SCHEDULED] + ) + ), + ), + ) + + if not task_runs: + return + + queue = retry_task_runs_queue() + for task_run in task_runs: + queue.put_nowait(schemas.core.TaskRun.from_orm(task_run)) + + logger.info("Restored %s scheduled task runs", len(task_runs)) diff --git a/src/prefect/server/utilities/subscriptions.py b/src/prefect/server/utilities/subscriptions.py index 751fa43219f1..b2e766bb5937 100644 --- a/src/prefect/server/utilities/subscriptions.py +++ b/src/prefect/server/utilities/subscriptions.py @@ -4,21 +4,19 @@ WebSocket, ) from starlette.status import WS_1002_PROTOCOL_ERROR, WS_1008_POLICY_VIOLATION +from starlette.websockets import WebSocketDisconnect from websockets.exceptions import ConnectionClosed -NORMAL_DISCONNECT_EXCEPTIONS = (IOError, ConnectionClosed) +NORMAL_DISCONNECT_EXCEPTIONS = (IOError, ConnectionClosed, WebSocketDisconnect) async def ping_pong(websocket: WebSocket): - try: - await websocket.send_json({"type": "ping"}) - - response = await websocket.receive_json() - if response.get("type") == "pong": - return True - else: - return False - except Exception: + await websocket.send_json({"type": "ping"}) + + response = await websocket.receive_json() + if response.get("type") == "pong": + return True + else: return False diff --git a/tests/server/api/test_task_run_subscriptions.py b/tests/server/api/test_task_run_subscriptions.py new file mode 100644 index 000000000000..1fea425c7008 --- /dev/null +++ b/tests/server/api/test_task_run_subscriptions.py @@ -0,0 +1,293 @@ +import json +from asyncio import AbstractEventLoop, CancelledError, gather +from contextlib import asynccontextmanager +from typing import AsyncGenerator, Callable, List +from uuid import uuid4 + +import anyio +import httpx +import pytest +import websockets +from sqlalchemy.ext.asyncio import AsyncSession +from uvicorn import Config, Server + +from prefect.client.schemas import TaskRun +from prefect.server import models +from prefect.server.api import task_runs +from prefect.server.api.server import create_app +from prefect.settings import ( + PREFECT_API_SERVICES_CANCELLATION_CLEANUP_ENABLED, + PREFECT_API_SERVICES_FLOW_RUN_NOTIFICATIONS_ENABLED, + PREFECT_API_SERVICES_LATE_RUNS_ENABLED, + PREFECT_API_SERVICES_PAUSE_EXPIRATIONS_ENABLED, + PREFECT_API_SERVICES_SCHEDULER_ENABLED, + PREFECT_EXPERIMENTAL_ENABLE_TASK_SCHEDULING, + PREFECT_SERVER_ANALYTICS_ENABLED, + temporary_settings, +) +from prefect.states import Scheduled + + +@pytest.fixture(scope="module", autouse=True) +def services_disabled() -> None: + with temporary_settings( + { + PREFECT_SERVER_ANALYTICS_ENABLED: False, + PREFECT_API_SERVICES_SCHEDULER_ENABLED: False, + PREFECT_API_SERVICES_LATE_RUNS_ENABLED: False, + PREFECT_API_SERVICES_FLOW_RUN_NOTIFICATIONS_ENABLED: False, + PREFECT_API_SERVICES_PAUSE_EXPIRATIONS_ENABLED: False, + PREFECT_API_SERVICES_CANCELLATION_CLEANUP_ENABLED: False, + PREFECT_EXPERIMENTAL_ENABLE_TASK_SCHEDULING: True, + } + ): + yield + + +@asynccontextmanager +async def running_prefect_server( + event_loop: AbstractEventLoop, port: int +) -> AsyncGenerator[str, None]: + api_url = f"http://localhost:{port}/api" + + app = create_app(ignore_cache=True) + server = Server(Config(app=app, host="127.0.0.1", port=port)) + server_task = event_loop.create_task(server.serve()) + + # Wait for the server to be ready + async with httpx.AsyncClient() as client: + response = None + with anyio.move_on_after(20): + while True: + try: + response = await client.get(api_url + "/health") + except httpx.ConnectError: + pass + else: + if response.status_code == 200: + break + await anyio.sleep(0.1) + if response: + response.raise_for_status() + if not response: + raise RuntimeError("Timed out while attempting to connect to test API") + + try: + yield api_url + finally: + server_task.cancel() + try: + await server_task + except CancelledError: + pass + + +@pytest.fixture(scope="module") +async def prefect_server( + unused_tcp_port_factory: Callable[[], int], + event_loop: AbstractEventLoop, +) -> AsyncGenerator[str, None]: + async with running_prefect_server(event_loop, unused_tcp_port_factory()) as api_url: + yield api_url + + +@pytest.fixture(autouse=True) +async def reset_task_queues(): + task_runs._scheduled_task_runs_queues = {} + task_runs._retry_task_runs_queues = {} + task_runs._scheduled_tasks_already_restored = False + + yield + + task_runs._scheduled_task_runs_queues = {} + task_runs._retry_task_runs_queues = {} + task_runs._scheduled_tasks_already_restored = False + + +@pytest.fixture +async def socket_url(prefect_server: str) -> str: + return prefect_server.replace("http", "ws", 1) + + +async def auth_dance(socket: websockets.WebSocketClientProtocol) -> None: + await socket.send(json.dumps({"type": "auth", "token": None})) + response = await socket.recv() + assert json.loads(response) == {"type": "auth_success"} + + +@pytest.fixture +async def authenticated_socket( + socket_url: str, +) -> AsyncGenerator[websockets.WebSocketClientProtocol, None]: + async with websockets.connect( + f"{socket_url}/task_runs/subscriptions/scheduled", + subprotocols=["prefect"], + ) as socket: + await auth_dance(socket) + yield socket + + +async def test_receiving_task_run( + authenticated_socket: websockets.WebSocketClientProtocol, +): + queued = TaskRun( + id=uuid4(), flow_run_id=None, task_key="runme", dynamic_key="runme-1" + ) + task_runs.scheduled_task_runs_queue().put_nowait(queued) + + received = TaskRun.parse_raw(await authenticated_socket.recv()) + + assert received.id == queued.id + + +async def test_receiving_ping_between_each_run( + authenticated_socket: websockets.WebSocketClientProtocol, +): + queue = task_runs.scheduled_task_runs_queue() + queue.put_nowait( + TaskRun(id=uuid4(), flow_run_id=None, task_key="runme", dynamic_key="runme-1") + ) + queue.put_nowait( + TaskRun(id=uuid4(), flow_run_id=None, task_key="runme", dynamic_key="runme-1") + ) + + run = json.loads(await authenticated_socket.recv()) + assert run["task_key"] == "runme" + + ping = json.loads(await authenticated_socket.recv()) + assert ping["type"] == "ping" + + await authenticated_socket.send(json.dumps({"type": "pong"})) + + run = json.loads(await authenticated_socket.recv()) + assert run["task_key"] == "runme" + + ping = json.loads(await authenticated_socket.recv()) + assert ping["type"] == "ping" + + await authenticated_socket.send(json.dumps({"type": "pong"})) + + +async def drain( + socket: websockets.WebSocketClientProtocol, timeout: float = 1 +) -> List[TaskRun]: + messages = [] + + with anyio.move_on_after(timeout): + while True: + message = json.loads(await socket.recv()) + if message.get("type") == "ping": + await socket.send(json.dumps({"type": "pong"})) + continue + messages.append(TaskRun.parse_obj(message)) + + return messages + + +@pytest.fixture +async def another_socket( + socket_url: str, +) -> AsyncGenerator[websockets.WebSocketClientProtocol, None]: + async with websockets.connect( + f"{socket_url}/task_runs/subscriptions/scheduled", + subprotocols=["prefect"], + ) as socket: + await auth_dance(socket) + yield socket + + +async def test_only_one_socket_gets_each_task_run( + authenticated_socket: websockets.WebSocketClientProtocol, + another_socket: websockets.WebSocketClientProtocol, +): + queue = task_runs.scheduled_task_runs_queue() + + queued: List[TaskRun] = [] + for _ in range(10): + run = TaskRun( + id=uuid4(), flow_run_id=None, task_key="runme", dynamic_key="runme-1" + ) + queue.put_nowait(run) + queued.append(run) + + received1, received2 = await gather( + drain(authenticated_socket), + drain(another_socket), + ) + + received1_ids = {r.id for r in received1} + received2_ids = {r.id for r in received2} + + # Each socket should have gotten some runs, and each run should have only been + # sent to one of the sockets + assert received1_ids, "Each socket should have gotten at least one run" + assert received2_ids, "Each socket should have gotten at least one run" + assert received1_ids.isdisjoint(received2_ids) + + queued_ids = {r.id for r in queued} + received_ids = received1_ids | received2_ids + + # While the asynchrony of this test means we won't necessarily get all 10, + # we should be getting at least 5 task runs + assert 5 <= len(received_ids) <= len(queued) + + assert received_ids.issubset(queued_ids) + + +async def test_server_redelivers_unacknowledged_runs(socket_url: str): + queue = task_runs.scheduled_task_runs_queue() + + run = TaskRun(id=uuid4(), flow_run_id=None, task_key="runme", dynamic_key="runme-1") + queue.put_nowait(run) + + async with websockets.connect( + f"{socket_url}/task_runs/subscriptions/scheduled", + subprotocols=["prefect"], + ) as socket1: + await auth_dance(socket1) + + received = json.loads(await socket1.recv()) + assert received["id"] == str(run.id) + + ping = json.loads(await socket1.recv()) + assert ping["type"] == "ping" + + # but importantly, disconnect without acknowledging the ping + + async with websockets.connect( + f"{socket_url}/task_runs/subscriptions/scheduled", + subprotocols=["prefect"], + ) as socket2: + await auth_dance(socket2) + + received = json.loads(await socket2.recv()) + assert received["id"] == str(run.id) + + +async def test_server_restores_scheduled_task_runs_at_startup( + session: AsyncSession, event_loop: AbstractEventLoop, unused_tcp_port: int +): + stored_run = TaskRun( + id=uuid4(), + flow_run_id=None, + task_key="runme", + dynamic_key="runme-1", + state=Scheduled(), + ) + await models.task_runs.create_task_run(session, stored_run) + await session.commit() + + async with running_prefect_server(event_loop, unused_tcp_port) as api_url: + # TODO: why isn't this startup item running? + await task_runs.restore_scheduled_tasks() + + socket_url = api_url.replace("http", "ws", 1) + + async with websockets.connect( + f"{socket_url}/task_runs/subscriptions/scheduled", + subprotocols=["prefect"], + ) as socket: + await auth_dance(socket) + + received = json.loads(await socket.recv()) + assert received["id"] == str(stored_run.id)