Skip to content

Commit

Permalink
Merge branch 'main' into add-event-logging-server-side
Browse files Browse the repository at this point in the history
  • Loading branch information
jakekaplan committed Sep 5, 2024
2 parents 2e15b4c + 2b994af commit be0297d
Show file tree
Hide file tree
Showing 5 changed files with 197 additions and 140 deletions.
8 changes: 6 additions & 2 deletions src/prefect/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ async def _read(self, key: str, holder: str) -> "ResultRecord":
A result record.
"""
if self.lock_manager is not None and not self.is_lock_holder(key, holder):
self.wait_for_lock(key)
await self.await_for_lock(key)

if self.result_storage is None:
self.result_storage = await get_default_result_storage()
Expand Down Expand Up @@ -407,7 +407,10 @@ def create_result_record(
return ResultRecord(
result=obj,
metadata=ResultRecordMetadata(
serializer=self.serializer, expiration=expiration, storage_key=key
serializer=self.serializer,
expiration=expiration,
storage_key=key,
storage_block_id=self.result_storage_block_id,
),
)

Expand Down Expand Up @@ -752,6 +755,7 @@ class ResultRecordMetadata(BaseModel):
expiration: Optional[DateTime] = Field(default=None)
serializer: Serializer = Field(default_factory=PickleSerializer)
prefect_version: str = Field(default=prefect.__version__)
storage_block_id: Optional[uuid.UUID] = Field(default=None)

def dump_bytes(self) -> bytes:
"""
Expand Down
99 changes: 24 additions & 75 deletions src/prefect/server/services/task_run_recorder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import asyncio
from contextlib import AsyncExitStack, asynccontextmanager
from datetime import timedelta
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Dict, Optional
from uuid import UUID

Expand Down Expand Up @@ -119,43 +118,31 @@ def task_run_from_event(event: ReceivedEvent) -> TaskRun:
)


async def record_task_run_event(event: ReceivedEvent, depth: int = 0):
db = provide_database_interface()

async with AsyncExitStack() as stack:
await stack.enter_async_context(
(
causal_ordering().preceding_event_confirmed(
record_task_run_event, event, depth=depth
)
)
)

task_run = task_run_from_event(event)
async def record_task_run_event(event: ReceivedEvent):
task_run = task_run_from_event(event)

task_run_attributes = task_run.model_dump_for_orm(
exclude={
"state_id",
"state",
"created",
"estimated_run_time",
"estimated_start_time_delta",
},
exclude_unset=True,
)
task_run_attributes = task_run.model_dump_for_orm(
exclude={
"state_id",
"state",
"created",
"estimated_run_time",
"estimated_start_time_delta",
},
exclude_unset=True,
)

assert task_run.state
assert task_run.state

denormalized_state_attributes = {
"state_id": task_run.state.id,
"state_type": task_run.state.type,
"state_name": task_run.state.name,
"state_timestamp": task_run.state.timestamp,
}
session = await stack.enter_async_context(
db.session_context(begin_transaction=True)
)
denormalized_state_attributes = {
"state_id": task_run.state.id,
"state_type": task_run.state.type,
"state_name": task_run.state.name,
"state_timestamp": task_run.state.timestamp,
}

db = provide_database_interface()
async with db.session_context(begin_transaction=True) as session:
await _insert_task_run(session, task_run, task_run_attributes)
await _insert_task_run_state(session, task_run)
await _update_task_run_with_state(
Expand All @@ -177,39 +164,8 @@ async def record_task_run_event(event: ReceivedEvent, depth: int = 0):
)


async def record_lost_follower_task_run_events():
    """Replay follower task-run events whose leader event never showed up.

    Asks the causal-ordering service for any "lost" followers and records
    each one as a regular task-run event.
    """
    lost_followers = await causal_ordering().get_lost_followers()
    for follower_event in lost_followers:
        await record_task_run_event(follower_event)


async def periodically_process_followers(periodic_granularity: timedelta) -> None:
    """Periodically process followers that are waiting on a leader event that never arrived"""
    logger.debug(
        "Starting periodically process followers task every %s seconds",
        periodic_granularity.total_seconds(),
    )
    while True:
        try:
            await record_lost_follower_task_run_events()
        except asyncio.CancelledError:
            # Cancellation is the expected shutdown signal: log and leave the loop.
            logger.debug("Periodically process followers task cancelled")
            return
        except Exception:
            # Never let an unexpected error kill the periodic loop; log it and
            # try again on the next tick.
            logger.exception("Error while processing task-run-recorders followers.")
        finally:
            # NOTE(review): `finally` also runs on the cancellation/return path, so
            # the task sleeps once more before actually exiting (and that sleep may
            # itself be cancelled) — confirm this ordering is intended.
            await asyncio.sleep(periodic_granularity.total_seconds())


