Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

2.x add cached completion hooks #15270

Merged
merged 13 commits into from
Sep 10, 2024
11 changes: 7 additions & 4 deletions src/prefect/client/schemas/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,12 +159,12 @@ class State(ObjectBaseModel, Generic[R]):
)

@overload
def result(self: "State[R]", raise_on_failure: bool = True) -> R:
...
def result(self: "State[R]", raise_on_failure: bool = True) -> R: ...

@overload
def result(self: "State[R]", raise_on_failure: bool = False) -> Union[R, Exception]:
...
def result(
self: "State[R]", raise_on_failure: bool = False
) -> Union[R, Exception]: ...

def result(
self, raise_on_failure: bool = True, fetch: Optional[bool] = None
Expand Down Expand Up @@ -305,6 +305,9 @@ def is_cancelled(self) -> bool:
def is_cancelling(self) -> bool:
    """Return ``True`` if this state's type is ``StateType.CANCELLING``."""
    return self.type == StateType.CANCELLING

def is_cached(self) -> bool:
    """Return ``True`` if this state represents a cached task result.

    NOTE(review): this keys off the human-readable state *name* rather
    than a dedicated state type -- confirm "Cached" is the canonical
    name set by the engine for cache hits.
    """
    cached_state_name = "Cached"
    return self.name == cached_state_name
derekahuang marked this conversation as resolved.
Show resolved Hide resolved

def is_final(self) -> bool:
return self.type in {
StateType.CANCELLED,
Expand Down
32 changes: 23 additions & 9 deletions src/prefect/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@
from prefect.settings import (
PREFECT_DEBUG_MODE,
PREFECT_EXPERIMENTAL_ENABLE_NEW_ENGINE,
PREFECT_RUN_ON_COMPLETION_HOOKS_ON_CACHED,
PREFECT_TASK_INTROSPECTION_WARN_THRESHOLD,
PREFECT_TASKS_REFRESH_CACHE,
PREFECT_UI_URL,
Expand Down Expand Up @@ -1003,8 +1004,7 @@ async def pause_flow_run(
poll_interval: int = 10,
reschedule: bool = False,
key: str = None,
) -> None:
...
) -> None: ...


@deprecated_callable(
Expand All @@ -1019,8 +1019,7 @@ async def pause_flow_run(
poll_interval: int = 10,
reschedule: bool = False,
key: str = None,
) -> T:
...
) -> T: ...


@sync_compatible
Expand Down Expand Up @@ -1255,8 +1254,7 @@ async def suspend_flow_run(
timeout: Optional[int] = 3600,
key: Optional[str] = None,
client: PrefectClient = None,
) -> None:
...
) -> None: ...


@overload
Expand All @@ -1266,8 +1264,7 @@ async def suspend_flow_run(
timeout: Optional[int] = 3600,
key: Optional[str] = None,
client: PrefectClient = None,
) -> T:
...
) -> T: ...


@sync_compatible
Expand Down Expand Up @@ -2117,6 +2114,18 @@ async def tick():
# flag to ensure we only update the task run name once
run_name_set = False

run_on_completion_hooks_on_cached = (
PREFECT_RUN_ON_COMPLETION_HOOKS_ON_CACHED and state.is_cached()
)

if run_on_completion_hooks_on_cached:
task_run = await client.read_task_run(task_run.id)
derekahuang marked this conversation as resolved.
Show resolved Hide resolved
await _run_task_hooks(
task=task,
task_run=task_run,
state=state,
)

# Only run the task if we enter a `RUNNING` state
while state.is_running():
# Retrieve the latest metadata for the task run context
Expand Down Expand Up @@ -2326,9 +2335,14 @@ async def _run_task_hooks(task: Task, task_run: TaskRun, state: State) -> None:
catch and log any errors that occur.
"""
hooks = None
run_on_completion_hooks_on_cached = (
PREFECT_RUN_ON_COMPLETION_HOOKS_ON_CACHED and state.is_cached()
zzstoatzz marked this conversation as resolved.
Show resolved Hide resolved
)
if state.is_failed() and task.on_failure:
hooks = task.on_failure
elif state.is_completed() and task.on_completion:
elif (
state.is_completed() or run_on_completion_hooks_on_cached
) and task.on_completion:
hooks = task.on_completion

if hooks:
Expand Down
4 changes: 4 additions & 0 deletions src/prefect/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -1759,6 +1759,10 @@ def default_cloud_ui_url(settings, value):
How long to cache related resource data for emitting server-side events
"""

