Skip to content

Commit

Permalink
Replace PersistedResultBlob with ResultRecord (#15064)
Browse files Browse the repository at this point in the history
  • Loading branch information
desertaxle authored Aug 24, 2024
1 parent 2e72bac commit d867dd0
Show file tree
Hide file tree
Showing 6 changed files with 271 additions and 125 deletions.
288 changes: 207 additions & 81 deletions src/prefect/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,16 @@
)
from uuid import UUID

from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, ValidationError
from pydantic import (
BaseModel,
ConfigDict,
Field,
PrivateAttr,
ValidationError,
field_serializer,
model_serializer,
model_validator,
)
from pydantic_core import PydanticUndefinedType
from pydantic_extra_types.pendulum_dt import DateTime
from typing_extensions import ParamSpec, Self
Expand All @@ -30,7 +39,7 @@
WritableFileSystem,
)
from prefect.logging import get_logger
from prefect.serializers import Serializer
from prefect.serializers import PickleSerializer, Serializer
from prefect.settings import (
PREFECT_DEFAULT_RESULT_STORAGE_BLOCK,
PREFECT_LOCAL_STORAGE_PATH,
Expand Down Expand Up @@ -360,20 +369,192 @@ def key_fn():
serialize_to_none=not self.persist_result,
)

# TODO: These two methods need to find a new home

@sync_compatible
async def store_parameters(self, identifier: UUID, parameters: Dict[str, Any]):
data = self.serializer.dumps(parameters)
blob = PersistedResultBlob(serializer=self.serializer, data=data)
record = ResultRecord(
result=parameters,
metadata=ResultRecordMetadata(
serializer=self.serializer, storage_key=str(identifier)
),
)
await self.storage_block.write_path(
f"parameters/{identifier}", content=blob.to_bytes()
f"parameters/{identifier}", content=record.serialize()
)

