Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to defer result persistence #14155

Merged
merged 7 commits into from
Jun 19, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 62 additions & 9 deletions src/prefect/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,19 @@ class PersistedResult(BaseResult):
expiration: Optional[DateTime] = None

_should_cache_object: bool = PrivateAttr(default=True)
_persisted: bool = PrivateAttr(default=False)
_storage_block: WritableFileSystem = PrivateAttr(default=None)
_serializer: Serializer = PrivateAttr(default=None)

def _cache_object(
self,
obj: Any,
storage_block: WritableFileSystem = None,
serializer: Serializer = None,
) -> None:
self._cache = obj
self._storage_block = storage_block
self._serializer = serializer

@sync_compatible
@inject_client
Expand All @@ -601,7 +614,7 @@ async def get(self, client: "PrefectClient") -> R:
return self._cache

blob = await self._read_blob(client=client)
obj = blob.serializer.loads(blob.data)
obj = blob.load()
self.expiration = blob.expiration

if self._should_cache_object:
Expand Down Expand Up @@ -632,6 +645,43 @@ def _infer_path(storage_block, key) -> str:
if hasattr(storage_block, "_remote_file_system"):
return storage_block._remote_file_system._resolve_path(key)

@sync_compatible
@inject_client
async def write(self, obj: R = NotSet, client: "PrefectClient" = None) -> None:
"""
Write the result to the storage block.
"""

if self._persisted:
# don't double write or overwrite
return

# load objects from a cache

# first the object itself
if obj is NotSet and not self.has_cached_object():
raise ValueError("Cannot write a result that has no object cached.")
obj = obj if obj is not NotSet else self._cache

# next, the storage block
storage_block = self._storage_block
if storage_block is None:
block_document = await client.read_block_document(self.storage_block_id)
storage_block = Block._from_block_document(block_document)

# finally, the serializer
serializer = self._serializer
if serializer is None:
# this could error if the serializer requires kwargs
serializer = Serializer(type=self.serializer_type)

data = serializer.dumps(obj)
blob = PersistedResultBlob(
serializer=serializer, data=data, expiration=self.expiration
)
await storage_block.write_path(self.storage_key, content=blob.to_bytes())
self._persisted = True

@classmethod
@sync_compatible
async def create(
Expand All @@ -643,6 +693,7 @@ async def create(
serializer: Serializer,
cache_object: bool = True,
expiration: Optional[DateTime] = None,
defer_persistence: bool = False,
) -> "PersistedResult[R]":
"""
Create a new result reference from a user's object.
Expand All @@ -652,19 +703,13 @@ async def create(
"""
assert (
storage_block_id is not None
), "Unexpected storage block ID. Was it persisted?"
data = serializer.dumps(obj)
blob = PersistedResultBlob(
serializer=serializer, data=data, expiration=expiration
)
), "Unexpected storage block ID. Was it saved?"

key = storage_key_fn()
if not isinstance(key, str):
raise TypeError(
f"Expected type 'str' for result storage key; got value {key!r}"
)
await storage_block.write_path(key, content=blob.to_bytes())

description = f"Result of type `{type(obj).__name__}`"
uri = cls._infer_path(storage_block, key)
if uri:
Expand All @@ -686,10 +731,15 @@ async def create(

if cache_object:
# Attach the object to the result so it's available without deserialization
result._cache_object(obj)
result._cache_object(
obj, storage_block=storage_block, serializer=serializer
)

object.__setattr__(result, "_should_cache_object", cache_object)

if not defer_persistence:
await result.write(obj=obj)

return result


Expand All @@ -705,6 +755,9 @@ class PersistedResultBlob(BaseModel):
prefect_version: str = Field(default=prefect.__version__)
expiration: Optional[DateTime] = None

def load(self) -> Any:
return self.serializer.loads(self.data)

def to_bytes(self) -> bytes:
return self.model_dump_json(serialize_as_any=True).encode()

Expand Down
20 changes: 20 additions & 0 deletions tests/results/test_persisted_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,3 +194,23 @@ async def test_expiration_when_loaded(self, storage_block):

assert await result.get() == 42
assert result.expiration == timestamp


async def test_lifecycle_of_defer_persistence(storage_block):
result = await PersistedResult.create(
"test-defer",
storage_block_id=storage_block._block_document_id,
storage_block=storage_block,
storage_key_fn=lambda: "test-defer-path",
serializer=JSONSerializer(),
defer_persistence=True,
)

assert await result.get() == "test-defer"

with pytest.raises(ValueError, match="does not exist"):
await result._read_blob()

await result.write()
blob = await result._read_blob()
assert blob.load() == "test-defer"