# Opt-in flag: when True, on_completion hooks also fire for task runs that
# finish in a Cached state, not only for freshly Completed runs. Defaults to
# False to preserve the pre-existing behavior.
PREFECT_RUN_ON_COMPLETION_HOOKS_ON_CACHED = Setting(bool, default=False)
"""
Whether or not to run on_completion hooks on cached task runs.
"""

def automation_settings_enabled() -> bool:
"""
Expand Down
85 changes: 85 additions & 0 deletions tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from prefect.settings import (
PREFECT_DEBUG_MODE,
PREFECT_EXPERIMENTAL_ENABLE_NEW_ENGINE,
PREFECT_RUN_ON_COMPLETION_HOOKS_ON_CACHED,
PREFECT_TASK_DEFAULT_RETRIES,
PREFECT_TASKS_REFRESH_CACHE,
temporary_settings,
Expand Down Expand Up @@ -3767,6 +3768,90 @@ def my_flow():
assert my_mock.call_args_list == [call(), call()]


class TestOnCompletionHooksForCachedTasks:
    """Tests for the ``PREFECT_RUN_ON_COMPLETION_HOOKS_ON_CACHED`` setting.

    Both tests run a cached task twice: the first run produces a Completed
    (uncached) state and populates the cache; the second run produces a
    Cached state, where hook behavior depends on the setting.
    """

    @pytest.mark.parametrize("run_hooks_on_cached", [True, False])
    def test_on_completion_hooks_for_cached_tasks(self, run_hooks_on_cached):
        """on_completion hooks fire for a Cached state only when enabled."""
        my_mock = MagicMock()

        def completed1(task, task_run, state):
            my_mock("completed1")

        def completed2(task, task_run, state):
            my_mock("completed2")

        @task(
            cache_key_fn=lambda *_: "cache_2",
            cache_expiration=datetime.timedelta(seconds=5),
            on_completion=[completed1, completed2],
        )
        def cached_task():
            return "cached_result"

        @flow
        def test_flow():
            return cached_task._run()

        with temporary_settings(
            {PREFECT_RUN_ON_COMPLETION_HOOKS_ON_CACHED: run_hooks_on_cached}
        ):
            # First run populates the cache; its state is Completed (not
            # Cached), so both hooks run regardless of the setting.
            state = test_flow()
            assert my_mock.call_args_list == [call("completed1"), call("completed2")]
            assert not state.is_cached()

            my_mock.reset_mock()

            # Second run hits the cache; hooks run only when the setting
            # is enabled, and in registration order.
            state = test_flow()
            assert state.is_cached()
            if run_hooks_on_cached:
                assert my_mock.call_args_list == [
                    call("completed1"),
                    call("completed2"),
                ]
            else:
                assert my_mock.call_args_list == []

    @pytest.mark.parametrize("run_hooks_on_cached", [True, False])
    def test_exception_in_on_completion_hook_for_cached_task(
        self, caplog, run_hooks_on_cached
    ):
        """A raising hook is logged, not raised; the cached result survives."""

        def failing_hook(task, task_run, state):
            raise ValueError("Hook failed")

        @task(
            cache_key_fn=lambda *_: "cache_3",
            cache_expiration=datetime.timedelta(seconds=5),
            on_completion=[failing_hook],
        )
        def cached_task():
            return "cached_result"

        @flow
        def test_flow():
            return cached_task._run()

        with temporary_settings(
            {PREFECT_RUN_ON_COMPLETION_HOOKS_ON_CACHED: run_hooks_on_cached}
        ):
            # First run to cache the result; the hook fires (Completed
            # state) but its error must not fail the run.
            state = test_flow()
            assert state.result() == "cached_result"
            assert not state.is_cached()

            # Clear log records from the first run so the assertions below
            # only see the second (cached) run's output.
            caplog.clear()

            # Second run to test the hook behavior on a Cached state.
            state = test_flow()
            assert state.is_cached()
            assert state.result() == "cached_result"

            if run_hooks_on_cached:
                # Hook errors are caught and logged by _run_task_hooks.
                assert "ValueError: Hook failed" in caplog.text
            else:
                assert "ValueError: Hook failed" not in caplog.text


class TestTaskHooksOnFailure:
def test_noniterable_hook_raises(self):
def failure_hook():
Expand Down
Loading