From 868a0ea02dbcb706551900de296b77097722c3b8 Mon Sep 17 00:00:00 2001 From: Chris White Date: Wed, 24 Jul 2024 14:21:10 -0700 Subject: [PATCH] Always pass failed state to retry handlers and support async --- src/prefect/task_engine.py | 19 ++++++--- tests/test_task_engine.py | 86 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 100 insertions(+), 5 deletions(-) diff --git a/src/prefect/task_engine.py b/src/prefect/task_engine.py index 70c4718e687f..01bb93b6f573 100644 --- a/src/prefect/task_engine.py +++ b/src/prefect/task_engine.py @@ -126,8 +126,7 @@ def state(self) -> State: raise ValueError("Task run is not set") return self.task_run.state - @property - def can_retry(self) -> bool: + def can_retry(self, exc: Exception) -> bool: retry_condition: Optional[ Callable[[Task[P, Coroutine[Any, Any, R]], TaskRun, State], bool] ] = self.task.retry_condition_fn @@ -138,9 +137,19 @@ def can_retry(self) -> bool: f"Running `retry_condition_fn` check {retry_condition!r} for task" f" {self.task.name!r}" ) - return not retry_condition or retry_condition( - self.task, self.task_run, self.state + state = Failed( + data=exc, + message=f"Task run encountered unexpected exception: {repr(exc)}", ) + if inspect.iscoroutinefunction(retry_condition): + should_retry = run_coro_as_sync( + retry_condition(self.task, self.task_run, state) + ) + elif inspect.isfunction(retry_condition): + should_retry = retry_condition(self.task, self.task_run, state) + else: + should_retry = not retry_condition + return should_retry except Exception: self.logger.error( ( @@ -418,7 +427,7 @@ def handle_retry(self, exc: Exception) -> bool: - If the task has a retry delay, place in AwaitingRetry state with a delayed scheduled time. - If the task has no retries left, or the retry condition is not met, return False. """ - if self.retries < self.task.retries and self.can_retry: + if self.retries < self.task.retries and self.can_retry(exc): if self.task.retry_delay_seconds: delay = ( self.task.retry_delay_seconds[ diff --git a/tests/test_task_engine.py b/tests/test_task_engine.py index c69bf8b8a7e2..1e003eaac834 100644 --- a/tests/test_task_engine.py +++ b/tests/test_task_engine.py @@ -935,6 +935,92 @@ async def test_flow(): "Completed", ] + async def test_task_passes_failed_state_to_retry_fn(self): + mock = MagicMock() + exc = SyntaxError("oops") + handler_mock = MagicMock() + + async def handler(task, task_run, state): + handler_mock() + assert state.is_failed() + try: + await state.result() + except SyntaxError: + return True + return False + + @task(retries=3, retry_condition_fn=handler) + async def flaky_function(): + mock() + if mock.call_count == 2: + return True + raise exc + + @flow + async def test_flow(): + return await flaky_function(return_state=True) + + task_run_state = await test_flow() + task_run_id = task_run_state.state_details.task_run_id + + assert task_run_state.is_completed() + assert await task_run_state.result() is True + assert mock.call_count == 2 + assert handler_mock.call_count == 1 + + states = await get_task_run_states(task_run_id) + + state_names = [state.name for state in states] + assert state_names == [ + "Pending", + "Running", + "Retrying", + "Completed", + ] + + async def test_task_passes_failed_state_to_retry_fn_sync(self): + mock = MagicMock() + exc = SyntaxError("oops") + handler_mock = MagicMock() + + def handler(task, task_run, state): + handler_mock() + assert state.is_failed() + try: + state.result() + except SyntaxError: + return True + return False + + @task(retries=3, retry_condition_fn=handler) + def flaky_function(): + mock() + if mock.call_count == 2: + return True + raise exc + + @flow + def test_flow(): + return flaky_function(return_state=True) + + task_run_state = test_flow() + task_run_id = task_run_state.state_details.task_run_id + + assert task_run_state.is_completed() + assert await task_run_state.result() is True + assert mock.call_count == 2 + assert handler_mock.call_count == 1 + + states = await get_task_run_states(task_run_id) + + state_names = [state.name for state in states] + assert state_names == [ + "Pending", + "Running", + "Retrying", + "Completed", + ] + async def test_task_retries_receive_latest_task_run_in_context(self): state_names: List[str] = [] run_counts = []