From 36e2ddfd1cf454e296ec034ba5c6863e6a13b298 Mon Sep 17 00:00:00 2001 From: jakekaplan <40362401+jakekaplan@users.noreply.github.com> Date: Tue, 16 Jul 2024 14:55:31 -0400 Subject: [PATCH] scaffolding for client task run orchestration (#14627) --- .../3.0rc/api-ref/rest-api/server/schema.json | 10 + src/prefect/server/api/server.py | 7 + .../server/services/task_run_recorder.py | 73 +++++++ src/prefect/settings.py | 12 ++ src/prefect/utilities/engine.py | 4 + .../test_task_run_state_change_events.py | 6 + .../server/services/test_task_run_recorder.py | 180 ++++++++++++++++++ 7 files changed, 292 insertions(+) create mode 100644 src/prefect/server/services/task_run_recorder.py create mode 100644 tests/server/services/test_task_run_recorder.py diff --git a/docs/3.0rc/api-ref/rest-api/server/schema.json b/docs/3.0rc/api-ref/rest-api/server/schema.json index 700ae8efdc04..9a0a745ca51a 100644 --- a/docs/3.0rc/api-ref/rest-api/server/schema.json +++ b/docs/3.0rc/api-ref/rest-api/server/schema.json @@ -22165,6 +22165,11 @@ "title": "Prefect Api Log Retryable Errors", "default": false }, + "PREFECT_API_SERVICES_TASK_RUN_RECORDER_ENABLED": { + "type": "boolean", + "title": "Prefect Api Services Task Run Recorder Enabled", + "default": true + }, "PREFECT_API_DEFAULT_LIMIT": { "type": "integer", "title": "Prefect Api Default Limit", @@ -22257,6 +22262,11 @@ "title": "Prefect Api Max Flow Run Graph Artifacts", "default": 10000 }, + "PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION": { + "type": "boolean", + "title": "Prefect Experimental Enable Client Side Task Orchestration", + "default": false + }, "PREFECT_RUNNER_PROCESS_LIMIT": { "type": "integer", "title": "Prefect Runner Process Limit", diff --git a/src/prefect/server/api/server.py b/src/prefect/server/api/server.py index 71634fff30e7..23dfcb178a88 100644 --- a/src/prefect/server/api/server.py +++ b/src/prefect/server/api/server.py @@ -40,6 +40,7 @@ from prefect.server.events.services.event_persister import 
EventPersister from prefect.server.events.services.triggers import ProactiveTriggers, ReactiveTriggers from prefect.server.exceptions import ObjectNotFoundError +from prefect.server.services.task_run_recorder import TaskRunRecorder from prefect.server.utilities.database import get_dialect from prefect.server.utilities.server import method_paths_from_routes from prefect.settings import ( @@ -603,6 +604,12 @@ async def start_services(): if prefect.settings.PREFECT_API_EVENTS_STREAM_OUT_ENABLED: service_instances.append(stream.Distributor()) + if ( + prefect.settings.PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION + and prefect.settings.PREFECT_API_SERVICES_TASK_RUN_RECORDER_ENABLED + ): + service_instances.append(TaskRunRecorder()) + loop = asyncio.get_running_loop() app.state.services = { diff --git a/src/prefect/server/services/task_run_recorder.py b/src/prefect/server/services/task_run_recorder.py new file mode 100644 index 000000000000..a9ad73d39612 --- /dev/null +++ b/src/prefect/server/services/task_run_recorder.py @@ -0,0 +1,73 @@ +import asyncio +from contextlib import asynccontextmanager +from typing import AsyncGenerator, Optional + +from prefect.logging import get_logger +from prefect.server.events.schemas.events import ReceivedEvent +from prefect.server.utilities.messaging import Message, MessageHandler, create_consumer + +logger = get_logger(__name__) + + +@asynccontextmanager +async def consumer() -> AsyncGenerator[MessageHandler, None]: + async def message_handler(message: Message): + event: ReceivedEvent = ReceivedEvent.model_validate_json(message.data) + + if not event.event.startswith("prefect.task-run"): + return + + if not event.resource.get("prefect.orchestration") == "client": + return + + logger.info( + f"Received event: {event.event} with id: {event.id} for resource: {event.resource.get('prefect.resource.id')}" + ) + + yield message_handler + + +class TaskRunRecorder: + """A service to record task run and task run states from 
events.""" + + name: str = "TaskRunRecorder" + + consumer_task: Optional[asyncio.Task] = None + + def __init__(self): + self._started_event: Optional[asyncio.Event] = None + + @property + def started_event(self) -> asyncio.Event: + if self._started_event is None: + self._started_event = asyncio.Event() + return self._started_event + + @started_event.setter + def started_event(self, value: asyncio.Event) -> None: + self._started_event = value + + async def start(self): + assert self.consumer_task is None, "TaskRunRecorder already started" + self.consumer = create_consumer("events") + + async with consumer() as handler: + self.consumer_task = asyncio.create_task(self.consumer.run(handler)) + logger.debug("TaskRunRecorder started") + self.started_event.set() + + try: + await self.consumer_task + except asyncio.CancelledError: + pass + + async def stop(self): + assert self.consumer_task is not None, "TaskRunRecorder not started" + self.consumer_task.cancel() + try: + await self.consumer_task + except asyncio.CancelledError: + pass + finally: + self.consumer_task = None + logger.debug("TaskRunRecorder stopped") diff --git a/src/prefect/settings.py b/src/prefect/settings.py index ee39e1a45744..c7d133909dc7 100644 --- a/src/prefect/settings.py +++ b/src/prefect/settings.py @@ -1160,6 +1160,11 @@ def default_cloud_ui_url(settings, value): PREFECT_API_LOG_RETRYABLE_ERRORS = Setting(bool, default=False) """If `True`, log retryable errors in the API and it's services.""" +PREFECT_API_SERVICES_TASK_RUN_RECORDER_ENABLED = Setting(bool, default=True) +""" +Whether or not to start the task run recorder service in the server application. +""" + PREFECT_API_DEFAULT_LIMIT = Setting( int, @@ -1309,6 +1314,13 @@ def default_cloud_ui_url(settings, value): """ +PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION = Setting( + bool, default=False +) +""" +Whether or not to enable experimental client side task run orchestration. 
+""" + # Prefect Events feature flags PREFECT_RUNNER_PROCESS_LIMIT = Setting(int, default=5) diff --git a/src/prefect/utilities/engine.py b/src/prefect/utilities/engine.py index 3c11d69aaa17..fca42bb881ae 100644 --- a/src/prefect/utilities/engine.py +++ b/src/prefect/utilities/engine.py @@ -51,6 +51,7 @@ ) from prefect.results import BaseResult from prefect.settings import ( + PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION, PREFECT_LOGGING_LOG_PRINTS, ) from prefect.states import ( @@ -798,6 +799,9 @@ def emit_task_run_state_change_event( else "" ), "prefect.state-type": str(validated_state.type.value), + "prefect.orchestration": "client" + if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION + else "server", }, follows=follows, ) diff --git a/tests/events/client/instrumentation/test_task_run_state_change_events.py b/tests/events/client/instrumentation/test_task_run_state_change_events.py index 338ba5a0fd89..88fbb5b392c9 100644 --- a/tests/events/client/instrumentation/test_task_run_state_change_events.py +++ b/tests/events/client/instrumentation/test_task_run_state_change_events.py @@ -52,6 +52,7 @@ def happy_path(): "prefect.state-type": "PENDING", "prefect.state-name": "Pending", "prefect.state-timestamp": task_run_states[0].timestamp.isoformat(), + "prefect.orchestration": "server", } ) assert ( @@ -103,6 +104,7 @@ def happy_path(): "prefect.state-type": "RUNNING", "prefect.state-name": "Running", "prefect.state-timestamp": task_run_states[1].timestamp.isoformat(), + "prefect.orchestration": "server", } ) assert ( @@ -159,6 +161,7 @@ def happy_path(): "prefect.state-type": "COMPLETED", "prefect.state-name": "Completed", "prefect.state-timestamp": task_run_states[2].timestamp.isoformat(), + "prefect.orchestration": "server", } ) assert ( @@ -251,6 +254,7 @@ def happy_path(): "prefect.state-type": "PENDING", "prefect.state-name": "Pending", "prefect.state-timestamp": task_run_states[0].timestamp.isoformat(), + "prefect.orchestration": 
"server", } ) assert ( @@ -302,6 +306,7 @@ def happy_path(): "prefect.state-type": "RUNNING", "prefect.state-name": "Running", "prefect.state-timestamp": task_run_states[1].timestamp.isoformat(), + "prefect.orchestration": "server", } ) assert ( @@ -361,6 +366,7 @@ def happy_path(): "prefect.state-type": "FAILED", "prefect.state-name": "Failed", "prefect.state-timestamp": task_run_states[2].timestamp.isoformat(), + "prefect.orchestration": "server", } ) assert ( diff --git a/tests/server/services/test_task_run_recorder.py b/tests/server/services/test_task_run_recorder.py new file mode 100644 index 000000000000..333bb4c9a98f --- /dev/null +++ b/tests/server/services/test_task_run_recorder.py @@ -0,0 +1,180 @@ +import asyncio +from typing import AsyncGenerator +from uuid import UUID + +import pendulum +import pytest + +from prefect.server.events.schemas.events import ReceivedEvent +from prefect.server.services import task_run_recorder +from prefect.server.utilities.messaging import MessageHandler +from prefect.server.utilities.messaging.memory import MemoryMessage + + +async def test_start_and_stop_service(): + service = task_run_recorder.TaskRunRecorder() + service_task = asyncio.create_task(service.start()) + service.started_event = asyncio.Event() + + await service.started_event.wait() + assert service.consumer_task is not None + + await service.stop() + assert service.consumer_task is None + + await service_task + + +@pytest.fixture +async def task_run_recorder_handler() -> AsyncGenerator[MessageHandler, None]: + async with task_run_recorder.consumer() as handler: + yield handler + + +@pytest.fixture +def hello_event() -> ReceivedEvent: + return ReceivedEvent( + occurred=pendulum.datetime(2022, 1, 2, 3, 4, 5, 6, "UTC"), + event="hello", + resource={ + "prefect.resource.id": "my.resource.id", + }, + related=[ + {"prefect.resource.id": "related-1", "prefect.resource.role": "role-1"}, + {"prefect.resource.id": "related-2", "prefect.resource.role": "role-1"}, + 
{"prefect.resource.id": "related-3", "prefect.resource.role": "role-2"}, + ], + payload={"hello": "world"}, + account=UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"), + workspace=UUID("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"), + received=pendulum.datetime(2022, 2, 3, 4, 5, 6, 7, "UTC"), + id=UUID("eeeeeeee-eeee-eeee-eeee-eeeeeeeeeeee"), + follows=UUID("ffffffff-ffff-ffff-ffff-ffffffffffff"), + ) + + +@pytest.fixture +def client_orchestrated_task_run_event() -> ReceivedEvent: + return ReceivedEvent( + occurred=pendulum.datetime(2022, 1, 2, 3, 4, 5, 6, "UTC"), + event="prefect.task-run.Running", + resource={ + "prefect.resource.id": "prefect.task-run.b75b283c-7cd5-439a-b23e-d0c59e78b042", + "prefect.resource.name": "my_task", + "prefect.state-message": "", + "prefect.state-name": "Running", + "prefect.state-timestamp": pendulum.datetime( + 2022, 1, 2, 3, 4, 5, 6, "UTC" + ).isoformat(), + "prefect.state-type": "RUNNING", + "prefect.orchestration": "client", + }, + related=[], + payload={ + "intended": {"from": "PENDING", "to": "RUNNING"}, + "initial_state": {"type": "PENDING", "name": "Pending", "message": ""}, + "validated_state": {"type": "RUNNING", "name": "Running", "message": ""}, + }, + account=UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"), + workspace=UUID("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"), + received=pendulum.datetime(2022, 2, 3, 4, 5, 6, 7, "UTC"), + id=UUID("eeeeeeee-eeee-eeee-eeee-eeeeeeeeeeee"), + follows=UUID("ffffffff-ffff-ffff-ffff-ffffffffffff"), + ) + + +@pytest.fixture +def server_orchestrated_task_run_event() -> ReceivedEvent: + return ReceivedEvent( + occurred=pendulum.datetime(2022, 1, 2, 3, 4, 5, 6, "UTC"), + event="prefect.task-run.Running", + resource={ + "prefect.resource.id": "prefect.task-run.b75b283c-7cd5-439a-b23e-d0c59e78b042", + "prefect.resource.name": "my_task", + "prefect.state-message": "", + "prefect.state-name": "Running", + "prefect.state-timestamp": pendulum.datetime( + 2022, 1, 2, 3, 4, 5, 6, "UTC" + ).isoformat(), + 
"prefect.state-type": "RUNNING", + "prefect.orchestration": "server", + }, + related=[], + payload={ + "intended": {"from": "PENDING", "to": "RUNNING"}, + "initial_state": {"type": "PENDING", "name": "Pending", "message": ""}, + "validated_state": {"type": "RUNNING", "name": "Running", "message": ""}, + }, + account=UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"), + workspace=UUID("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"), + received=pendulum.datetime(2022, 2, 3, 4, 5, 6, 7, "UTC"), + id=UUID("eeeeeeee-eeee-eeee-eeee-eeeeeeeeeeee"), + follows=UUID("ffffffff-ffff-ffff-ffff-ffffffffffff"), + ) + + +@pytest.fixture +def client_orchestrated_task_run_event_message( + client_orchestrated_task_run_event: ReceivedEvent, +) -> MemoryMessage: + return MemoryMessage( + data=client_orchestrated_task_run_event.model_dump_json().encode(), + attributes={}, + ) + + +@pytest.fixture +def server_orchestrated_task_run_event_message( + server_orchestrated_task_run_event: ReceivedEvent, +) -> MemoryMessage: + return MemoryMessage( + data=server_orchestrated_task_run_event.model_dump_json().encode(), + attributes={}, + ) + + +@pytest.fixture +def hello_event_message(hello_event: ReceivedEvent) -> MemoryMessage: + return MemoryMessage( + data=hello_event.model_dump_json().encode(), + attributes={}, + ) + + +async def test_handle_client_orchestrated_task_run_event( + task_run_recorder_handler: MessageHandler, + client_orchestrated_task_run_event: ReceivedEvent, + client_orchestrated_task_run_event_message: MemoryMessage, + caplog: pytest.LogCaptureFixture, +): + with caplog.at_level("INFO"): + await task_run_recorder_handler(client_orchestrated_task_run_event_message) + + assert "Received event" in caplog.text + assert str(client_orchestrated_task_run_event.id) in caplog.text + + +async def test_skip_non_task_run_event( + task_run_recorder_handler: MessageHandler, + hello_event: ReceivedEvent, + hello_event_message: MemoryMessage, + caplog: pytest.LogCaptureFixture, +): + with 
caplog.at_level("INFO"): + await task_run_recorder_handler(hello_event_message) + + assert "Received event" not in caplog.text + assert str(hello_event.id) not in caplog.text + + +async def test_skip_server_side_orchestrated_task_run( + task_run_recorder_handler: MessageHandler, + server_orchestrated_task_run_event: ReceivedEvent, + server_orchestrated_task_run_event_message: MemoryMessage, + caplog: pytest.LogCaptureFixture, +): + with caplog.at_level("INFO"): + await task_run_recorder_handler(server_orchestrated_task_run_event_message) + + assert "Received event" not in caplog.text + assert str(server_orchestrated_task_run_event.id) not in caplog.text