From 033cb35ef8d228805fb132f9463b9e361e940aa4 Mon Sep 17 00:00:00 2001 From: jakekaplan <40362401+jakekaplan@users.noreply.github.com> Date: Fri, 19 Jul 2024 15:15:20 -0400 Subject: [PATCH] set `end_time` during client-side task orchestration (#14681) --- src/prefect/task_engine.py | 11 +++ tests/test_task_engine.py | 142 ++++++++++++++++++++++++++++++++++++- 2 files changed, 151 insertions(+), 2 deletions(-) diff --git a/src/prefect/task_engine.py b/src/prefect/task_engine.py index 1c2417d9699e..d9c4d9fa2fd2 100644 --- a/src/prefect/task_engine.py +++ b/src/prefect/task_engine.py @@ -398,6 +398,11 @@ def handle_success(self, result: R, transaction: Transaction) -> R: ) if transaction.is_committed(): terminal_state.name = "Cached" + + if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION: + if self.task_run.start_time and not self.task_run.end_time: + self.task_run.end_time = terminal_state.timestamp + self.set_state(terminal_state) self._return_value = result return result @@ -458,6 +463,9 @@ def handle_exception(self, exc: Exception) -> None: result_factory=getattr(context, "result_factory", None), ) ) + if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION: + if self.task_run.start_time and not self.task_run.end_time: + self.task_run.end_time = state.timestamp self.set_state(state) self._raised = exc @@ -480,6 +488,9 @@ def handle_crash(self, exc: BaseException) -> None: state = run_coro_as_sync(exception_to_crashed_state(exc)) self.logger.error(f"Crash detected! {state.message}") self.logger.debug("Crash details:", exc_info=exc) + if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION: + if self.task_run.start_time and not self.task_run.end_time: + self.task_run.end_time = state.timestamp self.set_state(state, force=True) self._raised = exc diff --git a/tests/test_task_engine.py b/tests/test_task_engine.py index 0cd118ed4c2c..574605af4ad4 100644 --- a/tests/test_task_engine.py +++ b/tests/test_task_engine.py @@ -1108,7 +1108,7 @@ async def my_task(): class TestTaskTimeTracking: - async def test_sync_task_start_time_set_on_running(self): + async def test_sync_task_sets_start_time_on_running(self): @task def foo(): return TaskRunContext.get().task_run.id @@ -1121,7 +1121,7 @@ def foo(): assert len(states) == 1 and states[0].type == StateType.RUNNING assert states[0].timestamp == run.start_time - async def test_async_task_start_time_set_on_running(self): + async def test_async_task_sets_start_time_on_running(self): ID = None @task @@ -1137,6 +1137,144 @@ async def foo(): assert len(states) == 1 and states[0].type == StateType.RUNNING assert states[0].timestamp == run.start_time + async def test_sync_task_sets_end_time_on_completed(self): + @task + def foo(): + return TaskRunContext.get().task_run.id + + task_run_id = run_task_sync(foo) + run = await get_task_run(task_run_id) + + assert run.end_time is not None + states = await get_task_run_states(task_run_id, StateType.COMPLETED) + assert len(states) == 1 and states[0].type == StateType.COMPLETED + assert states[0].timestamp == run.end_time + + async def test_async_task_sets_end_time_on_completed(self): + @task + async def foo(): + return TaskRunContext.get().task_run.id + + task_run_id = await run_task_async(foo) + run = await get_task_run(task_run_id) + + assert run.end_time is not None + states = await get_task_run_states(task_run_id, StateType.COMPLETED) + assert len(states) == 1 and states[0].type == StateType.COMPLETED + assert states[0].timestamp == run.end_time + + async def test_sync_task_sets_end_time_on_failed(self): + ID = None + + @task + def foo(): + nonlocal ID + ID = TaskRunContext.get().task_run.id + raise ValueError("failure!!!") + + with pytest.raises(ValueError): + run_task_sync(foo) + + run = await get_task_run(ID) + + assert run.end_time is not None + states = await get_task_run_states(ID, StateType.FAILED) + assert len(states) == 1 and states[0].type == StateType.FAILED + assert states[0].timestamp == run.end_time + + async def test_async_task_sets_end_time_on_failed(self): + ID = None + + @task + async def foo(): + nonlocal ID + ID = TaskRunContext.get().task_run.id + raise ValueError("failure!!!") + + with pytest.raises(ValueError): + await run_task_async(foo) + + run = await get_task_run(ID) + + assert run.end_time is not None + states = await get_task_run_states(ID, StateType.FAILED) + assert len(states) == 1 and states[0].type == StateType.FAILED + assert states[0].timestamp == run.end_time + + async def test_sync_task_sets_end_time_on_crashed(self): + ID = None + + @task + def foo(): + nonlocal ID + ID = TaskRunContext.get().task_run.id + raise SystemExit + + with pytest.raises(SystemExit): + run_task_sync(foo) + + run = await get_task_run(ID) + + assert run.end_time is not None + states = await get_task_run_states(ID, StateType.CRASHED) + assert len(states) == 1 and states[0].type == StateType.CRASHED + assert states[0].timestamp == run.end_time + + async def test_async_task_sets_end_time_on_crashed(self): + ID = None + + @task + async def foo(): + nonlocal ID + ID = TaskRunContext.get().task_run.id + raise SystemExit + + with pytest.raises(SystemExit): + await run_task_async(foo) + + run = await get_task_run(ID) + + assert run.end_time is not None + states = await get_task_run_states(ID, StateType.CRASHED) + assert len(states) == 1 and states[0].type == StateType.CRASHED + assert states[0].timestamp == run.end_time + + async def test_sync_task_does_not_set_end_time_on_crash_pre_runnning( + self, monkeypatch + ): + monkeypatch.setattr( + TaskRunEngine, "begin_run", MagicMock(side_effect=SystemExit) + ) + + @task + def my_task(): + pass + + with pytest.raises(SystemExit): + my_task() + + run = await get_task_run(task_run_id=None) + + assert run.end_time is None + + async def test_async_task_does_not_set_end_time_on_crash_pre_running( + self, monkeypatch + ): + monkeypatch.setattr( + TaskRunEngine, "begin_run", MagicMock(side_effect=SystemExit) + ) + + @task + async def my_task(): + pass + + with pytest.raises(SystemExit): + await my_task() + + run = await get_task_run(task_run_id=None) + + assert run.end_time is None + class TestRunCountTracking: @pytest.fixture