Skip to content

Commit

Permalink
Consolidate where transaction hooks are run (#14877)
Browse files Browse the repository at this point in the history
  • Loading branch information
cicdw authored Aug 12, 2024
1 parent 6062153 commit 09f6f59
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 46 deletions.
48 changes: 4 additions & 44 deletions src/prefect/task_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from asyncio import CancelledError
from contextlib import ExitStack, asynccontextmanager, contextmanager
from dataclasses import dataclass, field
from functools import wraps
from textwrap import dedent
from typing import (
Any,
Expand Down Expand Up @@ -496,15 +495,8 @@ def handle_success(self, result: R, transaction: Transaction) -> R:
)
transaction.stage(
terminal_state.data,
on_rollback_hooks=[self.handle_rollback]
+ [
_with_transaction_hook_logging(hook, "rollback", self.logger)
for hook in self.task.on_rollback_hooks
],
on_commit_hooks=[
_with_transaction_hook_logging(hook, "commit", self.logger)
for hook in self.task.on_commit_hooks
],
on_rollback_hooks=[self.handle_rollback] + self.task.on_rollback_hooks,
on_commit_hooks=self.task.on_commit_hooks,
)
if transaction.is_committed():
terminal_state.name = "Cached"
Expand Down Expand Up @@ -1068,15 +1060,8 @@ async def handle_success(self, result: R, transaction: Transaction) -> R:
)
transaction.stage(
terminal_state.data,
on_rollback_hooks=[self.handle_rollback]
+ [
_with_transaction_hook_logging(hook, "rollback", self.logger)
for hook in self.task.on_rollback_hooks
],
on_commit_hooks=[
_with_transaction_hook_logging(hook, "commit", self.logger)
for hook in self.task.on_commit_hooks
],
on_rollback_hooks=[self.handle_rollback] + self.task.on_rollback_hooks,
on_commit_hooks=self.task.on_commit_hooks,
)
if transaction.is_committed():
terminal_state.name = "Cached"
Expand Down Expand Up @@ -1622,28 +1607,3 @@ def run_task(
return run_task_async(**kwargs)
else:
return run_task_sync(**kwargs)


def _with_transaction_hook_logging(
hook: Callable[[Transaction], None],
hook_type: Literal["rollback", "commit"],
logger: logging.Logger,
) -> Callable[[Transaction], None]:
@wraps(hook)
def _hook(txn: Transaction) -> None:
hook_name = _get_hook_name(hook)
logger.info(f"Running {hook_type} hook {hook_name!r}")

try:
hook(txn)
except Exception as exc:
logger.error(
f"An error was encountered while running {hook_type} hook {hook_name!r}",
)
raise exc
else:
logger.info(
f"{hook_type.capitalize()} hook {hook_name!r} finished running successfully"
)

return _hook
21 changes: 19 additions & 2 deletions src/prefect/transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
)
from prefect.utilities.asyncutils import run_coro_as_sync
from prefect.utilities.collections import AutoEnum
from prefect.utilities.engine import _get_hook_name


class IsolationLevel(AutoEnum):
Expand Down Expand Up @@ -183,7 +184,7 @@ def commit(self) -> bool:
child.commit()

for hook in self.on_commit_hooks:
hook(self)
self.run_hook(hook, "commit")

if self.store and self.key:
self.store.write(key=self.key, value=self._staged_value)
Expand All @@ -198,6 +199,22 @@ def commit(self) -> bool:
self.rollback()
return False

def run_hook(self, hook, hook_type: str) -> None:
hook_name = _get_hook_name(hook)
self.logger.info(f"Running {hook_type} hook {hook_name!r}")

try:
hook(self)
except Exception as exc:
self.logger.error(
f"An error was encountered while running {hook_type} hook {hook_name!r}",
)
raise exc
else:
self.logger.info(
f"{hook_type.capitalize()} hook {hook_name!r} finished running successfully"
)

def stage(
self,
value: BaseResult,
Expand All @@ -222,7 +239,7 @@ def rollback(self) -> bool:

try:
for hook in reversed(self.on_rollback_hooks):
hook(self)
self.run_hook(hook, "rollback")

self.state = TransactionState.ROLLED_BACK

Expand Down

0 comments on commit 09f6f59

Please sign in to comment.