Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Always pass failed state to retry handlers and support async #14746

Merged
merged 1 commit into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions src/prefect/task_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,7 @@ def state(self) -> State:
raise ValueError("Task run is not set")
return self.task_run.state

@property
def can_retry(self) -> bool:
def can_retry(self, exc: Exception) -> bool:
retry_condition: Optional[
Callable[[Task[P, Coroutine[Any, Any, R]], TaskRun, State], bool]
] = self.task.retry_condition_fn
Expand All @@ -138,9 +137,19 @@ def can_retry(self) -> bool:
f"Running `retry_condition_fn` check {retry_condition!r} for task"
f" {self.task.name!r}"
)
return not retry_condition or retry_condition(
self.task, self.task_run, self.state
state = Failed(
data=exc,
message=f"Task run encountered unexpected exception: {repr(exc)}",
)
if inspect.iscoroutinefunction(retry_condition):
should_retry = run_coro_as_sync(
retry_condition(self.task, self.task_run, state)
)
elif inspect.isfunction(retry_condition):
should_retry = retry_condition(self.task, self.task_run, state)
else:
should_retry = not retry_condition
return should_retry
except Exception:
self.logger.error(
(
Expand Down Expand Up @@ -418,7 +427,7 @@ def handle_retry(self, exc: Exception) -> bool:
- If the task has a retry delay, place in AwaitingRetry state with a delayed scheduled time.
- If the task has no retries left, or the retry condition is not met, return False.
"""
if self.retries < self.task.retries and self.can_retry:
if self.retries < self.task.retries and self.can_retry(exc):
if self.task.retry_delay_seconds:
delay = (
self.task.retry_delay_seconds[
Expand Down
86 changes: 86 additions & 0 deletions tests/test_task_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,6 +935,92 @@ async def test_flow():
"Completed",
]

async def test_task_passes_failed_state_to_retry_fn(self):
mock = MagicMock()
exc = SyntaxError("oops")
handler_mock = MagicMock()

async def handler(task, task_run, state):
handler_mock()
assert state.is_failed()
try:
await state.result()
except SyntaxError:
return True
return False

@task(retries=3, retry_condition_fn=handler)
async def flaky_function():
mock()
if mock.call_count == 2:
return True
raise exc

@flow
async def test_flow():
return await flaky_function(return_state=True)

task_run_state = await test_flow()
task_run_id = task_run_state.state_details.task_run_id

assert task_run_state.is_completed()
assert await task_run_state.result() is True
assert mock.call_count == 2
assert handler_mock.call_count == 1

states = await get_task_run_states(task_run_id)

state_names = [state.name for state in states]
assert state_names == [
"Pending",
"Running",
"Retrying",
"Completed",
]

async def test_task_passes_failed_state_to_retry_fn_sync(self):
mock = MagicMock()
exc = SyntaxError("oops")
handler_mock = MagicMock()

def handler(task, task_run, state):
handler_mock()
assert state.is_failed()
try:
state.result()
except SyntaxError:
return True
return False

@task(retries=3, retry_condition_fn=handler)
def flaky_function():
mock()
if mock.call_count == 2:
return True
raise exc

@flow
def test_flow():
return flaky_function(return_state=True)

task_run_state = test_flow()
task_run_id = task_run_state.state_details.task_run_id

assert task_run_state.is_completed()
assert await task_run_state.result() is True
assert mock.call_count == 2
assert handler_mock.call_count == 1

states = await get_task_run_states(task_run_id)

state_names = [state.name for state in states]
assert state_names == [
"Pending",
"Running",
"Retrying",
"Completed",
]

async def test_task_retries_receive_latest_task_run_in_context(self):
state_names: List[str] = []
run_counts = []
Expand Down
Loading