From 02c799191a43d66d1198e5ad4a7a9d18a7539ee7 Mon Sep 17 00:00:00 2001
From: "jake@prefect.io"
Date: Thu, 25 Jul 2024 16:18:39 -0400
Subject: [PATCH] add events pipeline and test utility

---
 src/prefect/events/clients.py   |   4 +
 src/prefect/testing/fixtures.py |  21 +-
 tests/test_task_engine.py       | 556 +++++++++++++++++---------
 3 files changed, 322 insertions(+), 259 deletions(-)

diff --git a/src/prefect/events/clients.py b/src/prefect/events/clients.py
index 8e32b53297d7..d9cc67ae2895 100644
--- a/src/prefect/events/clients.py
+++ b/src/prefect/events/clients.py
@@ -140,6 +140,10 @@ def reset(cls) -> None:
         cls.last = None
         cls.all = []

+    def reset_events(self) -> None:
+        """Clear the events captured by this client instance."""
+        self.events = []
+
     async def _emit(self, event: Event) -> None:
         self.events.append(event)

diff --git a/src/prefect/testing/fixtures.py b/src/prefect/testing/fixtures.py
index 6a4d5977bff8..c5cd2fd1ae36 100644
--- a/src/prefect/testing/fixtures.py
+++ b/src/prefect/testing/fixtures.py
@@ -1,3 +1,4 @@
+import asyncio
 import json
 import os
 import socket
@@ -19,6 +20,7 @@
 from prefect.events.clients import AssertingEventsClient
 from prefect.events.filters import EventFilter
 from prefect.events.worker import EventsWorker
