Skip to content

Commit

Permalink
Fix some more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
desertaxle committed Sep 7, 2024
1 parent 80b36df commit afb104b
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 18 deletions.
23 changes: 14 additions & 9 deletions src/prefect/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def _format_user_supplied_storage_key(key: str) -> str:
)
class ResultStore(BaseModel):
"""
Manages the storage and retrieval of results.ff
Manages the storage and retrieval of results.
Attributes:
result_storage: The storage for result records. If not provided, the default
Expand Down Expand Up @@ -404,15 +404,20 @@ async def _read(self, key: str, holder: str) -> "ResultRecord":
result_record = ResultRecord.deserialize_from_result_and_metadata(
result=result_content, metadata=metadata_content
)
if self.cache_result_in_memory:
self.cache[key] = result_record
return result_record
else:
content = await self.result_storage.read_path(key)
result_record = ResultRecord.deserialize(content)
if self.cache_result_in_memory:
self.cache[key] = result_record
return result_record

if self.cache_result_in_memory:
if self.result_storage_block_id is None and hasattr(
self.result_storage, "_resolve_path"
):
cache_key = str(self.result_storage._resolve_path(key))
else:
cache_key = key

self.cache[cache_key] = result_record
return result_record

def read(self, key: str, holder: Optional[str] = None) -> "ResultRecord":
"""
Expand Down Expand Up @@ -475,8 +480,8 @@ def create_result_record(

def write(
self,
key: str,
obj: Any,
key: Optional[str] = None,
expiration: Optional[DateTime] = None,
holder: Optional[str] = None,
):
Expand All @@ -501,8 +506,8 @@ def write(

async def awrite(
self,
key: str,
obj: Any,
key: Optional[str] = None,
expiration: Optional[DateTime] = None,
holder: Optional[str] = None,
):
Expand Down
4 changes: 2 additions & 2 deletions src/prefect/utilities/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,9 +735,9 @@ def emit_task_run_state_change_event(
) -> Event:
state_message_truncation_length = 100_000

if isinstance(validated_state.data, ResultRecord):
if isinstance(validated_state.data, ResultRecord) and should_persist_result():
data = validated_state.data.metadata.model_dump(mode="json")
elif isinstance(validated_state.data, BaseResult) and should_persist_result():
elif isinstance(validated_state.data, BaseResult):
data = validated_state.data.model_dump(mode="json")
else:
data = None
Expand Down
5 changes: 2 additions & 3 deletions tests/test_task_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1659,9 +1659,8 @@ async def async_task():
assert await state.result() == 42

async def test_task_loads_result_if_exists_using_result_storage_key(self):
store = ResultStore(persist_result=True)
result = await store.create_result(-92, key="foo-bar")
await result.write()
store = ResultStore()
store.write(obj=-92, key="foo-bar")

@task(result_storage_key="foo-bar", persist_result=True)
async def async_task():
Expand Down
11 changes: 7 additions & 4 deletions tests/test_transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,14 +278,15 @@ async def test_transaction_outside_of_run(self):
result = await txn.store.result_store.create_result(
obj={"foo": "bar"}, key=txn.key
)
result.serialize_to_none = False
txn.stage(result)

result = txn.read()
assert result
assert await result.get() == {"foo": "bar"}

async def test_transaction_inside_flow_default_storage(self):
@flow
@flow(persist_result=True)
def test_flow():
with transaction(key="test_transaction_inside_flow_default_storage") as txn:
assert isinstance(txn.store, ResultRecordStore)
Expand Down Expand Up @@ -330,14 +331,16 @@ async def test_flow():
async def test_transaction_inside_task_default_storage(self):
default_task_storage = await get_or_create_default_task_scheduling_storage()

@task
@task(persist_result=True)
async def test_task():
with transaction(key="test_transaction_inside_task_default_storage") as txn:
with transaction(
key="test_transaction_inside_task_default_storage",
commit_mode=CommitMode.EAGER,
) as txn:
assert isinstance(txn.store, ResultRecordStore)
result = await txn.store.result_store.create_result(
obj={"foo": "bar"}, key=txn.key
)
await result.write()
txn.stage(result)

result = txn.read()
Expand Down

0 comments on commit afb104b

Please sign in to comment.