-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
scaffolding for client task run orchestration (#14627)
- Loading branch information
1 parent
eb738c3
commit 36e2ddf
Showing
7 changed files
with
292 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import asyncio | ||
from contextlib import asynccontextmanager | ||
from typing import AsyncGenerator, Optional | ||
|
||
from prefect.logging import get_logger | ||
from prefect.server.events.schemas.events import ReceivedEvent | ||
from prefect.server.utilities.messaging import Message, MessageHandler, create_consumer | ||
|
||
logger = get_logger(__name__) | ||
|
||
|
||
@asynccontextmanager | ||
async def consumer() -> AsyncGenerator[MessageHandler, None]: | ||
async def message_handler(message: Message): | ||
event: ReceivedEvent = ReceivedEvent.model_validate_json(message.data) | ||
|
||
if not event.event.startswith("prefect.task-run"): | ||
return | ||
|
||
if not event.resource.get("prefect.orchestration") == "client": | ||
return | ||
|
||
logger.info( | ||
f"Received event: {event.event} with id: {event.id} for resource: {event.resource.get('prefect.resource.id')}" | ||
) | ||
|
||
yield message_handler | ||
|
||
|
||
class TaskRunRecorder: | ||
"""A service to record task run and task run states from events.""" | ||
|
||
name: str = "TaskRunRecorder" | ||
|
||
consumer_task: Optional[asyncio.Task] = None | ||
|
||
def __init__(self): | ||
self._started_event: Optional[asyncio.Event] = None | ||
|
||
@property | ||
def started_event(self) -> asyncio.Event: | ||
if self._started_event is None: | ||
self._started_event = asyncio.Event() | ||
return self._started_event | ||
|
||
@started_event.setter | ||
def started_event(self, value: asyncio.Event) -> None: | ||
self._started_event = value | ||
|
||
async def start(self): | ||
assert self.consumer_task is None, "TaskRunRecorder already started" | ||
self.consumer = create_consumer("events") | ||
|
||
async with consumer() as handler: | ||
self.consumer_task = asyncio.create_task(self.consumer.run(handler)) | ||
logger.debug("TaskRunRecorder started") | ||
self.started_event.set() | ||
|
||
try: | ||
await self.consumer_task | ||
except asyncio.CancelledError: | ||
pass | ||
|
||
async def stop(self): | ||
assert self.consumer_task is not None, "Logger not started" | ||
self.consumer_task.cancel() | ||
try: | ||
await self.consumer_task | ||
except asyncio.CancelledError: | ||
pass | ||
finally: | ||
self.consumer_task = None | ||
logger.debug("TaskRunRecorder stopped") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,180 @@ | ||
import asyncio | ||
from typing import AsyncGenerator | ||
from uuid import UUID | ||
|
||
import pendulum | ||
import pytest | ||
|
||
from prefect.server.events.schemas.events import ReceivedEvent | ||
from prefect.server.services import task_run_recorder | ||
from prefect.server.utilities.messaging import MessageHandler | ||
from prefect.server.utilities.messaging.memory import MemoryMessage | ||
|
||
|
||
async def test_start_and_stop_service(): | ||
service = task_run_recorder.TaskRunRecorder() | ||
service_task = asyncio.create_task(service.start()) | ||
service.started_event = asyncio.Event() | ||
|
||
await service.started_event.wait() | ||
assert service.consumer_task is not None | ||
|
||
await service.stop() | ||
assert service.consumer_task is None | ||
|
||
await service_task | ||
|
||
|
||
@pytest.fixture | ||
async def task_run_recorder_handler() -> AsyncGenerator[MessageHandler, None]: | ||
async with task_run_recorder.consumer() as handler: | ||
yield handler | ||
|
||
|
||
@pytest.fixture | ||
def hello_event() -> ReceivedEvent: | ||
return ReceivedEvent( | ||
occurred=pendulum.datetime(2022, 1, 2, 3, 4, 5, 6, "UTC"), | ||
event="hello", | ||
resource={ | ||
"prefect.resource.id": "my.resource.id", | ||
}, | ||
related=[ | ||
{"prefect.resource.id": "related-1", "prefect.resource.role": "role-1"}, | ||
{"prefect.resource.id": "related-2", "prefect.resource.role": "role-1"}, | ||
{"prefect.resource.id": "related-3", "prefect.resource.role": "role-2"}, | ||
], | ||
payload={"hello": "world"}, | ||
account=UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"), | ||
workspace=UUID("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"), | ||
received=pendulum.datetime(2022, 2, 3, 4, 5, 6, 7, "UTC"), | ||
id=UUID("eeeeeeee-eeee-eeee-eeee-eeeeeeeeeeee"), | ||
follows=UUID("ffffffff-ffff-ffff-ffff-ffffffffffff"), | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def client_orchestrated_task_run_event() -> ReceivedEvent: | ||
return ReceivedEvent( | ||
occurred=pendulum.datetime(2022, 1, 2, 3, 4, 5, 6, "UTC"), | ||
event="prefect.task-run.Running", | ||
resource={ | ||
"prefect.resource.id": "prefect.task-run.b75b283c-7cd5-439a-b23e-d0c59e78b042", | ||
"prefect.resource.name": "my_task", | ||
"prefect.state-message": "", | ||
"prefect.state-name": "Running", | ||
"prefect.state-timestamp": pendulum.datetime( | ||
2022, 1, 2, 3, 4, 5, 6, "UTC" | ||
).isoformat(), | ||
"prefect.state-type": "RUNNING", | ||
"prefect.orchestration": "client", | ||
}, | ||
related=[], | ||
payload={ | ||
"intended": {"from": "PENDING", "to": "RUNNING"}, | ||
"initial_state": {"type": "PENDING", "name": "Pending", "message": ""}, | ||
"validated_state": {"type": "RUNNING", "name": "Running", "message": ""}, | ||
}, | ||
account=UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"), | ||
workspace=UUID("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"), | ||
received=pendulum.datetime(2022, 2, 3, 4, 5, 6, 7, "UTC"), | ||
id=UUID("eeeeeeee-eeee-eeee-eeee-eeeeeeeeeeee"), | ||
follows=UUID("ffffffff-ffff-ffff-ffff-ffffffffffff"), | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def server_orchestrated_task_run_event() -> ReceivedEvent: | ||
return ReceivedEvent( | ||
occurred=pendulum.datetime(2022, 1, 2, 3, 4, 5, 6, "UTC"), | ||
event="prefect.task-run.Running", | ||
resource={ | ||
"prefect.resource.id": "prefect.task-run.b75b283c-7cd5-439a-b23e-d0c59e78b042", | ||
"prefect.resource.name": "my_task", | ||
"prefect.state-message": "", | ||
"prefect.state-name": "Running", | ||
"prefect.state-timestamp": pendulum.datetime( | ||
2022, 1, 2, 3, 4, 5, 6, "UTC" | ||
).isoformat(), | ||
"prefect.state-type": "RUNNING", | ||
"prefect.orchestration": "server", | ||
}, | ||
related=[], | ||
payload={ | ||
"intended": {"from": "PENDING", "to": "RUNNING"}, | ||
"initial_state": {"type": "PENDING", "name": "Pending", "message": ""}, | ||
"validated_state": {"type": "RUNNING", "name": "Running", "message": ""}, | ||
}, | ||
account=UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"), | ||
workspace=UUID("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"), | ||
received=pendulum.datetime(2022, 2, 3, 4, 5, 6, 7, "UTC"), | ||
id=UUID("eeeeeeee-eeee-eeee-eeee-eeeeeeeeeeee"), | ||
follows=UUID("ffffffff-ffff-ffff-ffff-ffffffffffff"), | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def client_orchestrated_task_run_event_message( | ||
client_orchestrated_task_run_event: ReceivedEvent, | ||
) -> MemoryMessage: | ||
return MemoryMessage( | ||
data=client_orchestrated_task_run_event.model_dump_json().encode(), | ||
attributes={}, | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def server_orchestrated_task_run_event_message( | ||
server_orchestrated_task_run_event: ReceivedEvent, | ||
) -> MemoryMessage: | ||
return MemoryMessage( | ||
data=server_orchestrated_task_run_event.model_dump_json().encode(), | ||
attributes={}, | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def hello_event_message(hello_event: ReceivedEvent) -> MemoryMessage: | ||
return MemoryMessage( | ||
data=hello_event.model_dump_json().encode(), | ||
attributes={}, | ||
) | ||
|
||
|
||
async def test_handle_client_orchestrated_task_run_event( | ||
task_run_recorder_handler: MessageHandler, | ||
client_orchestrated_task_run_event: ReceivedEvent, | ||
client_orchestrated_task_run_event_message: MemoryMessage, | ||
caplog: pytest.LogCaptureFixture, | ||
): | ||
with caplog.at_level("INFO"): | ||
await task_run_recorder_handler(client_orchestrated_task_run_event_message) | ||
|
||
assert "Received event" in caplog.text | ||
assert str(client_orchestrated_task_run_event.id) in caplog.text | ||
|
||
|
||
async def test_skip_non_task_run_event( | ||
task_run_recorder_handler: MessageHandler, | ||
hello_event: ReceivedEvent, | ||
hello_event_message: MemoryMessage, | ||
caplog: pytest.LogCaptureFixture, | ||
): | ||
with caplog.at_level("INFO"): | ||
await task_run_recorder_handler(hello_event_message) | ||
|
||
assert "Received event" not in caplog.text | ||
assert str(hello_event.id) not in caplog.text | ||
|
||
|
||
async def test_skip_server_side_orchestrated_task_run( | ||
task_run_recorder_handler: MessageHandler, | ||
server_orchestrated_task_run_event: ReceivedEvent, | ||
server_orchestrated_task_run_event_message: MemoryMessage, | ||
caplog: pytest.LogCaptureFixture, | ||
): | ||
with caplog.at_level("INFO"): | ||
await task_run_recorder_handler(server_orchestrated_task_run_event_message) | ||
|
||
assert "Received event" not in caplog.text | ||
assert str(server_orchestrated_task_run_event.id) not in caplog.text |