Skip to content

Commit

Permalink
scaffolding for client task run orchestration (#14627)
Browse files Browse the repository at this point in the history
  • Loading branch information
jakekaplan authored Jul 16, 2024
1 parent eb738c3 commit 36e2ddf
Show file tree
Hide file tree
Showing 7 changed files with 292 additions and 0 deletions.
10 changes: 10 additions & 0 deletions docs/3.0rc/api-ref/rest-api/server/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -22165,6 +22165,11 @@
"title": "Prefect Api Log Retryable Errors",
"default": false
},
"PREFECT_API_SERVICES_TASK_RUN_RECORDER_ENABLED": {
"type": "boolean",
"title": "Prefect Api Services Task Run Recorder Enabled",
"default": true
},
"PREFECT_API_DEFAULT_LIMIT": {
"type": "integer",
"title": "Prefect Api Default Limit",
Expand Down Expand Up @@ -22257,6 +22262,11 @@
"title": "Prefect Api Max Flow Run Graph Artifacts",
"default": 10000
},
"PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION": {
"type": "boolean",
"title": "Prefect Experimental Enable Client Side Task Orchestration",
"default": false
},
"PREFECT_RUNNER_PROCESS_LIMIT": {
"type": "integer",
"title": "Prefect Runner Process Limit",
Expand Down
7 changes: 7 additions & 0 deletions src/prefect/server/api/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from prefect.server.events.services.event_persister import EventPersister
from prefect.server.events.services.triggers import ProactiveTriggers, ReactiveTriggers
from prefect.server.exceptions import ObjectNotFoundError
from prefect.server.services.task_run_recorder import TaskRunRecorder
from prefect.server.utilities.database import get_dialect
from prefect.server.utilities.server import method_paths_from_routes
from prefect.settings import (
Expand Down Expand Up @@ -603,6 +604,12 @@ async def start_services():
if prefect.settings.PREFECT_API_EVENTS_STREAM_OUT_ENABLED:
service_instances.append(stream.Distributor())

if (
prefect.settings.PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION
and prefect.settings.PREFECT_API_SERVICES_TASK_RUN_RECORDER_ENABLED
):
service_instances.append(TaskRunRecorder())

loop = asyncio.get_running_loop()

app.state.services = {
Expand Down
73 changes: 73 additions & 0 deletions src/prefect/server/services/task_run_recorder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import asyncio
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Optional

from prefect.logging import get_logger
from prefect.server.events.schemas.events import ReceivedEvent
from prefect.server.utilities.messaging import Message, MessageHandler, create_consumer

logger = get_logger(__name__)


@asynccontextmanager
async def consumer() -> AsyncGenerator[MessageHandler, None]:
    """Yield the message handler used by the task run recorder.

    The yielded handler parses each message body as a `ReceivedEvent` and
    logs it, but only when the event name starts with `prefect.task-run` and
    the event's resource is labelled with `prefect.orchestration` set to
    `client`. Every other event is ignored.
    """

    async def message_handler(message: Message):
        event: ReceivedEvent = ReceivedEvent.model_validate_json(message.data)

        # Skip anything that is not a client-orchestrated task run event.
        if (
            not event.event.startswith("prefect.task-run")
            or event.resource.get("prefect.orchestration") != "client"
        ):
            return

        logger.info(
            f"Received event: {event.event} with id: {event.id} for resource: {event.resource.get('prefect.resource.id')}"
        )

    yield message_handler


class TaskRunRecorder:
    """A service to record task run and task run states from events."""

    name: str = "TaskRunRecorder"

    # Task running the message consumer; `None` whenever the service is not
    # running (before `start()` and after `stop()`).
    consumer_task: Optional[asyncio.Task] = None

    def __init__(self):
        # Created on first access through the `started_event` property.
        self._started_event: Optional[asyncio.Event] = None

    @property
    def started_event(self) -> asyncio.Event:
        """Event that `start()` sets once the consumer task has been created."""
        if self._started_event is None:
            self._started_event = asyncio.Event()
        return self._started_event

    @started_event.setter
    def started_event(self, value: asyncio.Event) -> None:
        self._started_event = value

    async def start(self):
        """Run the consumer for the "events" topic until it finishes or is cancelled.

        Creates the consumer, runs it in a task stored on `consumer_task`,
        sets `started_event`, and then waits for the task. Cancellation of
        the task (for example through `stop()`) is swallowed so that this
        coroutine returns normally.

        Raises:
            AssertionError: if the service has already been started.
        """
        assert self.consumer_task is None, "TaskRunRecorder already started"
        self.consumer = create_consumer("events")

        async with consumer() as handler:
            self.consumer_task = asyncio.create_task(self.consumer.run(handler))
            logger.debug("TaskRunRecorder started")
            self.started_event.set()

            try:
                await self.consumer_task
            except asyncio.CancelledError:
                pass

    async def stop(self):
        """Cancel the consumer task, wait for it to finish, and clear it.

        Raises:
            AssertionError: if the service has not been started.
        """
        # The message names this service; it previously read "Logger not
        # started", which pointed at the wrong component when debugging.
        assert self.consumer_task is not None, "TaskRunRecorder not started"
        self.consumer_task.cancel()
        try:
            await self.consumer_task
        except asyncio.CancelledError:
            pass
        finally:
            # Always reset so that the service can be started again.
            self.consumer_task = None
        logger.debug("TaskRunRecorder stopped")
12 changes: 12 additions & 0 deletions src/prefect/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -1160,6 +1160,11 @@ def default_cloud_ui_url(settings, value):
PREFECT_API_LOG_RETRYABLE_ERRORS = Setting(bool, default=False)
"""If `True`, log retryable errors in the API and its services."""

PREFECT_API_SERVICES_TASK_RUN_RECORDER_ENABLED = Setting(bool, default=True)
"""
Whether or not to start the task run recorder service in the server application.
"""


PREFECT_API_DEFAULT_LIMIT = Setting(
int,
Expand Down Expand Up @@ -1309,6 +1314,13 @@ def default_cloud_ui_url(settings, value):
"""


PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION = Setting(
bool, default=False
)
"""
Whether or not to enable experimental client side task run orchestration.
"""

# Prefect Events feature flags

PREFECT_RUNNER_PROCESS_LIMIT = Setting(int, default=5)
Expand Down
4 changes: 4 additions & 0 deletions src/prefect/utilities/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
)
from prefect.results import BaseResult
from prefect.settings import (
PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION,
PREFECT_LOGGING_LOG_PRINTS,
)
from prefect.states import (
Expand Down Expand Up @@ -798,6 +799,9 @@ def emit_task_run_state_change_event(
else ""
),
"prefect.state-type": str(validated_state.type.value),
"prefect.orchestration": "client"
if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION
else "server",
},
follows=follows,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def happy_path():
"prefect.state-type": "PENDING",
"prefect.state-name": "Pending",
"prefect.state-timestamp": task_run_states[0].timestamp.isoformat(),
"prefect.orchestration": "server",
}
)
assert (
Expand Down Expand Up @@ -103,6 +104,7 @@ def happy_path():
"prefect.state-type": "RUNNING",
"prefect.state-name": "Running",
"prefect.state-timestamp": task_run_states[1].timestamp.isoformat(),
"prefect.orchestration": "server",
}
)
assert (
Expand Down Expand Up @@ -159,6 +161,7 @@ def happy_path():
"prefect.state-type": "COMPLETED",
"prefect.state-name": "Completed",
"prefect.state-timestamp": task_run_states[2].timestamp.isoformat(),
"prefect.orchestration": "server",
}
)
assert (
Expand Down Expand Up @@ -251,6 +254,7 @@ def happy_path():
"prefect.state-type": "PENDING",
"prefect.state-name": "Pending",
"prefect.state-timestamp": task_run_states[0].timestamp.isoformat(),
"prefect.orchestration": "server",
}
)
assert (
Expand Down Expand Up @@ -302,6 +306,7 @@ def happy_path():
"prefect.state-type": "RUNNING",
"prefect.state-name": "Running",
"prefect.state-timestamp": task_run_states[1].timestamp.isoformat(),
"prefect.orchestration": "server",
}
)
assert (
Expand Down Expand Up @@ -361,6 +366,7 @@ def happy_path():
"prefect.state-type": "FAILED",
"prefect.state-name": "Failed",
"prefect.state-timestamp": task_run_states[2].timestamp.isoformat(),
"prefect.orchestration": "server",
}
)
assert (
Expand Down
180 changes: 180 additions & 0 deletions tests/server/services/test_task_run_recorder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
import asyncio
from typing import AsyncGenerator
from uuid import UUID

import pendulum
import pytest

from prefect.server.events.schemas.events import ReceivedEvent
from prefect.server.services import task_run_recorder
from prefect.server.utilities.messaging import MessageHandler
from prefect.server.utilities.messaging.memory import MemoryMessage


async def test_start_and_stop_service():
    """The recorder exposes a consumer task while running and clears it on stop."""
    recorder = task_run_recorder.TaskRunRecorder()
    # Install a fresh event for the service to set once it is running.
    recorder.started_event = asyncio.Event()
    running = asyncio.create_task(recorder.start())

    await recorder.started_event.wait()
    assert recorder.consumer_task is not None

    await recorder.stop()
    assert recorder.consumer_task is None

    # `start()` returns normally once its consumer task has been cancelled.
    await running


@pytest.fixture
async def task_run_recorder_handler() -> AsyncGenerator[MessageHandler, None]:
    """Provide the recorder's message handler for the duration of a test."""
    async with task_run_recorder.consumer() as message_handler:
        yield message_handler


@pytest.fixture
def hello_event() -> ReceivedEvent:
    """A generic `hello` event that is not a task run event."""
    occurred_at = pendulum.datetime(2022, 1, 2, 3, 4, 5, 6, "UTC")
    received_at = pendulum.datetime(2022, 2, 3, 4, 5, 6, 7, "UTC")
    related_resources = [
        {"prefect.resource.id": "related-1", "prefect.resource.role": "role-1"},
        {"prefect.resource.id": "related-2", "prefect.resource.role": "role-1"},
        {"prefect.resource.id": "related-3", "prefect.resource.role": "role-2"},
    ]
    return ReceivedEvent(
        occurred=occurred_at,
        event="hello",
        resource={"prefect.resource.id": "my.resource.id"},
        related=related_resources,
        payload={"hello": "world"},
        account=UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"),
        workspace=UUID("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"),
        received=received_at,
        id=UUID("eeeeeeee-eeee-eeee-eeee-eeeeeeeeeeee"),
        follows=UUID("ffffffff-ffff-ffff-ffff-ffffffffffff"),
    )


@pytest.fixture
def client_orchestrated_task_run_event() -> ReceivedEvent:
    """A `prefect.task-run.Running` event labelled as client-orchestrated."""
    occurred_at = pendulum.datetime(2022, 1, 2, 3, 4, 5, 6, "UTC")
    task_run_resource = {
        "prefect.resource.id": "prefect.task-run.b75b283c-7cd5-439a-b23e-d0c59e78b042",
        "prefect.resource.name": "my_task",
        "prefect.state-message": "",
        "prefect.state-name": "Running",
        "prefect.state-timestamp": pendulum.datetime(
            2022, 1, 2, 3, 4, 5, 6, "UTC"
        ).isoformat(),
        "prefect.state-type": "RUNNING",
        "prefect.orchestration": "client",
    }
    transition = {
        "intended": {"from": "PENDING", "to": "RUNNING"},
        "initial_state": {"type": "PENDING", "name": "Pending", "message": ""},
        "validated_state": {"type": "RUNNING", "name": "Running", "message": ""},
    }
    return ReceivedEvent(
        occurred=occurred_at,
        event="prefect.task-run.Running",
        resource=task_run_resource,
        related=[],
        payload=transition,
        account=UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"),
        workspace=UUID("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"),
        received=pendulum.datetime(2022, 2, 3, 4, 5, 6, 7, "UTC"),
        id=UUID("eeeeeeee-eeee-eeee-eeee-eeeeeeeeeeee"),
        follows=UUID("ffffffff-ffff-ffff-ffff-ffffffffffff"),
    )


@pytest.fixture
def server_orchestrated_task_run_event() -> ReceivedEvent:
    """A `prefect.task-run.Running` event labelled as server-orchestrated."""
    occurred_at = pendulum.datetime(2022, 1, 2, 3, 4, 5, 6, "UTC")
    task_run_resource = {
        "prefect.resource.id": "prefect.task-run.b75b283c-7cd5-439a-b23e-d0c59e78b042",
        "prefect.resource.name": "my_task",
        "prefect.state-message": "",
        "prefect.state-name": "Running",
        "prefect.state-timestamp": pendulum.datetime(
            2022, 1, 2, 3, 4, 5, 6, "UTC"
        ).isoformat(),
        "prefect.state-type": "RUNNING",
        "prefect.orchestration": "server",
    }
    transition = {
        "intended": {"from": "PENDING", "to": "RUNNING"},
        "initial_state": {"type": "PENDING", "name": "Pending", "message": ""},
        "validated_state": {"type": "RUNNING", "name": "Running", "message": ""},
    }
    return ReceivedEvent(
        occurred=occurred_at,
        event="prefect.task-run.Running",
        resource=task_run_resource,
        related=[],
        payload=transition,
        account=UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"),
        workspace=UUID("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"),
        received=pendulum.datetime(2022, 2, 3, 4, 5, 6, 7, "UTC"),
        id=UUID("eeeeeeee-eeee-eeee-eeee-eeeeeeeeeeee"),
        follows=UUID("ffffffff-ffff-ffff-ffff-ffffffffffff"),
    )


@pytest.fixture
def client_orchestrated_task_run_event_message(
    client_orchestrated_task_run_event: ReceivedEvent,
) -> MemoryMessage:
    """Wrap the client-orchestrated event's JSON dump in a `MemoryMessage`."""
    encoded = client_orchestrated_task_run_event.model_dump_json().encode()
    return MemoryMessage(data=encoded, attributes={})


@pytest.fixture
def server_orchestrated_task_run_event_message(
    server_orchestrated_task_run_event: ReceivedEvent,
) -> MemoryMessage:
    """Wrap the server-orchestrated event's JSON dump in a `MemoryMessage`."""
    encoded = server_orchestrated_task_run_event.model_dump_json().encode()
    return MemoryMessage(data=encoded, attributes={})


@pytest.fixture
def hello_event_message(hello_event: ReceivedEvent) -> MemoryMessage:
    """Wrap the `hello` event's JSON dump in a `MemoryMessage`."""
    encoded = hello_event.model_dump_json().encode()
    return MemoryMessage(data=encoded, attributes={})


async def test_handle_client_orchestrated_task_run_event(
    task_run_recorder_handler: MessageHandler,
    client_orchestrated_task_run_event: ReceivedEvent,
    client_orchestrated_task_run_event_message: MemoryMessage,
    caplog: pytest.LogCaptureFixture,
):
    """A client-orchestrated task run event is logged by the handler."""
    with caplog.at_level("INFO"):
        await task_run_recorder_handler(client_orchestrated_task_run_event_message)

    expected_id = str(client_orchestrated_task_run_event.id)
    assert "Received event" in caplog.text
    assert expected_id in caplog.text


async def test_skip_non_task_run_event(
    task_run_recorder_handler: MessageHandler,
    hello_event: ReceivedEvent,
    hello_event_message: MemoryMessage,
    caplog: pytest.LogCaptureFixture,
):
    """An event that is not a task run event produces no log output."""
    with caplog.at_level("INFO"):
        await task_run_recorder_handler(hello_event_message)

    unexpected_id = str(hello_event.id)
    assert "Received event" not in caplog.text
    assert unexpected_id not in caplog.text


async def test_skip_server_side_orchestrated_task_run(
    task_run_recorder_handler: MessageHandler,
    server_orchestrated_task_run_event: ReceivedEvent,
    server_orchestrated_task_run_event_message: MemoryMessage,
    caplog: pytest.LogCaptureFixture,
):
    """A server-orchestrated task run event produces no log output."""
    with caplog.at_level("INFO"):
        await task_run_recorder_handler(server_orchestrated_task_run_event_message)

    unexpected_id = str(server_orchestrated_task_run_event.id)
    assert "Received event" not in caplog.text
    assert unexpected_id not in caplog.text

0 comments on commit 36e2ddf

Please sign in to comment.