Skip to content

Commit

Permalink
Only persist a tasks result at commit time (#14157)
Browse files Browse the repository at this point in the history
  • Loading branch information
cicdw authored Jun 19, 2024
1 parent 9538cb4 commit c57e35f
Show file tree
Hide file tree
Showing 8 changed files with 197 additions and 27 deletions.
6 changes: 5 additions & 1 deletion src/prefect/records/result_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ def read(self, key: str) -> BaseResult:
raise ValueError("Result could not be read")

def write(self, key: str, value: Any) -> BaseResult:
if isinstance(value, BaseResult):
if isinstance(value, PersistedResult):
# if the value is already a persisted result, write it
value.write(_sync=True)
return value
elif isinstance(value, BaseResult):
return value
return run_coro_as_sync(self.result_factory.create_result(obj=value, key=key))
18 changes: 16 additions & 2 deletions src/prefect/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,11 @@ def resolve_serializer(serializer: ResultSerializer) -> Serializer:

@sync_compatible
async def create_result(
self, obj: R, key: Optional[str] = None, expiration: Optional[DateTime] = None
self,
obj: R,
key: Optional[str] = None,
expiration: Optional[DateTime] = None,
defer_persistence: bool = False,
) -> Union[R, "BaseResult[R]"]:
"""
Create a result type for the given object.
Expand Down Expand Up @@ -464,6 +468,7 @@ def key_fn():
serializer=self.serializer,
cache_object=should_cache_object,
expiration=expiration,
defer_persistence=defer_persistence,
)

@sync_compatible
Expand Down Expand Up @@ -682,6 +687,9 @@ async def write(self, obj: R = NotSet, client: "PrefectClient" = None) -> None:
await storage_block.write_path(self.storage_key, content=blob.to_bytes())
self._persisted = True

if not self._should_cache_object:
self._cache = NotSet

@classmethod
@sync_compatible
async def create(
Expand Down Expand Up @@ -729,7 +737,7 @@ async def create(
expiration=expiration,
)

if cache_object:
if cache_object and not defer_persistence:
# Attach the object to the result so it's available without deserialization
result._cache_object(
obj, storage_block=storage_block, serializer=serializer
Expand All @@ -739,6 +747,12 @@ async def create(

if not defer_persistence:
await result.write(obj=obj)
else:
# we must cache temporarily to allow for writing later
# the cache will be removed on write
result._cache_object(
obj, storage_block=storage_block, serializer=serializer
)

return result

Expand Down
16 changes: 13 additions & 3 deletions src/prefect/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ async def return_value_to_state(
result_factory: ResultFactory,
key: Optional[str] = None,
expiration: Optional[datetime.datetime] = None,
defer_persistence: bool = False,
) -> State[R]:
"""
Given a return value from a user's function, create a `State` the run should
Expand Down Expand Up @@ -242,7 +243,10 @@ async def return_value_to_state(
# to update the data to the correct type
if not isinstance(state.data, BaseResult):
state.data = await result_factory.create_result(
state.data, key=key, expiration=expiration
state.data,
key=key,
expiration=expiration,
defer_persistence=defer_persistence,
)

return state
Expand Down Expand Up @@ -284,7 +288,10 @@ async def return_value_to_state(
type=new_state_type,
message=message,
data=await result_factory.create_result(
retval, key=key, expiration=expiration
retval,
key=key,
expiration=expiration,
defer_persistence=defer_persistence,
),
)

Expand All @@ -300,7 +307,10 @@ async def return_value_to_state(
else:
return Completed(
data=await result_factory.create_result(
data, key=key, expiration=expiration
data,
key=key,
expiration=expiration,
defer_persistence=defer_persistence,
)
)

Expand Down
2 changes: 2 additions & 0 deletions src/prefect/task_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,8 @@ def handle_success(self, result: R, transaction: Transaction) -> R:
result_factory=result_factory,
key=transaction.key,
expiration=expiration,
# defer persistence to transaction commit
defer_persistence=True,
)
)
transaction.stage(
Expand Down
4 changes: 2 additions & 2 deletions src/prefect/transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __enter__(self):
if parent:
self.commit_mode = parent.commit_mode
else:
self.commit_mode = CommitMode.EAGER
self.commit_mode = CommitMode.LAZY

# this needs to go before begin, which could set the state to committed
self.state = TransactionState.ACTIVE
Expand Down Expand Up @@ -236,7 +236,7 @@ def get_transaction() -> Optional[Transaction]:
def transaction(
key: Optional[str] = None,
store: Optional[RecordStore] = None,
commit_mode: CommitMode = CommitMode.LAZY,
commit_mode: Optional[CommitMode] = None,
overwrite: bool = False,
) -> Generator[Transaction, None, None]:
"""
Expand Down
58 changes: 58 additions & 0 deletions tests/test_flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
load_flow_from_flow_run,
)
from prefect.logging import get_run_logger
from prefect.results import PersistedResultBlob
from prefect.runtime import flow_run as flow_run_ctx
from prefect.server.schemas.core import TaskRunResult
from prefect.server.schemas.filters import FlowFilter, FlowRunFilter
Expand Down Expand Up @@ -4280,6 +4281,63 @@ def main():
assert "called" not in data2
assert data1["called"] is True

def test_task_doesnt_persist_prior_to_commit(self, tmp_path):
result_storage = LocalFileSystem(basepath=tmp_path)

@task(result_storage=result_storage, result_storage_key="task1-result")
def task1():
pass

@task(result_storage=result_storage, result_storage_key="task2-result")
def task2():
raise RuntimeError("oopsie")

@flow
def main():
with transaction():
task1()
task2()

main(return_state=True)

with pytest.raises(ValueError, match="does not exist"):
result_storage.read_path("task1-result", _sync=True)

def test_task_persists_only_at_commit(self, tmp_path):
result_storage = LocalFileSystem(basepath=tmp_path)

@task(result_storage=result_storage, result_storage_key="task1-result-A")
def task1():
return dict(some="data")

@task(result_storage=result_storage, result_storage_key="task2-result-B")
def task2():
pass

@flow
def main():
retval = None

with transaction():
task1()

try:
result_storage.read_path("task1-result-A", _sync=True)
except ValueError as exc:
retval = exc

task2()

return retval

val = main()

assert isinstance(val, ValueError)
assert "does not exist" in str(val)
content = result_storage.read_path("task1-result-A", _sync=True)
blob = PersistedResultBlob.model_validate_json(content)
assert blob.load() == {"some": "data"}

def test_commit_isnt_called_on_rollback(self):
data = {}

Expand Down
90 changes: 86 additions & 4 deletions tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from prefect.states import State
from prefect.tasks import Task, task, task_input_hash
from prefect.testing.utilities import exceptions_equal
from prefect.transactions import Transaction
from prefect.transactions import CommitMode, Transaction, transaction
from prefect.utilities.annotations import allow_failure, unmapped
from prefect.utilities.asyncutils import run_coro_as_sync
from prefect.utilities.collections import quote
Expand Down Expand Up @@ -4412,7 +4412,13 @@ async def my_flow():
assert await state1.result() == 4
assert await state2.result() == 4

async def test_nested_cache_key_fn_inner_task_cached(self):
async def test_nested_cache_key_fn_inner_task_cached_default(self):
"""
By default, task transactions are LAZY committed and therefore
inner tasks do not persist data (i.e., create a cache) until
the outer task is complete.
"""

@task(cache_key_fn=task_input_hash)
def inner_task(x):
return x * 2
Expand All @@ -4432,12 +4438,54 @@ def my_flow():
assert state.name == "Completed"
inner_state1, inner_state2 = await state.result()
assert inner_state1.name == "Completed"
assert inner_state2.name == "Cached"
assert inner_state2.name == "Completed"

assert await inner_state1.result() == 4
assert await inner_state2.result() == 4

async def test_nested_async_cache_key_fn_inner_task_cached(self):
async def test_nested_cache_key_fn_inner_task_cached_eager(self):
"""
By default, task transactions are LAZY committed and therefore
inner tasks do not persist data (i.e., create a cache) until
the outer task is complete.
This behavior can be modified by using a transaction context manager.
"""

@task(cache_key_fn=task_input_hash)
def inner_task(x):
return x * 2

@task
def outer_task(x):
with transaction(commit_mode=CommitMode.EAGER):
state1 = inner_task(x, return_state=True)
state2 = inner_task(x, return_state=True)
return state1, state2

@flow
def my_flow():
state = outer_task(4, return_state=True)
return state

state = my_flow()
assert state.name == "Completed"
inner_state1, inner_state2 = await state.result()
assert inner_state1.name == "Completed"
assert inner_state2.name == "Cached"

assert await inner_state1.result() == 8
assert await inner_state2.result() == 8

async def test_nested_async_cache_key_fn_inner_task_cached_default(self):
"""
By default, task transactions are LAZY committed and therefore
inner tasks do not persist data (i.e., create a cache) until
the outer task is complete.
This behavior can be modified by using a transaction context manager.
"""

@task(cache_key_fn=task_input_hash)
async def inner_task(x):
return x * 2
Expand All @@ -4453,6 +4501,40 @@ async def my_flow():
state = await outer_task(2, return_state=True)
return state

state = await my_flow()
assert state.name == "Completed"
inner_state1, inner_state2 = await state.result()
assert inner_state1.name == "Completed"
assert inner_state2.name == "Completed"

assert await inner_state1.result() == 4
assert await inner_state2.result() == 4

async def test_nested_async_cache_key_fn_inner_task_cached_eager(self):
"""
By default, task transactions are LAZY committed and therefore
inner tasks do not persist data (i.e., create a cache) until
the outer task is complete.
This behavior can be modified by using a transaction context manager.
"""

@task(cache_key_fn=task_input_hash)
async def inner_task(x):
return x * 2

@task
async def outer_task(x):
with transaction(commit_mode=CommitMode.EAGER):
state1 = await inner_task(x, return_state=True)
state2 = await inner_task(x, return_state=True)
return state1, state2

@flow
async def my_flow():
state = await outer_task(2, return_state=True)
return state

state = await my_flow()
assert state.name == "Completed"
inner_state1, inner_state2 = await state.result()
Expand Down
30 changes: 15 additions & 15 deletions tests/test_transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,25 @@ def test_get_parent_with_parent(self):


class TestCommitMode:
def test_txns_auto_commit(self):
with Transaction() as txn:
def test_txns_dont_auto_commit(self):
with Transaction(key="outer") as outer:
assert not outer.is_committed()

with Transaction(key="inner") as inner:
pass

assert not inner.is_committed()

assert outer.is_committed()
assert inner.is_committed()

def test_txns_auto_commit_in_eager(self):
with Transaction(commit_mode=CommitMode.EAGER) as txn:
assert txn.is_active()
assert not txn.is_committed()
assert txn.is_committed()

with Transaction(key="outer") as outer:
with Transaction(key="outer", commit_mode=CommitMode.EAGER) as outer:
assert not outer.is_committed()
with Transaction(key="inner") as inner:
pass
Expand All @@ -110,18 +122,6 @@ def test_txns_dont_commit_on_rollback(self):
assert not txn.is_committed()
assert txn.is_rolled_back()

def test_txns_dont_auto_commit_with_lazy_parent(self):
with Transaction(key="outer", commit_mode=CommitMode.LAZY) as outer:
assert not outer.is_committed()

with Transaction(key="inner") as inner:
pass

assert not inner.is_committed()

assert outer.is_committed()
assert inner.is_committed()

def test_txns_commit_with_lazy_parent_if_eager(self):
with Transaction(key="outer", commit_mode=CommitMode.LAZY) as outer:
assert not outer.is_committed()
Expand Down

0 comments on commit c57e35f

Please sign in to comment.