Skip to content

Commit

Permalink
Add read and write methods to ResultFactory (#15176)
Browse files Browse the repository at this point in the history
  • Loading branch information
desertaxle authored Sep 3, 2024
1 parent 5d05449 commit 6fa0f99
Show file tree
Hide file tree
Showing 5 changed files with 209 additions and 57 deletions.
186 changes: 165 additions & 21 deletions src/prefect/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,12 @@ def storage_block_id(self) -> Optional[UUID]:
async def update_for_flow(self, flow: "Flow") -> Self:
"""
Create a new result factory for a flow with updated settings.
Args:
flow: The flow to update the result factory for.
Returns:
An updated result factory.
"""
update = {}
if flow.result_storage is not None:
Expand All @@ -208,6 +214,12 @@ async def update_for_flow(self, flow: "Flow") -> Self:
async def update_for_task(self: Self, task: "Task") -> Self:
"""
Create a new result factory for a task.
Args:
task: The task to update the result factory for.
Returns:
An updated result factory.
"""
update = {}
if task.result_storage is not None:
Expand All @@ -226,6 +238,139 @@ async def update_for_task(self: Self, task: "Task") -> Self:
update["storage_block"] = await get_default_result_storage()
return self.model_copy(update=update)

@sync_compatible
async def _read(self, key: str) -> "ResultRecord":
"""
Read a result record from storage.
This is the internal implementation. Use `read` or `aread` for synchronous and
asynchronous result reading respectively.
Args:
key: The key to read the result record from.
Returns:
A result record.
"""
if self.storage_block is None:
self.storage_block = await get_default_result_storage()

content = await self.storage_block.read_path(f"{key}")
return ResultRecord.deserialize(content)

def read(self, key: str) -> "ResultRecord":
"""
Read a result record from storage.
Args:
key: The key to read the result record from.
Returns:
A result record.
"""
return self._read(key=key, _sync=True)

async def aread(self, key: str) -> "ResultRecord":
"""
Read a result record from storage.
Args:
key: The key to read the result record from.
Returns:
A result record.
"""
return await self._read(key=key, _sync=False)

@sync_compatible
async def _write(
self,
obj: Any,
key: Optional[str] = None,
expiration: Optional[DateTime] = None,
):
"""
Write a result to storage.
This is the internal implementation. Use `write` or `awrite` for synchronous and
asynchronous result writing respectively.
Args:
key: The key to write the result record to.
obj: The object to write to storage.
expiration: The expiration time for the result record.
"""
if self.storage_block is None:
self.storage_block = await get_default_result_storage()
key = key or self.storage_key_fn()

record = ResultRecord(
result=obj,
metadata=ResultRecordMetadata(
serializer=self.serializer, expiration=expiration, storage_key=key
),
)
await self.apersist_result_record(record)

def write(self, key: str, obj: Any, expiration: Optional[DateTime] = None):
"""
Write a result to storage.
Handles the creation of a `ResultRecord` and its serialization to storage.
Args:
key: The key to write the result record to.
obj: The object to write to storage.
expiration: The expiration time for the result record.
"""
return self._write(obj=obj, key=key, expiration=expiration, _sync=True)

async def awrite(self, key: str, obj: Any, expiration: Optional[DateTime] = None):
"""
Write a result to storage.
Args:
key: The key to write the result record to.
obj: The object to write to storage.
expiration: The expiration time for the result record.
"""
return await self._write(obj=obj, key=key, expiration=expiration, _sync=False)

@sync_compatible
async def _persist_result_record(self, result_record: "ResultRecord"):
"""
Persist a result record to storage.
Args:
result_record: The result record to persist.
"""
if self.storage_block is None:
self.storage_block = await get_default_result_storage()

await self.storage_block.write_path(
result_record.metadata.storage_key, content=result_record.serialize()
)

def persist_result_record(self, result_record: "ResultRecord"):
"""
Persist a result record to storage.
Args:
result_record: The result record to persist.
"""
return self._persist_result_record(result_record=result_record, _sync=True)

async def apersist_result_record(self, result_record: "ResultRecord"):
"""
Persist a result record to storage.
Args:
result_record: The result record to persist.
"""
return await self._persist_result_record(
result_record=result_record, _sync=False
)

@sync_compatible
async def create_result(
self,
Expand All @@ -234,9 +379,7 @@ async def create_result(
expiration: Optional[DateTime] = None,
) -> Union[R, "BaseResult[R]"]:
"""
Create a result type for the given object.
If persistence is enabled the object is serialized, persisted to storage, and a reference is returned.
Create a `PersistedResult` for the given object.
"""
# Null objects are "cached" in memory at no cost
should_cache_object = self.cache_result_in_memory or obj is None
Expand Down Expand Up @@ -570,28 +713,31 @@ async def _get_storage_block(self, client: "PrefectClient") -> WritableFileSyste

@sync_compatible
@inject_client
async def get(self, client: "PrefectClient") -> R:
async def get(
self, ignore_cache: bool = False, client: "PrefectClient" = None
) -> R:
"""
Retrieve the data and deserialize it into the original object.
"""
if self.has_cached_object():
if self.has_cached_object() and not ignore_cache:
return self._cache

record = await self._read_result_record(client=client)
result_factory_kwargs = {}
if self._serializer:
result_factory_kwargs["serializer"] = resolve_serializer(self._serializer)
storage_block = await self._get_storage_block(client=client)
result_factory = ResultFactory(
storage_block=storage_block, **result_factory_kwargs
)

record = await result_factory.aread(self.storage_key)
self.expiration = record.expiration

if self._should_cache_object:
self._cache_object(record.result)

return record.result

@inject_client
async def _read_result_record(self, client: "PrefectClient") -> "ResultRecord":
block = await self._get_storage_block(client=client)
content = await block.read_path(self.storage_key)
record = ResultRecord.deserialize(content)
return record

@staticmethod
def _infer_path(storage_block, key) -> str:
"""
Expand Down Expand Up @@ -631,15 +777,13 @@ async def write(self, obj: R = NotSet, client: "PrefectClient" = None) -> None:
# this could error if the serializer requires kwargs
serializer = Serializer(type=self.serializer_type)

record = ResultRecord(
result=obj,
metadata=ResultRecordMetadata(
storage_key=self.storage_key,
expiration=self.expiration,
serializer=serializer,
),
result_factory = ResultFactory(
storage_block=storage_block, serializer=serializer
)
await result_factory.awrite(
obj=obj, key=self.storage_key, expiration=self.expiration
)
await storage_block.write_path(self.storage_key, content=record.serialize())

self._persisted = True

if not self._should_cache_object:
Expand Down
8 changes: 5 additions & 3 deletions src/prefect/testing/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from prefect.client.schemas import sorting
from prefect.client.utilities import inject_client
from prefect.logging.handlers import APILogWorker
from prefect.results import PersistedResult
from prefect.results import PersistedResult, ResultFactory
from prefect.serializers import Serializer
from prefect.server.api.server import SubprocessASGIServer
from prefect.states import State
Expand Down Expand Up @@ -196,9 +196,11 @@ async def assert_uses_result_serializer(
if isinstance(serializer, str)
else serializer.type
)
blob = await state.data._read_result_record()
blob = await ResultFactory(
storage_block=await state.data._get_storage_block()
).aread(state.data.storage_key)
assert (
blob.serializer == serializer
blob.metadata.serializer == serializer
if isinstance(serializer, Serializer)
else Serializer(type=serializer)
)
Expand Down
Loading

0 comments on commit 6fa0f99

Please sign in to comment.