diff --git a/src/prefect/transactions.py b/src/prefect/transactions.py index 9274db2bdf63..70919b27ce9b 100644 --- a/src/prefect/transactions.py +++ b/src/prefect/transactions.py @@ -4,6 +4,7 @@ from typing import ( Any, Callable, + Dict, Generator, List, Optional, @@ -11,7 +12,7 @@ Union, ) -from pydantic import Field +from pydantic import Field, PrivateAttr from typing_extensions import Self from prefect.context import ContextModel, FlowRunContext, TaskRunContext @@ -64,9 +65,18 @@ class Transaction(ContextModel): ) overwrite: bool = False logger: Union[logging.Logger, logging.LoggerAdapter, None] = None + _stored_values: Dict[str, Any] = PrivateAttr(default_factory=dict) _staged_value: Any = None __var__: ContextVar = ContextVar("transaction") + def set(self, name: str, value: Any) -> None: + self._stored_values[name] = value + + def get(self, name: str) -> Any: + if name not in self._stored_values: + raise ValueError(f"Could not retrieve value for unknown key: {name}") + return self._stored_values.get(name) + def is_committed(self) -> bool: return self.state == TransactionState.COMMITTED diff --git a/tests/test_flows.py b/tests/test_flows.py index 6560119c9e22..54f6dd8c3ad3 100644 --- a/tests/test_flows.py +++ b/tests/test_flows.py @@ -77,7 +77,7 @@ exceptions_equal, get_most_recent_flow_run, ) -from prefect.transactions import transaction +from prefect.transactions import get_transaction, transaction from prefect.types.entrypoint import EntrypointType from prefect.utilities.annotations import allow_failure, quote from prefect.utilities.callables import parameter_schema @@ -4416,6 +4416,37 @@ def main(): assert data2["called"] is True assert data1["called"] is True + def test_isolated_shared_state_on_txn_between_tasks(self): + data1, data2 = {}, {} + + @task + def task1(): + get_transaction().set("task", 1) + + @task1.on_rollback + def rollback(txn): + data1["hook"] = txn.get("task") + + @task + def task2(): + get_transaction().set("task", 2) + + @task2.on_rollback + def rollback2(txn): + data2["hook"] = txn.get("task") + + @flow + def main(): + with transaction(): + task1() + task2() + raise ValueError("oopsie") + + main(return_state=True) + + assert data2["hook"] == 2 + assert data1["hook"] == 1 + def test_task_failure_causes_previous_to_rollback(self): data1, data2 = {}, {} @@ -4904,8 +4935,8 @@ def test_annotations_and_defaults_rely_on_imports(self, tmp_path: Path): source_code = dedent( """ import pendulum - import datetime - from prefect import flow + import datetime + from prefect import flow @flow def f( @@ -5033,7 +5064,7 @@ def get_model() -> BaseModel: def f( param: Optional[MyModel] = None, ) -> None: - return MyModel() + return MyModel() """ ) tmp_path.joinpath("test.py").write_text(source_code) @@ -5052,7 +5083,7 @@ def test_raises_name_error_when_loaded_flow_cannot_run(self, tmp_path): @flow(description="Says woof!") def dog(): - return not_a_function('dog') + return not_a_function('dog') """ ) diff --git a/tests/test_transactions.py b/tests/test_transactions.py index baec004c9912..904ab2c2ff95 100644 --- a/tests/test_transactions.py +++ b/tests/test_transactions.py @@ -340,3 +340,24 @@ async def test_task(): return result await test_task() == {"foo": "bar"} + + +class TestHooks: + def test_get_and_set_data(self): + with transaction(key="test") as txn: + txn.set("x", 42) + assert txn.get("x") == 42 + + def test_get_and_set_data_in_nested_context(self): + with transaction(key="test") as top: + top.set("key", 42) + with transaction(key="nested") as inner: + inner.set("key", "string") + assert inner.get("key") == "string" + assert top.get("key") == 42 + assert top.get("key") == 42 + + def test_get_raises_on_unknown(self): + with transaction(key="test") as txn: + with pytest.raises(ValueError, match="foobar"): + txn.get("foobar")