From 6fa0f99a226870d97800e6f219267a4d38a7b015 Mon Sep 17 00:00:00 2001 From: Alexander Streed Date: Tue, 3 Sep 2024 09:49:29 -0500 Subject: [PATCH] Add `read` and `write` methods to `ResultFactory` (#15176) --- src/prefect/results.py | 186 ++++++++++++++++++++++--- src/prefect/testing/utilities.py | 8 +- tests/blocks/test_notifications.py | 39 +++--- tests/results/test_persisted_result.py | 16 +-- tests/results/test_state_result.py | 17 ++- 5 files changed, 209 insertions(+), 57 deletions(-) diff --git a/src/prefect/results.py b/src/prefect/results.py index 3037e3377e69..edbca0928a6c 100644 --- a/src/prefect/results.py +++ b/src/prefect/results.py @@ -190,6 +190,12 @@ def storage_block_id(self) -> Optional[UUID]: async def update_for_flow(self, flow: "Flow") -> Self: """ Create a new result factory for a flow with updated settings. + + Args: + flow: The flow to update the result factory for. + + Returns: + An updated result factory. """ update = {} if flow.result_storage is not None: @@ -208,6 +214,12 @@ async def update_for_flow(self, flow: "Flow") -> Self: async def update_for_task(self: Self, task: "Task") -> Self: """ Create a new result factory for a task. + + Args: + task: The task to update the result factory for. + + Returns: + An updated result factory. """ update = {} if task.result_storage is not None: @@ -226,6 +238,139 @@ async def update_for_task(self: Self, task: "Task") -> Self: update["storage_block"] = await get_default_result_storage() return self.model_copy(update=update) + @sync_compatible + async def _read(self, key: str) -> "ResultRecord": + """ + Read a result record from storage. + + This is the internal implementation. Use `read` or `aread` for synchronous and + asynchronous result reading respectively. + + Args: + key: The key to read the result record from. + + Returns: + A result record. 
+ """ + if self.storage_block is None: + self.storage_block = await get_default_result_storage() + + content = await self.storage_block.read_path(f"{key}") + return ResultRecord.deserialize(content) + + def read(self, key: str) -> "ResultRecord": + """ + Read a result record from storage. + + Args: + key: The key to read the result record from. + + Returns: + A result record. + """ + return self._read(key=key, _sync=True) + + async def aread(self, key: str) -> "ResultRecord": + """ + Read a result record from storage. + + Args: + key: The key to read the result record from. + + Returns: + A result record. + """ + return await self._read(key=key, _sync=False) + + @sync_compatible + async def _write( + self, + obj: Any, + key: Optional[str] = None, + expiration: Optional[DateTime] = None, + ): + """ + Write a result to storage. + + This is the internal implementation. Use `write` or `awrite` for synchronous and + asynchronous result writing respectively. + + Args: + key: The key to write the result record to. + obj: The object to write to storage. + expiration: The expiration time for the result record. + """ + if self.storage_block is None: + self.storage_block = await get_default_result_storage() + key = key or self.storage_key_fn() + + record = ResultRecord( + result=obj, + metadata=ResultRecordMetadata( + serializer=self.serializer, expiration=expiration, storage_key=key + ), + ) + await self.apersist_result_record(record) + + def write(self, key: str, obj: Any, expiration: Optional[DateTime] = None): + """ + Write a result to storage. + + Handles the creation of a `ResultRecord` and its serialization to storage. + + Args: + key: The key to write the result record to. + obj: The object to write to storage. + expiration: The expiration time for the result record. + """ + return self._write(obj=obj, key=key, expiration=expiration, _sync=True) + + async def awrite(self, key: str, obj: Any, expiration: Optional[DateTime] = None): + """ + Write a result to storage. 
+ + Args: + key: The key to write the result record to. + obj: The object to write to storage. + expiration: The expiration time for the result record. + """ + return await self._write(obj=obj, key=key, expiration=expiration, _sync=False) + + @sync_compatible + async def _persist_result_record(self, result_record: "ResultRecord"): + """ + Persist a result record to storage. + + Args: + result_record: The result record to persist. + """ + if self.storage_block is None: + self.storage_block = await get_default_result_storage() + + await self.storage_block.write_path( + result_record.metadata.storage_key, content=result_record.serialize() + ) + + def persist_result_record(self, result_record: "ResultRecord"): + """ + Persist a result record to storage. + + Args: + result_record: The result record to persist. + """ + return self._persist_result_record(result_record=result_record, _sync=True) + + async def apersist_result_record(self, result_record: "ResultRecord"): + """ + Persist a result record to storage. + + Args: + result_record: The result record to persist. + """ + return await self._persist_result_record( + result_record=result_record, _sync=False + ) + @sync_compatible async def create_result( self, @@ -234,9 +379,7 @@ async def create_result( expiration: Optional[DateTime] = None, ) -> Union[R, "BaseResult[R]"]: """ - Create a result type for the given object. - - If persistence is enabled the object is serialized, persisted to storage, and a reference is returned. + Create a `PersistedResult` for the given object. 
""" # Null objects are "cached" in memory at no cost should_cache_object = self.cache_result_in_memory or obj is None @@ -570,14 +713,24 @@ async def _get_storage_block(self, client: "PrefectClient") -> WritableFileSyste @sync_compatible @inject_client - async def get(self, client: "PrefectClient") -> R: + async def get( + self, ignore_cache: bool = False, client: "PrefectClient" = None + ) -> R: """ Retrieve the data and deserialize it into the original object. """ - if self.has_cached_object(): + if self.has_cached_object() and not ignore_cache: return self._cache - record = await self._read_result_record(client=client) + result_factory_kwargs = {} + if self._serializer: + result_factory_kwargs["serializer"] = resolve_serializer(self._serializer) + storage_block = await self._get_storage_block(client=client) + result_factory = ResultFactory( + storage_block=storage_block, **result_factory_kwargs + ) + + record = await result_factory.aread(self.storage_key) self.expiration = record.expiration if self._should_cache_object: @@ -585,13 +738,6 @@ async def get(self, client: "PrefectClient") -> R: return record.result - @inject_client - async def _read_result_record(self, client: "PrefectClient") -> "ResultRecord": - block = await self._get_storage_block(client=client) - content = await block.read_path(self.storage_key) - record = ResultRecord.deserialize(content) - return record - @staticmethod def _infer_path(storage_block, key) -> str: """ @@ -631,15 +777,13 @@ async def write(self, obj: R = NotSet, client: "PrefectClient" = None) -> None: # this could error if the serializer requires kwargs serializer = Serializer(type=self.serializer_type) - record = ResultRecord( - result=obj, - metadata=ResultRecordMetadata( - storage_key=self.storage_key, - expiration=self.expiration, - serializer=serializer, - ), + result_factory = ResultFactory( + storage_block=storage_block, serializer=serializer + ) + await result_factory.awrite( + obj=obj, key=self.storage_key, 
expiration=self.expiration ) - await storage_block.write_path(self.storage_key, content=record.serialize()) + self._persisted = True if not self._should_cache_object: diff --git a/src/prefect/testing/utilities.py b/src/prefect/testing/utilities.py index 0d59c5ef7d2b..86691dc2eb91 100644 --- a/src/prefect/testing/utilities.py +++ b/src/prefect/testing/utilities.py @@ -18,7 +18,7 @@ from prefect.client.schemas import sorting from prefect.client.utilities import inject_client from prefect.logging.handlers import APILogWorker -from prefect.results import PersistedResult +from prefect.results import PersistedResult, ResultFactory from prefect.serializers import Serializer from prefect.server.api.server import SubprocessASGIServer from prefect.states import State @@ -196,9 +196,11 @@ async def assert_uses_result_serializer( if isinstance(serializer, str) else serializer.type ) - blob = await state.data._read_result_record() + blob = await ResultFactory( + storage_block=await state.data._get_storage_block() + ).aread(state.data.storage_key) assert ( - blob.serializer == serializer + blob.metadata.serializer == serializer if isinstance(serializer, Serializer) else Serializer(type=serializer) ) diff --git a/tests/blocks/test_notifications.py b/tests/blocks/test_notifications.py index b77e3180067f..278db5d94a88 100644 --- a/tests/blocks/test_notifications.py +++ b/tests/blocks/test_notifications.py @@ -98,7 +98,7 @@ async def test_notify_async(self): AppriseMock.assert_called_once() apprise_instance_mock.add.assert_called_once_with( f"mmost://{mm_block.hostname}/{mm_block.token.get_secret_value()}/" - "?image=yes&format=text&overflow=upstream&rto=4.0&cto=4.0&verify=yes" + "?image=yes&format=text&overflow=upstream" ) apprise_instance_mock.async_notify.assert_awaited_once_with( body="test", title=None, notify_type=PREFECT_NOTIFY_TYPE_DEFAULT @@ -120,7 +120,7 @@ def test_flow(): AppriseMock.assert_called_once() apprise_instance_mock.add.assert_called_once_with( 
f"mmost://{mm_block.hostname}/{mm_block.token.get_secret_value()}/" - "?image=no&format=text&overflow=upstream&rto=4.0&cto=4.0&verify=yes" + "?image=no&format=text&overflow=upstream" ) apprise_instance_mock.async_notify.assert_called_once_with( body="test", title=None, notify_type=PREFECT_NOTIFY_TYPE_DEFAULT @@ -146,7 +146,7 @@ def test_flow(): AppriseMock.assert_called_once() apprise_instance_mock.add.assert_called_once_with( f"mmost://{mm_block.hostname}/{mm_block.token.get_secret_value()}/" - "?image=no&format=text&overflow=upstream&rto=4.0&cto=4.0&verify=yes" + "?image=no&format=text&overflow=upstream" "&channel=death-metal-anonymous%2Cgeneral" ) @@ -176,7 +176,7 @@ async def test_notify_async(self): AppriseMock.assert_called_once() apprise_instance_mock.add.assert_called_once_with( f"discord://{discord_block.webhook_id.get_secret_value()}/{discord_block.webhook_token.get_secret_value()}/" - "?tts=no&avatar=no&footer=no&footer_logo=yes&image=no&fields=yes&format=text&overflow=upstream&rto=4.0&cto=4.0&verify=yes" + "?tts=no&avatar=no&footer=no&footer_logo=yes&image=no&fields=yes&format=text&overflow=upstream" ) apprise_instance_mock.async_notify.assert_awaited_once_with( body="test", title=None, notify_type=PREFECT_NOTIFY_TYPE_DEFAULT @@ -200,7 +200,7 @@ def test_flow(): AppriseMock.assert_called_once() apprise_instance_mock.add.assert_called_once_with( f"discord://{discord_block.webhook_id.get_secret_value()}/{discord_block.webhook_token.get_secret_value()}/" - "?tts=no&avatar=no&footer=no&footer_logo=yes&image=no&fields=yes&format=text&overflow=upstream&rto=4.0&cto=4.0&verify=yes" + "?tts=no&avatar=no&footer=no&footer_logo=yes&image=no&fields=yes&format=text&overflow=upstream" ) apprise_instance_mock.async_notify.assert_called_once_with( body="test", title=None, notify_type=PREFECT_NOTIFY_TYPE_DEFAULT @@ -226,8 +226,9 @@ async def test_notify_async(self): AppriseMock.assert_called_once() apprise_instance_mock.add.assert_called_once_with( - 
f"opsgenie://{self.API_KEY}//?region=us&priority=normal&batch=no&" - "format=text&overflow=upstream&rto=4.0&cto=4.0&verify=yes" + f"opsgenie://{self.API_KEY}//?action=map&region=us&priority=normal&" + "batch=no&%3Ainfo=note&%3Asuccess=close&%3Awarning=new&%3Afailure=" + "new&format=text&overflow=upstream" ) apprise_instance_mock.async_notify.assert_awaited_once_with( @@ -237,7 +238,7 @@ async def test_notify_async(self): def _test_notify_sync(self, targets="", params=None, **kwargs): with patch("apprise.Apprise", autospec=True) as AppriseMock: if params is None: - params = "region=us&priority=normal&batch=no" + params = "action=map&region=us&priority=normal&batch=no" apprise_instance_mock = AppriseMock.return_value apprise_instance_mock.async_notify = AsyncMock() @@ -253,7 +254,7 @@ def test_flow(): AppriseMock.assert_called_once() apprise_instance_mock.add.assert_called_once_with( f"opsgenie://{self.API_KEY}/{targets}/?{params}" - "&format=text&overflow=upstream&rto=4.0&cto=4.0&verify=yes" + "&%3Ainfo=note&%3Asuccess=close&%3Awarning=new&%3Afailure=new&format=text&overflow=upstream" ) apprise_instance_mock.async_notify.assert_awaited_once_with( @@ -264,7 +265,7 @@ def test_notify_sync_simple(self): self._test_notify_sync() def test_notify_sync_params(self): - params = "region=eu&priority=low&batch=yes" + params = "action=map&region=eu&priority=low&batch=yes" self._test_notify_sync(params=params, region_name="eu", priority=1, batch=True) def test_notify_sync_targets(self): @@ -282,7 +283,7 @@ def test_notify_sync_users(self): self._test_notify_sync(targets=targets, target_user=["user1", "user2"]) def test_notify_sync_details(self): - params = "region=us&priority=normal&batch=no&%2Bkey1=value1&%2Bkey2=value2" + params = "action=map&region=us&priority=normal&batch=no&%2Bkey1=value1&%2Bkey2=value2" self._test_notify_sync( params=params, details={ @@ -304,7 +305,7 @@ async def test_notify_async(self): AppriseMock.assert_called_once()
apprise_instance_mock.add.assert_called_once_with( "pagerduty://int_key@api_key/Prefect/Notification?region=us&" - "image=yes&format=text&overflow=upstream&rto=4.0&cto=4.0&verify=yes" + "image=yes&format=text&overflow=upstream" ) notify_type = "info" @@ -328,7 +329,7 @@ def test_flow(): AppriseMock.assert_called_once() apprise_instance_mock.add.assert_called_once_with( "pagerduty://int_key@api_key/Prefect/Notification?region=us&" - "image=yes&format=text&overflow=upstream&rto=4.0&cto=4.0&verify=yes" + "image=yes&format=text&overflow=upstream" ) notify_type = "info" @@ -344,7 +345,7 @@ def valid_apprise_url(self) -> str: "twilio://ACabcdefabcdefabcdefabcdef" ":XXXXXXXXXXXXXXXXXXXXXXXX" "@%2B15555555555/%2B15555555556/%2B15555555557/" - "?format=text&overflow=upstream&rto=4.0&cto=4.0&verify=yes" + "?format=text&overflow=upstream" ) async def test_twilio_notify_async(self, valid_apprise_url): @@ -619,12 +620,6 @@ class TestSendgridEmail: "format": "html", # default overflow mode "overflow": "upstream", - # socket read timeout - "rto": 4.0, - # socket connect timeout - "cto": 4.0, - # ssl certificate authority verification - "verify": "yes", } async def test_notify_async(self): @@ -713,7 +708,7 @@ async def test_notify_async(self): apprise_instance_mock.add.assert_called_once_with( "workflow://prod-NO.LOCATION.logic.azure.com:443/WFID/SIGNATURE/" "?image=yes&wrap=yes" - "&format=markdown&overflow=upstream&rto=4.0&cto=4.0&verify=yes" + "&format=markdown&overflow=upstream" ) apprise_instance_mock.async_notify.assert_awaited_once_with( body="test", title=None, notify_type=PREFECT_NOTIFY_TYPE_DEFAULT @@ -736,7 +731,7 @@ def test_flow(): apprise_instance_mock.add.assert_called_once_with( "workflow://prod-NO.LOCATION.logic.azure.com:443/WFID/SIGNATURE/" "?image=yes&wrap=yes" - "&format=markdown&overflow=upstream&rto=4.0&cto=4.0&verify=yes" + "&format=markdown&overflow=upstream" ) apprise_instance_mock.async_notify.assert_called_once_with( body="test", title=None, 
notify_type=PREFECT_NOTIFY_TYPE_DEFAULT diff --git a/tests/results/test_persisted_result.py b/tests/results/test_persisted_result.py index 6705c431d5ed..b7e73a81dc5d 100644 --- a/tests/results/test_persisted_result.py +++ b/tests/results/test_persisted_result.py @@ -203,15 +203,15 @@ async def test_write_is_idempotent(storage_block): ) with pytest.raises(ValueError, match="does not exist"): - await result._read_result_record() + await result.get(ignore_cache=True) await result.write() - record = await result._read_result_record() - assert record.result == "test-defer" + obj = await result.get() + assert obj == "test-defer" await result.write(obj="new-object!") - record = await result._read_result_record() - assert record.result == "test-defer" + obj = await result.get() + assert obj == "test-defer" async def test_lifecycle_of_deferred_persistence(storage_block): @@ -226,11 +226,11 @@ async def test_lifecycle_of_deferred_persistence(storage_block): assert await result.get() == "test-defer" with pytest.raises(ValueError, match="does not exist"): - await result._read_result_record() + await result.get(ignore_cache=True) await result.write() - record = await result._read_result_record() - assert record.result == "test-defer" + obj = await result.get() + assert obj == "test-defer" async def test_read_old_format_into_result_record(): diff --git a/tests/results/test_state_result.py b/tests/results/test_state_result.py index 0b8dac168f2a..3220019e3f09 100644 --- a/tests/results/test_state_result.py +++ b/tests/results/test_state_result.py @@ -10,7 +10,12 @@ import prefect.states from prefect.exceptions import UnfinishedRun from prefect.filesystems import LocalFileSystem, WritableFileSystem -from prefect.results import PersistedResult, ResultFactory, ResultRecord +from prefect.results import ( + PersistedResult, + ResultFactory, + ResultRecord, + ResultRecordMetadata, +) from prefect.serializers import JSONSerializer from prefect.states import State, StateType from 
prefect.utilities.annotations import NotSet @@ -132,8 +137,14 @@ async def test_graceful_retries_eventually_succeed_while( ): # now write the result so it's available await a_real_result.write() - expected_record = await a_real_result._read_result_record() - assert isinstance(expected_record, ResultRecord) + expected_record = ResultRecord( + result="test-graceful-retry", + metadata=ResultRecordMetadata( + storage_key=a_real_result.storage_key, + expiration=a_real_result.expiration, + serializer=JSONSerializer(), + ), + ) # even if it misses a couple times, it will eventually return the data now = time.monotonic()