From 1b7d64cb7d2cd31234c48260374cb59ed5e0ebba Mon Sep 17 00:00:00 2001 From: Chris White Date: Wed, 19 Jun 2024 12:19:47 -0700 Subject: [PATCH] Only persist a task's result at commit time --- src/prefect/records/result_store.py | 6 ++- src/prefect/results.py | 7 +++- src/prefect/states.py | 16 ++++++-- src/prefect/task_engine.py | 2 + tests/test_flows.py | 58 +++++++++++++++++++++++++++++ 5 files changed, 84 insertions(+), 5 deletions(-) diff --git a/src/prefect/records/result_store.py b/src/prefect/records/result_store.py index 44d9ff2ef5bd..958aada685a8 100644 --- a/src/prefect/records/result_store.py +++ b/src/prefect/records/result_store.py @@ -44,6 +44,10 @@ def read(self, key: str) -> BaseResult: raise ValueError("Result could not be read") def write(self, key: str, value: Any) -> BaseResult: - if isinstance(value, BaseResult): + if isinstance(value, PersistedResult): + # if the value is already a persisted result, write it + value.write(_sync=True) + return value + elif isinstance(value, BaseResult): return value return run_coro_as_sync(self.result_factory.create_result(obj=value, key=key)) diff --git a/src/prefect/results.py b/src/prefect/results.py index 410777ca54db..a9ec3790d374 100644 --- a/src/prefect/results.py +++ b/src/prefect/results.py @@ -431,7 +431,11 @@ def resolve_serializer(serializer: ResultSerializer) -> Serializer: @sync_compatible async def create_result( - self, obj: R, key: Optional[str] = None, expiration: Optional[DateTime] = None + self, + obj: R, + key: Optional[str] = None, + expiration: Optional[DateTime] = None, + defer_persistence: bool = False, ) -> Union[R, "BaseResult[R]"]: """ Create a result type for the given object. 
@@ -464,6 +468,7 @@ def key_fn(): serializer=self.serializer, cache_object=should_cache_object, expiration=expiration, + defer_persistence=defer_persistence, ) @sync_compatible diff --git a/src/prefect/states.py b/src/prefect/states.py index f1245f9af3e2..0527e0ea6ea5 100644 --- a/src/prefect/states.py +++ b/src/prefect/states.py @@ -209,6 +209,7 @@ async def return_value_to_state( result_factory: ResultFactory, key: Optional[str] = None, expiration: Optional[datetime.datetime] = None, + defer_persistence: bool = False, ) -> State[R]: """ Given a return value from a user's function, create a `State` the run should @@ -242,7 +243,10 @@ async def return_value_to_state( # to update the data to the correct type if not isinstance(state.data, BaseResult): state.data = await result_factory.create_result( - state.data, key=key, expiration=expiration + state.data, + key=key, + expiration=expiration, + defer_persistence=defer_persistence, ) return state @@ -284,7 +288,10 @@ async def return_value_to_state( type=new_state_type, message=message, data=await result_factory.create_result( - retval, key=key, expiration=expiration + retval, + key=key, + expiration=expiration, + defer_persistence=defer_persistence, ), ) @@ -300,7 +307,10 @@ async def return_value_to_state( else: return Completed( data=await result_factory.create_result( - data, key=key, expiration=expiration + data, + key=key, + expiration=expiration, + defer_persistence=defer_persistence, ) ) diff --git a/src/prefect/task_engine.py b/src/prefect/task_engine.py index b4b26ed35739..6c1a42efdfc6 100644 --- a/src/prefect/task_engine.py +++ b/src/prefect/task_engine.py @@ -310,6 +310,8 @@ def handle_success(self, result: R, transaction: Transaction) -> R: result_factory=result_factory, key=transaction.key, expiration=expiration, + # defer persistence to transaction commit + defer_persistence=True, ) ) transaction.stage( diff --git a/tests/test_flows.py b/tests/test_flows.py index a41cfd77a92c..0f41d5659f9f 100644 --- 
a/tests/test_flows.py +++ b/tests/test_flows.py @@ -48,6 +48,7 @@ load_flow_from_flow_run, ) from prefect.logging import get_run_logger +from prefect.results import PersistedResultBlob from prefect.runtime import flow_run as flow_run_ctx from prefect.server.schemas.core import TaskRunResult from prefect.server.schemas.filters import FlowFilter, FlowRunFilter @@ -4280,6 +4281,63 @@ def main(): assert "called" not in data2 assert data1["called"] is True + def test_task_doesnt_persist_prior_to_commit(self, tmp_path): + result_storage = LocalFileSystem(basepath=tmp_path) + + @task(result_storage=result_storage, result_storage_key="task1-result") + def task1(): + pass + + @task(result_storage=result_storage, result_storage_key="task2-result") + def task2(): + raise RuntimeError("oopsie") + + @flow + def main(): + with transaction(): + task1() + task2() + + main(return_state=True) + + with pytest.raises(ValueError, match="does not exist"): + result_storage.read_path("task1-result", _sync=True) + + def test_task_persists_only_at_commit(self, tmp_path): + result_storage = LocalFileSystem(basepath=tmp_path) + + @task(result_storage=result_storage, result_storage_key="task1-result-A") + def task1(): + return dict(some="data") + + @task(result_storage=result_storage, result_storage_key="task2-result-B") + def task2(): + pass + + @flow + def main(): + retval = None + + with transaction(): + task1() + + try: + result_storage.read_path("task1-result-A", _sync=True) + except ValueError as exc: + retval = exc + + task2() + + return retval + + val = main() + + assert isinstance(val, ValueError) + assert "does not exist" in str(val) + content = result_storage.read_path("task1-result-A", _sync=True) + blob = PersistedResultBlob.model_validate_json(content) + assert blob.load() == {"some": "data"} + def test_commit_isnt_called_on_rollback(self): data = {}