diff --git a/src/prefect/results.py b/src/prefect/results.py index a45f53780a2d..63aaa679f1eb 100644 --- a/src/prefect/results.py +++ b/src/prefect/results.py @@ -17,7 +17,16 @@ ) from uuid import UUID -from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, ValidationError +from pydantic import ( + BaseModel, + ConfigDict, + Field, + PrivateAttr, + ValidationError, + field_serializer, + model_serializer, + model_validator, +) from pydantic_core import PydanticUndefinedType from pydantic_extra_types.pendulum_dt import DateTime from typing_extensions import ParamSpec, Self @@ -30,7 +39,7 @@ WritableFileSystem, ) from prefect.logging import get_logger -from prefect.serializers import Serializer +from prefect.serializers import PickleSerializer, Serializer from prefect.settings import ( PREFECT_DEFAULT_RESULT_STORAGE_BLOCK, PREFECT_LOCAL_STORAGE_PATH, @@ -360,20 +369,192 @@ def key_fn(): serialize_to_none=not self.persist_result, ) + # TODO: These two methods need to find a new home + @sync_compatible async def store_parameters(self, identifier: UUID, parameters: Dict[str, Any]): - data = self.serializer.dumps(parameters) - blob = PersistedResultBlob(serializer=self.serializer, data=data) + record = ResultRecord( + result=parameters, + metadata=ResultRecordMetadata( + serializer=self.serializer, storage_key=str(identifier) + ), + ) await self.storage_block.write_path( - f"parameters/{identifier}", content=blob.to_bytes() + f"parameters/{identifier}", content=record.serialize() ) @sync_compatible async def read_parameters(self, identifier: UUID) -> Dict[str, Any]: - blob = PersistedResultBlob.model_validate_json( + record = ResultRecord.deserialize( await self.storage_block.read_path(f"parameters/{identifier}") ) - return self.serializer.loads(blob.data) + return record.result + + +class ResultRecordMetadata(BaseModel): + """ + Metadata for a result record. + """ + + storage_key: Optional[str] = Field( + default=None + ) # optional for backwards compatibility + expiration: Optional[DateTime] = Field(default=None) + serializer: Serializer = Field(default_factory=PickleSerializer) + prefect_version: str = Field(default=prefect.__version__) + + def dump_bytes(self) -> bytes: + """ + Serialize the metadata to bytes. + + Returns: + bytes: the serialized metadata + """ + return self.model_dump_json(serialize_as_any=True).encode() + + @classmethod + def load_bytes(cls, data: bytes) -> "ResultRecordMetadata": + """ + Deserialize metadata from bytes. + + Args: + data: the serialized metadata + + Returns: + ResultRecordMetadata: the deserialized metadata + """ + return cls.model_validate_json(data) + + +class ResultRecord(BaseModel, Generic[R]): + """ + A record of a result. + """ + + metadata: ResultRecordMetadata + result: R + + @property + def expiration(self) -> Optional[DateTime]: + return self.metadata.expiration + + @property + def serializer(self) -> Serializer: + return self.metadata.serializer + + @field_serializer("result") + def serialize_result(self, value: R) -> bytes: + try: + data = self.serializer.dumps(value) + except Exception as exc: + extra_info = ( + 'You can try a different serializer (e.g. result_serializer="json") ' + "or disabling persistence (persist_result=False) for this flow or task." + ) + # check if this is a known issue with cloudpickle and pydantic + # and add extra information to help the user recover + + if ( + isinstance(exc, TypeError) + and isinstance(value, BaseModel) + and str(exc).startswith("cannot pickle") + ): + try: + from IPython import get_ipython + + if get_ipython() is not None: + extra_info = inspect.cleandoc( + """ + This is a known issue in Pydantic that prevents + locally-defined (non-imported) models from being + serialized by cloudpickle in IPython/Jupyter + environments. Please see + https://github.com/pydantic/pydantic/issues/8232 for + more information. To fix the issue, either: (1) move + your Pydantic class definition to an importable + location, (2) use the JSON serializer for your flow + or task (`result_serializer="json"`), or (3) + disable result persistence for your flow or task + (`persist_result=False`). + """ + ).replace("\n", " ") + except ImportError: + pass + raise ValueError( + f"Failed to serialize object of type {type(value).__name__!r} with " + f"serializer {self.serializer.type!r}. {extra_info}" + ) from exc + + return data + + @model_validator(mode="before") + @classmethod + def coerce_old_format(cls, value: Any): + if isinstance(value, dict): + if "data" in value: + value["result"] = value.pop("data") + if "metadata" not in value: + value["metadata"] = {} + if "expiration" in value: + value["metadata"]["expiration"] = value.pop("expiration") + if "serializer" in value: + value["metadata"]["serializer"] = value.pop("serializer") + if "prefect_version" in value: + value["metadata"]["prefect_version"] = value.pop("prefect_version") + return value + + def serialize_metadata(self) -> bytes: + return self.metadata.dump_bytes() + + def serialize( + self, + ) -> bytes: + """ + Serialize the record to bytes. + + Returns: + bytes: the serialized record + + """ + return self.model_dump_json(serialize_as_any=True).encode() + + @classmethod + def deserialize(cls, data: bytes) -> "ResultRecord[R]": + """ + Deserialize a record from bytes. + + Args: + data: the serialized record + + Returns: + ResultRecord: the deserialized record + """ + instance = cls.model_validate_json(data) + if isinstance(instance.result, bytes): + instance.result = instance.serializer.loads(instance.result) + elif isinstance(instance.result, str): + instance.result = instance.serializer.loads(instance.result.encode()) + return instance + + @classmethod + def deserialize_from_result_and_metadata( + cls, result: bytes, metadata: bytes + ) -> "ResultRecord[R]": + """ + Deserialize a record from separate result and metadata bytes. + + Args: + result: the result + metadata: the serialized metadata + + Returns: + ResultRecord: the deserialized record + """ + result_record_metadata = ResultRecordMetadata.load_bytes(metadata) + return cls( + metadata=result_record_metadata, + result=result_record_metadata.serializer.loads(result), + ) @register_base_type @@ -429,7 +610,7 @@ class PersistedResult(BaseResult): Result type which stores a reference to a persisted result. When created, the user's object is serialized and stored. The format for the content - is defined by `PersistedResultBlob`. This reference contains metadata necessary for retrieval + is defined by `ResultRecord`. This reference contains metadata necessary for retrieval of the object, such as a reference to the storage block and the key where the content was written. """ @@ -447,11 +628,11 @@ class PersistedResult(BaseResult): _storage_block: WritableFileSystem = PrivateAttr(default=None) _serializer: Serializer = PrivateAttr(default=None) - def model_dump(self, *args, **kwargs): + @model_serializer(mode="wrap") + def serialize_model(self, handler, info): if self.serialize_to_none: return None - else: - return super().model_dump(*args, **kwargs) + return handler(self, info) def _cache_object( self, @@ -483,21 +664,20 @@ async def get(self, client: "PrefectClient") -> R: if self.has_cached_object(): return self._cache - blob = await self._read_blob(client=client) - obj = blob.load() - self.expiration = blob.expiration + record = await self._read_result_record(client=client) + self.expiration = record.expiration if self._should_cache_object: - self._cache_object(obj) + self._cache_object(record.result) - return obj + return record.result @inject_client - async def _read_blob(self, client: "PrefectClient") -> "PersistedResultBlob": + async def _read_result_record(self, client: "PrefectClient") -> "ResultRecord": block = await self._get_storage_block(client=client) content = await block.read_path(self.storage_key) - blob = PersistedResultBlob.model_validate_json(content) - return blob + record = ResultRecord.deserialize(content) + return record @staticmethod def _infer_path(storage_block, key) -> str: @@ -538,50 +718,15 @@ async def write(self, obj: R = NotSet, client: "PrefectClient" = None) -> None: # this could error if the serializer requires kwargs serializer = Serializer(type=self.serializer_type) - try: - data = serializer.dumps(obj) - except Exception as exc: - extra_info = ( - 'You can try a different serializer (e.g. result_serializer="json") ' - "or disabling persistence (persist_result=False) for this flow or task." - ) - # check if this is a known issue with cloudpickle and pydantic - # and add extra information to help the user recover - - if ( - isinstance(exc, TypeError) - and isinstance(obj, BaseModel) - and str(exc).startswith("cannot pickle") - ): - try: - from IPython import get_ipython - - if get_ipython() is not None: - extra_info = inspect.cleandoc( - """ - This is a known issue in Pydantic that prevents - locally-defined (non-imported) models from being - serialized by cloudpickle in IPython/Jupyter - environments. Please see - https://github.com/pydantic/pydantic/issues/8232 for - more information. To fix the issue, either: (1) move - your Pydantic class definition to an importable - location, (2) use the JSON serializer for your flow - or task (`result_serializer="json"`), or (3) - disable result persistence for your flow or task - (`persist_result=False`). - """ - ).replace("\n", " ") - except ImportError: - pass - raise ValueError( - f"Failed to serialize object of type {type(obj).__name__!r} with " - f"serializer {serializer.type!r}. {extra_info}" - ) from exc - blob = PersistedResultBlob( - serializer=serializer, data=data, expiration=self.expiration + record = ResultRecord( + result=obj, + metadata=ResultRecordMetadata( + storage_key=self.storage_key, + expiration=self.expiration, + serializer=serializer, + ), ) - await storage_block.write_path(self.storage_key, content=blob.to_bytes()) + await storage_block.write_path(self.storage_key, content=record.serialize()) self._persisted = True if not self._should_cache_object: @@ -642,22 +787,3 @@ def __eq__(self, other): and self.storage_block_id == other.storage_block_id and self.expiration == other.expiration ) - - -class PersistedResultBlob(BaseModel): - """ - The format of the content stored by a persisted result. - - Typically, this is written to a file as bytes. - """ - - serializer: Serializer - data: bytes - prefect_version: str = Field(default=prefect.__version__) - expiration: Optional[DateTime] = None - - def load(self) -> Any: - return self.serializer.loads(self.data) - - def to_bytes(self) -> bytes: - return self.model_dump_json(serialize_as_any=True).encode() diff --git a/src/prefect/testing/utilities.py b/src/prefect/testing/utilities.py index a50d9decef02..0d59c5ef7d2b 100644 --- a/src/prefect/testing/utilities.py +++ b/src/prefect/testing/utilities.py @@ -196,7 +196,7 @@ async def assert_uses_result_serializer( if isinstance(serializer, str) else serializer.type ) - blob = await state.data._read_blob() + blob = await state.data._read_result_record() assert ( blob.serializer == serializer if isinstance(serializer, Serializer) diff --git a/tests/results/test_flow_results.py b/tests/results/test_flow_results.py index 4a619d15e102..660a98170f7f 100644 --- a/tests/results/test_flow_results.py +++ b/tests/results/test_flow_results.py @@ -1,5 +1,3 @@ -import base64 -import pickle from pathlib import Path import pytest @@ -11,7 +9,7 @@ from prefect.filesystems import LocalFileSystem from prefect.results import ( PersistedResult, - PersistedResultBlob, + ResultRecord, ) from prefect.serializers import ( CompressedSerializer, @@ -384,9 +382,7 @@ async def foo(): assert result == {"foo": "bar"} local_storage = await LocalFileSystem.load("my-result-storage") result_bytes = await local_storage.read_path(f"{tmp_path/'my-result.pkl'}") - saved_python_result = pickle.loads( - base64.b64decode(PersistedResultBlob.model_validate_json(result_bytes).data) - ) + saved_python_result = ResultRecord.deserialize(result_bytes).result assert saved_python_result == {"foo": "bar"} @@ -444,11 +440,7 @@ def some_flow() -> Block: storage_block = some_flow() assert isinstance(storage_block, LocalFileSystem) - result = pickle.loads( - base64.b64decode( - PersistedResultBlob.model_validate_json( - storage_block.read_path("somespecialflowversion") - ).data - ) - ) + result = ResultRecord.deserialize( + storage_block.read_path("somespecialflowversion") + ).result assert result == "hello" diff --git a/tests/results/test_persisted_result.py b/tests/results/test_persisted_result.py index cb761f357d6e..6705c431d5ed 100644 --- a/tests/results/test_persisted_result.py +++ b/tests/results/test_persisted_result.py @@ -5,7 +5,12 @@ import pytest from prefect.filesystems import LocalFileSystem -from prefect.results import DEFAULT_STORAGE_KEY_FN, PersistedResult, PersistedResultBlob +from prefect.results import ( + DEFAULT_STORAGE_KEY_FN, + PersistedResult, + ResultRecord, + ResultRecordMetadata, +) from prefect.serializers import JSONSerializer, PickleSerializer @@ -65,9 +70,9 @@ async def test_result_reference_create_uses_serializer(storage_block): assert result.serializer_type == serializer.type contents = await storage_block.read_path(result.storage_key) - blob = PersistedResultBlob.model_validate_json(contents) - assert blob.serializer == serializer - assert serializer.loads(blob.data) == "test" + record = ResultRecord.deserialize(contents) + assert record.serializer == serializer + assert record.result == "test" async def test_result_reference_file_blob_is_json(storage_block): @@ -87,13 +92,13 @@ async def test_result_reference_file_blob_is_json(storage_block): contents = await storage_block.read_path(result.storage_key) # Should be readable by JSON - blob_dict = json.loads(contents) + json.loads(contents) - # Should conform to the PersistedResultBlob spec - blob = PersistedResultBlob.model_validate(blob_dict) + # Should conform to the ResultRecord spec + blob = ResultRecord.deserialize(contents) assert blob.serializer - assert blob.data + assert blob.result async def test_result_reference_create_uses_storage_key_fn(storage_block): @@ -123,9 +128,11 @@ async def test_init_doesnt_error_when_doesnt_exist(storage_block): with pytest.raises(ValueError, match="does not exist"): await result.get() - blob = PersistedResultBlob(serializer=JSONSerializer(), data=b"38") - await storage_block.write_path(path, blob.to_bytes()) - assert await result.get() == 38 + record = ResultRecord( + metadata=ResultRecordMetadata(serializer=JSONSerializer()), result=b"38" + ) + await storage_block.write_path(path, record.serialize()) + assert await result.get() == b"38" class TestExpirationField: @@ -168,10 +175,13 @@ async def test_setting_expiration_at_create(self, storage_block): async def test_expiration_when_loaded(self, storage_block): path = uuid.uuid4().hex timestamp = pendulum.now("utc").subtract(days=100) - blob = PersistedResultBlob( - serializer=JSONSerializer(), data=b"42", expiration=timestamp + record = ResultRecord( + metadata=ResultRecordMetadata( + serializer=JSONSerializer(), expiration=timestamp + ), + result=b"42", ) - await storage_block.write_path(path, blob.to_bytes()) + await storage_block.write_path(path, record.serialize()) result = PersistedResult( storage_block_id=storage_block._block_document_id, @@ -179,7 +189,7 @@ async def test_expiration_when_loaded(self, storage_block): serializer_type="json", ) - assert await result.get() == 42 + assert await result.get() == b"42" assert result.expiration == timestamp @@ -193,15 +203,15 @@ async def test_write_is_idempotent(storage_block): ) with pytest.raises(ValueError, match="does not exist"): - await result._read_blob() + await result._read_result_record() await result.write() - blob = await result._read_blob() - assert blob.load() == "test-defer" + record = await result._read_result_record() + assert record.result == "test-defer" await result.write(obj="new-object!") - blob = await result._read_blob() - assert blob.load() == "test-defer" + record = await result._read_result_record() + assert record.result == "test-defer" async def test_lifecycle_of_deferred_persistence(storage_block): @@ -216,8 +226,26 @@ async def test_lifecycle_of_deferred_persistence(storage_block): assert await result.get() == "test-defer" with pytest.raises(ValueError, match="does not exist"): - await result._read_blob() + await result._read_result_record() await result.write() - blob = await result._read_blob() - assert blob.load() == "test-defer" + record = await result._read_result_record() + assert record.result == "test-defer" + + +async def test_read_old_format_into_result_record(): + old_blob = { + "serializer": { + "type": "pickle", + "picklelib": "cloudpickle", + "picklelib_version": None, + }, + "data": "gAVLCS4=\n", + "prefect_version": "2.20.1", + "expiration": None, + } + record = ResultRecord.deserialize(json.dumps(old_blob).encode()) + assert record.result == 9 + assert record.metadata.serializer == PickleSerializer(picklelib="cloudpickle") + assert record.metadata.prefect_version == "2.20.1" + assert record.metadata.expiration is None diff --git a/tests/results/test_state_result.py b/tests/results/test_state_result.py index 5b95839f8d43..6d4acf529733 100644 --- a/tests/results/test_state_result.py +++ b/tests/results/test_state_result.py @@ -10,7 +10,7 @@ import prefect.states from prefect.exceptions import UnfinishedRun from prefect.filesystems import LocalFileSystem, WritableFileSystem -from prefect.results import PersistedResult, PersistedResultBlob, ResultFactory +from prefect.results import PersistedResult, ResultFactory, ResultRecord from prefect.serializers import JSONSerializer from prefect.states import State, StateType from prefect.utilities.annotations import NotSet @@ -136,8 +136,8 @@ async def test_graceful_retries_eventually_succeed_while( ): # now write the result so it's available await a_real_result.write() - expected_blob = await a_real_result._read_blob() - assert isinstance(expected_blob, PersistedResultBlob) + expected_record = await a_real_result._read_result_record() + assert isinstance(expected_record, ResultRecord) # even if it misses a couple times, it will eventually return the data now = time.monotonic() @@ -147,7 +147,7 @@ async def test_graceful_retries_eventually_succeed_while( side_effect=[ FileNotFoundError, TimeoutError, - expected_blob.model_dump_json().encode(), + expected_record.model_dump_json().encode(), ] ), ) as m: diff --git a/tests/test_flows.py b/tests/test_flows.py index e66af38760cb..aec1b8421ddf 100644 --- a/tests/test_flows.py +++ b/tests/test_flows.py @@ -53,7 +53,7 @@ safe_load_flow_from_entrypoint, ) from prefect.logging import get_run_logger -from prefect.results import PersistedResultBlob +from prefect.results import ResultRecord from prefect.runtime import flow_run as flow_run_ctx from prefect.server.schemas.core import TaskRunResult from prefect.server.schemas.filters import FlowFilter, FlowRunFilter @@ -4545,8 +4545,8 @@ def main(): assert isinstance(val, ValueError) assert "does not exist" in str(val) content = result_storage.read_path("task1-result-A", _sync=True) - blob = PersistedResultBlob.model_validate_json(content) - assert blob.load() == {"some": "data"} + record = ResultRecord.deserialize(content) + assert record.result == {"some": "data"} def test_commit_isnt_called_on_rollback(self): data = {}