Skip to content

Commit

Permalink
Only persist a task's result at commit time
Browse files Browse the repository at this point in the history
  • Loading branch information
cicdw committed Jun 19, 2024
1 parent 4b29ba6 commit 1b7d64c
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 5 deletions.
6 changes: 5 additions & 1 deletion src/prefect/records/result_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ def read(self, key: str) -> BaseResult:
raise ValueError("Result could not be read")

def write(self, key: str, value: Any) -> BaseResult:
    """
    Store `value` under `key` and return the result object representing it.

    Args:
        key: forwarded to the result factory when a new result has to be
            created for a raw value.
        value: either an existing result object or a raw Python object.

    Returns:
        `value` itself when it is already a `BaseResult` (after writing it,
        if it is a `PersistedResult`); otherwise whatever
        `self.result_factory.create_result` produces for it.
    """
    # This check has to run before the generic `BaseResult` check below.
    # Presumably `PersistedResult` subclasses `BaseResult` (TODO confirm), in
    # which case the generic branch would return it without writing it.
    if isinstance(value, PersistedResult):
        # if the value is already a persisted result, write it
        value.write(_sync=True)
        return value
    elif isinstance(value, BaseResult):
        # Any other result object is handed back untouched.
        return value
    # Raw value: build a result for it via the factory, driving the
    # coroutine to completion synchronously.
    return run_coro_as_sync(self.result_factory.create_result(obj=value, key=key))
7 changes: 6 additions & 1 deletion src/prefect/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,11 @@ def resolve_serializer(serializer: ResultSerializer) -> Serializer:

@sync_compatible
async def create_result(
self, obj: R, key: Optional[str] = None, expiration: Optional[DateTime] = None
self,
obj: R,
key: Optional[str] = None,
expiration: Optional[DateTime] = None,
defer_persistence: bool = False,
) -> Union[R, "BaseResult[R]"]:
"""
Create a result type for the given object.
Expand Down Expand Up @@ -464,6 +468,7 @@ def key_fn():
serializer=self.serializer,
cache_object=should_cache_object,
expiration=expiration,
defer_persistence=defer_persistence,
)

@sync_compatible
Expand Down
16 changes: 13 additions & 3 deletions src/prefect/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ async def return_value_to_state(
result_factory: ResultFactory,
key: Optional[str] = None,
expiration: Optional[datetime.datetime] = None,
defer_persistence: bool = False,
) -> State[R]:
"""
Given a return value from a user's function, create a `State` the run should
Expand Down Expand Up @@ -242,7 +243,10 @@ async def return_value_to_state(
# to update the data to the correct type
if not isinstance(state.data, BaseResult):
state.data = await result_factory.create_result(
state.data, key=key, expiration=expiration
state.data,
key=key,
expiration=expiration,
defer_persistence=defer_persistence,
)

return state
Expand Down Expand Up @@ -284,7 +288,10 @@ async def return_value_to_state(
type=new_state_type,
message=message,
data=await result_factory.create_result(
retval, key=key, expiration=expiration
retval,
key=key,
expiration=expiration,
defer_persistence=defer_persistence,
),
)

Expand All @@ -300,7 +307,10 @@ async def return_value_to_state(
else:
return Completed(
data=await result_factory.create_result(
data, key=key, expiration=expiration
data,
key=key,
expiration=expiration,
defer_persistence=defer_persistence,
)
)

Expand Down
2 changes: 2 additions & 0 deletions src/prefect/task_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,8 @@ def handle_success(self, result: R, transaction: Transaction) -> R:
result_factory=result_factory,
key=transaction.key,
expiration=expiration,
# defer persistence to transaction commit
defer_persistence=True,
)
)
transaction.stage(
Expand Down
58 changes: 58 additions & 0 deletions tests/test_flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
load_flow_from_flow_run,
)
from prefect.logging import get_run_logger
from prefect.results import PersistedResultBlob
from prefect.runtime import flow_run as flow_run_ctx
from prefect.server.schemas.core import TaskRunResult
from prefect.server.schemas.filters import FlowFilter, FlowRunFilter
Expand Down Expand Up @@ -4280,6 +4281,63 @@ def main():
assert "called" not in data2
assert data1["called"] is True

def test_task_doesnt_persist_prior_to_commit(self, tmp_path):
    """
    When a later task in the same transaction raises, the storage must not
    contain a result for the task that ran before it.
    """
    storage = LocalFileSystem(basepath=tmp_path)

    @task(result_storage=storage, result_storage_key="task1-result")
    def task1():
        return None

    @task(result_storage=storage, result_storage_key="task2-result")
    def task2():
        raise RuntimeError("oopsie")

    @flow
    def main():
        # task1 finishes first; task2 then raises inside the same
        # transaction block.
        with transaction():
            task1()
            task2()

    # NOTE(review): `return_state=True` presumably makes the flow hand back
    # its final state instead of re-raising task2's error -- confirm.
    main(return_state=True)

    # Reading task1's storage key is expected to fail with "does not exist".
    with pytest.raises(ValueError, match="does not exist"):
        storage.read_path("task1-result", _sync=True)

def test_task_persists_only_at_commit(self, tmp_path):
    """
    Inside an open transaction a finished task's result is not yet readable
    from storage; once the flow has returned, it is.
    """
    storage = LocalFileSystem(basepath=tmp_path)

    @task(result_storage=storage, result_storage_key="task1-result-A")
    def task1():
        return {"some": "data"}

    @task(result_storage=storage, result_storage_key="task2-result-B")
    def task2():
        return None

    @flow
    def main():
        read_error = None

        with transaction():
            task1()

            # task1 is done but the transaction block is still open:
            # capture the error raised when reading its key, if any.
            try:
                storage.read_path("task1-result-A", _sync=True)
            except ValueError as exc:
                read_error = exc

            task2()

        return read_error

    val = main()

    # The mid-transaction read must have failed because the key was absent.
    assert isinstance(val, ValueError)
    assert "does not exist" in str(val)

    # After the flow has finished, the same key holds task1's return value.
    blob = PersistedResultBlob.model_validate_json(
        storage.read_path("task1-result-A", _sync=True)
    )
    assert blob.load() == {"some": "data"}

def test_commit_isnt_called_on_rollback(self):
data = {}

Expand Down

0 comments on commit 1b7d64c

Please sign in to comment.