Skip to content

Commit

Permalink
Sharing state with transaction hooks (#14895)
Browse files Browse the repository at this point in the history
Co-authored-by: nate nowack <[email protected]>
  • Loading branch information
cicdw and zzstoatzz authored Aug 12, 2024
1 parent 8476f52 commit 9d39d68
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 6 deletions.
12 changes: 11 additions & 1 deletion src/prefect/transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
from typing import (
Any,
Callable,
Dict,
Generator,
List,
Optional,
Type,
Union,
)

from pydantic import Field
from pydantic import Field, PrivateAttr
from typing_extensions import Self

from prefect.context import ContextModel, FlowRunContext, TaskRunContext
Expand Down Expand Up @@ -64,9 +65,18 @@ class Transaction(ContextModel):
)
overwrite: bool = False
logger: Union[logging.Logger, logging.LoggerAdapter, None] = None
_stored_values: Dict[str, Any] = PrivateAttr(default_factory=dict)
_staged_value: Any = None
__var__: ContextVar = ContextVar("transaction")

def set(self, name: str, value: Any) -> None:
self._stored_values[name] = value

def get(self, name: str) -> Any:
if name not in self._stored_values:
raise ValueError(f"Could not retrieve value for unknown key: {name}")
return self._stored_values.get(name)

def is_committed(self) -> bool:
return self.state == TransactionState.COMMITTED

Expand Down
41 changes: 36 additions & 5 deletions tests/test_flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@
exceptions_equal,
get_most_recent_flow_run,
)
from prefect.transactions import transaction
from prefect.transactions import get_transaction, transaction
from prefect.types.entrypoint import EntrypointType
from prefect.utilities.annotations import allow_failure, quote
from prefect.utilities.callables import parameter_schema
Expand Down Expand Up @@ -4416,6 +4416,37 @@ def main():
assert data2["called"] is True
assert data1["called"] is True

def test_isolated_shared_state_on_txn_between_tasks(self):
data1, data2 = {}, {}

@task
def task1():
get_transaction().set("task", 1)

@task1.on_rollback
def rollback(txn):
data1["hook"] = txn.get("task")

@task
def task2():
get_transaction().set("task", 2)

@task2.on_rollback
def rollback2(txn):
data2["hook"] = txn.get("task")

@flow
def main():
with transaction():
task1()
task2()
raise ValueError("oopsie")

main(return_state=True)

assert data2["hook"] == 2
assert data1["hook"] == 1

def test_task_failure_causes_previous_to_rollback(self):
data1, data2 = {}, {}

Expand Down Expand Up @@ -4904,8 +4935,8 @@ def test_annotations_and_defaults_rely_on_imports(self, tmp_path: Path):
source_code = dedent(
"""
import pendulum
import datetime
from prefect import flow
import datetime
from prefect import flow
@flow
def f(
Expand Down Expand Up @@ -5033,7 +5064,7 @@ def get_model() -> BaseModel:
def f(
param: Optional[MyModel] = None,
) -> None:
return MyModel()
return MyModel()
"""
)
tmp_path.joinpath("test.py").write_text(source_code)
Expand All @@ -5052,7 +5083,7 @@ def test_raises_name_error_when_loaded_flow_cannot_run(self, tmp_path):
@flow(description="Says woof!")
def dog():
return not_a_function('dog')
return not_a_function('dog')
"""
)

Expand Down
21 changes: 21 additions & 0 deletions tests/test_transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,3 +340,24 @@ async def test_task():
return result

await test_task() == {"foo": "bar"}


class TestHooks:
def test_get_and_set_data(self):
with transaction(key="test") as txn:
txn.set("x", 42)
assert txn.get("x") == 42

def test_get_and_set_data_in_nested_context(self):
with transaction(key="test") as top:
top.set("key", 42)
with transaction(key="nested") as inner:
inner.set("key", "string")
assert inner.get("key") == "string"
assert top.get("key") == 42
assert top.get("key") == 42

def test_get_raises_on_unknown(self):
with transaction(key="test") as txn:
with pytest.raises(ValueError, match="foobar"):
txn.get("foobar")

0 comments on commit 9d39d68

Please sign in to comment.