+from prefect.server.events.pipeline import EventsPipeline
 from prefect.settings import (
     PREFECT_API_URL,
     PREFECT_SERVER_CSRF_PROTECTION_ENABLED,
@@ -335,7 +337,7 @@ def mock_should_emit_events(monkeypatch) -> mock.Mock:
     return m


-@pytest.fixture
+@pytest.fixture(autouse=True)
 def asserting_events_worker(monkeypatch) -> Generator[EventsWorker, None, None]:
     worker = EventsWorker.instance(AssertingEventsClient)
     # Always yield the asserting worker when new instances are retrieved
@@ -346,6 +348,23 @@ def asserting_events_worker(monkeypatch) -> Generator[EventsWorker, None, None]:
         worker.drain()


+@pytest.fixture
+async def events_pipeline(asserting_events_worker: EventsWorker):
+    class AssertingEventsPipeline(EventsPipeline):
+        def sync_process_events(self):
+            asyncio.run(self.process_events())
+
+        async def process_events(self):
+            asserting_events_worker.wait_until_empty()
+            events = asserting_events_worker._client.events
+
+            messages = self.events_to_messages(events)
+            await self.process_messages(messages)
+            asserting_events_worker._client.reset_events()
+
+    yield AssertingEventsPipeline()
+
+
 @pytest.fixture
 def reset_worker_events(asserting_events_worker: EventsWorker):
     yield
diff --git a/tests/test_task_engine.py b/tests/test_task_engine.py
index 1e003eaac834..d669b55f7681 100644
--- a/tests/test_task_engine.py
+++ b/tests/test_task_engine.py
@@ -14,15 +14,14 @@
 from prefect import Task, flow, task
 from prefect.cache_policies import FLOW_PARAMETERS
-from prefect.client.orchestration import PrefectClient, SyncPrefectClient, get_client
-from prefect.client.schemas.objects import StateType, TaskRun
+from prefect.client.orchestration import PrefectClient, SyncPrefectClient
+from prefect.client.schemas.objects import StateType
 from prefect.context import (
     EngineContext,
     FlowRunContext,
     TaskRunContext,
     get_run_context,
 )
-from prefect.events.clients import AssertingEventsClient
 from prefect.events.worker import EventsWorker
 from prefect.exceptions import CrashedRun, MissingResult
 from prefect.filesystems import LocalFileSystem
@@ -52,125 +51,6 @@ def enable_client_side_task_run_orchestration(
         yield enabled


-def state_from_event(event) -> State:
-    return State(
-        id=event.id,
-        timestamp=event.occurred,
-        **event.payload["validated_state"],
-    )
-
-
-async def get_task_run(task_run_id:
Optional[UUID]) -> TaskRun: - if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION: - task_run = get_task_run_sync(task_run_id) - else: - client = get_client() - if task_run_id: - task_run = await client.read_task_run(task_run_id) - else: - task_runs = await client.read_task_runs() - task_run = task_runs[-1] - - return task_run - - -def get_task_run_sync(task_run_id: Optional[UUID]) -> TaskRun: - if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION: - # the asserting_events_worker fixture - # ensures that calling .instance() here will always - # yield the same one - worker = EventsWorker.instance() - worker.wait_until_empty() - - events = AssertingEventsClient.last.events - events = sorted(events, key=lambda e: e.occurred) - if task_run_id: - events = [ - e - for e in events - if e.resource.prefect_object_id("prefect.task-run") == task_run_id - ] - last_event = events[-1] - state = state_from_event(last_event) - task_run = TaskRun( - id=last_event.resource.prefect_object_id("prefect.task-run"), - state=state, - state_id=state.id, - state_type=state.type, - state_name=state.name, - **last_event.payload["task_run"], - ) - else: - client = get_client(sync_client=True) - if task_run_id: - task_run = client.read_task_run(task_run_id) - else: - task_runs = client.read_task_runs() - task_run = task_runs[-1] - - return task_run - - -async def get_task_run_states( - task_run_id: UUID, state_type: Optional[StateType] = None -) -> List[State]: - if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION: - # the asserting_events_worker fixture - # ensures that calling .instance() here will always - # yield the same one - worker = EventsWorker.instance() - worker.wait_until_empty() - events = AssertingEventsClient.last.events - events = sorted(events, key=lambda e: e.occurred) - events = [ - e - for e in events - if e.resource.prefect_object_id("prefect.task-run") == task_run_id - ] - states = [state_from_event(e) for e in events] - else: - client = get_client() - states = await client.read_task_run_states(task_run_id) - - if state_type: - states = [state for state in states if state.type == state_type] - - return states - - -async def get_task_run_state( - task_run_id: UUID, - state_type: StateType, -) -> State: - """ - Get a single state of a given type for a task run. If more than one state - of the given type is found, an error is raised. 
- """ - - if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION: - # the asserting_events_worker fixture - # ensures that calling .instance() here will always - # yield the same one - worker = EventsWorker.instance() - worker.wait_until_empty() - events = AssertingEventsClient.last.events - events = sorted(events, key=lambda e: e.occurred) - events = [ - e - for e in events - if e.resource.prefect_object_id("prefect.task-run") == task_run_id - ] - states = [state_from_event(e) for e in events] - else: - client = get_client() - states = await client.read_task_run_states(task_run_id) - - states = [state for state in states if state.type == state_type] - - assert len(states) == 1 - return states[0] - - @task async def foo(): return 42 @@ -199,7 +79,9 @@ async def test_client_attr_returns_client_after_starting(self): class TestRunTask: - def test_run_task_with_client_provided_uuid(self): + def test_run_task_with_client_provided_uuid( + self, sync_prefect_client, events_pipeline + ): @task def foo(): return 42 @@ -208,7 +90,9 @@ def foo(): run_task_sync(foo, task_run_id=task_run_id) - task_run = get_task_run_sync(task_run_id) + events_pipeline.sync_process_events() + + task_run = sync_prefect_client.read_task_run(task_run_id) assert task_run.id == task_run_id async def test_with_provided_context(self, prefect_client): @@ -242,7 +126,9 @@ def foo(): class TestTaskRunsAsync: async def test_run_task_async_with_client_provided_uuid( - self, prefect_client: PrefectClient + self, + prefect_client: PrefectClient, + events_pipeline, ): @task async def foo(): @@ -252,7 +138,9 @@ async def foo(): await run_task_async(foo, task_run_id=task_run_id) - task_run = await get_task_run(task_run_id) + await events_pipeline.process_events() + + task_run = await prefect_client.read_task_run(task_run_id) assert task_run.id == task_run_id async def test_with_provided_context(self, prefect_client): @@ -328,13 +216,15 @@ async def f(*args, x, **kwargs): result = await f(1, 2, x=5, y=6, z=7) assert result == ((1, 2), 5, dict(y=6, z=7)) - async def test_task_run_name(self, prefect_client): + async def test_task_run_name(self, prefect_client, events_pipeline): @task(task_run_name="name is {x}") async def foo(x): return TaskRunContext.get().task_run.id result = await run_task_async(foo, parameters=dict(x="blue")) - run = await get_task_run(result) + await events_pipeline.process_events() + + run = await prefect_client.read_task_run(result) assert run.name == "name is blue" @@ -371,17 +261,19 @@ async def workflow(): assert await workflow() == flow_run_id - async def test_task_ends_in_completed(self, prefect_client): + async def test_task_ends_in_completed(self, prefect_client, events_pipeline): @task async def foo(): return TaskRunContext.get().task_run.id result = await run_task_async(foo) - run = await get_task_run(result) + await events_pipeline.process_events() + + run = await prefect_client.read_task_run(result) assert run.state_type == StateType.COMPLETED - async def test_task_ends_in_failed(self, prefect_client): + async def test_task_ends_in_failed(self, prefect_client, events_pipeline): ID = None @task @@ -393,11 +285,15 @@ async def foo(): with pytest.raises(ValueError, match="xyz"): await run_task_async(foo) - run = await get_task_run(ID) + await events_pipeline.process_events() + + run = await prefect_client.read_task_run(ID) assert run.state_type == StateType.FAILED - async def test_task_ends_in_failed_after_retrying(self, prefect_client): + async def test_task_ends_in_failed_after_retrying( + self, 
prefect_client, events_pipeline + ): ID = None @task(retries=1) @@ -411,11 +307,15 @@ async def foo(): result = await run_task_async(foo) - run = await get_task_run(result) + await events_pipeline.process_events() + + run = await prefect_client.read_task_run(result) assert run.state_type == StateType.COMPLETED - async def test_task_tracks_nested_parent_as_dependency(self, prefect_client): + async def test_task_tracks_nested_parent_as_dependency( + self, prefect_client, events_pipeline + ): @task async def inner(): return TaskRunContext.get().task_run.id @@ -428,16 +328,20 @@ async def outer(): a, b = await run_task_async(outer) assert a != b + await events_pipeline.process_events() + # assertions on outer - outer_run = await get_task_run(b) + outer_run = await prefect_client.read_task_run(b) assert outer_run.task_inputs == {} # assertions on inner - inner_run = await get_task_run(a) + inner_run = await prefect_client.read_task_run(a) assert "__parents__" in inner_run.task_inputs assert inner_run.task_inputs["__parents__"][0].id == b - async def test_multiple_nested_tasks_track_parent(self, prefect_client): + async def test_multiple_nested_tasks_track_parent( + self, prefect_client, events_pipeline + ): @task def level_3(): return TaskRunContext.get().task_run.id @@ -459,16 +363,18 @@ def f(): id1, id2, id3 = f() assert id1 != id2 != id3 + await events_pipeline.process_events() + for id_, parent_id in [(id3, id2), (id2, id1)]: - run = await get_task_run(id_) + run = await prefect_client.read_task_run(id_) assert "__parents__" in run.task_inputs assert run.task_inputs["__parents__"][0].id == parent_id - run = await get_task_run(id1) + run = await prefect_client.read_task_run(id1) assert "__parents__" not in run.task_inputs async def test_tasks_in_subflow_do_not_track_subflow_dummy_task_as_parent( - self, + self, prefect_client, events_pipeline ): """ Ensures that tasks in a subflow do not track the subflow's dummy task as @@ -501,11 +407,13 @@ def level_1(): level_3_id = level_1() - tr = await get_task_run(level_3_id) + await events_pipeline.process_events() + + tr = await prefect_client.read_task_run(level_3_id) assert "__parents__" not in tr.task_inputs async def test_tasks_in_subflow_do_not_track_subflow_dummy_task_parent_as_parent( - self, + self, prefect_client, events_pipeline ): """ Ensures that tasks in a subflow do not track the subflow's dummy task as @@ -538,11 +446,15 @@ def level_1(): level_4_id = level_1() - tr = await get_task_run(level_4_id) + await events_pipeline.process_events() + + tr = await prefect_client.read_task_run(level_4_id) assert "__parents__" not in tr.task_inputs - async def test_task_runs_respect_result_persistence(self, prefect_client): + async def test_task_runs_respect_result_persistence( + self, prefect_client, events_pipeline + ): @task(persist_result=False) async def no_persist(): return TaskRunContext.get().task_run.id @@ -553,7 +465,8 @@ async def persist(): # assert no persistence run_id = await run_task_async(no_persist) - task_run = await get_task_run(run_id) + await events_pipeline.process_events() + task_run = await prefect_client.read_task_run(run_id) api_state = task_run.state with pytest.raises(MissingResult): @@ -561,7 +474,8 @@ async def persist(): # assert persistence run_id = await run_task_async(persist) - task_run = await get_task_run(run_id) + await events_pipeline.process_events() + task_run = await prefect_client.read_task_run(run_id) api_state = task_run.state assert await api_state.result() == run_id @@ -626,13 +540,14 @@ def 
f(*args, x, **kwargs): result = f(1, 2, x=5, y=6, z=7) assert result == ((1, 2), 5, dict(y=6, z=7)) - async def test_task_run_name(self, prefect_client): + async def test_task_run_name(self, prefect_client, events_pipeline): @task(task_run_name="name is {x}") def foo(x): return TaskRunContext.get().task_run.id result = run_task_sync(foo, parameters=dict(x="blue")) - run = await get_task_run(result) + await events_pipeline.process_events() + run = await prefect_client.read_task_run(result) assert run.name == "name is blue" def test_get_run_logger(self, caplog): @@ -668,17 +583,18 @@ def workflow(): assert workflow() == flow_run_id - async def test_task_ends_in_completed(self, prefect_client): + async def test_task_ends_in_completed(self, prefect_client, events_pipeline): @task def foo(): return TaskRunContext.get().task_run.id result = run_task_sync(foo) - run = await get_task_run(result) + await events_pipeline.process_events() + run = await prefect_client.read_task_run(result) assert run.state_type == StateType.COMPLETED - async def test_task_ends_in_failed(self, prefect_client): + async def test_task_ends_in_failed(self, prefect_client, events_pipeline): ID = None @task @@ -690,11 +606,14 @@ def foo(): with pytest.raises(ValueError, match="xyz"): run_task_sync(foo) - run = await get_task_run(ID) + await events_pipeline.process_events() + run = await prefect_client.read_task_run(ID) assert run.state_type == StateType.FAILED - async def test_task_ends_in_failed_after_retrying(self, prefect_client): + async def test_task_ends_in_failed_after_retrying( + self, prefect_client, events_pipeline + ): ID = None @task(retries=1) @@ -708,11 +627,14 @@ def foo(): result = run_task_sync(foo) - run = await get_task_run(result) + await events_pipeline.process_events() + run = await prefect_client.read_task_run(result) assert run.state_type == StateType.COMPLETED - async def test_task_tracks_nested_parent_as_dependency(self, prefect_client): + async def test_task_tracks_nested_parent_as_dependency( + self, prefect_client, events_pipeline + ): @task def inner(): return TaskRunContext.get().task_run.id @@ -724,17 +646,20 @@ def outer(): a, b = run_task_sync(outer) assert a != b + await events_pipeline.process_events() # assertions on outer - outer_run = await get_task_run(b) + outer_run = await prefect_client.read_task_run(b) assert outer_run.task_inputs == {} # assertions on inner - inner_run = await get_task_run(a) + inner_run = await prefect_client.read_task_run(a) assert "__parents__" in inner_run.task_inputs assert inner_run.task_inputs["__parents__"][0].id == b - async def test_task_runs_respect_result_persistence(self, prefect_client): + async def test_task_runs_respect_result_persistence( + self, prefect_client, events_pipeline + ): @task(persist_result=False) def no_persist(): ctx = TaskRunContext.get() @@ -749,7 +674,8 @@ def persist(): # assert no persistence run_id = run_task_sync(no_persist) - task_run = await get_task_run(run_id) + await events_pipeline.process_events() + task_run = await prefect_client.read_task_run(run_id) api_state = task_run.state with pytest.raises(MissingResult): @@ -757,7 +683,8 @@ def persist(): # assert persistence run_id = run_task_sync(persist) - task_run = await get_task_run(run_id) + await events_pipeline.process_events() + task_run = await prefect_client.read_task_run(run_id) api_state = task_run.state assert await api_state.result() == run_id @@ -809,7 +736,9 @@ async def foo(): class TestTaskRetries: @pytest.mark.parametrize("always_fail", [True, False]) - 
async def test_task_respects_retry_count(self, always_fail, prefect_client): + async def test_task_respects_retry_count( + self, always_fail, prefect_client, events_pipeline + ): mock = MagicMock() exc = ValueError() @@ -843,7 +772,8 @@ async def test_flow(): assert await task_run_state.result() is True assert mock.call_count == 4 - states = await get_task_run_states(task_run_id) + await events_pipeline.process_events() + states = await prefect_client.read_task_run_states(task_run_id) state_names = [state.name for state in states] assert state_names == [ @@ -856,7 +786,9 @@ async def test_flow(): ] @pytest.mark.parametrize("always_fail", [True, False]) - async def test_task_respects_retry_count_sync(self, always_fail): + async def test_task_respects_retry_count_sync( + self, always_fail, prefect_client, events_pipeline + ): mock = MagicMock() exc = ValueError() @@ -891,7 +823,8 @@ def test_flow(): assert await task_run_state.result() is True # type: ignore assert mock.call_count == 4 - states = await get_task_run_states(task_run_id) + await events_pipeline.process_events() + states = await prefect_client.read_task_run_states(task_run_id) state_names = [state.name for state in states] assert state_names == [ @@ -903,7 +836,9 @@ def test_flow(): "Failed" if always_fail else "Completed", ] - async def test_task_only_uses_necessary_retries(self): + async def test_task_only_uses_necessary_retries( + self, prefect_client, events_pipeline + ): mock = MagicMock() exc = ValueError() @@ -925,7 +860,8 @@ async def test_flow(): assert await task_run_state.result() is True assert mock.call_count == 2 - states = await get_task_run_states(task_run_id) + await events_pipeline.process_events() + states = await prefect_client.read_task_run_states(task_run_id) state_names = [state.name for state in states] assert state_names == [ @@ -935,7 +871,9 @@ async def test_flow(): "Completed", ] - async def test_task_passes_failed_state_to_retry_fn(self): + async def test_task_passes_failed_state_to_retry_fn( + self, prefect_client, events_pipeline + ): mock = MagicMock() exc = SyntaxError("oops") handler_mock = MagicMock() @@ -968,7 +906,8 @@ async def test_flow(): assert mock.call_count == 2 assert handler_mock.call_count == 1 - states = await get_task_run_states(task_run_id) + await events_pipeline.process_events() + states = await prefect_client.read_task_run_states(task_run_id) state_names = [state.name for state in states] assert state_names == [ @@ -978,7 +917,9 @@ async def test_flow(): "Completed", ] - async def test_task_passes_failed_state_to_retry_fn_sync(self): + async def test_task_passes_failed_state_to_retry_fn_sync( + self, prefect_client, events_pipeline + ): mock = MagicMock() exc = SyntaxError("oops") handler_mock = MagicMock() @@ -1011,7 +952,8 @@ def test_flow(): assert mock.call_count == 2 assert handler_mock.call_count == 1 - states = await get_task_run_states(task_run_id) + await events_pipeline.process_events() + states = await prefect_client.read_task_run_states(task_run_id) state_names = [state.name for state in states] assert state_names == [ @@ -1089,7 +1031,12 @@ async def test_flow(): ], ) async def test_async_task_respects_retry_delay_seconds( - self, retry_delay_seconds, expected_delay_sequence, prefect_client, monkeypatch + self, + retry_delay_seconds, + expected_delay_sequence, + prefect_client, + events_pipeline, + monkeypatch, ): mock_sleep = AsyncMock() monkeypatch.setattr(anyio, "sleep", mock_sleep) @@ -1107,7 +1054,8 @@ async def flaky_function(): call(pytest.approx(delay, 
abs=1)) for delay in expected_delay_sequence ] - states = await get_task_run_states(task_run_id) + await events_pipeline.process_events() + states = await prefect_client.read_task_run_states(task_run_id) state_names = [state.name for state in states] assert state_names == [ "Pending", @@ -1133,7 +1081,12 @@ async def flaky_function(): ], ) async def test_sync_task_respects_retry_delay_seconds( - self, retry_delay_seconds, expected_delay_sequence, prefect_client, monkeypatch + self, + retry_delay_seconds, + expected_delay_sequence, + prefect_client, + events_pipeline, + monkeypatch, ): mock_sleep = AsyncMock() monkeypatch.setattr(anyio, "sleep", mock_sleep) @@ -1150,8 +1103,8 @@ def flaky_function(): assert mock_sleep.call_args_list == [ call(pytest.approx(delay, abs=1)) for delay in expected_delay_sequence ] - - states = await get_task_run_states(task_run_id) + await events_pipeline.process_events() + states = await prefect_client.read_task_run_states(task_run_id) state_names = [state.name for state in states] assert state_names == [ "Pending", @@ -1169,7 +1122,7 @@ def flaky_function(): class TestTaskCrashDetection: @pytest.mark.parametrize("interrupt_type", [KeyboardInterrupt, SystemExit]) async def test_interrupt_in_task_function_crashes_task( - self, prefect_client, interrupt_type + self, prefect_client, interrupt_type, events_pipeline ): @task async def my_task(): @@ -1178,7 +1131,10 @@ async def my_task(): with pytest.raises(interrupt_type): await my_task() - task_run = await get_task_run(task_run_id=None) + await events_pipeline.process_events() + task_runs = await prefect_client.read_task_runs() + assert len(task_runs) == 1 + task_run = task_runs[0] assert task_run.state.is_crashed() assert task_run.state.type == StateType.CRASHED assert "Execution was aborted" in task_run.state.message @@ -1187,7 +1143,7 @@ async def my_task(): @pytest.mark.parametrize("interrupt_type", [KeyboardInterrupt, SystemExit]) async def test_interrupt_in_task_function_crashes_task_sync( - self, prefect_client, interrupt_type + self, prefect_client, events_pipeline, interrupt_type ): @task def my_task(): @@ -1196,7 +1152,10 @@ def my_task(): with pytest.raises(interrupt_type): my_task() - task_run = await get_task_run(task_run_id=None) + await events_pipeline.process_events() + task_runs = await prefect_client.read_task_runs() + assert len(task_runs) == 1 + task_run = task_runs[0] assert task_run.state.is_crashed() assert task_run.state.type == StateType.CRASHED assert "Execution was aborted" in task_run.state.message @@ -1205,7 +1164,7 @@ def my_task(): @pytest.mark.parametrize("interrupt_type", [KeyboardInterrupt, SystemExit]) async def test_interrupt_in_task_orchestration_crashes_task_and_flow( - self, interrupt_type, monkeypatch + self, prefect_client, events_pipeline, interrupt_type, monkeypatch ): monkeypatch.setattr( TaskRunEngine, "begin_run", MagicMock(side_effect=interrupt_type) @@ -1218,7 +1177,10 @@ async def my_task(): with pytest.raises(interrupt_type): await my_task() - task_run = await get_task_run(task_run_id=None) + await events_pipeline.process_events() + task_runs = await prefect_client.read_task_runs() + assert len(task_runs) == 1 + task_run = task_runs[0] assert task_run.state.is_crashed() assert task_run.state.type == StateType.CRASHED assert "Execution was aborted" in task_run.state.message @@ -1227,61 +1189,79 @@ async def my_task(): class TestTaskTimeTracking: - async def test_sync_task_sets_start_time_on_running(self): + async def test_sync_task_sets_start_time_on_running( + 
self, prefect_client, events_pipeline + ): @task def foo(): return TaskRunContext.get().task_run.id task_run_id = run_task_sync(foo) - run = await get_task_run(task_run_id) + await events_pipeline.process_events() - running = await get_task_run_state(task_run_id, StateType.RUNNING) + run = await prefect_client.read_task_run(task_run_id) + + states = await prefect_client.read_task_run_states(task_run_id) + running = [state for state in states if state.type == StateType.RUNNING][0] assert run.start_time assert run.start_time == running.timestamp - async def test_async_task_sets_start_time_on_running(self): + async def test_async_task_sets_start_time_on_running( + self, prefect_client, events_pipeline + ): @task async def foo(): return TaskRunContext.get().task_run.id task_run_id = await run_task_async(foo) - run = await get_task_run(task_run_id) + await events_pipeline.process_events() + run = await prefect_client.read_task_run(task_run_id) - running = await get_task_run_state(run.id, StateType.RUNNING) + states = await prefect_client.read_task_run_states(task_run_id) + running = [state for state in states if state.type == StateType.RUNNING][0] assert run.start_time assert run.start_time == running.timestamp - async def test_sync_task_sets_end_time_on_completed(self): + async def test_sync_task_sets_end_time_on_completed( + self, prefect_client, events_pipeline + ): @task def foo(): return TaskRunContext.get().task_run.id task_run_id = run_task_sync(foo) - run = await get_task_run(task_run_id) + await events_pipeline.process_events() + run = await prefect_client.read_task_run(task_run_id) - running = await get_task_run_state(task_run_id, StateType.RUNNING) - completed = await get_task_run_state(task_run_id, StateType.COMPLETED) + states = await prefect_client.read_task_run_states(task_run_id) + running = [state for state in states if state.type == StateType.RUNNING][0] + completed = [state for state in states if state.type == StateType.COMPLETED][0] assert run.end_time assert run.end_time == completed.timestamp assert run.total_run_time == completed.timestamp - running.timestamp - async def test_async_task_sets_end_time_on_completed(self): + async def test_async_task_sets_end_time_on_completed( + self, prefect_client, events_pipeline + ): @task async def foo(): return TaskRunContext.get().task_run.id task_run_id = await run_task_async(foo) - run = await get_task_run(task_run_id) - - running = await get_task_run_state(task_run_id, StateType.RUNNING) - completed = await get_task_run_state(task_run_id, StateType.COMPLETED) + await events_pipeline.process_events() + run = await prefect_client.read_task_run(task_run_id) + states = await prefect_client.read_task_run_states(task_run_id) + running = [state for state in states if state.type == StateType.RUNNING][0] + completed = [state for state in states if state.type == StateType.COMPLETED][0] assert run.end_time assert run.end_time == completed.timestamp assert run.total_run_time == completed.timestamp - running.timestamp - async def test_sync_task_sets_end_time_on_failed(self): + async def test_sync_task_sets_end_time_on_failed( + self, prefect_client, events_pipeline + ): ID = None @task @@ -1293,16 +1273,21 @@ def foo(): with pytest.raises(ValueError): run_task_sync(foo) - run = await get_task_run(ID) + await events_pipeline.process_events() + + run = await prefect_client.read_task_run(ID) - running = await get_task_run_state(run.id, StateType.RUNNING) - failed = await get_task_run_state(run.id, StateType.FAILED) + states = await 
prefect_client.read_task_run_states(ID) + running = [state for state in states if state.type == StateType.RUNNING][0] + failed = [state for state in states if state.type == StateType.FAILED][0] assert run.end_time assert run.end_time == failed.timestamp assert run.total_run_time == failed.timestamp - running.timestamp - async def test_async_task_sets_end_time_on_failed(self): + async def test_async_task_sets_end_time_on_failed( + self, prefect_client, events_pipeline + ): ID = None @task @@ -1314,16 +1299,19 @@ async def foo(): with pytest.raises(ValueError): await run_task_async(foo) - run = await get_task_run(ID) - - running = await get_task_run_state(run.id, StateType.RUNNING) - failed = await get_task_run_state(run.id, StateType.FAILED) + await events_pipeline.process_events() + run = await prefect_client.read_task_run(ID) + states = await prefect_client.read_task_run_states(ID) + running = [state for state in states if state.type == StateType.RUNNING][0] + failed = [state for state in states if state.type == StateType.FAILED][0] assert run.end_time assert run.end_time == failed.timestamp assert run.total_run_time == failed.timestamp - running.timestamp - async def test_sync_task_sets_end_time_on_crashed(self): + async def test_sync_task_sets_end_time_on_crashed( + self, prefect_client, events_pipeline + ): ID = None @task @@ -1334,17 +1322,20 @@ def foo(): with pytest.raises(SystemExit): run_task_sync(foo) + await events_pipeline.process_events() - run = await get_task_run(ID) - - running = await get_task_run_state(run.id, StateType.RUNNING) - crashed = await get_task_run_state(run.id, StateType.CRASHED) + run = await prefect_client.read_task_run(ID) + states = await prefect_client.read_task_run_states(ID) + running = [state for state in states if state.type == StateType.RUNNING][0] + crashed = [state for state in states if state.type == StateType.CRASHED][0] assert run.end_time assert run.end_time == crashed.timestamp assert run.total_run_time == crashed.timestamp - running.timestamp - async def test_async_task_sets_end_time_on_crashed(self): + async def test_async_task_sets_end_time_on_crashed( + self, prefect_client, events_pipeline + ): ID = None @task @@ -1356,17 +1347,18 @@ async def foo(): with pytest.raises(SystemExit): await run_task_async(foo) - run = await get_task_run(ID) - - running = await get_task_run_state(run.id, StateType.RUNNING) - crashed = await get_task_run_state(run.id, StateType.CRASHED) + await events_pipeline.process_events() + run = await prefect_client.read_task_run(ID) + states = await prefect_client.read_task_run_states(ID) + running = [state for state in states if state.type == StateType.RUNNING][0] + crashed = [state for state in states if state.type == StateType.CRASHED][0] assert run.end_time assert run.end_time == crashed.timestamp assert run.total_run_time == crashed.timestamp - running.timestamp async def test_sync_task_does_not_set_end_time_on_crash_pre_runnning( - self, monkeypatch + self, monkeypatch, prefect_client, events_pipeline ): monkeypatch.setattr( TaskRunEngine, "begin_run", MagicMock(side_effect=SystemExit) @@ -1379,12 +1371,15 @@ def my_task(): with pytest.raises(SystemExit): my_task() - run = await get_task_run(task_run_id=None) + await events_pipeline.process_events() + task_runs = await prefect_client.read_task_runs() + assert len(task_runs) == 1 + run = task_runs[0] assert run.end_time is None async def test_async_task_does_not_set_end_time_on_crash_pre_running( - self, monkeypatch + self, monkeypatch, prefect_client, 
events_pipeline ): monkeypatch.setattr( TaskRunEngine, "begin_run", MagicMock(side_effect=SystemExit) @@ -1397,31 +1392,43 @@ async def my_task(): with pytest.raises(SystemExit): await my_task() - run = await get_task_run(task_run_id=None) + await events_pipeline.process_events() + + task_runs = await prefect_client.read_task_runs() + assert len(task_runs) == 1 + run = task_runs[0] assert run.end_time is None - async def test_sync_task_sets_expected_start_time_on_pending(self): + async def test_sync_task_sets_expected_start_time_on_pending( + self, prefect_client, events_pipeline + ): @task def foo(): return TaskRunContext.get().task_run.id task_run_id = run_task_sync(foo) - run = await get_task_run(task_run_id) + await events_pipeline.process_events() + run = await prefect_client.read_task_run(task_run_id) + + states = await prefect_client.read_task_run_states(task_run_id) + pending = [state for state in states if state.type == StateType.PENDING][0] - pending = await get_task_run_state(task_run_id, StateType.PENDING) assert run.expected_start_time assert run.expected_start_time == pending.timestamp - async def test_async_task_sets_expected_start_time_on_pending(self): + async def test_async_task_sets_expected_start_time_on_pending( + self, prefect_client, events_pipeline + ): @task async def foo(): return TaskRunContext.get().task_run.id task_run_id = await run_task_async(foo) - run = await get_task_run(task_run_id) - - pending = await get_task_run_state(run.id, StateType.PENDING) + await events_pipeline.process_events() + run = await prefect_client.read_task_run(task_run_id) + states = await prefect_client.read_task_run_states(task_run_id) + pending = [state for state in states if state.type == StateType.PENDING][0] assert run.expected_start_time assert run.expected_start_time == pending.timestamp @@ -1450,7 +1457,9 @@ def f(): parameters={"x": "y"}, ) - def test_sync_task_run_counts(self, flow_run_context: EngineContext): + def test_sync_task_run_counts( + self, flow_run_context: EngineContext, sync_prefect_client, events_pipeline + ): ID = None proof_that_i_ran = uuid4() @@ -1473,12 +1482,15 @@ def foo(): with flow_run_context: assert run_task_sync(foo) == proof_that_i_ran - task_run = get_task_run_sync(ID) + events_pipeline.sync_process_events() + task_run = sync_prefect_client.read_task_run(ID) assert task_run assert task_run.run_count == 1 assert task_run.flow_run_run_count == flow_run_context.flow_run.run_count - async def test_async_task_run_counts(self, flow_run_context: EngineContext): + async def test_async_task_run_counts( + self, flow_run_context: EngineContext, prefect_client, events_pipeline + ): ID = None proof_that_i_ran = uuid4() @@ -1501,7 +1513,8 @@ async def foo(): with flow_run_context: assert await run_task_async(foo) == proof_that_i_ran - task_run = await get_task_run(ID) + await events_pipeline.process_events() + task_run = await prefect_client.read_task_run(ID) assert task_run assert task_run.run_count == 1 assert task_run.flow_run_run_count == flow_run_context.flow_run.run_count @@ -1812,7 +1825,9 @@ def g(): for i in g(return_state=True): pass - async def test_generator_task_states(self, prefect_client: PrefectClient): + async def test_generator_task_states( + self, prefect_client: PrefectClient, events_pipeline + ): """ Test for generator behavior including StopIteration """ @@ -1824,14 +1839,16 @@ def g(): gen = g() tr_id = next(gen) - tr = await get_task_run(tr_id) + await events_pipeline.process_events() + tr = await prefect_client.read_task_run(tr_id) 
assert tr.state.is_running() # exhaust the generator for _ in gen: pass - tr = await get_task_run(tr_id) + await events_pipeline.process_events() + tr = await prefect_client.read_task_run(tr_id) assert tr.state.is_completed() async def test_generator_task_with_return(self): @@ -1863,7 +1880,7 @@ def g(): next(gen) async def test_generator_task_with_exception_is_failed( - self, prefect_client: PrefectClient + self, prefect_client: PrefectClient, events_pipeline ): @task def g(): @@ -1874,10 +1891,13 @@ def g(): tr_id = next(gen) with pytest.raises(ValueError, match="xyz"): next(gen) - tr = await get_task_run(tr_id) + await events_pipeline.process_events() + tr = await prefect_client.read_task_run(tr_id) assert tr.state.is_failed() - async def test_generator_parent_tracking(self, prefect_client: PrefectClient): + async def test_generator_parent_tracking( + self, prefect_client: PrefectClient, events_pipeline + ): """ """ @task(task_run_name="gen-1000") @@ -1895,14 +1915,16 @@ def parent_tracking(): return tr_id tr_id = parent_tracking() - tr = await get_task_run(tr_id) + await events_pipeline.process_events() + tr = await prefect_client.read_task_run(tr_id) assert "x" in tr.task_inputs assert "__parents__" in tr.task_inputs # the parent run and upstream 'x' run are the same assert tr.task_inputs["__parents__"][0].id == tr.task_inputs["x"][0].id # the parent run is "gen-1000" gen_id = tr.task_inputs["__parents__"][0].id - gen_tr = await get_task_run(gen_id) + await events_pipeline.process_events() + gen_tr = await prefect_client.read_task_run(gen_id) assert gen_tr.name == "gen-1000" async def test_generator_retries(self): @@ -2019,7 +2041,9 @@ async def g(): async for i in g(return_state=True): pass - async def test_generator_task_states(self, prefect_client: PrefectClient): + async def test_generator_task_states( + self, prefect_client: PrefectClient, events_pipeline + ): """ Test for generator behavior including StopIteration """ @@ -2030,10 +2054,12 @@ async def g(): async for val in g(): tr_id = val - tr = await get_task_run(tr_id) + await events_pipeline.process_events() + tr = await prefect_client.read_task_run(tr_id) assert tr.state.is_running() - tr = await get_task_run(tr_id) + await events_pipeline.process_events() + tr = await prefect_client.read_task_run(tr_id) assert tr.state.is_completed() async def test_generator_task_with_exception(self): @@ -2047,7 +2073,7 @@ async def g(): assert val == 1 async def test_generator_task_with_exception_is_failed( - self, prefect_client: PrefectClient + self, prefect_client: PrefectClient, events_pipeline ): @task async def g(): @@ -2057,11 +2083,13 @@ async def g(): with pytest.raises(ValueError, match="xyz"): async for val in g(): tr_id = val - - tr = await get_task_run(tr_id) + await events_pipeline.process_events() + tr = await prefect_client.read_task_run(tr_id) assert tr.state.is_failed() - async def test_generator_parent_tracking(self, prefect_client: PrefectClient): + async def test_generator_parent_tracking( + self, prefect_client: PrefectClient, events_pipeline + ): """ """ @task(task_run_name="gen-1000") @@ -2079,14 +2107,16 @@ async def parent_tracking(): return tr_id tr_id = await parent_tracking() - tr = await get_task_run(tr_id) + await events_pipeline.process_events() + tr = await prefect_client.read_task_run(tr_id) assert "x" in tr.task_inputs assert "__parents__" in tr.task_inputs # the parent run and upstream 'x' run are the same assert tr.task_inputs["__parents__"][0].id == tr.task_inputs["x"][0].id # the parent run is 
"gen-1000" gen_id = tr.task_inputs["__parents__"][0].id - gen_tr = await get_task_run(gen_id) + await events_pipeline.process_events() + gen_tr = await prefect_client.read_task_run(gen_id) assert gen_tr.name == "gen-1000" async def test_generator_retries(self): @@ -2165,7 +2195,9 @@ async def g(): class TestRunStateIsDenormalized: - async def test_state_attributes_are_denormalized_async_success(self): + async def test_state_attributes_are_denormalized_async_success( + self, prefect_client, events_pipeline + ): ID = None @task @@ -2183,8 +2215,8 @@ async def foo(): assert task_run.state_name == task_run.state.name == "Running" await run_task_async(foo) - - task_run = await get_task_run(ID) + await events_pipeline.process_events() + task_run = await prefect_client.read_task_run(ID) assert task_run assert task_run.state @@ -2193,7 +2225,9 @@ async def foo(): assert task_run.state_type == task_run.state.type == StateType.COMPLETED assert task_run.state_name == task_run.state.name == "Completed" - async def test_state_attributes_are_denormalized_async_failure(self): + async def test_state_attributes_are_denormalized_async_failure( + self, prefect_client, events_pipeline + ): ID = None @task @@ -2215,7 +2249,8 @@ async def foo(): with pytest.raises(ValueError, match="woops!"): await run_task_async(foo) - task_run = await get_task_run(ID) + await events_pipeline.process_events() + task_run = await prefect_client.read_task_run(ID) assert task_run assert task_run.state @@ -2224,7 +2259,9 @@ async def foo(): assert task_run.state_type == task_run.state.type == StateType.FAILED assert task_run.state_name == task_run.state.name == "Failed" - def test_state_attributes_are_denormalized_sync_success(self): + def test_state_attributes_are_denormalized_sync_success( + self, sync_prefect_client, events_pipeline + ): ID = None @task @@ -2242,8 +2279,8 @@ def foo(): assert task_run.state_name == task_run.state.name == "Running" run_task_sync(foo) - - task_run = get_task_run_sync(ID) + events_pipeline.sync_process_events() + task_run = sync_prefect_client.read_task_run(ID) assert task_run assert task_run.state @@ -2252,7 +2289,9 @@ def foo(): assert task_run.state_type == task_run.state.type == StateType.COMPLETED assert task_run.state_name == task_run.state.name == "Completed" - def test_state_attributes_are_denormalized_sync_failure(self): + def test_state_attributes_are_denormalized_sync_failure( + self, sync_prefect_client, events_pipeline + ): ID = None @task @@ -2274,7 +2313,8 @@ def foo(): with pytest.raises(ValueError, match="woops!"): run_task_sync(foo) - task_run = get_task_run_sync(ID) + events_pipeline.sync_process_events() + task_run = sync_prefect_client.read_task_run(ID) assert task_run assert task_run.state