From f9d348ea3d528ff8627b20c3fbbad938fd0dc6be Mon Sep 17 00:00:00 2001 From: "jake@prefect.io" Date: Wed, 24 Jul 2024 09:37:46 -0400 Subject: [PATCH 1/3] task run recorder --- src/prefect/server/events/schemas/events.py | 10 + .../server/services/task_run_recorder.py | 177 +++++- .../server/services/test_task_run_recorder.py | 592 ++++++++++++++++-- 3 files changed, 742 insertions(+), 37 deletions(-) diff --git a/src/prefect/server/events/schemas/events.py b/src/prefect/server/events/schemas/events.py index 07b240fd6d7e..8d46e1d70c93 100644 --- a/src/prefect/server/events/schemas/events.py +++ b/src/prefect/server/events/schemas/events.py @@ -66,6 +66,16 @@ def id(self) -> str: def name(self) -> Optional[str]: return self.get("prefect.resource.name") + def prefect_object_id(self, kind: str) -> UUID: + """Extracts the UUID from an event's resource ID if it's the expected kind + of prefect resource""" + prefix = f"{kind}." if not kind.endswith(".") else kind + + if not self.id.startswith(prefix): + raise ValueError(f"Resource ID {self.id} does not start with {prefix}") + + return UUID(self.id[len(prefix) :]) + class RelatedResource(Resource): """A Resource with a specific role in an Event""" diff --git a/src/prefect/server/services/task_run_recorder.py b/src/prefect/server/services/task_run_recorder.py index a9ad73d39612..037b8fe932c0 100644 --- a/src/prefect/server/services/task_run_recorder.py +++ b/src/prefect/server/services/task_run_recorder.py @@ -1,14 +1,179 @@ import asyncio -from contextlib import asynccontextmanager -from typing import AsyncGenerator, Optional +from contextlib import AsyncExitStack, asynccontextmanager +from typing import Any, AsyncGenerator, Dict, Optional +from uuid import UUID + +import pendulum +import sqlalchemy as sa +from sqlalchemy.ext.asyncio import AsyncSession from prefect.logging import get_logger +from prefect.server.database.dependencies import db_injector, provide_database_interface +from prefect.server.database.interface import PrefectDBInterface +from prefect.server.events.ordering import CausalOrdering, EventArrivedEarly from prefect.server.events.schemas.events import ReceivedEvent +from prefect.server.schemas.core import TaskRun +from prefect.server.schemas.states import State from prefect.server.utilities.messaging import Message, MessageHandler, create_consumer logger = get_logger(__name__) +def causal_ordering(): + return CausalOrdering( + "task-run-recorder", + ) + + +@db_injector +async def _insert_task_run( + db: PrefectDBInterface, + session: AsyncSession, + task_run: TaskRun, + task_run_attributes: Dict[str, Any], +): + await session.execute( + db.insert(db.TaskRun) + .values( + created=pendulum.now("UTC"), + **task_run_attributes, + ) + .on_conflict_do_update( + index_elements=[ + "id", + ], + set_={ + "updated": pendulum.now("UTC"), + **task_run_attributes, + }, + where=db.TaskRun.state_timestamp < task_run.state.timestamp, + ) + ) + + +@db_injector +async def _insert_task_run_state( + db: PrefectDBInterface, session: AsyncSession, task_run: TaskRun +): + await session.execute( + db.insert(db.TaskRunState) + .values( + created=pendulum.now("UTC"), + task_run_id=task_run.id, + **task_run.state.model_dump(), + ) + .on_conflict_do_nothing( + index_elements=[ + "id", + ] + ) + ) + + +@db_injector +async def _update_task_run_with_state( + db: PrefectDBInterface, + session: AsyncSession, + task_run: TaskRun, + denormalized_state_attributes: Dict[str, Any], +): + await session.execute( + sa.update(db.TaskRun) + .where( + db.TaskRun.id == 
task_run.id, + sa.or_( + db.TaskRun.state_timestamp.is_(None), + db.TaskRun.state_timestamp < task_run.state.timestamp, + ), + ) + .values(**denormalized_state_attributes) + ) + + +def task_run_from_event(event: ReceivedEvent) -> TaskRun: + task_run_id = event.resource.prefect_object_id("prefect.task-run") + + flow_run_id: Optional[UUID] = None + if flow_run_resource := event.resource_in_role.get("flow-run"): + flow_run_id = flow_run_resource.prefect_object_id("prefect.flow-run") + + state: State = State.model_validate( + { + "id": event.id, + "timestamp": event.occurred, + **event.payload["validated_state"], + } + ) + state.state_details.task_run_id = task_run_id + state.state_details.flow_run_id = flow_run_id + + return TaskRun.model_validate( + { + "id": task_run_id, + "flow_run_id": flow_run_id, + "state_id": state.id, + "state": state, + **event.payload["task_run"], + } + ) + + +async def record_task_run_event(event: ReceivedEvent, depth: int = 0): + db = provide_database_interface() + + async with AsyncExitStack() as stack: + await stack.enter_async_context( + ( + causal_ordering().preceding_event_confirmed( + record_task_run_event, event, depth=depth + ) + ) + ) + + task_run = task_run_from_event(event) + + task_run_attributes = task_run.model_dump_for_orm( + exclude={ + "state_id", + "state", + "created", + "estimated_run_time", + "estimated_start_time_delta", + }, + exclude_unset=True, + ) + + assert task_run.state + + denormalized_state_attributes = { + "state_id": task_run.state.id, + "state_type": task_run.state.type, + "state_name": task_run.state.name, + "state_timestamp": task_run.state.timestamp, + } + session = await stack.enter_async_context(db.session_context()) + + await _insert_task_run(session, task_run, task_run_attributes) + await _insert_task_run_state(session, task_run) + await _update_task_run_with_state( + session, task_run, denormalized_state_attributes + ) + + logger.info( + "Recorded task run state change", + extra={ + "task_run_id": task_run.id, + "flow_run_id": task_run.flow_run_id, + "event_id": event.id, + "event_follows": event.follows, + "event": event.event, + "occurred": event.occurred, + "current_state_type": task_run.state_type, + "current_state_name": task_run.state_name, + }, + ) + + @asynccontextmanager async def consumer() -> AsyncGenerator[MessageHandler, None]: async def message_handler(message: Message): @@ -24,6 +189,14 @@ async def message_handler(message: Message): f"Received event: {event.event} with id: {event.id} for resource: {event.resource.get('prefect.resource.id')}" ) + try: + await record_task_run_event(event) + except EventArrivedEarly: + # We're safe to ACK this message because it has been parked by the + # causal ordering mechanism and will be reprocessed when the preceding + # event arrives. 
+ pass + yield message_handler diff --git a/tests/server/services/test_task_run_recorder.py b/tests/server/services/test_task_run_recorder.py index 333bb4c9a98f..c00804128b58 100644 --- a/tests/server/services/test_task_run_recorder.py +++ b/tests/server/services/test_task_run_recorder.py @@ -1,11 +1,18 @@ import asyncio +from datetime import timedelta from typing import AsyncGenerator from uuid import UUID import pendulum import pytest +from sqlalchemy.ext.asyncio import AsyncSession from prefect.server.events.schemas.events import ReceivedEvent +from prefect.server.models.flow_runs import create_flow_run +from prefect.server.models.task_run_states import read_task_run_state +from prefect.server.models.task_runs import read_task_run +from prefect.server.schemas.core import FlowRun, TaskRunPolicy +from prefect.server.schemas.states import StateDetails, StateType from prefect.server.services import task_run_recorder from prefect.server.utilities.messaging import MessageHandler from prefect.server.utilities.messaging.memory import MemoryMessage @@ -31,6 +38,13 @@ async def task_run_recorder_handler() -> AsyncGenerator[MessageHandler, None]: yield handler +def message(event: ReceivedEvent) -> MemoryMessage: + return MemoryMessage( + data=event.model_dump_json().encode(), + attributes={}, + ) + + @pytest.fixture def hello_event() -> ReceivedEvent: return ReceivedEvent( @@ -74,6 +88,33 @@ def client_orchestrated_task_run_event() -> ReceivedEvent: "intended": {"from": "PENDING", "to": "RUNNING"}, "initial_state": {"type": "PENDING", "name": "Pending", "message": ""}, "validated_state": {"type": "RUNNING", "name": "Running", "message": ""}, + "task_run": { + "name": "my_task", + "task_key": "add-0bf8d992", + "dynamic_key": "add-0bf8d992-4bb2bae02a7f4ac6afaf493d28a57d96", + "empirical_policy": { + "max_retries": 0, + "retry_delay_seconds": 0, + "retries": 0, + "retry_delay": 0, + }, + "tags": [], + "task_inputs": {"x": [], "y": []}, + "run_count": 1, + "flow_run_run_count": 0, + "expected_start_time": pendulum.datetime( + 2022, 1, 2, 3, 4, 5, 5, "UTC" + ).isoformat(), + "start_time": pendulum.datetime( + 2022, 1, 2, 3, 4, 5, 5, "UTC" + ).isoformat(), + "end_time": pendulum.datetime( + 2022, 1, 2, 3, 4, 5, 6, "UTC" + ).isoformat(), + "total_run_time": 0.002024, + "estimated_run_time": 0, + "estimated_start_time_delta": 0, + }, }, account=UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"), workspace=UUID("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"), @@ -113,55 +154,25 @@ def server_orchestrated_task_run_event() -> ReceivedEvent: ) -@pytest.fixture -def client_orchestrated_task_run_event_message( - client_orchestrated_task_run_event: ReceivedEvent, -) -> MemoryMessage: - return MemoryMessage( - data=client_orchestrated_task_run_event.model_dump_json().encode(), - attributes={}, - ) - - -@pytest.fixture -def server_orchestrated_task_run_event_message( - server_orchestrated_task_run_event: ReceivedEvent, -) -> MemoryMessage: - return MemoryMessage( - data=server_orchestrated_task_run_event.model_dump_json().encode(), - attributes={}, - ) - - -@pytest.fixture -def hello_event_message(hello_event: ReceivedEvent) -> MemoryMessage: - return MemoryMessage( - data=hello_event.model_dump_json().encode(), - attributes={}, - ) - - async def test_handle_client_orchestrated_task_run_event( task_run_recorder_handler: MessageHandler, client_orchestrated_task_run_event: ReceivedEvent, - client_orchestrated_task_run_event_message: MemoryMessage, caplog: pytest.LogCaptureFixture, ): with caplog.at_level("INFO"): - await 
task_run_recorder_handler(client_orchestrated_task_run_event_message) + await task_run_recorder_handler(message(client_orchestrated_task_run_event)) - assert "Received event" in caplog.text + assert "Recorded task run state change" in caplog.text assert str(client_orchestrated_task_run_event.id) in caplog.text async def test_skip_non_task_run_event( task_run_recorder_handler: MessageHandler, hello_event: ReceivedEvent, - hello_event_message: MemoryMessage, caplog: pytest.LogCaptureFixture, ): with caplog.at_level("INFO"): - await task_run_recorder_handler(hello_event_message) + await task_run_recorder_handler(message(hello_event)) assert "Received event" not in caplog.text assert str(hello_event.id) not in caplog.text @@ -170,11 +181,522 @@ async def test_skip_non_task_run_event( async def test_skip_server_side_orchestrated_task_run( task_run_recorder_handler: MessageHandler, server_orchestrated_task_run_event: ReceivedEvent, - server_orchestrated_task_run_event_message: MemoryMessage, caplog: pytest.LogCaptureFixture, ): with caplog.at_level("INFO"): - await task_run_recorder_handler(server_orchestrated_task_run_event_message) + await task_run_recorder_handler(message(server_orchestrated_task_run_event)) assert "Received event" not in caplog.text assert str(server_orchestrated_task_run_event.id) not in caplog.text + + +@pytest.fixture +async def flow_run(session: AsyncSession, flow): + flow_run = await create_flow_run( + session=session, + flow_run=FlowRun( + id=UUID("ffffffff-ffff-ffff-ffff-ffffffffffff"), + flow_id=flow.id, + ), + ) + return flow_run + + +@pytest.fixture +def pending_event(flow_run) -> ReceivedEvent: + occurred = pendulum.datetime(2024, 1, 1, 0, 0, 0, 0, "UTC") + return ReceivedEvent( + occurred=occurred, + event="prefect.task-run.Pending", + resource={ + "prefect.resource.id": "prefect.task-run.aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", + "prefect.resource.name": "my_task", + "prefect.state-message": "", + "prefect.state-type": "PENDING", + "prefect.state-name": "Pending", + "prefect.state-timestamp": occurred.isoformat(), + "prefect.orchestration": "client", + }, + related=[ + { + "prefect.resource.id": "prefect.flow-run.ffffffff-ffff-ffff-ffff-ffffffffffff", + "prefect.resource.role": "flow-run", + }, + ], + payload={ + "intended": {"from": None, "to": "PENDING"}, + "initial_state": None, + "validated_state": { + "type": "PENDING", + "name": "Pending", + "message": "Hi there!", + "state_details": { + "pause_reschedule": False, + "untrackable_result": False, + }, + "data": None, + }, + "task_run": { + "task_key": "my_task-abcdefg", + "dynamic_key": "1", + "empirical_policy": { + "max_retries": 2, + "retries": 3, + "retry_delay": 4, + "retry_delay_seconds": 5.0, + }, + "expected_start_time": "2024-01-01T00:00:00Z", + "estimated_start_time_delta": 0.1, + "name": "my_task", + "tags": [ + "tag-1", + "tag-2", + ], + "task_inputs": { + "x": [{"input_type": "parameter", "name": "x"}], + "y": [{"input_type": "parameter", "name": "y"}], + }, + }, + }, + received=occurred + timedelta(seconds=1), + follows=None, + id=UUID("11111111-1111-1111-1111-111111111111"), + ) + + +@pytest.fixture +def running_event(flow_run) -> ReceivedEvent: + occurred = pendulum.datetime(2024, 1, 1, 0, 1, 0, 0, "UTC") + return ReceivedEvent( + occurred=occurred, + event="prefect.task-run.Running", + resource={ + "prefect.resource.id": "prefect.task-run.aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", + "prefect.resource.name": "my_task", + "prefect.state-message": "", + "prefect.state-type": "RUNNING", + 
"prefect.state-name": "Running", + "prefect.state-timestamp": occurred.isoformat(), + "prefect.orchestration": "client", + }, + related=[ + { + "prefect.resource.id": "prefect.flow-run.ffffffff-ffff-ffff-ffff-ffffffffffff", + "prefect.resource.role": "flow-run", + }, + ], + payload={ + "intended": {"from": "PENDING", "to": "RUNNING"}, + "initial_state": { + "type": "PENDING", + "name": "Pending", + "message": "", + "state_details": { + "pause_reschedule": False, + "untrackable_result": False, + }, + }, + "validated_state": { + "type": "RUNNING", + "name": "Running", + "message": "Weeeeeee look at me go!", + "state_details": { + "pause_reschedule": False, + "untrackable_result": False, + }, + "data": None, + }, + "task_run": { + "task_key": "my_task-abcdefg", + "dynamic_key": "1", + "empirical_policy": { + "max_retries": 2, + "retries": 3, + "retry_delay": 4, + "retry_delay_seconds": 5.0, + }, + "estimated_run_time": 6.0, + "expected_start_time": "2024-01-01T00:00:00Z", + "estimated_start_time_delta": 0.1, + "flow_run_run_count": 7, + "name": "my_task", + "run_count": 8, + "start_time": "2024-01-01T00:01:00Z", + "tags": [ + "tag-1", + "tag-2", + ], + "task_inputs": { + "x": [{"input_type": "parameter", "name": "x"}], + "y": [{"input_type": "parameter", "name": "y"}], + }, + "total_run_time": 9.0, + }, + }, + received=occurred + timedelta(seconds=1), + follows=UUID("11111111-1111-1111-1111-111111111111"), + id=UUID("22222222-2222-2222-2222-222222222222"), + ) + + +@pytest.fixture +def completed_event(flow_run) -> ReceivedEvent: + occurred = pendulum.datetime(2024, 1, 1, 0, 2, 0, 0, "UTC") + return ReceivedEvent( + occurred=occurred, + event="prefect.task-run.Completed", + resource={ + "prefect.resource.id": "prefect.task-run.aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", + "prefect.resource.name": "my_task", + "prefect.state-message": "", + "prefect.state-type": "COMPLETED", + "prefect.state-name": "Completed", + "prefect.state-timestamp": occurred.isoformat(), + "prefect.orchestration": "client", + }, + related=[ + { + "prefect.resource.id": "prefect.flow-run.ffffffff-ffff-ffff-ffff-ffffffffffff", + "prefect.resource.role": "flow-run", + }, + ], + payload={ + "intended": {"from": "RUNNING", "to": "COMPLETED"}, + "initial_state": { + "type": "RUNNING", + "name": "Running", + "message": "", + "state_details": { + "pause_reschedule": False, + "untrackable_result": False, + }, + }, + "validated_state": { + "type": "COMPLETED", + "name": "Completed", + "message": "Stick a fork in me, I'm done", + "state_details": { + "pause_reschedule": False, + "untrackable_result": False, + }, + "data": {"type": "unpersisted"}, + }, + "task_run": { + # required fields + "task_key": "my_task-abcdefg", + "dynamic_key": "1", + # Only set the end_time, to test partial updates + "end_time": "2024-01-01T00:02:00Z", + }, + }, + received=occurred + timedelta(seconds=1), + follows=UUID("22222222-2222-2222-2222-222222222222"), + id=UUID("33333333-3333-3333-3333-333333333333"), + ) + + +async def test_recording_single_event( + session: AsyncSession, + pending_event: ReceivedEvent, + task_run_recorder_handler: MessageHandler, +): + pending_transition_time = pendulum.datetime(2024, 1, 1, 0, 0, 0, 0, "UTC") + assert pending_event.occurred == pending_transition_time + + await task_run_recorder_handler(message(pending_event)) + + task_run = await read_task_run( + session=session, + task_run_id=UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"), + ) + + assert task_run + + assert task_run.id == UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa") + 
assert task_run.name == "my_task" + assert task_run.flow_run_id == UUID("ffffffff-ffff-ffff-ffff-ffffffffffff") + assert task_run.task_key == "my_task-abcdefg" + assert task_run.dynamic_key == "1" + assert task_run.tags == ["tag-1", "tag-2"] + + assert task_run.flow_run_run_count == 0 + assert task_run.run_count == 0 + assert task_run.total_run_time == timedelta(0) + assert task_run.task_inputs == { + "x": [{"input_type": "parameter", "name": "x"}], + "y": [{"input_type": "parameter", "name": "y"}], + } + assert task_run.empirical_policy == TaskRunPolicy( + max_retries=2, + retries=3, + retry_delay=4, + retry_delay_seconds=5.0, + ) + + assert task_run.expected_start_time == pending_transition_time + assert task_run.start_time is None + assert task_run.end_time is None + + assert task_run.state_id == UUID("11111111-1111-1111-1111-111111111111") + assert task_run.state_timestamp == pending_transition_time + assert task_run.state_type == StateType.PENDING + assert task_run.state_name == "Pending" + assert task_run.state_timestamp == pending_transition_time + + state = await read_task_run_state( + session=session, + task_run_state_id=UUID("11111111-1111-1111-1111-111111111111"), + ) + + assert state + + assert state.id == UUID("11111111-1111-1111-1111-111111111111") + assert state.type == StateType.PENDING + assert state.name == "Pending" + assert state.message == "Hi there!" + assert state.timestamp == pending_transition_time + assert state.state_details == StateDetails( + flow_run_id=UUID("ffffffff-ffff-ffff-ffff-ffffffffffff"), + task_run_id=UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"), + pause_reschedule=False, + untrackable_result=False, + ) + + +async def test_updates_task_run_on_subsequent_state_changes( + session: AsyncSession, + pending_event: ReceivedEvent, + running_event: ReceivedEvent, + task_run_recorder_handler: MessageHandler, +): + pending_transition_time = pendulum.datetime(2024, 1, 1, 0, 0, 0, 0, "UTC") + assert pending_event.occurred == pending_transition_time + + running_transition_time = pendulum.datetime(2024, 1, 1, 0, 1, 0, 0, "UTC") + assert running_event.occurred == running_transition_time + + await task_run_recorder_handler(message(pending_event)) + await task_run_recorder_handler(message(running_event)) + + task_run = await read_task_run( + session=session, + task_run_id=UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"), + ) + + assert task_run + + assert task_run.id == UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa") + assert task_run.name == "my_task" + assert task_run.flow_run_id == UUID("ffffffff-ffff-ffff-ffff-ffffffffffff") + assert task_run.task_key == "my_task-abcdefg" + assert task_run.dynamic_key == "1" + assert task_run.tags == ["tag-1", "tag-2"] + + assert task_run.flow_run_run_count == 7 + assert task_run.run_count == 8 + assert task_run.total_run_time == timedelta(seconds=9) + assert task_run.task_inputs == { + "x": [{"input_type": "parameter", "name": "x"}], + "y": [{"input_type": "parameter", "name": "y"}], + } + assert task_run.empirical_policy == TaskRunPolicy( + max_retries=2, + retries=3, + retry_delay=4, + retry_delay_seconds=5.0, + ) + + assert task_run.expected_start_time == pending_transition_time + assert task_run.start_time == running_transition_time + assert task_run.end_time is None + + assert task_run.state_id == UUID("22222222-2222-2222-2222-222222222222") + assert task_run.state_timestamp == running_transition_time + assert task_run.state_type == StateType.RUNNING + assert task_run.state_name == "Running" + assert task_run.state_timestamp == 
running_transition_time + + state = await read_task_run_state( + session=session, + task_run_state_id=UUID("22222222-2222-2222-2222-222222222222"), + ) + + assert state + + assert state.id == UUID("22222222-2222-2222-2222-222222222222") + assert state.type == StateType.RUNNING + assert state.name == "Running" + assert state.message == "Weeeeeee look at me go!" + assert state.timestamp == running_transition_time + assert state.state_details == StateDetails( + flow_run_id=UUID("ffffffff-ffff-ffff-ffff-ffffffffffff"), + task_run_id=UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"), + pause_reschedule=False, + untrackable_result=False, + ) + + +async def test_updates_only_fields_that_are_set( + session: AsyncSession, + pending_event: ReceivedEvent, + running_event: ReceivedEvent, + completed_event: ReceivedEvent, + task_run_recorder_handler: MessageHandler, +): + pending_transition_time = pendulum.datetime(2024, 1, 1, 0, 0, 0, 0, "UTC") + assert pending_event.occurred == pending_transition_time + + running_transition_time = pendulum.datetime(2024, 1, 1, 0, 1, 0, 0, "UTC") + assert running_event.occurred == running_transition_time + + completed_transition_time = pendulum.datetime(2024, 1, 1, 0, 2, 0, 0, "UTC") + assert completed_event.occurred == completed_transition_time + + await task_run_recorder_handler(message(pending_event)) + await task_run_recorder_handler(message(running_event)) + await task_run_recorder_handler(message(completed_event)) + + task_run = await read_task_run( + session=session, + task_run_id=UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"), + ) + + assert task_run + + # The Completed transition here in the tests only sets the end_time, so we + # would expect all the other values to reflect what was set in the Running + # transition. + + assert task_run.id == UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa") + assert task_run.name == "my_task" + assert task_run.flow_run_id == UUID("ffffffff-ffff-ffff-ffff-ffffffffffff") + assert task_run.task_key == "my_task-abcdefg" + assert task_run.dynamic_key == "1" + assert task_run.tags == ["tag-1", "tag-2"] + + assert task_run.flow_run_run_count == 7 + assert task_run.run_count == 8 + assert task_run.total_run_time == timedelta(seconds=9) + assert task_run.task_inputs == { + "x": [{"input_type": "parameter", "name": "x"}], + "y": [{"input_type": "parameter", "name": "y"}], + } + assert task_run.empirical_policy == TaskRunPolicy( + max_retries=2, + retries=3, + retry_delay=4, + retry_delay_seconds=5.0, + ) + + assert task_run.expected_start_time == pending_transition_time + assert task_run.start_time == running_transition_time + assert task_run.end_time == completed_transition_time + + assert task_run.state_id == UUID("33333333-3333-3333-3333-333333333333") + assert task_run.state_type == StateType.COMPLETED + assert task_run.state_name == "Completed" + assert task_run.state_timestamp == completed_transition_time + + state = await read_task_run_state( + session=session, + task_run_state_id=UUID("33333333-3333-3333-3333-333333333333"), + ) + + assert state + + assert state.id == UUID("33333333-3333-3333-3333-333333333333") + assert state.type == StateType.COMPLETED + assert state.name == "Completed" + assert state.message == "Stick a fork in me, I'm done" + assert state.timestamp == completed_transition_time + assert state.state_details == StateDetails( + flow_run_id=UUID("ffffffff-ffff-ffff-ffff-ffffffffffff"), + task_run_id=UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"), + pause_reschedule=False, + untrackable_result=False, + ) + + +async def 
test_updates_task_run_on_out_of_order_state_change( + session: AsyncSession, + pending_event: ReceivedEvent, + running_event: ReceivedEvent, + completed_event: ReceivedEvent, + task_run_recorder_handler: MessageHandler, +): + pending_transition_time = pendulum.datetime(2024, 1, 1, 0, 0, 0, 0, "UTC") + assert pending_event.occurred == pending_transition_time + + running_transition_time = pendulum.datetime(2024, 1, 1, 0, 1, 0, 0, "UTC") + assert running_event.occurred == running_transition_time + + # force the completed event to an older time so that it won't update the task run + completed_event.occurred = running_transition_time - timedelta(seconds=1) + completed_transition_time = completed_event.occurred + + await task_run_recorder_handler(message(pending_event)) + await task_run_recorder_handler(message(running_event)) + await task_run_recorder_handler(message(completed_event)) + + task_run = await read_task_run( + session=session, + task_run_id=UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"), + ) + + assert task_run + + # We expect that the task run will still be showing the denormalized info from + # the prior state change, not the completed state change, because the timestamp + # of the completed state is older. This isn't a sensible thing to happen in + # the wild, but we want to be explicit about the behavior when that happens... + + assert task_run.id == UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa") + assert task_run.name == "my_task" + assert task_run.flow_run_id == UUID("ffffffff-ffff-ffff-ffff-ffffffffffff") + assert task_run.task_key == "my_task-abcdefg" + assert task_run.dynamic_key == "1" + assert task_run.tags == ["tag-1", "tag-2"] + + assert task_run.flow_run_run_count == 7 + assert task_run.run_count == 8 + assert task_run.total_run_time == timedelta(seconds=9) + assert task_run.task_inputs == { + "x": [{"input_type": "parameter", "name": "x"}], + "y": [{"input_type": "parameter", "name": "y"}], + } + assert task_run.empirical_policy == TaskRunPolicy( + max_retries=2, + retries=3, + retry_delay=4, + retry_delay_seconds=5.0, + ) + + assert task_run.expected_start_time == pending_transition_time + assert task_run.start_time == running_transition_time + assert task_run.end_time is None + + assert task_run.state_id == UUID("22222222-2222-2222-2222-222222222222") + assert task_run.state_timestamp == running_transition_time + assert task_run.state_type == StateType.RUNNING + assert task_run.state_name == "Running" + assert task_run.state_timestamp == running_transition_time + # ...however, the new completed state _is_ recorded + + state = await read_task_run_state( + session=session, + task_run_state_id=UUID("33333333-3333-3333-3333-333333333333"), + ) + + assert state + + assert state.id == UUID("33333333-3333-3333-3333-333333333333") + assert state.type == StateType.COMPLETED + assert state.name == "Completed" + assert state.message == "Stick a fork in me, I'm done" + assert state.timestamp == completed_transition_time + assert state.state_details == StateDetails( + flow_run_id=UUID("ffffffff-ffff-ffff-ffff-ffffffffffff"), + task_run_id=UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"), + pause_reschedule=False, + untrackable_result=False, + ) From d59dd3bd55d54b5b5bee3ca6486ac74a35903a9d Mon Sep 17 00:00:00 2001 From: "jake@prefect.io" Date: Wed, 24 Jul 2024 09:54:16 -0400 Subject: [PATCH 2/3] make sure to commit the flow run --- tests/server/services/test_task_run_recorder.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/server/services/test_task_run_recorder.py 
b/tests/server/services/test_task_run_recorder.py index c00804128b58..ecb48281f4d5 100644 --- a/tests/server/services/test_task_run_recorder.py +++ b/tests/server/services/test_task_run_recorder.py @@ -199,6 +199,7 @@ async def flow_run(session: AsyncSession, flow): flow_id=flow.id, ), ) + await session.commit() return flow_run From 26e74d59176afc28e3282c1c2d4523a154210f0f Mon Sep 17 00:00:00 2001 From: "jake@prefect.io" Date: Wed, 24 Jul 2024 10:05:41 -0400 Subject: [PATCH 3/3] make sure to enter a transaction --- src/prefect/server/services/task_run_recorder.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/prefect/server/services/task_run_recorder.py b/src/prefect/server/services/task_run_recorder.py index 037b8fe932c0..cb79a2d05d8c 100644 --- a/src/prefect/server/services/task_run_recorder.py +++ b/src/prefect/server/services/task_run_recorder.py @@ -151,7 +151,9 @@ async def record_task_run_event(event: ReceivedEvent, depth: int = 0): "state_name": task_run.state.name, "state_timestamp": task_run.state.timestamp, } - session = await stack.enter_async_context(db.session_context()) + session = await stack.enter_async_context( + db.session_context(begin_transaction=True) + ) await _insert_task_run(session, task_run, task_run_attributes) await _insert_task_run_state(session, task_run)
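
---

Illustration only, not part of the patch series above: the core idea behind `_insert_task_run` and `_update_task_run_with_state` is a "newest state wins" upsert, where both the conflict update and the denormalized-state update are guarded by the state timestamp so that late-arriving, older events (like the out-of-order Completed event in the tests) cannot clobber newer data. Below is a minimal standalone sketch of that pattern using plain SQLAlchemy Core against an in-memory SQLite database; the table name, columns, and `record` helper are hypothetical and stand in for Prefect's `db.TaskRun` model and `_insert_task_run`.

```python
# Illustration only -- a sketch of the timestamp-guarded upsert pattern,
# not Prefect's actual implementation.
import datetime

import sqlalchemy as sa
from sqlalchemy.dialects.sqlite import insert

metadata = sa.MetaData()

# Hypothetical stand-in for the real task_run table
task_runs = sa.Table(
    "task_runs",
    metadata,
    sa.Column("id", sa.String, primary_key=True),
    sa.Column("state_name", sa.String),
    sa.Column("state_timestamp", sa.DateTime),
)

engine = sa.create_engine("sqlite://")
metadata.create_all(engine)


def record(conn, run_id: str, state_name: str, timestamp: datetime.datetime):
    # Insert the row if it's new; on conflict, only overwrite the denormalized
    # state columns when the incoming state is newer than what is stored.
    stmt = (
        insert(task_runs)
        .values(id=run_id, state_name=state_name, state_timestamp=timestamp)
        .on_conflict_do_update(
            index_elements=["id"],
            set_={"state_name": state_name, "state_timestamp": timestamp},
            where=task_runs.c.state_timestamp < timestamp,
        )
    )
    conn.execute(stmt)


with engine.begin() as conn:
    t0 = datetime.datetime(2024, 1, 1, 0, 0, 0)
    record(conn, "run-1", "Running", t0 + datetime.timedelta(minutes=1))
    # A late-arriving, older event does not clobber the newer state
    record(conn, "run-1", "Pending", t0)
    row = conn.execute(sa.select(task_runs)).one()
    assert row.state_name == "Running"
```

The same guard appears twice in the patch: once as the `where=` clause of the `on_conflict_do_update` in `_insert_task_run`, and once as the `sa.or_(state_timestamp IS NULL, state_timestamp < ...)` filter in `_update_task_run_with_state`, which is what the `test_updates_task_run_on_out_of_order_state_change` test exercises.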