From f9ec197818cc57076ac5fc2f1bf2e2c986f7fcd2 Mon Sep 17 00:00:00 2001 From: Alex Streed Date: Fri, 6 Sep 2024 12:07:28 -0500 Subject: [PATCH] Use `ResultStore` by default for task transactions --- src/prefect/results.py | 21 +++++++++++++++--- src/prefect/states.py | 32 +++++++++++++++------------ src/prefect/task_engine.py | 39 ++++++++++----------------------- src/prefect/transactions.py | 2 ++ src/prefect/utilities/engine.py | 13 ++++++----- tests/test_tasks.py | 21 +++++++++--------- 6 files changed, 67 insertions(+), 61 deletions(-) diff --git a/src/prefect/results.py b/src/prefect/results.py index bae82080fbcb..b5ab69dbe394 100644 --- a/src/prefect/results.py +++ b/src/prefect/results.py @@ -19,6 +19,7 @@ ) from uuid import UUID +import pendulum from pydantic import ( BaseModel, ConfigDict, @@ -293,16 +294,30 @@ async def _exists(self, key: str) -> bool: # so the entire payload doesn't need to be read try: metadata_content = await self.metadata_storage.read_path(key) - return metadata_content is not None + if metadata_content is None: + return False + metadata = ResultRecordMetadata.load_bytes(metadata_content) + except Exception: return False else: try: content = await self.result_storage.read_path(key) - return content is not None + if content is None: + return False + record = ResultRecord.deserialize(content) + metadata = record.metadata except Exception: return False + if metadata.expiration: + # if the result has an expiration, + # check if it is still in the future + exists = metadata.expiration > pendulum.now("utc") + else: + exists = True + return exists + def exists(self, key: str) -> bool: """ Check if a result record exists in storage. @@ -390,8 +405,8 @@ async def aread(self, key: str, holder: Optional[str] = None) -> "ResultRecord": def create_result_record( self, - key: str, obj: Any, + key: Optional[str] = None, expiration: Optional[DateTime] = None, ): """ diff --git a/src/prefect/states.py b/src/prefect/states.py index f3bf2a48ee56..bc4e7be827e6 100644 --- a/src/prefect/states.py +++ b/src/prefect/states.py @@ -25,7 +25,7 @@ UnfinishedRun, ) from prefect.logging.loggers import get_logger, get_run_logger -from prefect.results import BaseResult, R, ResultStore +from prefect.results import BaseResult, R, ResultRecord, ResultStore from prefect.settings import PREFECT_ASYNC_FETCH_STATE_RESULT from prefect.utilities.annotations import BaseAnnotation from prefect.utilities.asyncutils import in_async_main_thread, sync_compatible @@ -131,6 +131,8 @@ async def _get_state_result( result = await _get_state_result_data_with_retries( state, retry_result_failure=retry_result_failure ) + elif isinstance(state.data, ResultRecord): + result = state.data.result elif state.data is None: if state.is_failed() or state.is_crashed() or state.is_cancelled(): @@ -207,7 +209,7 @@ async def exception_to_crashed_state( ) if result_store: - data = await result_store.create_result(exc) + data = result_store.create_result_record(exc) else: # Attach the exception for local usage, will not be available when retrieved # from the API @@ -240,10 +242,10 @@ async def exception_to_failed_state( pass if result_store: - data = await result_store.create_result(exc) + data = result_store.create_result_record(exc) if write_result: try: - await data.write() + await result_store.apersist_result_record(data) except Exception as exc: local_logger.warning( "Failed to write result: %s Execution will continue, but the result has not been written", @@ -310,20 +312,20 @@ async def return_value_to_state( # Unless the user has already constructed a result explicitly, use the store # to update the data to the correct type if not isinstance(state.data, BaseResult): - result = await result_store.create_result( + result_record = result_store.create_result_record( state.data, key=key, expiration=expiration, ) if write_result: try: - await result.write() + await result_store.apersist_result_record(result_record) except Exception as exc: local_logger.warning( "Encountered an error while persisting result: %s Execution will continue, but the result has not been persisted", exc, ) - state.data = result + state.data = result_record return state # Determine a new state from the aggregate of contained states @@ -359,14 +361,14 @@ async def return_value_to_state( # TODO: We may actually want to set the data to a `StateGroup` object and just # allow it to be unpacked into a tuple and such so users can interact with # it - result = await result_store.create_result( + result_record = result_store.create_result_record( retval, key=key, expiration=expiration, ) if write_result: try: - await result.write() + await result_store.apersist_result_record(result_record) except Exception as exc: local_logger.warning( "Encountered an error while persisting result: %s Execution will continue, but the result has not been persisted", @@ -375,7 +377,7 @@ async def return_value_to_state( return State( type=new_state_type, message=message, - data=result, + data=result_record, ) # Generators aren't portable, implicitly convert them to a list. @@ -385,23 +387,23 @@ async def return_value_to_state( data = retval # Otherwise, they just gave data and this is a completed retval - if isinstance(data, BaseResult): + if isinstance(data, (BaseResult, ResultRecord)): return Completed(data=data) else: - result = await result_store.create_result( + result_record = result_store.create_result_record( data, key=key, expiration=expiration, ) if write_result: try: - await result.write() + await result_store.apersist_result_record(result_record) except Exception as exc: local_logger.warning( "Encountered an error while persisting result: %s Execution will continue, but the result has not been persisted", exc, ) - return Completed(data=result) + return Completed(data=result_record) @sync_compatible @@ -442,6 +444,8 @@ async def get_state_exception(state: State) -> BaseException: if isinstance(state.data, BaseResult): result = await _get_state_result_data_with_retries(state) + elif isinstance(state.data, ResultRecord): + result = state.data.result elif state.data is None: result = None else: diff --git a/src/prefect/task_engine.py b/src/prefect/task_engine.py index dec57aa40c7e..e9ea0156a494 100644 --- a/src/prefect/task_engine.py +++ b/src/prefect/task_engine.py @@ -55,9 +55,9 @@ ) from prefect.futures import PrefectFuture from prefect.logging.loggers import get_logger, patch_print, task_run_logger -from prefect.records.result_store import ResultRecordStore from prefect.results import ( BaseResult, + ResultRecord, _format_user_supplied_storage_key, get_current_result_store, ) @@ -418,6 +418,8 @@ def set_state(self, state: State, force: bool = False) -> State: result = state.result(raise_on_failure=False, fetch=True) if inspect.isawaitable(result): result = run_coro_as_sync(result) + elif isinstance(state.data, ResultRecord): + result = state.data.result else: result = state.data @@ -454,10 +456,6 @@ def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]": return self._raised def handle_success(self, result: R, transaction: Transaction) -> R: - result_store = getattr(TaskRunContext.get(), "result_store", None) - if result_store is None: - raise ValueError("Result store is not set") - if self.task.cache_expiration is not None: expiration = pendulum.now("utc") + self.task.cache_expiration else: @@ -466,7 +464,7 @@ def handle_success(self, result: R, transaction: Transaction) -> R: terminal_state = run_coro_as_sync( return_value_to_state( result, - result_store=result_store, + result_store=get_current_result_store(), key=transaction.key, expiration=expiration, ) @@ -538,12 +536,11 @@ def handle_exception(self, exc: Exception) -> None: # If the task fails, and we have retries left, set the task to retrying. if not self.handle_retry(exc): # If the task has no retries left, or the retry condition is not met, set the task to failed. - context = TaskRunContext.get() state = run_coro_as_sync( exception_to_failed_state( exc, message="Task run encountered an exception", - result_store=getattr(context, "result_store", None), + result_store=get_current_result_store(), write_result=True, ) ) @@ -723,15 +720,9 @@ def transaction_context(self) -> Generator[Transaction, None, None]: else PREFECT_TASKS_REFRESH_CACHE.value() ) - result_store = getattr(TaskRunContext.get(), "result_store", None) - if result_store and result_store.persist_result: - store = ResultRecordStore(result_store=result_store) - else: - store = None - with transaction( key=self.compute_transaction_key(), - store=store, + store=get_current_result_store(), overwrite=overwrite, logger=self.logger, ) as txn: @@ -933,6 +924,8 @@ async def set_state(self, state: State, force: bool = False) -> State: # Avoid fetching the result unless it is cached, otherwise we defeat # the purpose of disabling `cache_result_in_memory` result = await new_state.result(raise_on_failure=False, fetch=True) + elif isinstance(new_state.data, ResultRecord): + result = new_state.data.result else: result = new_state.data @@ -966,10 +959,6 @@ async def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]" return self._raised async def handle_success(self, result: R, transaction: Transaction) -> R: - result_store = getattr(TaskRunContext.get(), "result_store", None) - if result_store is None: - raise ValueError("Result store is not set") - if self.task.cache_expiration is not None: expiration = pendulum.now("utc") + self.task.cache_expiration else: @@ -977,7 +966,7 @@ async def handle_success(self, result: R, transaction: Transaction) -> R: terminal_state = await return_value_to_state( result, - result_store=result_store, + result_store=get_current_result_store(), key=transaction.key, expiration=expiration, ) @@ -1048,11 +1037,10 @@ async def handle_exception(self, exc: Exception) -> None: # If the task fails, and we have retries left, set the task to retrying. if not await self.handle_retry(exc): # If the task has no retries left, or the retry condition is not met, set the task to failed. - context = TaskRunContext.get() state = await exception_to_failed_state( exc, message="Task run encountered an exception", - result_store=getattr(context, "result_store", None), + result_store=get_current_result_store(), ) self.record_terminal_state_timing(state) await self.set_state(state) @@ -1226,15 +1214,10 @@ async def transaction_context(self) -> AsyncGenerator[Transaction, None]: if self.task.refresh_cache is not None else PREFECT_TASKS_REFRESH_CACHE.value() ) - result_store = getattr(TaskRunContext.get(), "result_store", None) - if result_store and result_store.persist_result: - store = ResultRecordStore(result_store=result_store) - else: - store = None with transaction( key=self.compute_transaction_key(), - store=store, + store=get_current_result_store(), overwrite=overwrite, logger=self.logger, ) as txn: diff --git a/src/prefect/transactions.py b/src/prefect/transactions.py index f0268e33b5b2..a7010f3c5e86 100644 --- a/src/prefect/transactions.py +++ b/src/prefect/transactions.py @@ -233,6 +233,8 @@ def commit(self) -> bool: if isinstance(self.store, ResultStore): if isinstance(self._staged_value, BaseResult): self.store.write(self.key, self._staged_value.get(_sync=True)) + elif isinstance(self._staged_value, ResultRecord): + self.store.persist_result_record(self._staged_value) else: self.store.write(self.key, self._staged_value) else: diff --git a/src/prefect/utilities/engine.py b/src/prefect/utilities/engine.py index 26cd9bc5fe6e..702bfd2bd5e5 100644 --- a/src/prefect/utilities/engine.py +++ b/src/prefect/utilities/engine.py @@ -44,12 +44,11 @@ ) from prefect.flows import Flow from prefect.futures import PrefectFuture -from prefect.futures import PrefectFuture as NewPrefectFuture from prefect.logging.loggers import ( get_logger, task_run_logger, ) -from prefect.results import BaseResult +from prefect.results import BaseResult, ResultRecord from prefect.settings import ( PREFECT_LOGGING_LOG_PRINTS, ) @@ -122,7 +121,7 @@ def add_futures_and_states_to_inputs(obj): def collect_task_run_inputs_sync( - expr: Any, future_cls: Any = NewPrefectFuture, max_depth: int = -1 + expr: Any, future_cls: Any = PrefectFuture, max_depth: int = -1 ) -> Set[TaskRunInput]: """ This function recurses through an expression to generate a set of any discernible @@ -131,7 +130,7 @@ def collect_task_run_inputs_sync( Examples: >>> task_inputs = { - >>> k: collect_task_run_inputs(v) for k, v in parameters.items() + >>> k: collect_task_run_inputs_sync(v) for k, v in parameters.items() >>> } """ # TODO: This function needs to be updated to detect parameters and constants @@ -401,6 +400,8 @@ async def propose_state( # Avoid fetching the result unless it is cached, otherwise we defeat # the purpose of disabling `cache_result_in_memory` result = await state.result(raise_on_failure=False, fetch=True) + elif isinstance(state.data, ResultRecord): + result = state.data.result else: result = state.data @@ -504,6 +505,8 @@ def propose_state_sync( result = state.result(raise_on_failure=False, fetch=True) if inspect.isawaitable(result): result = run_coro_as_sync(result) + elif isinstance(state.data, ResultRecord): + result = state.data.result else: result = state.data @@ -822,7 +825,7 @@ def resolve_to_final_result(expr, context): if isinstance(context.get("annotation"), quote): raise StopVisiting() - if isinstance(expr, NewPrefectFuture): + if isinstance(expr, PrefectFuture): upstream_task_run = context.get("current_task_run") upstream_task = context.get("current_task") if ( diff --git a/tests/test_tasks.py b/tests/test_tasks.py index e74681290f5e..e08c963f3f5e 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -29,8 +29,7 @@ ReservedArgumentError, ) from prefect.filesystems import LocalFileSystem -from prefect.futures import PrefectDistributedFuture -from prefect.futures import PrefectFuture as NewPrefectFuture +from prefect.futures import PrefectDistributedFuture, PrefectFuture from prefect.logging import get_run_logger from prefect.results import ResultStore, get_or_create_default_task_scheduling_storage from prefect.runtime import task_run as task_run_ctx @@ -635,7 +634,7 @@ def foo(x): @flow def bar(): future = foo.submit(1) - assert isinstance(future, NewPrefectFuture) + assert isinstance(future, PrefectFuture) return future task_state = bar() @@ -677,7 +676,7 @@ async def foo(x): @flow async def bar(): future = foo.submit(1) - assert isinstance(future, NewPrefectFuture) + assert isinstance(future, PrefectFuture) return future task_state = await bar() @@ -691,7 +690,7 @@ def foo(x): @flow async def bar(): future = foo.submit(1) - assert isinstance(future, NewPrefectFuture) + assert isinstance(future, PrefectFuture) return future task_state = await bar() @@ -705,7 +704,7 @@ async def foo(x): @flow def bar(): future = foo.submit(1) - assert isinstance(future, NewPrefectFuture) + assert isinstance(future, PrefectFuture) return future task_state = bar() @@ -3449,7 +3448,7 @@ async def test_simple_map(self): @flow def my_flow(): futures = TestTaskMap.add_one.map([1, 2, 3]) - assert all(isinstance(f, NewPrefectFuture) for f in futures) + assert all(isinstance(f, PrefectFuture) for f in futures) return futures task_states = my_flow() @@ -3469,7 +3468,7 @@ async def test_map_can_take_tuple_as_input(self): @flow def my_flow(): futures = TestTaskMap.add_one.map((1, 2, 3)) - assert all(isinstance(f, NewPrefectFuture) for f in futures) + assert all(isinstance(f, PrefectFuture) for f in futures) return futures task_states = my_flow() @@ -3485,7 +3484,7 @@ def generate_numbers(): @flow def my_flow(): futures = TestTaskMap.add_one.map(generate_numbers()) - assert all(isinstance(f, NewPrefectFuture) for f in futures) + assert all(isinstance(f, PrefectFuture) for f in futures) return futures task_states = my_flow() @@ -3508,7 +3507,7 @@ async def test_can_take_quoted_iterable_as_input(self): @flow def my_flow(): futures = TestTaskMap.add_together.map(quote(1), [1, 2, 3]) - assert all(isinstance(f, NewPrefectFuture) for f in futures) + assert all(isinstance(f, PrefectFuture) for f in futures) return futures task_states = my_flow() @@ -3518,7 +3517,7 @@ async def test_does_not_treat_quote_as_iterable(self): @flow def my_flow(): futures = TestTaskMap.add_one.map(quote([1, 2, 3])) - assert all(isinstance(f, NewPrefectFuture) for f in futures) + assert all(isinstance(f, PrefectFuture) for f in futures) return futures task_states = my_flow()