diff --git a/src/prefect/server/services/task_run_recorder.py b/src/prefect/server/services/task_run_recorder.py index f9e75813abe2..f8e91e97b834 100644 --- a/src/prefect/server/services/task_run_recorder.py +++ b/src/prefect/server/services/task_run_recorder.py @@ -1,6 +1,5 @@ import asyncio -from contextlib import AsyncExitStack, asynccontextmanager -from datetime import timedelta +from contextlib import asynccontextmanager from typing import Any, AsyncGenerator, Dict, Optional from uuid import UUID @@ -119,43 +118,31 @@ def task_run_from_event(event: ReceivedEvent) -> TaskRun: ) -async def record_task_run_event(event: ReceivedEvent, depth: int = 0): - db = provide_database_interface() - - async with AsyncExitStack() as stack: - await stack.enter_async_context( - ( - causal_ordering().preceding_event_confirmed( - record_task_run_event, event, depth=depth - ) - ) - ) - - task_run = task_run_from_event(event) +async def record_task_run_event(event: ReceivedEvent): + task_run = task_run_from_event(event) - task_run_attributes = task_run.model_dump_for_orm( - exclude={ - "state_id", - "state", - "created", - "estimated_run_time", - "estimated_start_time_delta", - }, - exclude_unset=True, - ) + task_run_attributes = task_run.model_dump_for_orm( + exclude={ + "state_id", + "state", + "created", + "estimated_run_time", + "estimated_start_time_delta", + }, + exclude_unset=True, + ) - assert task_run.state + assert task_run.state - denormalized_state_attributes = { - "state_id": task_run.state.id, - "state_type": task_run.state.type, - "state_name": task_run.state.name, - "state_timestamp": task_run.state.timestamp, - } - session = await stack.enter_async_context( - db.session_context(begin_transaction=True) - ) + denormalized_state_attributes = { + "state_id": task_run.state.id, + "state_type": task_run.state.type, + "state_name": task_run.state.name, + "state_timestamp": task_run.state.timestamp, + } + db = provide_database_interface() + async with db.session_context(begin_transaction=True) as session: await _insert_task_run(session, task_run, task_run_attributes) await _insert_task_run_state(session, task_run) await _update_task_run_with_state( @@ -177,39 +164,8 @@ async def record_task_run_event(event: ReceivedEvent, depth: int = 0): ) -async def record_lost_follower_task_run_events(): - events = await causal_ordering().get_lost_followers() - - for event in events: - await record_task_run_event(event) - - -async def periodically_process_followers(periodic_granularity: timedelta): - """Periodically process followers that are waiting on a leader event that never arrived""" - logger.debug( - "Starting periodically process followers task every %s seconds", - periodic_granularity.total_seconds(), - ) - while True: - try: - await record_lost_follower_task_run_events() - except asyncio.CancelledError: - logger.debug("Periodically process followers task cancelled") - return - except Exception: - logger.exception("Error while processing task-run-recorders followers.") - finally: - await asyncio.sleep(periodic_granularity.total_seconds()) - - @asynccontextmanager -async def consumer( - periodic_granularity: timedelta = timedelta(seconds=5), -) -> AsyncGenerator[MessageHandler, None]: - record_lost_followers_task = asyncio.create_task( - periodically_process_followers(periodic_granularity=periodic_granularity) - ) - +async def consumer() -> AsyncGenerator[MessageHandler, None]: async def message_handler(message: Message): event: ReceivedEvent = ReceivedEvent.model_validate_json(message.data) @@ -231,14 +187,7 @@ async def message_handler(message: Message): # event arrives. pass - try: - yield message_handler - finally: - try: - record_lost_followers_task.cancel() - await record_lost_followers_task - except asyncio.CancelledError: - logger.debug("Periodically process followers task cancelled successfully") + yield message_handler class TaskRunRecorder: diff --git a/tests/server/services/test_task_run_recorder.py b/tests/server/services/test_task_run_recorder.py index fda050101147..08846aa392b2 100644 --- a/tests/server/services/test_task_run_recorder.py +++ b/tests/server/services/test_task_run_recorder.py @@ -1,25 +1,23 @@ import asyncio from datetime import timedelta +from itertools import permutations from typing import AsyncGenerator -from unittest.mock import AsyncMock, patch -from uuid import UUID, uuid4 +from uuid import UUID import pendulum import pytest from sqlalchemy.ext.asyncio import AsyncSession -from prefect.server.events.ordering import EventArrivedEarly from prefect.server.events.schemas.events import ReceivedEvent from prefect.server.models.flow_runs import create_flow_run -from prefect.server.models.task_run_states import read_task_run_state +from prefect.server.models.task_run_states import ( + read_task_run_state, + read_task_run_states, +) from prefect.server.models.task_runs import read_task_run from prefect.server.schemas.core import FlowRun, TaskRunPolicy from prefect.server.schemas.states import StateDetails, StateType from prefect.server.services import task_run_recorder -from prefect.server.services.task_run_recorder import ( - record_lost_follower_task_run_events, - record_task_run_event, -) from prefect.server.utilities.messaging import MessageHandler from prefect.server.utilities.messaging.memory import MemoryMessage @@ -40,9 +38,7 @@ async def test_start_and_stop_service(): @pytest.fixture async def task_run_recorder_handler() -> AsyncGenerator[MessageHandler, None]: - async with task_run_recorder.consumer( - periodic_granularity=timedelta(seconds=0.0001) - ) as handler: + async with task_run_recorder.consumer() as handler: yield handler @@ -711,50 +707,49 @@ async def test_updates_task_run_on_out_of_order_state_change( ) -async def test_lost_followers_are_recorded(monkeypatch: pytest.MonkeyPatch): - now = pendulum.now("UTC") - event = ReceivedEvent( - occurred=now.subtract(minutes=2), - received=now.subtract(minutes=1), - event="prefect.task-run.Running", - resource={ - "prefect.resource.id": f"prefect.task-run.{str(uuid4())}", - }, - account=uuid4(), - workspace=uuid4(), - follows=uuid4(), - id=uuid4(), - ) - # record a follower that never sees its leader - with pytest.raises(EventArrivedEarly): - await record_task_run_event(event) - - record_task_run_event_mock = AsyncMock() - monkeypatch.setattr( - "prefect.server.services.task_run_recorder.record_task_run_event", - record_task_run_event_mock, - ) - - # move time forward so we can record the lost follower - with patch("prefect.server.events.ordering.pendulum.now") as the_future: - the_future.return_value = now.add(minutes=20) - await record_lost_follower_task_run_events() - - assert record_task_run_event_mock.await_count == 1 - record_task_run_event_mock.assert_awaited_with(event) +@pytest.mark.parametrize( + "event_order", + list(permutations(["PENDING", "RUNNING", "COMPLETED"])), + ids=lambda x: "->".join(x), +) +async def test_task_run_recorder_handles_all_out_of_order_permutations( + session: AsyncSession, + pending_event: ReceivedEvent, + running_event: ReceivedEvent, + completed_event: ReceivedEvent, + task_run_recorder_handler: MessageHandler, + event_order: tuple, +): + # Set up event times + base_time = pendulum.datetime(2024, 1, 1, 0, 0, 0, 0, "UTC") + pending_event.occurred = base_time + running_event.occurred = base_time.add(minutes=1) + completed_event.occurred = base_time.add(minutes=2) + + event_map = { + "PENDING": pending_event, + "RUNNING": running_event, + "COMPLETED": completed_event, + } + # Process events in the specified order + for event_name in event_order: + await task_run_recorder_handler(message(event_map[event_name])) -async def test_lost_followers_are_recorded_periodically( - task_run_recorder_handler, - monkeypatch: pytest.MonkeyPatch, -): - record_lost_follower_task_run_events_mock = AsyncMock() - monkeypatch.setattr( - "prefect.server.services.task_run_recorder.record_lost_follower_task_run_events", - record_lost_follower_task_run_events_mock, + # Verify the task run always has the "final" state + task_run = await read_task_run( + session=session, + task_run_id=UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"), ) - # let the period task run a few times - await asyncio.sleep(0.1) + assert task_run + assert task_run.state_type == StateType.COMPLETED + assert task_run.state_name == "Completed" + assert task_run.state_timestamp == completed_event.occurred + + # Verify all states are recorded + states = await read_task_run_states(session, task_run.id) + assert len(states) == 3 - assert record_lost_follower_task_run_events_mock.await_count >= 1 + state_types = set(state.type for state in states) + assert state_types == {StateType.PENDING, StateType.RUNNING, StateType.COMPLETED}