Skip to content

Commit

Permalink
Tests for autonomous task subscriptions (#11801)
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisguidry authored Feb 1, 2024
1 parent 4ac4232 commit ceba448
Show file tree
Hide file tree
Showing 3 changed files with 349 additions and 16 deletions.
54 changes: 48 additions & 6 deletions src/prefect/server/api/task_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@
import prefect.server.schemas as schemas
from prefect.logging import get_logger
from prefect.server.api.run_history import run_history
from prefect.server.database.dependencies import provide_database_interface
from prefect.server.database.dependencies import inject_db, provide_database_interface
from prefect.server.database.interface import PrefectDBInterface
from prefect.server.orchestration import dependencies as orchestration_dependencies
from prefect.server.orchestration.policies import BaseOrchestrationPolicy
from prefect.server.schemas import filters, states
from prefect.server.schemas.responses import OrchestrationResult
from prefect.server.utilities import subscriptions
from prefect.server.utilities.schemas import DateTimeTZ
Expand All @@ -35,6 +36,7 @@

logger = get_logger("server.api")


router = PrefectRouter(prefix="/task_runs", tags=["Task Runs"])


Expand Down Expand Up @@ -296,15 +298,18 @@ async def scheduled_task_subscription(websocket: WebSocket):
if not websocket:
return

await restore_scheduled_tasks()

scheduled_queue = scheduled_task_runs_queue()
retry_queue = retry_task_runs_queue()

while True:
task_run: schemas.core.TaskRun = None
# First, check if there's anything in the retry queue
if not retry_queue.empty():
task_run = await retry_queue.get()
else:
task_run: schemas.core.TaskRun

try:
# First, check if there's anything in the retry queue
task_run = retry_queue.get_nowait()
except asyncio.QueueEmpty:
task_run = await scheduled_queue.get()

try:
Expand All @@ -318,3 +323,40 @@ async def scheduled_task_subscription(websocket: WebSocket):
# If sending fails or pong fails, put the task back into the retry queue
await retry_queue.put(task_run)
break


_scheduled_tasks_already_restored: bool = False


@inject_db
async def restore_scheduled_tasks(db: PrefectDBInterface):
global _scheduled_tasks_already_restored
if _scheduled_tasks_already_restored:
return

_scheduled_tasks_already_restored = True

if not PREFECT_EXPERIMENTAL_ENABLE_TASK_SCHEDULING.value():
return

async with db.session_context() as session:
task_runs = await models.task_runs.read_task_runs(
session=session,
task_run_filter=filters.TaskRunFilter(
flow_run_id=filters.TaskRunFilterFlowRunId(is_null_=True),
state=filters.TaskRunFilterState(
type=filters.TaskRunFilterStateType(
any_=[states.StateType.SCHEDULED]
)
),
),
)

if not task_runs:
return

queue = retry_task_runs_queue()
for task_run in task_runs:
queue.put_nowait(schemas.core.TaskRun.from_orm(task_run))

logger.info("Restored %s scheduled task runs", len(task_runs))
18 changes: 8 additions & 10 deletions src/prefect/server/utilities/subscriptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,19 @@
WebSocket,
)
from starlette.status import WS_1002_PROTOCOL_ERROR, WS_1008_POLICY_VIOLATION
from starlette.websockets import WebSocketDisconnect
from websockets.exceptions import ConnectionClosed

NORMAL_DISCONNECT_EXCEPTIONS = (IOError, ConnectionClosed)
NORMAL_DISCONNECT_EXCEPTIONS = (IOError, ConnectionClosed, WebSocketDisconnect)


async def ping_pong(websocket: WebSocket):
try:
await websocket.send_json({"type": "ping"})

response = await websocket.receive_json()
if response.get("type") == "pong":
return True
else:
return False
except Exception:
await websocket.send_json({"type": "ping"})

response = await websocket.receive_json()
if response.get("type") == "pong":
return True
else:
return False


Expand Down
Loading

0 comments on commit ceba448

Please sign in to comment.