From 8ecdce468c71dc5b3a967c41fa96d7ab6c42912f Mon Sep 17 00:00:00 2001 From: Alexander Streed Date: Thu, 5 Sep 2024 17:01:36 -0500 Subject: [PATCH] Allow using a `ResultStore` with a `Transaction` (#15247) --- src/prefect/results.py | 8 ++- src/prefect/transactions.py | 28 ++++++---- tests/test_transactions.py | 101 ++++++++++++++++++++++++++++++++++++ 3 files changed, 125 insertions(+), 12 deletions(-) diff --git a/src/prefect/results.py b/src/prefect/results.py index 64b4d8cc7840..46470a93f494 100644 --- a/src/prefect/results.py +++ b/src/prefect/results.py @@ -343,7 +343,7 @@ async def _read(self, key: str, holder: str) -> "ResultRecord": A result record. """ if self.lock_manager is not None and not self.is_lock_holder(key, holder): - self.wait_for_lock(key) + await self.await_for_lock(key) if self.result_storage is None: self.result_storage = await get_default_result_storage() @@ -407,7 +407,10 @@ def create_result_record( return ResultRecord( result=obj, metadata=ResultRecordMetadata( - serializer=self.serializer, expiration=expiration, storage_key=key + serializer=self.serializer, + expiration=expiration, + storage_key=key, + storage_block_id=self.result_storage_block_id, ), ) @@ -752,6 +755,7 @@ class ResultRecordMetadata(BaseModel): expiration: Optional[DateTime] = Field(default=None) serializer: Serializer = Field(default_factory=PickleSerializer) prefect_version: str = Field(default=prefect.__version__) + storage_block_id: Optional[uuid.UUID] = Field(default=None) def dump_bytes(self) -> bytes: """ diff --git a/src/prefect/transactions.py b/src/prefect/transactions.py index 73fbe7c60472..f0268e33b5b2 100644 --- a/src/prefect/transactions.py +++ b/src/prefect/transactions.py @@ -4,7 +4,6 @@ from contextvars import ContextVar, Token from functools import partial from typing import ( - TYPE_CHECKING, Any, Callable, Dict, @@ -22,13 +21,12 @@ from prefect.exceptions import MissingContextError, SerializationError from prefect.logging.loggers import get_logger, 
get_run_logger from prefect.records import RecordStore +from prefect.records.base import TransactionRecord +from prefect.results import BaseResult, ResultRecord, ResultStore from prefect.utilities.annotations import NotSet from prefect.utilities.collections import AutoEnum from prefect.utilities.engine import _get_hook_name -if TYPE_CHECKING: - from prefect.results import BaseResult - class IsolationLevel(AutoEnum): READ_COMMITTED = AutoEnum.auto() @@ -54,7 +52,7 @@ class Transaction(ContextModel): A base model for transaction state. """ - store: Optional[RecordStore] = None + store: Union[RecordStore, ResultStore, None] = None key: Optional[str] = None children: List["Transaction"] = Field(default_factory=list) commit_mode: Optional[CommitMode] = None @@ -175,10 +173,14 @@ def begin(self): ): self.state = TransactionState.COMMITTED - def read(self) -> Optional["BaseResult"]: + def read(self) -> Union["BaseResult", ResultRecord, None]: if self.store and self.key: record = self.store.read(key=self.key) - if record is not None: + if isinstance(record, ResultRecord): + return record + # for backwards compatibility, if we encounter a transaction record, return the result + # This happens when the transaction is using a `RecordStore` + if isinstance(record, TransactionRecord): return record.result return None @@ -228,7 +230,13 @@ def commit(self) -> bool: self.run_hook(hook, "commit") if self.store and self.key: - self.store.write(key=self.key, result=self._staged_value) + if isinstance(self.store, ResultStore): + if isinstance(self._staged_value, BaseResult): + self.store.write(self.key, self._staged_value.get(_sync=True)) + else: + self.store.write(self.key, self._staged_value) + else: + self.store.write(self.key, self._staged_value) self.state = TransactionState.COMMITTED if ( self.store @@ -279,7 +287,7 @@ def run_hook(self, hook, hook_type: str) -> None: def stage( self, - value: "BaseResult", + value: Union["BaseResult", Any], on_rollback_hooks: Optional[List] = 
None, on_commit_hooks: Optional[List] = None, ) -> None: @@ -337,7 +345,7 @@ def get_transaction() -> Optional[Transaction]: @contextmanager def transaction( key: Optional[str] = None, - store: Optional[RecordStore] = None, + store: Union[RecordStore, ResultStore, None] = None, commit_mode: Optional[CommitMode] = None, isolation_level: Optional[IsolationLevel] = None, overwrite: bool = False, diff --git a/tests/test_transactions.py b/tests/test_transactions.py index f8e36a668b8a..d62baf22f584 100644 --- a/tests/test_transactions.py +++ b/tests/test_transactions.py @@ -5,6 +5,7 @@ from prefect.filesystems import LocalFileSystem from prefect.flows import flow +from prefect.locking.memory import MemoryLockManager from prefect.records import RecordStore from prefect.records.memory import MemoryRecordStore from prefect.records.result_store import ResultRecordStore @@ -466,6 +467,106 @@ def winning_transaction(): assert record.result == result_1 +class TestWithResultStore: + @pytest.fixture() + def default_storage_setting(self, tmp_path): + name = str(uuid.uuid4()) + LocalFileSystem(basepath=tmp_path).save(name) + with temporary_settings( + { + PREFECT_DEFAULT_RESULT_STORAGE_BLOCK: f"local-file-system/{name}", + PREFECT_TASK_SCHEDULING_DEFAULT_STORAGE_BLOCK: f"local-file-system/{name}", + } + ): + yield + + @pytest.fixture + async def result_store(self, default_storage_setting): + result_store = ResultStore( + persist_result=True, lock_manager=MemoryLockManager() + ) + return result_store + + async def test_basic_transaction(self, result_store): + with transaction(key="test_basic_transaction", store=result_store) as txn: + assert isinstance(txn.store, ResultStore) + txn.stage({"foo": "bar"}) + + record_1 = txn.read() + assert record_1 + assert record_1.result == {"foo": "bar"} + + record_2 = result_store.read("test_basic_transaction") + assert record_2 + assert record_2 == record_1 + assert record_2.metadata.storage_key == "test_basic_transaction" + + async def 
test_competing_read_transaction(self, result_store): + write_transaction_open = threading.Event() + + def writing_transaction(): + # isolation level is SERIALIZABLE, so a lock will be taken + with transaction( + key="test_competing_read_transaction", + store=result_store, + isolation_level=IsolationLevel.SERIALIZABLE, + ) as txn: + write_transaction_open.set() + txn.stage({"foo": "bar"}) + + thread = threading.Thread(target=writing_transaction) + thread.start() + write_transaction_open.wait() + with transaction( + key="test_competing_read_transaction", store=result_store + ) as txn: + read_result = txn.read() + + assert read_result.result == {"foo": "bar"} + thread.join() + + async def test_competing_write_transaction(self, result_store): + transaction_1_open = threading.Event() + + def winning_transaction(): + with transaction( + key="test_competing_write_transaction", + store=result_store, + isolation_level=IsolationLevel.SERIALIZABLE, + ) as txn: + transaction_1_open.set() + txn.stage({"foo": "bar"}) + + thread = threading.Thread(target=winning_transaction) + thread.start() + transaction_1_open.wait() + with transaction( + key="test_competing_write_transaction", + store=result_store, + isolation_level=IsolationLevel.SERIALIZABLE, + ) as txn: + txn.stage({"fizz": "buzz"}) + + thread.join() + record = result_store.read("test_competing_write_transaction") + assert record + # the first transaction should have written its result + # and the second transaction should not have written on exit + assert record.result == {"foo": "bar"} + + async def test_can_handle_staged_base_result(self, result_store): + result_1 = await result_store.create_result(obj={"foo": "bar"}) + with transaction( + key="test_can_handle_staged_base_result", store=result_store + ) as txn: + txn.stage(result_1) + + record = txn.read() + assert record + assert record.result == {"foo": "bar"} + assert record.metadata.storage_block_id == result_1.storage_block_id + + class TestHooks: def 
test_get_and_set_data(self): with transaction(key="test") as txn: