From c5f19705715af9745a96d59bfbdb6bad1839c411 Mon Sep 17 00:00:00 2001 From: nate nowack Date: Mon, 22 Jul 2024 18:06:50 -0500 Subject: [PATCH] Fix task key computation (#14704) --- src/prefect/tasks.py | 48 ++++++++++++------- .../test_task_run_state_change_events.py | 36 +++----------- tests/test_background_tasks.py | 22 --------- tests/test_tasks.py | 14 ++++++ 4 files changed, 52 insertions(+), 68 deletions(-) diff --git a/src/prefect/tasks.py b/src/prefect/tasks.py index 91489192c69a..c5406852e58c 100644 --- a/src/prefect/tasks.py +++ b/src/prefect/tasks.py @@ -6,7 +6,6 @@ import datetime import inspect -import os from copy import copy from functools import partial, update_wrapper from typing import ( @@ -188,6 +187,31 @@ def _infer_parent_task_runs( return parents +def _generate_task_key(fn: Callable[..., Any]) -> str: + """Generate a task key based on the function name and source code. + + We may eventually want some sort of top-level namespace here to + disambiguate tasks with the same function name in different modules, + in a more human-readable way, while avoiding relative import problems (see #12337). + + As long as the task implementations are unique (even if named the same), we should + not have any collisions. + + Args: + fn: The function to generate a task key for. + """ + if not hasattr(fn, "__qualname__"): + return to_qualified_name(type(fn)) + + qualname = fn.__qualname__.split(".")[-1] + + code_hash = ( + h[:NUM_CHARS_DYNAMIC_KEY] if (h := hash_objects(fn.__code__)) else "unknown" + ) + + return f"{qualname}-{code_hash}" + + class Task(Generic[P, R]): """ A Prefect task definition. @@ -270,7 +294,7 @@ def __init__( description: Optional[str] = None, tags: Optional[Iterable[str]] = None, version: Optional[str] = None, - cache_policy: Optional[CachePolicy] = NotSet, + cache_policy: Union[CachePolicy, Type[NotSet]] = NotSet, cache_key_fn: Optional[ Callable[["TaskRunContext", Dict[str, Any]], Optional[str]] ] = None, @@ -369,17 +393,7 @@ def __init__( self.tags = set(tags if tags else []) - if not hasattr(self.fn, "__qualname__"): - self.task_key = to_qualified_name(type(self.fn)) - else: - try: - task_origin_hash = hash_objects( - self.name, os.path.abspath(inspect.getsourcefile(self.fn)) - ) - except TypeError: - task_origin_hash = "unknown-source-file" - - self.task_key = f"{self.fn.__qualname__}-{task_origin_hash}" + self.task_key = _generate_task_key(self.fn) if cache_policy is not NotSet and cache_key_fn is not None: logger.warning( @@ -1496,7 +1510,7 @@ async def serve(self) -> NoReturn: Args: task_runner: The task runner to use for serving the task. If not provided, - the default ConcurrentTaskRunner will be used. + the default task runner will be used. Examples: Serve a task using the default task runner @@ -1523,7 +1537,7 @@ def task( description: Optional[str] = None, tags: Optional[Iterable[str]] = None, version: Optional[str] = None, - cache_policy: CachePolicy = NotSet, + cache_policy: Union[CachePolicy, Type[NotSet]] = NotSet, cache_key_fn: Optional[ Callable[["TaskRunContext", Dict[str, Any]], Optional[str]] ] = None, @@ -1561,7 +1575,9 @@ def task( tags: Optional[Iterable[str]] = None, version: Optional[str] = None, cache_policy: Union[CachePolicy, Type[NotSet]] = NotSet, - cache_key_fn: Callable[["TaskRunContext", Dict[str, Any]], Optional[str]] = None, + cache_key_fn: Union[ + Callable[["TaskRunContext", Dict[str, Any]], Optional[str]], None + ] = None, cache_expiration: Optional[datetime.timedelta] = None, task_run_name: Optional[Union[Callable[[], str], str]] = None, retries: Optional[int] = None, diff --git a/tests/events/client/instrumentation/test_task_run_state_change_events.py b/tests/events/client/instrumentation/test_task_run_state_change_events.py index 88fbb5b392c9..8edf6084328f 100644 --- a/tests/events/client/instrumentation/test_task_run_state_change_events.py +++ b/tests/events/client/instrumentation/test_task_run_state_change_events.py @@ -60,11 +60,7 @@ def happy_path(): == task_run.expected_start_time ) assert pending.payload["task_run"].pop("estimated_start_time_delta") > 0.0 - assert ( - pending.payload["task_run"] - .pop("task_key") - .startswith("test_task_state_change_happy_path..happy_little_tree") - ) + assert pending.payload["task_run"].pop("task_key").startswith("happy_little_tree") assert pending.payload == { "initial_state": None, "intended": {"from": None, "to": "PENDING"}, @@ -112,11 +108,7 @@ def happy_path(): == task_run.expected_start_time ) assert running.payload["task_run"].pop("estimated_start_time_delta") > 0.0 - assert ( - running.payload["task_run"] - .pop("task_key") - .startswith("test_task_state_change_happy_path..happy_little_tree") - ) + assert running.payload["task_run"].pop("task_key").startswith("happy_little_tree") assert running.payload == { "intended": {"from": "PENDING", "to": "RUNNING"}, "initial_state": { @@ -169,11 +161,7 @@ def happy_path(): == task_run.expected_start_time ) assert completed.payload["task_run"].pop("estimated_start_time_delta") > 0.0 - assert ( - completed.payload["task_run"] - .pop("task_key") - .startswith("test_task_state_change_happy_path..happy_little_tree") - ) + assert completed.payload["task_run"].pop("task_key").startswith("happy_little_tree") assert completed.payload["task_run"].pop("estimated_run_time") > 0.0 assert ( pendulum.parse(completed.payload["task_run"].pop("start_time")) @@ -262,11 +250,7 @@ def happy_path(): == task_run.expected_start_time ) assert pending.payload["task_run"].pop("estimated_start_time_delta") > 0.0 - assert ( - pending.payload["task_run"] - .pop("task_key") - .startswith("test_task_state_change_task_failure..happy_little_tree") - ) + assert pending.payload["task_run"].pop("task_key").startswith("happy_little_tree") assert pending.payload == { "initial_state": None, "intended": {"from": None, "to": "PENDING"}, @@ -314,11 +298,7 @@ def happy_path(): == task_run.expected_start_time ) assert running.payload["task_run"].pop("estimated_start_time_delta") > 0.0 - assert ( - running.payload["task_run"] - .pop("task_key") - .startswith("test_task_state_change_task_failure..happy_little_tree") - ) + assert running.payload["task_run"].pop("task_key").startswith("happy_little_tree") assert running.payload == { "intended": {"from": "PENDING", "to": "RUNNING"}, "initial_state": { @@ -374,11 +354,7 @@ def happy_path(): == task_run.expected_start_time ) assert failed.payload["task_run"].pop("estimated_start_time_delta") > 0.0 - assert ( - failed.payload["task_run"] - .pop("task_key") - .startswith("test_task_state_change_task_failure..happy_little_tree") - ) + assert failed.payload["task_run"].pop("task_key").startswith("happy_little_tree") assert failed.payload["task_run"].pop("estimated_run_time") > 0.0 assert ( pendulum.parse(failed.payload["task_run"].pop("start_time")) diff --git a/tests/test_background_tasks.py b/tests/test_background_tasks.py index 718f55a92d30..d4767796f510 100644 --- a/tests/test_background_tasks.py +++ b/tests/test_background_tasks.py @@ -1,6 +1,4 @@ import asyncio -import inspect -import os from datetime import timedelta from pathlib import Path from typing import TYPE_CHECKING, AsyncGenerator, Iterable, Tuple @@ -26,7 +24,6 @@ temporary_settings, ) from prefect.task_worker import TaskWorker -from prefect.utilities.hashing import hash_objects if TYPE_CHECKING: from prefect.client.orchestration import PrefectClient @@ -447,22 +444,3 @@ async def bar(x: int, mappable: Iterable) -> Tuple[int, Iterable]: "parameters": {"x": i + 1, "mappable": ["some", "iterable"]}, "context": mock.ANY, } - - -class TestTaskKey: - def test_task_key_includes_qualname_and_source_file_hash(self): - def some_fn(): - pass - - t = Task(fn=some_fn) - source_file = os.path.abspath(inspect.getsourcefile(some_fn)) - task_origin_hash = hash_objects(t.name, source_file) - assert t.task_key == f"{some_fn.__qualname__}-{task_origin_hash}" - - def test_task_key_handles_unknown_source_file(self, monkeypatch): - def some_fn(): - pass - - monkeypatch.setattr(inspect, "getsourcefile", lambda x: None) - t = Task(fn=some_fn) - assert t.task_key == f"{some_fn.__qualname__}-unknown-source-file" diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 02bb90bd9f0f..1210dea1e08a 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -102,6 +102,20 @@ def my_task(): assert my_task.name == "another_name" +class TestTaskKey: + def test_task_key_typical_case(self): + @task + def my_task(): + pass + + assert my_task.task_key.startswith("my_task-") + + def test_task_key_after_import(self): + from tests.generic_tasks import noop + + assert noop.task_key.startswith("noop-") + + class TestTaskRunName: def test_run_name_default(self): @task