@asynccontextmanager
async def consumer(
periodic_granularity: timedelta = timedelta(seconds=5),
) -> AsyncGenerator[MessageHandler, None]:
record_lost_followers_task = asyncio.create_task(
periodically_process_followers(periodic_granularity=periodic_granularity)
)

async def consumer() -> AsyncGenerator[MessageHandler, None]:
async def message_handler(message: Message):
event: ReceivedEvent = ReceivedEvent.model_validate_json(message.data)

Expand All @@ -234,14 +190,7 @@ async def message_handler(message: Message):
# event arrives.
pass

try:
yield message_handler
finally:
try:
record_lost_followers_task.cancel()
await record_lost_followers_task
except asyncio.CancelledError:
logger.debug("Periodically process followers task cancelled successfully")
yield message_handler


class TaskRunRecorder:
Expand Down
28 changes: 18 additions & 10 deletions src/prefect/transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from contextvars import ContextVar, Token
from functools import partial
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Expand All @@ -22,13 +21,12 @@
from prefect.exceptions import MissingContextError, SerializationError
from prefect.logging.loggers import get_logger, get_run_logger
from prefect.records import RecordStore
from prefect.records.base import TransactionRecord
from prefect.results import BaseResult, ResultRecord, ResultStore
from prefect.utilities.annotations import NotSet
from prefect.utilities.collections import AutoEnum
from prefect.utilities.engine import _get_hook_name

if TYPE_CHECKING:
from prefect.results import BaseResult


