Skip to content

Commit

Permalink
Allow using a ResultStore with a Transaction (#15247)
Browse files Browse the repository at this point in the history
  • Loading branch information
desertaxle authored Sep 5, 2024
1 parent 76c4323 commit 8ecdce4
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 12 deletions.
8 changes: 6 additions & 2 deletions src/prefect/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ async def _read(self, key: str, holder: str) -> "ResultRecord":
A result record.
"""
if self.lock_manager is not None and not self.is_lock_holder(key, holder):
self.wait_for_lock(key)
await self.await_for_lock(key)

if self.result_storage is None:
self.result_storage = await get_default_result_storage()
Expand Down Expand Up @@ -407,7 +407,10 @@ def create_result_record(
return ResultRecord(
result=obj,
metadata=ResultRecordMetadata(
serializer=self.serializer, expiration=expiration, storage_key=key
serializer=self.serializer,
expiration=expiration,
storage_key=key,
storage_block_id=self.result_storage_block_id,
),
)

Expand Down Expand Up @@ -752,6 +755,7 @@ class ResultRecordMetadata(BaseModel):
expiration: Optional[DateTime] = Field(default=None)
serializer: Serializer = Field(default_factory=PickleSerializer)
prefect_version: str = Field(default=prefect.__version__)
storage_block_id: Optional[uuid.UUID] = Field(default=None)

def dump_bytes(self) -> bytes:
"""
Expand Down
28 changes: 18 additions & 10 deletions src/prefect/transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from contextvars import ContextVar, Token
from functools import partial
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Expand All @@ -22,13 +21,12 @@
from prefect.exceptions import MissingContextError, SerializationError
from prefect.logging.loggers import get_logger, get_run_logger
from prefect.records import RecordStore
from prefect.records.base import TransactionRecord
from prefect.results import BaseResult, ResultRecord, ResultStore
from prefect.utilities.annotations import NotSet
from prefect.utilities.collections import AutoEnum
from prefect.utilities.engine import _get_hook_name

if TYPE_CHECKING:
from prefect.results import BaseResult


class IsolationLevel(AutoEnum):
READ_COMMITTED = AutoEnum.auto()
Expand All @@ -54,7 +52,7 @@ class Transaction(ContextModel):
A base model for transaction state.
"""

store: Optional[RecordStore] = None
store: Union[RecordStore, ResultStore, None] = None
key: Optional[str] = None
children: List["Transaction"] = Field(default_factory=list)
commit_mode: Optional[CommitMode] = None
Expand Down Expand Up @@ -175,10 +173,14 @@ def begin(self):
):
self.state = TransactionState.COMMITTED

def read(self) -> Optional["BaseResult"]:
def read(self) -> Union["BaseResult", ResultRecord, None]:
if self.store and self.key:
record = self.store.read(key=self.key)
if record is not None:
if isinstance(record, ResultRecord):
return record
# for backwards compatibility, if we encounter a transaction record, return the result
# This happens when the transaction is using a `ResultStore`
if isinstance(record, TransactionRecord):
return record.result
return None

Expand Down Expand Up @@ -228,7 +230,13 @@ def commit(self) -> bool:
self.run_hook(hook, "commit")

if self.store and self.key:
self.store.write(key=self.key, result=self._staged_value)
if isinstance(self.store, ResultStore):
if isinstance(self._staged_value, BaseResult):
self.store.write(self.key, self._staged_value.get(_sync=True))
else:
self.store.write(self.key, self._staged_value)
else:
self.store.write(self.key, self._staged_value)
self.state = TransactionState.COMMITTED
if (
self.store
Expand Down Expand Up @@ -279,7 +287,7 @@ def run_hook(self, hook, hook_type: str) -> None:

def stage(
self,
value: "BaseResult",
value: Union["BaseResult", Any],
on_rollback_hooks: Optional[List] = None,
on_commit_hooks: Optional[List] = None,
) -> None:
Expand Down Expand Up @@ -337,7 +345,7 @@ def get_transaction() -> Optional[Transaction]:
@contextmanager
def transaction(
key: Optional[str] = None,
store: Optional[RecordStore] = None,
store: Union[RecordStore, ResultStore, None] = None,
commit_mode: Optional[CommitMode] = None,
isolation_level: Optional[IsolationLevel] = None,
overwrite: bool = False,
Expand Down
101 changes: 101 additions & 0 deletions tests/test_transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from prefect.filesystems import LocalFileSystem
from prefect.flows import flow
from prefect.locking.memory import MemoryLockManager
from prefect.records import RecordStore
from prefect.records.memory import MemoryRecordStore
from prefect.records.result_store import ResultRecordStore
Expand Down Expand Up @@ -466,6 +467,106 @@ def winning_transaction():
assert record.result == result_1


class TestWithResultStore:
@pytest.fixture()
def default_storage_setting(self, tmp_path):
name = str(uuid.uuid4())
LocalFileSystem(basepath=tmp_path).save(name)
with temporary_settings(
{
PREFECT_DEFAULT_RESULT_STORAGE_BLOCK: f"local-file-system/{name}",
PREFECT_TASK_SCHEDULING_DEFAULT_STORAGE_BLOCK: f"local-file-system/{name}",
}
):
yield

@pytest.fixture
async def result_store(self, default_storage_setting):
result_store = ResultStore(
persist_result=True, lock_manager=MemoryLockManager()
)
return result_store

async def test_basic_transaction(self, result_store):
with transaction(key="test_basic_transaction", store=result_store) as txn:
assert isinstance(txn.store, ResultStore)
txn.stage({"foo": "bar"})

record_1 = txn.read()
assert record_1
assert record_1.result == {"foo": "bar"}

record_2 = result_store.read("test_basic_transaction")
assert record_2
assert record_2 == record_1
assert record_2.metadata.storage_key == "test_basic_transaction"

async def test_competing_read_transaction(self, result_store):
write_transaction_open = threading.Event()

def writing_transaction():
# isolation level is SERIALIZABLE, so a lock will be taken
with transaction(
key="test_competing_read_transaction",
store=result_store,
isolation_level=IsolationLevel.SERIALIZABLE,
) as txn:
write_transaction_open.set()
txn.stage({"foo": "bar"})

thread = threading.Thread(target=writing_transaction)
thread.start()
write_transaction_open.wait()
with transaction(
key="test_competing_read_transaction", store=result_store
) as txn:
read_result = txn.read()

assert read_result.result == {"foo": "bar"}
thread.join()

async def test_competing_write_transaction(self, result_store):
transaction_1_open = threading.Event()

def winning_transaction():
with transaction(
key="test_competing_write_transaction",
store=result_store,
isolation_level=IsolationLevel.SERIALIZABLE,
) as txn:
transaction_1_open.set()
txn.stage({"foo": "bar"})

thread = threading.Thread(target=winning_transaction)
thread.start()
transaction_1_open.wait()
with transaction(
key="test_competing_write_transaction",
store=result_store,
isolation_level=IsolationLevel.SERIALIZABLE,
) as txn:
txn.stage({"fizz": "buzz"})

thread.join()
record = result_store.read("test_competing_write_transaction")
assert record
# the first transaction should have written its result
# and the second transaction should not have written on exit
assert record.result == {"foo": "bar"}

async def test_can_handle_staged_base_result(self, result_store):
result_1 = await result_store.create_result(obj={"foo": "bar"})
with transaction(
key="test_can_handle_staged_base_result", store=result_store
) as txn:
txn.stage(result_1)

record = txn.read()
assert record
assert record.result == {"foo": "bar"}
assert record.metadata.storage_block_id == result_1.storage_block_id


class TestHooks:
def test_get_and_set_data(self):
with transaction(key="test") as txn:
Expand Down

0 comments on commit 8ecdce4

Please sign in to comment.