-
Notifications
You must be signed in to change notification settings - Fork 1.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
let TaskRunRecorder
process events into task runs/task run states
#14729
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,179 @@ | ||
import asyncio | ||
from contextlib import asynccontextmanager | ||
from typing import AsyncGenerator, Optional | ||
from contextlib import AsyncExitStack, asynccontextmanager | ||
from typing import Any, AsyncGenerator, Dict, Optional | ||
from uuid import UUID | ||
|
||
import pendulum | ||
import sqlalchemy as sa | ||
from sqlalchemy.ext.asyncio import AsyncSession | ||
|
||
from prefect.logging import get_logger | ||
from prefect.server.database.dependencies import db_injector, provide_database_interface | ||
from prefect.server.database.interface import PrefectDBInterface | ||
from prefect.server.events.ordering import CausalOrdering, EventArrivedEarly | ||
from prefect.server.events.schemas.events import ReceivedEvent | ||
from prefect.server.schemas.core import TaskRun | ||
from prefect.server.schemas.states import State | ||
from prefect.server.utilities.messaging import Message, MessageHandler, create_consumer | ||
|
||
logger = get_logger(__name__) | ||
|
||
|
||
def causal_ordering(): | ||
return CausalOrdering( | ||
"task-run-recorder", | ||
) | ||
|
||
|
||
@db_injector | ||
async def _insert_task_run( | ||
db: PrefectDBInterface, | ||
session: AsyncSession, | ||
task_run: TaskRun, | ||
task_run_attributes: Dict[str, Any], | ||
): | ||
await session.execute( | ||
db.insert(db.TaskRun) | ||
.values( | ||
created=pendulum.now("UTC"), | ||
**task_run_attributes, | ||
) | ||
.on_conflict_do_update( | ||
index_elements=[ | ||
"id", | ||
], | ||
set_={ | ||
"updated": pendulum.now("UTC"), | ||
**task_run_attributes, | ||
}, | ||
where=db.TaskRun.state_timestamp < task_run.state.timestamp, | ||
) | ||
) | ||
|
||
|
||
@db_injector | ||
async def _insert_task_run_state( | ||
db: PrefectDBInterface, session: AsyncSession, task_run: TaskRun | ||
): | ||
await session.execute( | ||
db.insert(db.TaskRunState) | ||
.values( | ||
created=pendulum.now("UTC"), | ||
task_run_id=task_run.id, | ||
**task_run.state.model_dump(), | ||
) | ||
.on_conflict_do_nothing( | ||
index_elements=[ | ||
"id", | ||
] | ||
) | ||
) | ||
|
||
|
||
@db_injector | ||
async def _update_task_run_with_state( | ||
db: PrefectDBInterface, | ||
session: AsyncSession, | ||
task_run: TaskRun, | ||
denormalized_state_attributes: Dict[str, Any], | ||
): | ||
await session.execute( | ||
sa.update(db.TaskRun) | ||
.where( | ||
db.TaskRun.id == task_run.id, | ||
sa.or_( | ||
db.TaskRun.state_timestamp.is_(None), | ||
db.TaskRun.state_timestamp < task_run.state.timestamp, | ||
), | ||
) | ||
.values(**denormalized_state_attributes) | ||
) | ||
|
||
|
||
def task_run_from_event(event: ReceivedEvent) -> TaskRun: | ||
task_run_id = event.resource.prefect_object_id("prefect.task-run") | ||
|
||
flow_run_id: Optional[UUID] = None | ||
if flow_run_resource := event.resource_in_role.get("flow-run"): | ||
flow_run_id = flow_run_resource.prefect_object_id("prefect.flow-run") | ||
|
||
state: State = State.model_validate( | ||
{ | ||
"id": event.id, | ||
"timestamp": event.occurred, | ||
**event.payload["validated_state"], | ||
} | ||
) | ||
state.state_details.task_run_id = task_run_id | ||
state.state_details.flow_run_id = flow_run_id | ||
|
||
return TaskRun.model_validate( | ||
{ | ||
"id": task_run_id, | ||
"flow_run_id": flow_run_id, | ||
"state_id": state.id, | ||
"state": state, | ||
**event.payload["task_run"], | ||
} | ||
) | ||
|
||
|
||
async def record_task_run_event(event: ReceivedEvent, depth: int = 0): | ||
db = provide_database_interface() | ||
|
||
async with AsyncExitStack() as stack: | ||
await stack.enter_async_context( | ||
( | ||
causal_ordering().preceding_event_confirmed( | ||
record_task_run_event, event, depth=depth | ||
) | ||
) | ||
) | ||
|
||
task_run = task_run_from_event(event) | ||
|
||
task_run_attributes = task_run.model_dump_for_orm( | ||
exclude={ | ||
"state_id", | ||
"state", | ||
"created", | ||
"estimated_run_time", | ||
"estimated_start_time_delta", | ||
}, | ||
exclude_unset=True, | ||
) | ||
|
||
assert task_run.state | ||
|
||
denormalized_state_attributes = { | ||
"state_id": task_run.state.id, | ||
"state_type": task_run.state.type, | ||
"state_name": task_run.state.name, | ||
"state_timestamp": task_run.state.timestamp, | ||
} | ||
session = await stack.enter_async_context(db.session_context()) | ||
|
||
await _insert_task_run(session, task_run, task_run_attributes) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unfortunately due to FK's we need to insert the task run, then insert the task run state, then go back and update the task run with the denormalized state attributes There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay, sometimes this trips me up with the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good question, that might make the lock a little more aggressive? I'm thinking about the SQLite implementation hitting |
||
await _insert_task_run_state(session, task_run) | ||
await _update_task_run_with_state( | ||
session, task_run, denormalized_state_attributes | ||
) | ||
|
||
logger.info( | ||
"Recorded task run state change", | ||
extra={ | ||
"task_run_id": task_run.id, | ||
"flow_run_id": task_run.flow_run_id, | ||
"event_id": event.id, | ||
"event_follows": event.follows, | ||
"event": event.event, | ||
"occurred": event.occurred, | ||
"current_state_type": task_run.state_type, | ||
"current_state_name": task_run.state_name, | ||
}, | ||
) | ||
|
||
|
||
@asynccontextmanager | ||
async def consumer() -> AsyncGenerator[MessageHandler, None]: | ||
async def message_handler(message: Message): | ||
|
@@ -24,6 +189,14 @@ async def message_handler(message: Message): | |
f"Received event: {event.event} with id: {event.id} for resource: {event.resource.get('prefect.resource.id')}" | ||
) | ||
|
||
try: | ||
await record_task_run_event(event) | ||
except EventArrivedEarly: | ||
# We're safe to ACK this message because it has been parked by the | ||
# causal ordering mechanism and will be reprocessed when the preceding | ||
# event arrives. | ||
pass | ||
|
||
yield message_handler | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this fits our pattern so that we'll only do this if things are in order