From c57e35fb99584febae65cbab91e1a2475ae8b0ee Mon Sep 17 00:00:00 2001 From: Chris White Date: Wed, 19 Jun 2024 13:44:43 -0700 Subject: [PATCH] Only persist a tasks result at commit time (#14157) --- src/prefect/records/result_store.py | 6 +- src/prefect/results.py | 18 +++++- src/prefect/states.py | 16 ++++- src/prefect/task_engine.py | 2 + src/prefect/transactions.py | 4 +- tests/test_flows.py | 58 +++++++++++++++++++ tests/test_tasks.py | 90 +++++++++++++++++++++++++++-- tests/test_transactions.py | 30 +++++----- 8 files changed, 197 insertions(+), 27 deletions(-) diff --git a/src/prefect/records/result_store.py b/src/prefect/records/result_store.py index 44d9ff2ef5bd..958aada685a8 100644 --- a/src/prefect/records/result_store.py +++ b/src/prefect/records/result_store.py @@ -44,6 +44,10 @@ def read(self, key: str) -> BaseResult: raise ValueError("Result could not be read") def write(self, key: str, value: Any) -> BaseResult: - if isinstance(value, BaseResult): + if isinstance(value, PersistedResult): + # if the value is already a persisted result, write it + value.write(_sync=True) + return value + elif isinstance(value, BaseResult): return value return run_coro_as_sync(self.result_factory.create_result(obj=value, key=key)) diff --git a/src/prefect/results.py b/src/prefect/results.py index 410777ca54db..e273f8c7234c 100644 --- a/src/prefect/results.py +++ b/src/prefect/results.py @@ -431,7 +431,11 @@ def resolve_serializer(serializer: ResultSerializer) -> Serializer: @sync_compatible async def create_result( - self, obj: R, key: Optional[str] = None, expiration: Optional[DateTime] = None + self, + obj: R, + key: Optional[str] = None, + expiration: Optional[DateTime] = None, + defer_persistence: bool = False, ) -> Union[R, "BaseResult[R]"]: """ Create a result type for the given object. @@ -464,6 +468,7 @@ def key_fn(): serializer=self.serializer, cache_object=should_cache_object, expiration=expiration, + defer_persistence=defer_persistence, ) @sync_compatible @@ -682,6 +687,9 @@ async def write(self, obj: R = NotSet, client: "PrefectClient" = None) -> None: await storage_block.write_path(self.storage_key, content=blob.to_bytes()) self._persisted = True + if not self._should_cache_object: + self._cache = NotSet + @classmethod @sync_compatible async def create( @@ -729,7 +737,7 @@ async def create( expiration=expiration, ) - if cache_object: + if cache_object and not defer_persistence: # Attach the object to the result so it's available without deserialization result._cache_object( obj, storage_block=storage_block, serializer=serializer @@ -739,6 +747,12 @@ async def create( if not defer_persistence: await result.write(obj=obj) + else: + # we must cache temporarily to allow for writing later + # the cache will be removed on write + result._cache_object( + obj, storage_block=storage_block, serializer=serializer + ) return result diff --git a/src/prefect/states.py b/src/prefect/states.py index f1245f9af3e2..0527e0ea6ea5 100644 --- a/src/prefect/states.py +++ b/src/prefect/states.py @@ -209,6 +209,7 @@ async def return_value_to_state( result_factory: ResultFactory, key: Optional[str] = None, expiration: Optional[datetime.datetime] = None, + defer_persistence: bool = False, ) -> State[R]: """ Given a return value from a user's function, create a `State` the run should @@ -242,7 +243,10 @@ async def return_value_to_state( # to update the data to the correct type if not isinstance(state.data, BaseResult): state.data = await result_factory.create_result( - state.data, key=key, expiration=expiration + state.data, + key=key, + expiration=expiration, + defer_persistence=defer_persistence, ) return state @@ -284,7 +288,10 @@ async def return_value_to_state( type=new_state_type, message=message, data=await result_factory.create_result( - retval, key=key, expiration=expiration + retval, + key=key, + expiration=expiration, + defer_persistence=defer_persistence, ), ) @@ -300,7 +307,10 @@ async def return_value_to_state( else: return Completed( data=await result_factory.create_result( - data, key=key, expiration=expiration + data, + key=key, + expiration=expiration, + defer_persistence=defer_persistence, ) ) diff --git a/src/prefect/task_engine.py b/src/prefect/task_engine.py index b4b26ed35739..6c1a42efdfc6 100644 --- a/src/prefect/task_engine.py +++ b/src/prefect/task_engine.py @@ -310,6 +310,8 @@ def handle_success(self, result: R, transaction: Transaction) -> R: result_factory=result_factory, key=transaction.key, expiration=expiration, + # defer persistence to transaction commit + defer_persistence=True, ) ) transaction.stage( diff --git a/src/prefect/transactions.py b/src/prefect/transactions.py index 439c8359549c..497bf20fce43 100644 --- a/src/prefect/transactions.py +++ b/src/prefect/transactions.py @@ -89,7 +89,7 @@ def __enter__(self): if parent: self.commit_mode = parent.commit_mode else: - self.commit_mode = CommitMode.EAGER + self.commit_mode = CommitMode.LAZY # this needs to go before begin, which could set the state to committed self.state = TransactionState.ACTIVE @@ -236,7 +236,7 @@ def get_transaction() -> Optional[Transaction]: def transaction( key: Optional[str] = None, store: Optional[RecordStore] = None, - commit_mode: CommitMode = CommitMode.LAZY, + commit_mode: Optional[CommitMode] = None, overwrite: bool = False, ) -> Generator[Transaction, None, None]: """ diff --git a/tests/test_flows.py b/tests/test_flows.py index a41cfd77a92c..0f41d5659f9f 100644 --- a/tests/test_flows.py +++ b/tests/test_flows.py @@ -48,6 +48,7 @@ load_flow_from_flow_run, ) from prefect.logging import get_run_logger +from prefect.results import PersistedResultBlob from prefect.runtime import flow_run as flow_run_ctx from prefect.server.schemas.core import TaskRunResult from prefect.server.schemas.filters import FlowFilter, FlowRunFilter @@ -4280,6 +4281,63 @@ def main(): assert "called" not in data2 assert data1["called"] is True + def test_task_doesnt_persist_prior_to_commit(self, tmp_path): + result_storage = LocalFileSystem(basepath=tmp_path) + + @task(result_storage=result_storage, result_storage_key="task1-result") + def task1(): + pass + + @task(result_storage=result_storage, result_storage_key="task2-result") + def task2(): + raise RuntimeError("oopsie") + + @flow + def main(): + with transaction(): + task1() + task2() + + main(return_state=True) + + with pytest.raises(ValueError, match="does not exist"): + result_storage.read_path("task1-result", _sync=True) + + def test_task_persists_only_at_commit(self, tmp_path): + result_storage = LocalFileSystem(basepath=tmp_path) + + @task(result_storage=result_storage, result_storage_key="task1-result-A") + def task1(): + return dict(some="data") + + @task(result_storage=result_storage, result_storage_key="task2-result-B") + def task2(): + pass + + @flow + def main(): + retval = None + + with transaction(): + task1() + + try: + result_storage.read_path("task1-result-A", _sync=True) + except ValueError as exc: + retval = exc + + task2() + + return retval + + val = main() + + assert isinstance(val, ValueError) + assert "does not exist" in str(val) + content = result_storage.read_path("task1-result-A", _sync=True) + blob = PersistedResultBlob.model_validate_json(content) + assert blob.load() == {"some": "data"} + def test_commit_isnt_called_on_rollback(self): data = {} diff --git a/tests/test_tasks.py b/tests/test_tasks.py index e410c8c81614..a3cfab81ece1 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -44,7 +44,7 @@ from prefect.states import State from prefect.tasks import Task, task, task_input_hash from prefect.testing.utilities import exceptions_equal -from prefect.transactions import Transaction +from prefect.transactions import CommitMode, Transaction, transaction from prefect.utilities.annotations import allow_failure, unmapped from prefect.utilities.asyncutils import run_coro_as_sync from prefect.utilities.collections import quote @@ -4412,7 +4412,13 @@ async def my_flow(): assert await state1.result() == 4 assert await state2.result() == 4 - async def test_nested_cache_key_fn_inner_task_cached(self): + async def test_nested_cache_key_fn_inner_task_cached_default(self): + """ + By default, task transactions are LAZY committed and therefore + inner tasks do not persist data (i.e., create a cache) until + the outer task is complete. + """ + @task(cache_key_fn=task_input_hash) def inner_task(x): return x * 2 @@ -4432,12 +4438,54 @@ def my_flow(): assert state.name == "Completed" inner_state1, inner_state2 = await state.result() assert inner_state1.name == "Completed" - assert inner_state2.name == "Cached" + assert inner_state2.name == "Completed" assert await inner_state1.result() == 4 assert await inner_state2.result() == 4 - async def test_nested_async_cache_key_fn_inner_task_cached(self): + async def test_nested_cache_key_fn_inner_task_cached_eager(self): + """ + By default, task transactions are LAZY committed and therefore + inner tasks do not persist data (i.e., create a cache) until + the outer task is complete. + + This behavior can be modified by using a transaction context manager. + """ + + @task(cache_key_fn=task_input_hash) + def inner_task(x): + return x * 2 + + @task + def outer_task(x): + with transaction(commit_mode=CommitMode.EAGER): + state1 = inner_task(x, return_state=True) + state2 = inner_task(x, return_state=True) + return state1, state2 + + @flow + def my_flow(): + state = outer_task(4, return_state=True) + return state + + state = my_flow() + assert state.name == "Completed" + inner_state1, inner_state2 = await state.result() + assert inner_state1.name == "Completed" + assert inner_state2.name == "Cached" + + assert await inner_state1.result() == 8 + assert await inner_state2.result() == 8 + + async def test_nested_async_cache_key_fn_inner_task_cached_default(self): + """ + By default, task transactions are LAZY committed and therefore + inner tasks do not persist data (i.e., create a cache) until + the outer task is complete. + + This behavior can be modified by using a transaction context manager. + """ + @task(cache_key_fn=task_input_hash) async def inner_task(x): return x * 2 @@ -4453,6 +4501,40 @@ async def my_flow(): state = await outer_task(2, return_state=True) return state + state = await my_flow() + assert state.name == "Completed" + inner_state1, inner_state2 = await state.result() + assert inner_state1.name == "Completed" + assert inner_state2.name == "Completed" + + assert await inner_state1.result() == 4 + assert await inner_state2.result() == 4 + + async def test_nested_async_cache_key_fn_inner_task_cached_eager(self): + """ + By default, task transactions are LAZY committed and therefore + inner tasks do not persist data (i.e., create a cache) until + the outer task is complete. + + This behavior can be modified by using a transaction context manager. + """ + + @task(cache_key_fn=task_input_hash) + async def inner_task(x): + return x * 2 + + @task + async def outer_task(x): + with transaction(commit_mode=CommitMode.EAGER): + state1 = await inner_task(x, return_state=True) + state2 = await inner_task(x, return_state=True) + return state1, state2 + + @flow + async def my_flow(): + state = await outer_task(2, return_state=True) + return state + state = await my_flow() assert state.name == "Completed" inner_state1, inner_state2 = await state.result() diff --git a/tests/test_transactions.py b/tests/test_transactions.py index fc0d8235a13d..d687fc540657 100644 --- a/tests/test_transactions.py +++ b/tests/test_transactions.py @@ -82,13 +82,25 @@ def test_get_parent_with_parent(self): class TestCommitMode: - def test_txns_auto_commit(self): - with Transaction() as txn: + def test_txns_dont_auto_commit(self): + with Transaction(key="outer") as outer: + assert not outer.is_committed() + + with Transaction(key="inner") as inner: + pass + + assert not inner.is_committed() + + assert outer.is_committed() + assert inner.is_committed() + + def test_txns_auto_commit_in_eager(self): + with Transaction(commit_mode=CommitMode.EAGER) as txn: assert txn.is_active() assert not txn.is_committed() assert txn.is_committed() - with Transaction(key="outer") as outer: + with Transaction(key="outer", commit_mode=CommitMode.EAGER) as outer: assert not outer.is_committed() with Transaction(key="inner") as inner: pass @@ -110,18 +122,6 @@ def test_txns_dont_commit_on_rollback(self): assert not txn.is_committed() assert txn.is_rolled_back() - def test_txns_dont_auto_commit_with_lazy_parent(self): - with Transaction(key="outer", commit_mode=CommitMode.LAZY) as outer: - assert not outer.is_committed() - - with Transaction(key="inner") as inner: - pass - - assert not inner.is_committed() - - assert outer.is_committed() - assert inner.is_committed() - def test_txns_commit_with_lazy_parent_if_eager(self): with Transaction(key="outer", commit_mode=CommitMode.LAZY) as outer: assert not outer.is_committed()