@sync_compatible
async def read_parameters(self, identifier: UUID) -> Dict[str, Any]:
blob = PersistedResultBlob.model_validate_json(
record = ResultRecord.deserialize(
await self.storage_block.read_path(f"parameters/{identifier}")
)
return self.serializer.loads(blob.data)
return record.result


class ResultRecordMetadata(BaseModel):
    """
    Descriptive information stored alongside a persisted result.

    Captures where the payload lives (`storage_key`), when it stops being
    valid (`expiration`), how it was encoded (`serializer`), and which
    Prefect version wrote it (`prefect_version`).
    """

    # `None` is allowed so records written before storage keys were recorded
    # still validate (backwards compatibility).
    storage_key: Optional[str] = Field(default=None)
    expiration: Optional[DateTime] = Field(default=None)
    serializer: Serializer = Field(default_factory=PickleSerializer)
    prefect_version: str = Field(default=prefect.__version__)

    def dump_bytes(self) -> bytes:
        """
        Encode this metadata as UTF-8 JSON bytes.

        Returns:
            bytes: the serialized metadata
        """
        as_json = self.model_dump_json(serialize_as_any=True)
        return as_json.encode()

    @classmethod
    def load_bytes(cls, data: bytes) -> "ResultRecordMetadata":
        """
        Reconstruct metadata from JSON bytes produced by `dump_bytes`.

        Args:
            data: the serialized metadata

        Returns:
            ResultRecordMetadata: the deserialized metadata
        """
        return cls.model_validate_json(data)


class ResultRecord(BaseModel, Generic[R]):
    """
    A record of a result.

    Pairs a (deserialized) result value with the `ResultRecordMetadata`
    describing how it is persisted: storage key, expiration, serializer,
    and the Prefect version that wrote it. Also accepts the legacy
    `PersistedResultBlob` wire format via a `mode="before"` validator.
    """

    # How and where the result is persisted.
    metadata: ResultRecordMetadata
    # The in-memory result value; serialized via `serialize_result` on dump.
    result: R

    @property
    def expiration(self) -> Optional[DateTime]:
        # Convenience passthrough to the metadata's expiration timestamp.
        return self.metadata.expiration

    @property
    def serializer(self) -> Serializer:
        # Convenience passthrough to the serializer recorded in the metadata.
        return self.metadata.serializer

    @field_serializer("result")
    def serialize_result(self, value: R) -> bytes:
        """
        Serialize the `result` field with the record's serializer when the
        model is dumped.

        Raises:
            ValueError: if serialization fails; the message includes recovery
                guidance, with extra detail for the known cloudpickle/Pydantic
                failure mode in IPython/Jupyter environments.
        """
        try:
            data = self.serializer.dumps(value)
        except Exception as exc:
            extra_info = (
                'You can try a different serializer (e.g. result_serializer="json") '
                "or disabling persistence (persist_result=False) for this flow or task."
            )
            # check if this is a known issue with cloudpickle and pydantic
            # and add extra information to help the user recover

            if (
                isinstance(exc, TypeError)
                and isinstance(value, BaseModel)
                and str(exc).startswith("cannot pickle")
            ):
                try:
                    # Imported lazily: IPython may not be installed at all.
                    from IPython import get_ipython

                    if get_ipython() is not None:
                        extra_info = inspect.cleandoc(
                            """
                            This is a known issue in Pydantic that prevents
                            locally-defined (non-imported) models from being
                            serialized by cloudpickle in IPython/Jupyter
                            environments. Please see
                            https://github.com/pydantic/pydantic/issues/8232 for
                            more information. To fix the issue, either: (1) move
                            your Pydantic class definition to an importable
                            location, (2) use the JSON serializer for your flow
                            or task (`result_serializer="json"`), or (3)
                            disable result persistence for your flow or task
                            (`persist_result=False`).
                            """
                        ).replace("\n", " ")
                except ImportError:
                    pass
            raise ValueError(
                f"Failed to serialize object of type {type(value).__name__!r} with "
                f"serializer {self.serializer.type!r}. {extra_info}"
            ) from exc

        return data

    @model_validator(mode="before")
    @classmethod
    def coerce_old_format(cls, value: Any) -> Any:
        """
        Accept the legacy `PersistedResultBlob` layout by moving its flat
        fields (`data`, `expiration`, `serializer`, `prefect_version`) into
        the nested `result`/`metadata` shape used by `ResultRecord`.
        """
        if isinstance(value, dict):
            # NOTE(review): keys are popped from the input dict in place —
            # fine for dicts freshly produced by JSON parsing, but this
            # mutates a caller-supplied dict; confirm no caller reuses it.
            if "data" in value:
                value["result"] = value.pop("data")
            if "metadata" not in value:
                value["metadata"] = {}
            if "expiration" in value:
                value["metadata"]["expiration"] = value.pop("expiration")
            if "serializer" in value:
                value["metadata"]["serializer"] = value.pop("serializer")
            if "prefect_version" in value:
                value["metadata"]["prefect_version"] = value.pop("prefect_version")
        return value

    def serialize_metadata(self) -> bytes:
        # Serialize only the metadata portion (used when result and metadata
        # are stored separately; see `deserialize_from_result_and_metadata`).
        return self.metadata.dump_bytes()

    def serialize(
        self,
    ) -> bytes:
        """
        Serialize the record to bytes.

        The `result` field is encoded by `serialize_result` during the dump.

        Returns:
            bytes: the serialized record
        """
        return self.model_dump_json(serialize_as_any=True).encode()

    @classmethod
    def deserialize(cls, data: bytes) -> "ResultRecord[R]":
        """
        Deserialize a record from bytes.

        Args:
            data: the serialized record

        Returns:
            ResultRecord: the deserialized record
        """
        instance = cls.model_validate_json(data)
        # After JSON validation the result payload may still be the raw
        # serialized form (bytes, or str from JSON decoding) — decode it
        # with the serializer recorded in the metadata.
        if isinstance(instance.result, bytes):
            instance.result = instance.serializer.loads(instance.result)
        elif isinstance(instance.result, str):
            instance.result = instance.serializer.loads(instance.result.encode())
        return instance

    @classmethod
    def deserialize_from_result_and_metadata(
        cls, result: bytes, metadata: bytes
    ) -> "ResultRecord[R]":
        """
        Deserialize a record from separate result and metadata bytes.

        Args:
            result: the serialized result payload
            metadata: the serialized metadata

        Returns:
            ResultRecord: the deserialized record
        """
        result_record_metadata = ResultRecordMetadata.load_bytes(metadata)
        return cls(
            metadata=result_record_metadata,
            result=result_record_metadata.serializer.loads(result),
        )


@register_base_type
Expand Down Expand Up @@ -429,7 +610,7 @@ class PersistedResult(BaseResult):
Result type which stores a reference to a persisted result.
When created, the user's object is serialized and stored. The format for the content
is defined by `PersistedResultBlob`. This reference contains metadata necessary for retrieval
is defined by `ResultRecord`. This reference contains metadata necessary for retrieval
of the object, such as a reference to the storage block and the key where the
content was written.
"""
Expand All @@ -447,11 +628,11 @@ class PersistedResult(BaseResult):
_storage_block: WritableFileSystem = PrivateAttr(default=None)
_serializer: Serializer = PrivateAttr(default=None)

def model_dump(self, *args, **kwargs):
@model_serializer(mode="wrap")
def serialize_model(self, handler, info):
if self.serialize_to_none:
return None
else:
return super().model_dump(*args, **kwargs)
return handler(self, info)

def _cache_object(
self,
Expand Down Expand Up @@ -483,21 +664,20 @@ async def get(self, client: "PrefectClient") -> R:
if self.has_cached_object():
return self._cache

blob = await self._read_blob(client=client)
obj = blob.load()
self.expiration = blob.expiration
record = await self._read_result_record(client=client)
self.expiration = record.expiration

if self._should_cache_object:
self._cache_object(obj)
self._cache_object(record.result)

return obj
return record.result

@inject_client
async def _read_blob(self, client: "PrefectClient") -> "PersistedResultBlob":
async def _read_result_record(self, client: "PrefectClient") -> "ResultRecord":
block = await self._get_storage_block(client=client)
content = await block.read_path(self.storage_key)
blob = PersistedResultBlob.model_validate_json(content)
return blob
record = ResultRecord.deserialize(content)
return record

@staticmethod
def _infer_path(storage_block, key) -> str:
Expand Down Expand Up @@ -538,50 +718,15 @@ async def write(self, obj: R = NotSet, client: "PrefectClient" = None) -> None:
# this could error if the serializer requires kwargs
serializer = Serializer(type=self.serializer_type)

try:
data = serializer.dumps(obj)
except Exception as exc:
extra_info = (
'You can try a different serializer (e.g. result_serializer="json") '
"or disabling persistence (persist_result=False) for this flow or task."
)
# check if this is a known issue with cloudpickle and pydantic
# and add extra information to help the user recover

if (
isinstance(exc, TypeError)
and isinstance(obj, BaseModel)
and str(exc).startswith("cannot pickle")
):
try:
from IPython import get_ipython

if get_ipython() is not None:
extra_info = inspect.cleandoc(
"""
This is a known issue in Pydantic that prevents
locally-defined (non-imported) models from being
serialized by cloudpickle in IPython/Jupyter
environments. Please see
https://github.com/pydantic/pydantic/issues/8232 for
more information. To fix the issue, either: (1) move
your Pydantic class definition to an importable
location, (2) use the JSON serializer for your flow
or task (`result_serializer="json"`), or (3)
disable result persistence for your flow or task
(`persist_result=False`).
"""
).replace("\n", " ")
except ImportError:
pass
raise ValueError(
f"Failed to serialize object of type {type(obj).__name__!r} with "
f"serializer {serializer.type!r}. {extra_info}"
) from exc
blob = PersistedResultBlob(
serializer=serializer, data=data, expiration=self.expiration
record = ResultRecord(
result=obj,
metadata=ResultRecordMetadata(
storage_key=self.storage_key,
expiration=self.expiration,
serializer=serializer,
),
)
await storage_block.write_path(self.storage_key, content=blob.to_bytes())
await storage_block.write_path(self.storage_key, content=record.serialize())
self._persisted = True

if not self._should_cache_object:
Expand Down Expand Up @@ -642,22 +787,3 @@ def __eq__(self, other):
and self.storage_block_id == other.storage_block_id
and self.expiration == other.expiration
)


class PersistedResultBlob(BaseModel):
    """
    The format of the content stored by a persisted result.

    Pairs the serialized payload (`data`) with the serializer needed to
    decode it, the writing Prefect version, and an optional expiration.
    Typically, this is written to a file as bytes.
    """

    serializer: Serializer
    data: bytes
    prefect_version: str = Field(default=prefect.__version__)
    expiration: Optional[DateTime] = None

    def load(self) -> Any:
        """Decode and return the stored payload using the blob's serializer."""
        payload = self.data
        return self.serializer.loads(payload)

    def to_bytes(self) -> bytes:
        """Render the blob as UTF-8 JSON bytes suitable for writing to storage."""
        rendered = self.model_dump_json(serialize_as_any=True)
        return rendered.encode()
2 changes: 1 addition & 1 deletion src/prefect/testing/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ async def assert_uses_result_serializer(
if isinstance(serializer, str)
else serializer.type
)
blob = await state.data._read_blob()
blob = await state.data._read_result_record()
assert (
blob.serializer == serializer
if isinstance(serializer, Serializer)
Expand Down
Loading

0 comments on commit d867dd0

Please sign in to comment.