Skip to content

Commit

Permalink
task run recorder handles all out of order permutations
Browse files Browse the repository at this point in the history
  • Loading branch information
jakekaplan committed Sep 5, 2024
1 parent d7e66e0 commit 9e84d9d
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 128 deletions.
99 changes: 24 additions & 75 deletions src/prefect/server/services/task_run_recorder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import asyncio
from contextlib import AsyncExitStack, asynccontextmanager
from datetime import timedelta
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Dict, Optional
from uuid import UUID

Expand Down Expand Up @@ -119,43 +118,31 @@ def task_run_from_event(event: ReceivedEvent) -> TaskRun:
)


async def record_task_run_event(event: ReceivedEvent, depth: int = 0):
db = provide_database_interface()

async with AsyncExitStack() as stack:
await stack.enter_async_context(
(
causal_ordering().preceding_event_confirmed(
record_task_run_event, event, depth=depth
)
)
)

task_run = task_run_from_event(event)
async def record_task_run_event(event: ReceivedEvent):
task_run = task_run_from_event(event)

task_run_attributes = task_run.model_dump_for_orm(
exclude={
"state_id",
"state",
"created",
"estimated_run_time",
"estimated_start_time_delta",
},
exclude_unset=True,
)
task_run_attributes = task_run.model_dump_for_orm(
exclude={
"state_id",
"state",
"created",
"estimated_run_time",
"estimated_start_time_delta",
},
exclude_unset=True,
)

assert task_run.state
assert task_run.state

denormalized_state_attributes = {
"state_id": task_run.state.id,
"state_type": task_run.state.type,
"state_name": task_run.state.name,
"state_timestamp": task_run.state.timestamp,
}
session = await stack.enter_async_context(
db.session_context(begin_transaction=True)
)
denormalized_state_attributes = {
"state_id": task_run.state.id,
"state_type": task_run.state.type,
"state_name": task_run.state.name,
"state_timestamp": task_run.state.timestamp,
}

db = provide_database_interface()
async with db.session_context(begin_transaction=True) as session:
await _insert_task_run(session, task_run, task_run_attributes)
await _insert_task_run_state(session, task_run)
await _update_task_run_with_state(
Expand All @@ -177,39 +164,8 @@ async def record_task_run_event(event: ReceivedEvent, depth: int = 0):
)


async def record_lost_follower_task_run_events():
events = await causal_ordering().get_lost_followers()

for event in events:
await record_task_run_event(event)


async def periodically_process_followers(periodic_granularity: timedelta):
"""Periodically process followers that are waiting on a leader event that never arrived"""
logger.debug(
"Starting periodically process followers task every %s seconds",
periodic_granularity.total_seconds(),
)
while True:
try:
await record_lost_follower_task_run_events()
except asyncio.CancelledError:
logger.debug("Periodically process followers task cancelled")
return
except Exception:
logger.exception("Error while processing task-run-recorders followers.")
finally:
await asyncio.sleep(periodic_granularity.total_seconds())


@asynccontextmanager
async def consumer(
periodic_granularity: timedelta = timedelta(seconds=5),
) -> AsyncGenerator[MessageHandler, None]:
record_lost_followers_task = asyncio.create_task(
periodically_process_followers(periodic_granularity=periodic_granularity)
)

async def consumer() -> AsyncGenerator[MessageHandler, None]:
async def message_handler(message: Message):
event: ReceivedEvent = ReceivedEvent.model_validate_json(message.data)

Expand All @@ -231,14 +187,7 @@ async def message_handler(message: Message):
# event arrives.
pass

try:
yield message_handler
finally:
try:
record_lost_followers_task.cancel()
await record_lost_followers_task
except asyncio.CancelledError:
logger.debug("Periodically process followers task cancelled successfully")
yield message_handler


class TaskRunRecorder:
Expand Down
101 changes: 48 additions & 53 deletions tests/server/services/test_task_run_recorder.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,23 @@
import asyncio
from datetime import timedelta
from itertools import permutations
from typing import AsyncGenerator
from unittest.mock import AsyncMock, patch
from uuid import UUID, uuid4
from uuid import UUID

import pendulum
import pytest
from sqlalchemy.ext.asyncio import AsyncSession