class IsolationLevel(AutoEnum):
READ_COMMITTED = AutoEnum.auto()
Expand All @@ -54,7 +52,7 @@ class Transaction(ContextModel):
A base model for transaction state.
"""

store: Optional[RecordStore] = None
store: Union[RecordStore, ResultStore, None] = None
key: Optional[str] = None
children: List["Transaction"] = Field(default_factory=list)
commit_mode: Optional[CommitMode] = None
Expand Down Expand Up @@ -175,10 +173,14 @@ def begin(self):
):
self.state = TransactionState.COMMITTED

def read(self) -> Optional["BaseResult"]:
def read(self) -> Union["BaseResult", ResultRecord, None]:
if self.store and self.key:
record = self.store.read(key=self.key)
if record is not None:
if isinstance(record, ResultRecord):
return record
# for backwards compatibility, if we encounter a transaction record, return the result
# This happens when the transaction is using a `ResultStore`
if isinstance(record, TransactionRecord):
return record.result
return None

Expand Down Expand Up @@ -228,7 +230,13 @@ def commit(self) -> bool:
self.run_hook(hook, "commit")

if self.store and self.key:
self.store.write(key=self.key, result=self._staged_value)
if isinstance(self.store, ResultStore):
if isinstance(self._staged_value, BaseResult):
self.store.write(self.key, self._staged_value.get(_sync=True))
else:
self.store.write(self.key, self._staged_value)
else:
self.store.write(self.key, self._staged_value)
self.state = TransactionState.COMMITTED
if (
self.store
Expand Down Expand Up @@ -279,7 +287,7 @@ def run_hook(self, hook, hook_type: str) -> None:

def stage(
self,
value: "BaseResult",
value: Union["BaseResult", Any],
on_rollback_hooks: Optional[List] = None,
on_commit_hooks: Optional[List] = None,
) -> None:
Expand Down Expand Up @@ -337,7 +345,7 @@ def get_transaction() -> Optional[Transaction]:
@contextmanager
def transaction(
key: Optional[str] = None,
store: Optional[RecordStore] = None,
store: Union[RecordStore, ResultStore, None] = None,
commit_mode: Optional[CommitMode] = None,
isolation_level: Optional[IsolationLevel] = None,
overwrite: bool = False,
Expand Down
101 changes: 48 additions & 53 deletions tests/server/services/test_task_run_recorder.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,23 @@
import asyncio
from datetime import timedelta
from itertools import permutations
from typing import AsyncGenerator
from unittest.mock import AsyncMock, patch
from uuid import UUID, uuid4
from uuid import UUID

import pendulum
import pytest
from sqlalchemy.ext.asyncio import AsyncSession

from prefect.server.events.ordering import EventArrivedEarly
from prefect.server.events.schemas.events import ReceivedEvent
from prefect.server.models.flow_runs import create_flow_run
from prefect.server.models.task_run_states import read_task_run_state
from prefect.server.models.task_run_states import (
read_task_run_state,
read_task_run_states,
)
from prefect.server.models.task_runs import read_task_run
from prefect.server.schemas.core import FlowRun, TaskRunPolicy
from prefect.server.schemas.states import StateDetails, StateType
from prefect.server.services import task_run_recorder
from prefect.server.services.task_run_recorder import (
record_lost_follower_task_run_events,
record_task_run_event,
)
from prefect.server.utilities.messaging import MessageHandler
from prefect.server.utilities.messaging.memory import MemoryMessage

Expand All @@ -40,9 +38,7 @@ async def test_start_and_stop_service():

@pytest.fixture
async def task_run_recorder_handler() -> AsyncGenerator[MessageHandler, None]:
async with task_run_recorder.consumer(
periodic_granularity=timedelta(seconds=0.0001)
) as handler:
async with task_run_recorder.consumer() as handler:
yield handler


Expand Down Expand Up @@ -711,50 +707,49 @@ async def test_updates_task_run_on_out_of_order_state_change(
)


async def test_lost_followers_are_recorded(monkeypatch: pytest.MonkeyPatch):
now = pendulum.now("UTC")
event = ReceivedEvent(
occurred=now.subtract(minutes=2),
received=now.subtract(minutes=1),
event="prefect.task-run.Running",
resource={
"prefect.resource.id": f"prefect.task-run.{str(uuid4())}",
},
account=uuid4(),
workspace=uuid4(),
follows=uuid4(),
id=uuid4(),
)
# record a follower that never sees its leader
with pytest.raises(EventArrivedEarly):
await record_task_run_event(event)

record_task_run_event_mock = AsyncMock()
monkeypatch.setattr(
"prefect.server.services.task_run_recorder.record_task_run_event",
record_task_run_event_mock,
)

# move time forward so we can record the lost follower
with patch("prefect.server.events.ordering.pendulum.now") as the_future:
the_future.return_value = now.add(minutes=20)
await record_lost_follower_task_run_events()

assert record_task_run_event_mock.await_count == 1
record_task_run_event_mock.assert_awaited_with(event)
@pytest.mark.parametrize(
"event_order",
list(permutations(["PENDING", "RUNNING", "COMPLETED"])),
ids=lambda x: "->".join(x),
)
async def test_task_run_recorder_handles_all_out_of_order_permutations(
session: AsyncSession,
pending_event: ReceivedEvent,
running_event: ReceivedEvent,
completed_event: ReceivedEvent,
task_run_recorder_handler: MessageHandler,
event_order: tuple,
):
# Set up event times
base_time = pendulum.datetime(2024, 1, 1, 0, 0, 0, 0, "UTC")
pending_event.occurred = base_time
running_event.occurred = base_time.add(minutes=1)
completed_event.occurred = base_time.add(minutes=2)

event_map = {
"PENDING": pending_event,
"RUNNING": running_event,
"COMPLETED": completed_event,
}

# Process events in the specified order
for event_name in event_order:
await task_run_recorder_handler(message(event_map[event_name]))

async def test_lost_followers_are_recorded_periodically(
task_run_recorder_handler,
monkeypatch: pytest.MonkeyPatch,
):
record_lost_follower_task_run_events_mock = AsyncMock()
monkeypatch.setattr(
"prefect.server.services.task_run_recorder.record_lost_follower_task_run_events",
record_lost_follower_task_run_events_mock,
# Verify the task run always has the "final" state
task_run = await read_task_run(
session=session,
task_run_id=UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"),
)

# let the period task run a few times
await asyncio.sleep(0.1)
assert task_run
assert task_run.state_type == StateType.COMPLETED
assert task_run.state_name == "Completed"
assert task_run.state_timestamp == completed_event.occurred

# Verify all states are recorded
states = await read_task_run_states(session, task_run.id)
assert len(states) == 3

assert record_lost_follower_task_run_events_mock.await_count >= 1
state_types = set(state.type for state in states)
assert state_types == {StateType.PENDING, StateType.RUNNING, StateType.COMPLETED}
Loading

0 comments on commit be0297d

Please sign in to comment.