diff --git a/src/prefect/engine.py b/src/prefect/engine.py index ef3fa7e790ea..ec9f286cb3d0 100644 --- a/src/prefect/engine.py +++ b/src/prefect/engine.py @@ -167,6 +167,7 @@ from prefect.settings import ( PREFECT_DEBUG_MODE, PREFECT_EXPERIMENTAL_ENABLE_NEW_ENGINE, + PREFECT_RUN_ON_COMPLETION_HOOKS_ON_CACHED, PREFECT_TASK_INTROSPECTION_WARN_THRESHOLD, PREFECT_TASKS_REFRESH_CACHE, PREFECT_UI_URL, @@ -2117,6 +2118,19 @@ async def tick(): # flag to ensure we only update the task run name once run_name_set = False + run_on_completion_hooks_on_cached = ( + PREFECT_RUN_ON_COMPLETION_HOOKS_ON_CACHED + and state.is_completed() + and state.name == "Cached" + ) + + if run_on_completion_hooks_on_cached: + await _run_task_hooks( + task=task, + task_run=task_run, + state=state, + ) + # Only run the task if we enter a `RUNNING` state while state.is_running(): # Retrieve the latest metadata for the task run context @@ -2326,9 +2340,16 @@ async def _run_task_hooks(task: Task, task_run: TaskRun, state: State) -> None: catch and log any errors that occur. """ hooks = None + run_on_completion_hooks_on_cached = ( + PREFECT_RUN_ON_COMPLETION_HOOKS_ON_CACHED + and state.is_completed() + and state.name == "Cached" + ) if state.is_failed() and task.on_failure: hooks = task.on_failure - elif state.is_completed() and task.on_completion: + elif ( + state.is_completed() or run_on_completion_hooks_on_cached + ) and task.on_completion: hooks = task.on_completion if hooks: diff --git a/src/prefect/settings.py b/src/prefect/settings.py index ec2b42491624..7130ca65e20a 100644 --- a/src/prefect/settings.py +++ b/src/prefect/settings.py @@ -1759,6 +1759,11 @@ def default_cloud_ui_url(settings, value): How long to cache related resource data for emitting server-side vents """ +PREFECT_RUN_ON_COMPLETION_HOOKS_ON_CACHED = Setting(bool, default=False) +""" +Whether or not to run on_completion hooks on cached task runs. +""" + def automation_settings_enabled() -> bool: """ diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 5d3b6e92ba1a..849660b95714 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -29,6 +29,7 @@ from prefect.settings import ( PREFECT_DEBUG_MODE, PREFECT_EXPERIMENTAL_ENABLE_NEW_ENGINE, + PREFECT_RUN_ON_COMPLETION_HOOKS_ON_CACHED, PREFECT_TASK_DEFAULT_RETRIES, PREFECT_TASKS_REFRESH_CACHE, temporary_settings, @@ -3767,6 +3768,90 @@ def my_flow(): assert my_mock.call_args_list == [call(), call()] +class TestOnCompletionHooksForCachedTasks: + @pytest.mark.parametrize("run_hooks_on_cached", [True, False]) + def test_on_completion_hooks_for_cached_tasks(self, run_hooks_on_cached): + my_mock = MagicMock() + + def completed1(task, task_run, state): + my_mock("completed1") + + def completed2(task, task_run, state): + my_mock("completed2") + + @task( + cache_key_fn=lambda *_: "cache_2", + cache_expiration=datetime.timedelta(seconds=5), + on_completion=[completed1, completed2], + ) + def cached_task(): + return "cached_result" + + @flow + def test_flow(): + return cached_task._run() + + with temporary_settings( + {PREFECT_RUN_ON_COMPLETION_HOOKS_ON_CACHED: run_hooks_on_cached} + ): + # First run to cache the result + state = test_flow() + assert my_mock.call_args_list == [call("completed1"), call("completed2")] + assert not (state.is_completed() and state.name == "Cached") + + my_mock.reset_mock() + + # Second run to test the hooks behavior + state = test_flow() + assert state.is_completed() and state.name == "Cached" + if run_hooks_on_cached: + assert my_mock.call_args_list == [ + call("completed1"), + call("completed2"), + ] + else: + assert my_mock.call_args_list == [] + + @pytest.mark.parametrize("run_hooks_on_cached", [True, False]) + def test_exception_in_on_completion_hook_for_cached_task( + self, caplog, run_hooks_on_cached + ): + def failing_hook(task, task_run, state): + raise ValueError("Hook failed") + + @task( + cache_key_fn=lambda *_: "cache_3", + cache_expiration=datetime.timedelta(seconds=5), + on_completion=[failing_hook], + ) + def cached_task(): + return "cached_result" + + @flow + def test_flow(): + return cached_task._run() + + with temporary_settings( + {PREFECT_RUN_ON_COMPLETION_HOOKS_ON_CACHED: run_hooks_on_cached} + ): + # First run to cache the result + state = test_flow() + assert state.result() == "cached_result" + assert not (state.is_completed() and state.name == "Cached") + + caplog.clear() + + # Second run to test the hook behavior + state = test_flow() + assert state.is_completed() and state.name == "Cached" + assert state.result() == "cached_result" + + if run_hooks_on_cached: + assert "ValueError: Hook failed" in caplog.text + else: + assert "ValueError: Hook failed" not in caplog.text + + class TestTaskHooksOnFailure: def test_noniterable_hook_raises(self): def failure_hook():