diff --git a/src/prefect/client/schemas/objects.py b/src/prefect/client/schemas/objects.py index 3d375ed6ce63..d4d170c74dd2 100644 --- a/src/prefect/client/schemas/objects.py +++ b/src/prefect/client/schemas/objects.py @@ -179,7 +179,10 @@ def result(self: "State[R]", raise_on_failure: bool = False) -> Union[R, Excepti ... def result( - self, raise_on_failure: bool = True, fetch: Optional[bool] = None + self, + raise_on_failure: bool = True, + fetch: Optional[bool] = None, + retry_result_failure: bool = True, ) -> Union[R, Exception]: """ Retrieve the result attached to this state. @@ -191,6 +194,8 @@ def result( results into data. For synchronous users, this defaults to `True`. For asynchronous users, this defaults to `False` for backwards compatibility. + retry_result_failure: a boolean specifying whether to retry on failures to + load the result from result storage Raises: TypeError: If the state is failed but the result is not an exception. @@ -253,7 +258,12 @@ def result( """ from prefect.states import get_state_result - return get_state_result(self, raise_on_failure=raise_on_failure, fetch=fetch) + return get_state_result( + self, + raise_on_failure=raise_on_failure, + fetch=fetch, + retry_result_failure=retry_result_failure, + ) def to_state_create(self): """ diff --git a/src/prefect/states.py b/src/prefect/states.py index 4be6264132da..59a7b721ab19 100644 --- a/src/prefect/states.py +++ b/src/prefect/states.py @@ -34,7 +34,10 @@ def get_state_result( - state: State[R], raise_on_failure: bool = True, fetch: Optional[bool] = None + state: State[R], + raise_on_failure: bool = True, + fetch: Optional[bool] = None, + retry_result_failure: bool = True, ) -> R: """ Get the result from a state. @@ -62,37 +65,50 @@ def get_state_result( return state.data else: - return _get_state_result(state, raise_on_failure=raise_on_failure) + return _get_state_result( + state, + raise_on_failure=raise_on_failure, + retry_result_failure=retry_result_failure, + ) RESULT_READ_MAXIMUM_ATTEMPTS = 10 RESULT_READ_RETRY_DELAY = 0.25 -async def _get_state_result_data_with_retries(state: State[R]) -> R: +async def _get_state_result_data_with_retries( + state: State[R], retry_result_failure: bool = True +) -> R: # Results may be written asynchronously, possibly after their corresponding # state has been written and events have been emitted, so we should give some # grace here about missing results. The exception below could come in the form # of a missing file, a short read, or other types of errors depending on the # result storage backend. - for i in range(1, RESULT_READ_MAXIMUM_ATTEMPTS + 1): + if retry_result_failure is False: + max_attempts = 1 + else: + max_attempts = RESULT_READ_MAXIMUM_ATTEMPTS + + for i in range(1, max_attempts + 1): try: return await state.data.get() except Exception as e: - if i == RESULT_READ_MAXIMUM_ATTEMPTS: + if i == max_attempts: raise logger.debug( "Exception %r while reading result, retry %s/%s in %ss...", e, i, - RESULT_READ_MAXIMUM_ATTEMPTS, + max_attempts, RESULT_READ_RETRY_DELAY, ) await asyncio.sleep(RESULT_READ_RETRY_DELAY) @sync_compatible -async def _get_state_result(state: State[R], raise_on_failure: bool) -> R: +async def _get_state_result( + state: State[R], raise_on_failure: bool, retry_result_failure: bool = True +) -> R: """ Internal implementation for `get_state_result` without async backwards compatibility """ @@ -111,7 +127,9 @@ async def _get_state_result(state: State[R], raise_on_failure: bool) -> R: raise await get_state_exception(state) if isinstance(state.data, BaseResult): - result = await _get_state_result_data_with_retries(state) + result = await _get_state_result_data_with_retries( + state, retry_result_failure=retry_result_failure + ) elif state.data is None: if state.is_failed() or state.is_crashed() or state.is_cancelled(): diff --git a/src/prefect/task_engine.py b/src/prefect/task_engine.py index da01c6a8afe3..947286d73f07 100644 --- a/src/prefect/task_engine.py +++ b/src/prefect/task_engine.py @@ -249,6 +249,16 @@ def begin_run(self): new_state = Running() state = self.set_state(new_state) + # TODO: this is temporary until the API stops rejecting state transitions + # and the client / transaction store becomes the source of truth + # this is a bandaid caused by the API storing a Completed state with a bad + # result reference that no longer exists + if state.is_completed(): + try: + state.result(retry_result_failure=False, _sync=True) + except Exception: + state = self.set_state(new_state, force=True) + BACKOFF_MAX = 10 backoff_count = 0 diff --git a/tests/test_task_engine.py b/tests/test_task_engine.py index 7b62903d7a1a..9b62422abb32 100644 --- a/tests/test_task_engine.py +++ b/tests/test_task_engine.py @@ -1,5 +1,6 @@ import asyncio import logging +import os import time from datetime import timedelta from pathlib import Path @@ -1243,6 +1244,33 @@ def my_param_flow(x: int, other_val: str): assert third_result not in [first_result, second_result] assert fourth_result not in [first_result, second_result] + async def test_bad_api_result_references_cause_reruns(self, tmp_path: Path): + fs = LocalFileSystem(basepath=tmp_path) + + PAYLOAD = {"return": 42} + + @task(result_storage=fs, result_storage_key="tmp-first") + async def first(): + return PAYLOAD["return"], get_run_context().task_run + + result, task_run = await run_task_async(first) + + assert result == 42 + assert await fs.read_path("tmp-first") + + # delete record + path = fs._resolve_path("tmp-first") + os.unlink(path) + with pytest.raises(ValueError, match="does not exist"): + assert await fs.read_path("tmp-first") + + # rerun with same task run ID + PAYLOAD["return"] = "bar" + result, task_run = await run_task_async(first, task_run=task_run) + + assert result == "bar" + assert await fs.read_path("tmp-first") + class TestGenerators: async def test_generator_task(self):