diff --git a/src/prefect/exceptions.py b/src/prefect/exceptions.py index aa738003b6c7..89a7f0acfa2f 100644 --- a/src/prefect/exceptions.py +++ b/src/prefect/exceptions.py @@ -412,3 +412,9 @@ class PrefectImportError(ImportError): def __init__(self, message: str) -> None: super().__init__(message) + + +class SerializationError(PrefectException): + """ + Raised when an object cannot be serialized. + """ diff --git a/src/prefect/results.py b/src/prefect/results.py index 63aaa679f1eb..bd13d5856f39 100644 --- a/src/prefect/results.py +++ b/src/prefect/results.py @@ -23,7 +23,6 @@ Field, PrivateAttr, ValidationError, - field_serializer, model_serializer, model_validator, ) @@ -34,6 +33,7 @@ import prefect from prefect.blocks.core import Block from prefect.client.utilities import inject_client +from prefect.exceptions import SerializationError from prefect.filesystems import ( LocalFileSystem, WritableFileSystem, @@ -442,10 +442,9 @@ def expiration(self) -> Optional[DateTime]: def serializer(self) -> Serializer: return self.metadata.serializer - @field_serializer("result") - def serialize_result(self, value: R) -> bytes: + def serialize_result(self) -> bytes: try: - data = self.serializer.dumps(value) + data = self.serializer.dumps(self.result) except Exception as exc: extra_info = ( 'You can try a different serializer (e.g. result_serializer="json") ' @@ -456,7 +455,7 @@ def serialize_result(self, value: R) -> bytes: if ( isinstance(exc, TypeError) - and isinstance(value, BaseModel) + and isinstance(self.result, BaseModel) and str(exc).startswith("cannot pickle") ): try: @@ -480,8 +479,8 @@ def serialize_result(self, value: R) -> bytes: ).replace("\n", " ") except ImportError: pass - raise ValueError( - f"Failed to serialize object of type {type(value).__name__!r} with " + raise SerializationError( + f"Failed to serialize object of type {type(self.result).__name__!r} with " f"serializer {self.serializer.type!r}. {extra_info}" ) from exc @@ -516,7 +515,11 @@ def serialize( bytes: the serialized record """ - return self.model_dump_json(serialize_as_any=True).encode() + return ( + self.model_copy(update={"result": self.serialize_result()}) + .model_dump_json(serialize_as_any=True) + .encode() + ) @classmethod def deserialize(cls, data: bytes) -> "ResultRecord[R]": diff --git a/src/prefect/states.py b/src/prefect/states.py index bc68a6b34fb7..3ef02744398d 100644 --- a/src/prefect/states.py +++ b/src/prefect/states.py @@ -18,12 +18,13 @@ CancelledRun, CrashedRun, FailedRun, + MissingContextError, MissingResult, PausedRun, TerminationSignal, UnfinishedRun, ) -from prefect.logging.loggers import get_logger +from prefect.logging.loggers import get_logger, get_run_logger from prefect.results import BaseResult, R, ResultFactory from prefect.settings import PREFECT_ASYNC_FETCH_STATE_RESULT from prefect.utilities.annotations import BaseAnnotation @@ -224,6 +225,11 @@ async def exception_to_failed_state( """ Convenience function for creating `Failed` states from exceptions """ + try: + local_logger = get_run_logger() + except MissingContextError: + local_logger = logger + if not exc: _, exc, _ = sys.exc_info() if exc is None: @@ -236,7 +242,13 @@ async def exception_to_failed_state( if result_factory: data = await result_factory.create_result(exc) if write_result: - await data.write() + try: + await data.write() + except Exception as exc: + local_logger.warning( + "Failed to write result: %s Execution will continue, but the result has not been written", + exc, + ) else: # Attach the exception for local usage, will not be available when retrieved # from the API @@ -283,6 +295,10 @@ async def return_value_to_state( Callers should resolve all futures into states before passing return values to this function. """ + try: + local_logger = get_run_logger() + except MissingContextError: + local_logger = logger if ( isinstance(retval, State) @@ -300,7 +316,13 @@ async def return_value_to_state( expiration=expiration, ) if write_result: - await result.write() + try: + await result.write() + except Exception as exc: + local_logger.warning( + "Encountered an error while writing result: %s Execution will continue, but the result has not been written", + exc, + ) state.data = result return state @@ -343,7 +365,13 @@ async def return_value_to_state( expiration=expiration, ) if write_result: - await result.write() + try: + await result.write() + except Exception as exc: + local_logger.warning( + "Encountered an error while writing result: %s Execution will continue, but the result has not been written", + exc, + ) return State( type=new_state_type, message=message, @@ -366,7 +394,13 @@ async def return_value_to_state( expiration=expiration, ) if write_result: - await result.write() + try: + await result.write() + except Exception as exc: + local_logger.warning( + "Encountered an error while writing result: %s Execution will continue, but the result has not been written", + exc, + ) return Completed(data=result) diff --git a/src/prefect/task_engine.py b/src/prefect/task_engine.py index 755ecbfcac71..efc188a4a319 100644 --- a/src/prefect/task_engine.py +++ b/src/prefect/task_engine.py @@ -5,6 +5,7 @@ from asyncio import CancelledError from contextlib import ExitStack, asynccontextmanager, contextmanager from dataclasses import dataclass, field +from functools import partial from textwrap import dedent from typing import ( Any, @@ -466,9 +467,14 @@ def handle_success(self, result: R, transaction: Transaction) -> R: expiration=expiration, ) ) + + # Avoid logging when running this rollback hook since it is not user-defined + handle_rollback = partial(self.handle_rollback) + handle_rollback.log_on_run = False + transaction.stage( terminal_state.data, - on_rollback_hooks=[self.handle_rollback] + self.task.on_rollback_hooks, + on_rollback_hooks=[handle_rollback] + self.task.on_rollback_hooks, on_commit_hooks=self.task.on_commit_hooks, ) if transaction.is_committed(): @@ -969,9 +975,14 @@ async def handle_success(self, result: R, transaction: Transaction) -> R: key=transaction.key, expiration=expiration, ) + + # Avoid logging when running this rollback hook since it is not user-defined + handle_rollback = partial(self.handle_rollback) + handle_rollback.log_on_run = False + transaction.stage( terminal_state.data, - on_rollback_hooks=[self.handle_rollback] + self.task.on_rollback_hooks, + on_rollback_hooks=[handle_rollback] + self.task.on_rollback_hooks, on_commit_hooks=self.task.on_commit_hooks, ) if transaction.is_committed(): diff --git a/src/prefect/transactions.py b/src/prefect/transactions.py index 5a7fb34e6b67..09191f2ef8c8 100644 --- a/src/prefect/transactions.py +++ b/src/prefect/transactions.py @@ -18,7 +18,7 @@ from typing_extensions import Self from prefect.context import ContextModel, FlowRunContext, TaskRunContext -from prefect.exceptions import MissingContextError +from prefect.exceptions import MissingContextError, SerializationError from prefect.logging.loggers import get_logger, get_run_logger from prefect.records import RecordStore from prefect.results import ( @@ -240,6 +240,14 @@ def commit(self) -> bool: self.logger.debug(f"Releasing lock for transaction {self.key!r}") self.store.release_lock(self.key) return True + except SerializationError as exc: + if self.logger: + self.logger.warning( + f"Encountered an error while serializing result for transaction {self.key!r}: {exc}" + " Code execution will continue, but the transaction will not be committed.", + ) + self.rollback() + return False except Exception: if self.logger: self.logger.exception( @@ -251,19 +259,25 @@ def commit(self) -> bool: def run_hook(self, hook, hook_type: str) -> None: hook_name = _get_hook_name(hook) - self.logger.info(f"Running {hook_type} hook {hook_name!r}") + # Undocumented way to disable logging for a hook. Subject to change. + should_log = getattr(hook, "log_on_run", True) + + if should_log: + self.logger.info(f"Running {hook_type} hook {hook_name!r}") try: hook(self) except Exception as exc: - self.logger.error( - f"An error was encountered while running {hook_type} hook {hook_name!r}", - ) + if should_log: + self.logger.error( + f"An error was encountered while running {hook_type} hook {hook_name!r}", + ) raise exc else: - self.logger.info( - f"{hook_type.capitalize()} hook {hook_name!r} finished running successfully" - ) + if should_log: + self.logger.info( + f"{hook_type.capitalize()} hook {hook_name!r} finished running successfully" + ) def stage( self, diff --git a/tests/results/test_state_result.py b/tests/results/test_state_result.py index 6d4acf529733..c9e8d7f6641e 100644 --- a/tests/results/test_state_result.py +++ b/tests/results/test_state_result.py @@ -147,7 +147,7 @@ async def test_graceful_retries_eventually_succeed_while( side_effect=[ FileNotFoundError, TimeoutError, - expected_record.model_dump_json().encode(), + expected_record.serialize(), ] ), ) as m: diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 7c5a14511224..860b91b3f56b 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -1352,6 +1352,24 @@ def base(): assert my_task.persist_result is True assert new_task.persist_result is True + def test_logs_warning_on_serialization_error(self, caplog): + @task(result_serializer="json") + def my_task(): + return lambda: 1 + + my_task() + + record = next( + ( + record + for record in caplog.records + if "Encountered an error while serializing result" in record.message + ), + None, + ) + assert record is not None + assert record.levelname == "WARNING" + class TestTaskCaching: async def test_repeated_task_call_within_flow_is_cached_by_default(self): @@ -5053,6 +5071,19 @@ def commit(txn): assert state.name == "Completed" assert isinstance(data["txn"], Transaction) + def test_does_not_log_rollback_when_no_user_defined_rollback_hooks(self, caplog): + @task + def my_task(): + pass + + @my_task.on_commit + def commit(txn): + raise Exception("oops") + + my_task() + + assert "Running rollback hook" not in caplog.text + class TestApplyAsync: @pytest.mark.parametrize(