From e301f54dca4749e0272789a614bff2937e4d0ccd Mon Sep 17 00:00:00 2001 From: zzstoatzz Date: Mon, 22 Jul 2024 15:22:46 -0500 Subject: [PATCH 1/7] fix task key computation --- src/prefect/task_worker.py | 1 - src/prefect/tasks.py | 45 ++++++++++++++++++++++++-------------- tests/test_tasks.py | 14 ++++++++++++ 3 files changed, 42 insertions(+), 18 deletions(-) diff --git a/src/prefect/task_worker.py b/src/prefect/task_worker.py index fa1bc9003f5e..442f1050f197 100644 --- a/src/prefect/task_worker.py +++ b/src/prefect/task_worker.py @@ -402,7 +402,6 @@ def status(): return status_app -@sync_compatible async def serve( *tasks: Task, limit: Optional[int] = 10, status_server_port: Optional[int] = None ): diff --git a/src/prefect/tasks.py b/src/prefect/tasks.py index 91489192c69a..d82a3124f281 100644 --- a/src/prefect/tasks.py +++ b/src/prefect/tasks.py @@ -4,9 +4,9 @@ # This file requires type-checking with pyright because mypy does not yet support PEP612 # See https://github.com/python/mypy/issues/8645 +import asyncio import datetime import inspect -import os from copy import copy from functools import partial, update_wrapper from typing import ( @@ -64,7 +64,6 @@ from prefect.utilities.annotations import NotSet from prefect.utilities.asyncutils import ( run_coro_as_sync, - sync_compatible, ) from prefect.utilities.callables import ( expand_mapping_parameters, @@ -188,6 +187,24 @@ def _infer_parent_task_runs( return parents +def _generate_task_key(fn: Callable) -> str: + """Generate a task key based on the function name and source code. + + We may eventually want some sort of top-level namespace here to + disambiguate tasks with the same function name in different modules, + in a more human-readable way, while avoiding relative import problems (see #12337). + + As long as the task implementations are unique (even if named the same), we should + not have any collisions. + + Args: + fn: The function to generate a task key for. + """ + qualname = fn.__qualname__.split(".")[-1] + source_hash = h[:8] if (h := hash_objects(inspect.getsource(fn))) else "unknown" + return f"{qualname}-{source_hash}" + + class Task(Generic[P, R]): """ A Prefect task definition. @@ -270,7 +287,7 @@ def __init__( description: Optional[str] = None, tags: Optional[Iterable[str]] = None, version: Optional[str] = None, - cache_policy: Optional[CachePolicy] = NotSet, + cache_policy: Union[CachePolicy, Type[NotSet]] = NotSet, cache_key_fn: Optional[ Callable[["TaskRunContext", Dict[str, Any]], Optional[str]] ] = None, @@ -372,14 +389,7 @@ def __init__( if not hasattr(self.fn, "__qualname__"): self.task_key = to_qualified_name(type(self.fn)) else: - try: - task_origin_hash = hash_objects( - self.name, os.path.abspath(inspect.getsourcefile(self.fn)) - ) - except TypeError: - task_origin_hash = "unknown-source-file" - - self.task_key = f"{self.fn.__qualname__}-{task_origin_hash}" + self.task_key = _generate_task_key(self.fn) if cache_policy is not NotSet and cache_key_fn is not None: logger.warning( @@ -1488,15 +1498,14 @@ def delay(self, *args: P.args, **kwargs: P.kwargs) -> PrefectDistributedFuture: """ return self.apply_async(args=args, kwargs=kwargs) - @sync_compatible - async def serve(self) -> NoReturn: + def serve(self): """Serve the task using the provided task runner. This method is used to establish a websocket connection with the Prefect server and listen for submitted task runs to execute. Args: task_runner: The task runner to use for serving the task. If not provided, - the default ConcurrentTaskRunner will be used. + the default task runner will be used. Examples: Serve a task using the default task runner @@ -1508,7 +1517,7 @@ async def serve(self) -> NoReturn: """ from prefect.task_worker import serve - await serve(self) + asyncio.run(serve(self)) @overload @@ -1523,7 +1532,7 @@ def task( description: Optional[str] = None, tags: Optional[Iterable[str]] = None, version: Optional[str] = None, - cache_policy: CachePolicy = NotSet, + cache_policy: Union[CachePolicy, Type[NotSet]] = NotSet, cache_key_fn: Optional[ Callable[["TaskRunContext", Dict[str, Any]], Optional[str]] ] = None, @@ -1561,7 +1570,9 @@ def task( tags: Optional[Iterable[str]] = None, version: Optional[str] = None, cache_policy: Union[CachePolicy, Type[NotSet]] = NotSet, - cache_key_fn: Callable[["TaskRunContext", Dict[str, Any]], Optional[str]] = None, + cache_key_fn: Union[ + Callable[["TaskRunContext", Dict[str, Any]], Optional[str]], None + ] = None, cache_expiration: Optional[datetime.timedelta] = None, task_run_name: Optional[Union[Callable[[], str], str]] = None, retries: Optional[int] = None, diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 02bb90bd9f0f..1210dea1e08a 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -102,6 +102,20 @@ def my_task(): assert my_task.name == "another_name" +class TestTaskKey: + def test_task_key_typical_case(self): + @task + def my_task(): + pass + + assert my_task.task_key.startswith("my_task-") + + def test_task_key_after_import(self): + from tests.generic_tasks import noop + + assert noop.task_key.startswith("noop-") + + class TestTaskRunName: def test_run_name_default(self): @task From 60a347818e8010a4d6c32bfd1255675264344c06 Mon Sep 17 00:00:00 2001 From: zzstoatzz Date: Mon, 22 Jul 2024 15:32:21 -0500 Subject: [PATCH 2/7] reduce scope --- src/prefect/task_worker.py | 1 + src/prefect/tasks.py | 7 ++++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/prefect/task_worker.py b/src/prefect/task_worker.py index 442f1050f197..fa1bc9003f5e 100644 --- a/src/prefect/task_worker.py +++ b/src/prefect/task_worker.py @@ -402,6 +402,7 @@ def status(): return status_app +@sync_compatible async def serve( *tasks: Task, limit: Optional[int] = 10, status_server_port: Optional[int] = None ): diff --git a/src/prefect/tasks.py b/src/prefect/tasks.py index d82a3124f281..483f9c22dc41 100644 --- a/src/prefect/tasks.py +++ b/src/prefect/tasks.py @@ -4,7 +4,6 @@ # This file requires type-checking with pyright because mypy does not yet support PEP612 # See https://github.com/python/mypy/issues/8645 -import asyncio import datetime import inspect from copy import copy @@ -64,6 +63,7 @@ from prefect.utilities.annotations import NotSet from prefect.utilities.asyncutils import ( run_coro_as_sync, + sync_compatible, ) from prefect.utilities.callables import ( expand_mapping_parameters, @@ -1498,7 +1498,8 @@ def delay(self, *args: P.args, **kwargs: P.kwargs) -> PrefectDistributedFuture: """ return self.apply_async(args=args, kwargs=kwargs) - def serve(self): + @sync_compatible + async def serve(self) -> NoReturn: """Serve the task using the provided task runner. This method is used to establish a websocket connection with the Prefect server and listen for submitted task runs to execute. @@ -1517,7 +1518,7 @@ def serve(self): """ from prefect.task_worker import serve - asyncio.run(serve(self)) + await serve(self) @overload From 88c8d947bfe99d7c12548383a861e4dddbc4ef7c Mon Sep 17 00:00:00 2001 From: zzstoatzz Date: Mon, 22 Jul 2024 15:46:41 -0500 Subject: [PATCH 3/7] encapsulate qualname check --- src/prefect/tasks.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/prefect/tasks.py b/src/prefect/tasks.py index 483f9c22dc41..09d10220e00d 100644 --- a/src/prefect/tasks.py +++ b/src/prefect/tasks.py @@ -200,6 +200,9 @@ def _generate_task_key(fn: Callable) -> str: Args: fn: The function to generate a task key for. """ + if not hasattr(fn, "__qualname__"): + return to_qualified_name(type(fn)) + qualname = fn.__qualname__.split(".")[-1] source_hash = h[:8] if (h := hash_objects(inspect.getsource(fn))) else "unknown" return f"{qualname}-{source_hash}" @@ -386,10 +389,7 @@ def __init__( self.tags = set(tags if tags else []) - if not hasattr(self.fn, "__qualname__"): - self.task_key = to_qualified_name(type(self.fn)) - else: - self.task_key = _generate_task_key(self.fn) + self.task_key = _generate_task_key(self.fn) if cache_policy is not NotSet and cache_key_fn is not None: logger.warning( From 522e7626da4118b2b23e78219ebbd40af7fc04dd Mon Sep 17 00:00:00 2001 From: zzstoatzz Date: Mon, 22 Jul 2024 16:30:48 -0500 Subject: [PATCH 4/7] try to fix task tests --- src/prefect/tasks.py | 6 ++++-- tests/test_background_tasks.py | 22 ---------------------- 2 files changed, 4 insertions(+), 24 deletions(-) diff --git a/src/prefect/tasks.py b/src/prefect/tasks.py index 09d10220e00d..0409a2a1624f 100644 --- a/src/prefect/tasks.py +++ b/src/prefect/tasks.py @@ -204,8 +204,10 @@ def _generate_task_key(fn: Callable) -> str: return to_qualified_name(type(fn)) qualname = fn.__qualname__.split(".")[-1] - source_hash = h[:8] if (h := hash_objects(inspect.getsource(fn))) else "unknown" - return f"{qualname}-{source_hash}" + + code_hash = h[:8] if (h := hash_objects(fn.__code__)) else "unknown" + + return f"{qualname}-{code_hash}" class Task(Generic[P, R]): diff --git a/tests/test_background_tasks.py b/tests/test_background_tasks.py index 718f55a92d30..d4767796f510 100644 --- a/tests/test_background_tasks.py +++ b/tests/test_background_tasks.py @@ -1,6 +1,4 @@ import asyncio -import inspect -import os from datetime import timedelta from pathlib import Path from typing import TYPE_CHECKING, AsyncGenerator, Iterable, Tuple @@ -26,7 +24,6 @@ temporary_settings, ) from prefect.task_worker import TaskWorker -from prefect.utilities.hashing import hash_objects if TYPE_CHECKING: from prefect.client.orchestration import PrefectClient @@ -447,22 +444,3 @@ async def bar(x: int, mappable: Iterable) -> Tuple[int, Iterable]: "parameters": {"x": i + 1, "mappable": ["some", "iterable"]}, "context": mock.ANY, } - - -class TestTaskKey: - def test_task_key_includes_qualname_and_source_file_hash(self): - def some_fn(): - pass - - t = Task(fn=some_fn) - source_file = os.path.abspath(inspect.getsourcefile(some_fn)) - task_origin_hash = hash_objects(t.name, source_file) - assert t.task_key == f"{some_fn.__qualname__}-{task_origin_hash}" - - def test_task_key_handles_unknown_source_file(self, monkeypatch): - def some_fn(): - pass - - monkeypatch.setattr(inspect, "getsourcefile", lambda x: None) - t = Task(fn=some_fn) - assert t.task_key == f"{some_fn.__qualname__}-unknown-source-file" From a63c9bdef17bd34edfe61d8c3c29c01abc191d2a Mon Sep 17 00:00:00 2001 From: zzstoatzz Date: Mon, 22 Jul 2024 16:43:08 -0500 Subject: [PATCH 5/7] fix expected task keys in tests --- .../test_task_run_state_change_events.py | 36 ++++--------------- 1 file changed, 6 insertions(+), 30 deletions(-) diff --git a/tests/events/client/instrumentation/test_task_run_state_change_events.py b/tests/events/client/instrumentation/test_task_run_state_change_events.py index 88fbb5b392c9..8edf6084328f 100644 --- a/tests/events/client/instrumentation/test_task_run_state_change_events.py +++ b/tests/events/client/instrumentation/test_task_run_state_change_events.py @@ -60,11 +60,7 @@ def happy_path(): == task_run.expected_start_time ) assert pending.payload["task_run"].pop("estimated_start_time_delta") > 0.0 - assert ( - pending.payload["task_run"] - .pop("task_key") - .startswith("test_task_state_change_happy_path..happy_little_tree") - ) + assert pending.payload["task_run"].pop("task_key").startswith("happy_little_tree") assert pending.payload == { "initial_state": None, "intended": {"from": None, "to": "PENDING"}, @@ -112,11 +108,7 @@ def happy_path(): == task_run.expected_start_time ) assert running.payload["task_run"].pop("estimated_start_time_delta") > 0.0 - assert ( - running.payload["task_run"] - .pop("task_key") - .startswith("test_task_state_change_happy_path..happy_little_tree") - ) + assert running.payload["task_run"].pop("task_key").startswith("happy_little_tree") assert running.payload == { "intended": {"from": "PENDING", "to": "RUNNING"}, "initial_state": { @@ -169,11 +161,7 @@ def happy_path(): == task_run.expected_start_time ) assert completed.payload["task_run"].pop("estimated_start_time_delta") > 0.0 - assert ( - completed.payload["task_run"] - .pop("task_key") - .startswith("test_task_state_change_happy_path..happy_little_tree") - ) + assert completed.payload["task_run"].pop("task_key").startswith("happy_little_tree") assert completed.payload["task_run"].pop("estimated_run_time") > 0.0 assert ( pendulum.parse(completed.payload["task_run"].pop("start_time")) @@ -262,11 +250,7 @@ def happy_path(): == task_run.expected_start_time ) assert pending.payload["task_run"].pop("estimated_start_time_delta") > 0.0 - assert ( - pending.payload["task_run"] - .pop("task_key") - .startswith("test_task_state_change_task_failure..happy_little_tree") - ) + assert pending.payload["task_run"].pop("task_key").startswith("happy_little_tree") assert pending.payload == { "initial_state": None, "intended": {"from": None, "to": "PENDING"}, @@ -314,11 +298,7 @@ def happy_path(): == task_run.expected_start_time ) assert running.payload["task_run"].pop("estimated_start_time_delta") > 0.0 - assert ( - running.payload["task_run"] - .pop("task_key") - .startswith("test_task_state_change_task_failure..happy_little_tree") - ) + assert running.payload["task_run"].pop("task_key").startswith("happy_little_tree") assert running.payload == { "intended": {"from": "PENDING", "to": "RUNNING"}, "initial_state": { @@ -374,11 +354,7 @@ def happy_path(): == task_run.expected_start_time ) assert failed.payload["task_run"].pop("estimated_start_time_delta") > 0.0 - assert ( - failed.payload["task_run"] - .pop("task_key") - .startswith("test_task_state_change_task_failure..happy_little_tree") - ) + assert failed.payload["task_run"].pop("task_key").startswith("happy_little_tree") assert failed.payload["task_run"].pop("estimated_run_time") > 0.0 assert ( pendulum.parse(failed.payload["task_run"].pop("start_time")) From cbc661a3a34b0cb6049171a472a88a8cb92c9c85 Mon Sep 17 00:00:00 2001 From: zzstoatzz Date: Mon, 22 Jul 2024 17:43:07 -0500 Subject: [PATCH 6/7] use existing const NUM_CHARS_DYNAMIC_KEY --- src/prefect/tasks.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/prefect/tasks.py b/src/prefect/tasks.py index 0409a2a1624f..c5331d10149a 100644 --- a/src/prefect/tasks.py +++ b/src/prefect/tasks.py @@ -205,7 +205,9 @@ def _generate_task_key(fn: Callable) -> str: qualname = fn.__qualname__.split(".")[-1] - code_hash = h[:8] if (h := hash_objects(fn.__code__)) else "unknown" + code_hash = ( + h[:NUM_CHARS_DYNAMIC_KEY] if (h := hash_objects(fn.__code__)) else "unknown" + ) return f"{qualname}-{code_hash}" From e7d869161b12c8560b8e56ee413b592d7abd429c Mon Sep 17 00:00:00 2001 From: zzstoatzz Date: Mon, 22 Jul 2024 17:45:50 -0500 Subject: [PATCH 7/7] better typing --- src/prefect/tasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/prefect/tasks.py b/src/prefect/tasks.py index c5331d10149a..c5406852e58c 100644 --- a/src/prefect/tasks.py +++ b/src/prefect/tasks.py @@ -187,7 +187,7 @@ def _infer_parent_task_runs( return parents -def _generate_task_key(fn: Callable) -> str: +def _generate_task_key(fn: Callable[..., Any]) -> str: """Generate a task key based on the function name and source code. We may eventually want some sort of top-level namespace here to