from prefect.server.events.ordering import EventArrivedEarly
from prefect.server.events.schemas.events import ReceivedEvent
from prefect.server.models.flow_runs import create_flow_run
from prefect.server.models.task_run_states import read_task_run_state
from prefect.server.models.task_run_states import (
read_task_run_state,
read_task_run_states,
)
from prefect.server.models.task_runs import read_task_run
from prefect.server.schemas.core import FlowRun, TaskRunPolicy
from prefect.server.schemas.states import StateDetails, StateType
from prefect.server.services import task_run_recorder
from prefect.server.services.task_run_recorder import (
record_lost_follower_task_run_events,
record_task_run_event,
)
from prefect.server.utilities.messaging import MessageHandler
from prefect.server.utilities.messaging.memory import MemoryMessage

Expand All @@ -40,9 +38,7 @@ async def test_start_and_stop_service():

@pytest.fixture
async def task_run_recorder_handler() -> AsyncGenerator[MessageHandler, None]:
async with task_run_recorder.consumer(
periodic_granularity=timedelta(seconds=0.0001)
) as handler:
async with task_run_recorder.consumer() as handler:
yield handler


Expand Down Expand Up @@ -711,50 +707,49 @@ async def test_updates_task_run_on_out_of_order_state_change(
)


async def test_lost_followers_are_recorded(monkeypatch: pytest.MonkeyPatch):
now = pendulum.now("UTC")
event = ReceivedEvent(
occurred=now.subtract(minutes=2),
received=now.subtract(minutes=1),
event="prefect.task-run.Running",
resource={
"prefect.resource.id": f"prefect.task-run.{str(uuid4())}",
},
account=uuid4(),
workspace=uuid4(),
follows=uuid4(),
id=uuid4(),
)
# record a follower that never sees its leader
with pytest.raises(EventArrivedEarly):
await record_task_run_event(event)

record_task_run_event_mock = AsyncMock()
monkeypatch.setattr(
"prefect.server.services.task_run_recorder.record_task_run_event",
record_task_run_event_mock,
)

# move time forward so we can record the lost follower
with patch("prefect.server.events.ordering.pendulum.now") as the_future:
the_future.return_value = now.add(minutes=20)
await record_lost_follower_task_run_events()

assert record_task_run_event_mock.await_count == 1
record_task_run_event_mock.assert_awaited_with(event)
@pytest.mark.parametrize(
"event_order",
list(permutations(["PENDING", "RUNNING", "COMPLETED"])),
ids=lambda x: "->".join(x),
)
async def test_task_run_recorder_handles_all_out_of_order_permutations(
session: AsyncSession,
pending_event: ReceivedEvent,
running_event: ReceivedEvent,
completed_event: ReceivedEvent,
task_run_recorder_handler: MessageHandler,
event_order: tuple,
):
# Set up event times
base_time = pendulum.datetime(2024, 1, 1, 0, 0, 0, 0, "UTC")
pending_event.occurred = base_time
running_event.occurred = base_time.add(minutes=1)
completed_event.occurred = base_time.add(minutes=2)

event_map = {
"PENDING": pending_event,
"RUNNING": running_event,
"COMPLETED": completed_event,
}

# Process events in the specified order
for event_name in event_order:
await task_run_recorder_handler(message(event_map[event_name]))

async def test_lost_followers_are_recorded_periodically(
task_run_recorder_handler,
monkeypatch: pytest.MonkeyPatch,
):
record_lost_follower_task_run_events_mock = AsyncMock()
monkeypatch.setattr(
"prefect.server.services.task_run_recorder.record_lost_follower_task_run_events",
record_lost_follower_task_run_events_mock,
# Verify the task run always has the "final" state
task_run = await read_task_run(
session=session,
task_run_id=UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"),
)

# let the period task run a few times
await asyncio.sleep(0.1)
assert task_run
assert task_run.state_type == StateType.COMPLETED
assert task_run.state_name == "Completed"
assert task_run.state_timestamp == completed_event.occurred

# Verify all states are recorded
states = await read_task_run_states(session, task_run.id)
assert len(states) == 3

assert record_lost_follower_task_run_events_mock.await_count >= 1
state_types = set(state.type for state in states)
assert state_types == {StateType.PENDING, StateType.RUNNING, StateType.COMPLETED}

0 comments on commit 9e84d9d

Please sign in to comment.