From 09f6f59e2820f9e9becbc707ae2acf278ee3b6e7 Mon Sep 17 00:00:00 2001 From: Chris White Date: Mon, 12 Aug 2024 11:33:35 -0500 Subject: [PATCH] Consolidate where transaction hooks are run (#14877) --- src/prefect/task_engine.py | 48 ++++--------------------------------- src/prefect/transactions.py | 21 ++++++++++++++-- 2 files changed, 23 insertions(+), 46 deletions(-) diff --git a/src/prefect/task_engine.py b/src/prefect/task_engine.py index e2679a9d002b..63517c6398d4 100644 --- a/src/prefect/task_engine.py +++ b/src/prefect/task_engine.py @@ -5,7 +5,6 @@ from asyncio import CancelledError from contextlib import ExitStack, asynccontextmanager, contextmanager from dataclasses import dataclass, field -from functools import wraps from textwrap import dedent from typing import ( Any, @@ -496,15 +495,8 @@ def handle_success(self, result: R, transaction: Transaction) -> R: ) transaction.stage( terminal_state.data, - on_rollback_hooks=[self.handle_rollback] - + [ - _with_transaction_hook_logging(hook, "rollback", self.logger) - for hook in self.task.on_rollback_hooks - ], - on_commit_hooks=[ - _with_transaction_hook_logging(hook, "commit", self.logger) - for hook in self.task.on_commit_hooks - ], + on_rollback_hooks=[self.handle_rollback] + self.task.on_rollback_hooks, + on_commit_hooks=self.task.on_commit_hooks, ) if transaction.is_committed(): terminal_state.name = "Cached" @@ -1068,15 +1060,8 @@ async def handle_success(self, result: R, transaction: Transaction) -> R: ) transaction.stage( terminal_state.data, - on_rollback_hooks=[self.handle_rollback] - + [ - _with_transaction_hook_logging(hook, "rollback", self.logger) - for hook in self.task.on_rollback_hooks - ], - on_commit_hooks=[ - _with_transaction_hook_logging(hook, "commit", self.logger) - for hook in self.task.on_commit_hooks - ], + on_rollback_hooks=[self.handle_rollback] + self.task.on_rollback_hooks, + on_commit_hooks=self.task.on_commit_hooks, ) if transaction.is_committed(): terminal_state.name = "Cached" @@ -1622,28 +1607,3 @@ def run_task( return run_task_async(**kwargs) else: return run_task_sync(**kwargs) - - -def _with_transaction_hook_logging( - hook: Callable[[Transaction], None], - hook_type: Literal["rollback", "commit"], - logger: logging.Logger, -) -> Callable[[Transaction], None]: - @wraps(hook) - def _hook(txn: Transaction) -> None: - hook_name = _get_hook_name(hook) - logger.info(f"Running {hook_type} hook {hook_name!r}") - - try: - hook(txn) - except Exception as exc: - logger.error( - f"An error was encountered while running {hook_type} hook {hook_name!r}", - ) - raise exc - else: - logger.info( - f"{hook_type.capitalize()} hook {hook_name!r} finished running successfully" - ) - - return _hook diff --git a/src/prefect/transactions.py b/src/prefect/transactions.py index 41c6c07d3001..9274db2bdf63 100644 --- a/src/prefect/transactions.py +++ b/src/prefect/transactions.py @@ -26,6 +26,7 @@ ) from prefect.utilities.asyncutils import run_coro_as_sync from prefect.utilities.collections import AutoEnum +from prefect.utilities.engine import _get_hook_name class IsolationLevel(AutoEnum): @@ -183,7 +184,7 @@ def commit(self) -> bool: child.commit() for hook in self.on_commit_hooks: - hook(self) + self.run_hook(hook, "commit") if self.store and self.key: self.store.write(key=self.key, value=self._staged_value) @@ -198,6 +199,22 @@ def commit(self) -> bool: self.rollback() return False + def run_hook(self, hook, hook_type: str) -> None: + hook_name = _get_hook_name(hook) + self.logger.info(f"Running {hook_type} hook {hook_name!r}") + + try: + hook(self) + except Exception as exc: + self.logger.error( + f"An error was encountered while running {hook_type} hook {hook_name!r}", + ) + raise exc + else: + self.logger.info( + f"{hook_type.capitalize()} hook {hook_name!r} finished running successfully" + ) + def stage( self, value: BaseResult, @@ -222,7 +239,7 @@ def rollback(self) -> bool: try: for hook in reversed(self.on_rollback_hooks): - hook(self) + self.run_hook(hook, "rollback") self.state = TransactionState.ROLLED_BACK