From 4a7424d9ced3477ec21839c2f41413bed9d3ef74 Mon Sep 17 00:00:00 2001 From: nate nowack Date: Mon, 5 Aug 2024 10:19:25 -0500 Subject: [PATCH] Remove `wrapt` and fail on incorrect `await` usage (#14837) --- docs/3.0rc/resources/upgrade-prefect-3.mdx | 63 ++++++++++++++++- requirements-client.txt | 3 +- .../_internal/compatibility/deprecated.py | 53 --------------- src/prefect/futures.py | 5 -- src/prefect/tasks.py | 5 -- src/prefect/utilities/asyncutils.py | 6 -- .../flows/test_previously_awaited_methods.py | 67 +++++-------------- 7 files changed, 78 insertions(+), 124 deletions(-) diff --git a/docs/3.0rc/resources/upgrade-prefect-3.mdx b/docs/3.0rc/resources/upgrade-prefect-3.mdx index e6703604596a..a3156560190d 100644 --- a/docs/3.0rc/resources/upgrade-prefect-3.mdx +++ b/docs/3.0rc/resources/upgrade-prefect-3.mdx @@ -333,4 +333,65 @@ print(my_flow()) # Output: Failed(message='Flow failed due to task failure') ``` -Choose the strategy that best fits your specific use case and error handling requirements. \ No newline at end of file +Choose the strategy that best fits your specific use case and error handling requirements. + +## Breaking changes: Errors and resolutions + +#### `can't be used in 'await' expression` + +In Prefect 3, certain methods that were contextually sync/async in Prefect 2 are now synchronous: + +##### `Task` +- `submit` +- `map` + +##### `PrefectFuture` +- `result` +- `wait` + +Attempting to use `await` with these methods will result in a `TypeError`, like: + +```python +TypeError: object PrefectConcurrentFuture can't be used in 'await' expression +``` + +**Example and Resolution** + +You should **remove the `await` keyword** from calls of these methods in Prefect 3: + +```python +from prefect import flow, task +import asyncio + +@task +async def fetch_user_data(user_id): + return {"id": user_id, "score": user_id * 10} + +@task +def calculate_average(user_data): + return sum(user["score"] for user in user_data) / len(user_data) + +@flow +async def prefect_2_flow(n_users: int = 10): # ❌ + users = await fetch_user_data.map(range(1, n_users + 1)) + avg = calculate_average.submit(users) + print(f"Users: {await users.result()}") + print(f"Average score: {await avg.result()}") + +@flow +def prefect_3_flow(n_users: int = 10): # ✅ + users = fetch_user_data.map(range(1, n_users + 1)) + avg = calculate_average.submit(users) + print(f"Users: {users.result()}") + print(f"Average score: {avg.result()}") + +try: + asyncio.run(prefect_2_flow()) + raise AssertionError("Expected a TypeError") +except TypeError as e: + assert "can't be used in 'await' expression" in str(e) + +prefect_3_flow() +# Users: [{'id': 1, 'score': 10}, ... , {'id': 10, 'score': 100}] +# Average score: 55.0 +``` \ No newline at end of file diff --git a/requirements-client.txt b/requirements-client.txt index a7282cc38f21..0113ed7dd746 100644 --- a/requirements-client.txt +++ b/requirements-client.txt @@ -35,5 +35,4 @@ toml >= 0.10.0 typing_extensions >= 4.5.0, < 5.0.0 ujson >= 5.8.0, < 6.0.0 uvicorn >=0.14.0, !=0.29.0 -websockets >= 10.4, < 13.0 -wrapt >= 1.16.0 +websockets >= 10.4, < 13.0 \ No newline at end of file diff --git a/src/prefect/_internal/compatibility/deprecated.py b/src/prefect/_internal/compatibility/deprecated.py index 3acdc5ff74c2..c8d114d5ce20 100644 --- a/src/prefect/_internal/compatibility/deprecated.py +++ b/src/prefect/_internal/compatibility/deprecated.py @@ -16,7 +16,6 @@ from typing import Any, Callable, List, Optional, Type, TypeVar import pendulum -import wrapt from pydantic import BaseModel from prefect.utilities.callables import get_call_parameters @@ -273,55 +272,3 @@ def callback(_): DEPRECATED_MODULE_ALIASES.append( AliasedModuleDefinition(old_name, new_name, callback) ) - - -class AsyncCompatProxy(wrapt.ObjectProxy): - """ - A proxy object that allows for awaiting a method that is no longer async. - - See https://wrapt.readthedocs.io/en/master/wrappers.html#object-proxy for more - """ - - def __init__(self, wrapped, class_name: str, method_name: str): - super().__init__(wrapped) - self._self_class_name = class_name - self._self_method_name = method_name - self._self_already_awaited = False - - def __await__(self): - if not self._self_already_awaited: - warnings.warn( - ( - f"The {self._self_method_name!r} method on {self._self_class_name!r}" - " is no longer async and awaiting it will raise an error after Dec 2024" - " - please remove the `await` keyword." - ), - DeprecationWarning, - stacklevel=2, - ) - self._self_already_awaited = True - yield - return self.__wrapped__ - - def __repr__(self): - return repr(self.__wrapped__) - - def __reduce_ex__(self, protocol): - return ( - type(self), - (self.__wrapped__,), - {"_self_already_awaited": self._self_already_awaited}, - ) - - -def deprecated_async_method(wrapped): - """Decorator that wraps a sync method to allow awaiting it even though it is no longer async.""" - - @wrapt.decorator - def wrapper(wrapped, instance, args, kwargs): - result = wrapped(*args, **kwargs) - return AsyncCompatProxy( - result, class_name=instance.__class__.__name__, method_name=wrapped.__name__ - ) - - return wrapper(wrapped) diff --git a/src/prefect/futures.py b/src/prefect/futures.py index 136f409001b4..ba2d0d46a9d4 100644 --- a/src/prefect/futures.py +++ b/src/prefect/futures.py @@ -10,7 +10,6 @@ from typing_extensions import TypeVar -from prefect._internal.compatibility.deprecated import deprecated_async_method from prefect.client.orchestration import get_client from prefect.client.schemas.objects import TaskRun from prefect.exceptions import ObjectNotFound @@ -135,7 +134,6 @@ class PrefectConcurrentFuture(PrefectWrappedFuture[R, concurrent.futures.Future] when the task run is submitted to a ThreadPoolExecutor. """ - @deprecated_async_method def wait(self, timeout: Optional[float] = None) -> None: try: result = self._wrapped_future.result(timeout=timeout) @@ -144,7 +142,6 @@ def wait(self, timeout: Optional[float] = None) -> None: if isinstance(result, State): self._final_state = result - @deprecated_async_method def result( self, timeout: Optional[float] = None, @@ -198,7 +195,6 @@ class PrefectDistributedFuture(PrefectFuture[R]): done_callbacks: List[Callable[[PrefectFuture], None]] = [] waiter = None - @deprecated_async_method def wait(self, timeout: Optional[float] = None) -> None: return run_coro_as_sync(self.wait_async(timeout=timeout)) @@ -235,7 +231,6 @@ async def wait_async(self, timeout: Optional[float] = None): self._final_state = task_run.state return - @deprecated_async_method def result( self, timeout: Optional[float] = None, diff --git a/src/prefect/tasks.py b/src/prefect/tasks.py index c5406852e58c..8d3fc1ba59b8 100644 --- a/src/prefect/tasks.py +++ b/src/prefect/tasks.py @@ -33,9 +33,6 @@ from typing_extensions import Literal, ParamSpec import prefect.states -from prefect._internal.compatibility.deprecated import ( - deprecated_async_method, -) from prefect.cache_policies import DEFAULT, NONE, CachePolicy from prefect.client.orchestration import get_client from prefect.client.schemas import TaskRun @@ -1038,7 +1035,6 @@ def submit( ) -> State[T]: ... - @deprecated_async_method def submit( self, *args: Any, @@ -1203,7 +1199,6 @@ def map( ) -> PrefectFutureList[State[T]]: ... - @deprecated_async_method def map( self, *args: Any, diff --git a/src/prefect/utilities/asyncutils.py b/src/prefect/utilities/asyncutils.py index 5832fd3440aa..ed29e8a80a12 100644 --- a/src/prefect/utilities/asyncutils.py +++ b/src/prefect/utilities/asyncutils.py @@ -30,7 +30,6 @@ import anyio.from_thread import anyio.to_thread import sniffio -import wrapt from typing_extensions import Literal, ParamSpec, TypeGuard from prefect._internal.concurrency.api import _cast_to_call, from_sync @@ -210,11 +209,6 @@ def run_coro_as_sync( Returns: The result of the coroutine if wait_for_result is True, otherwise None. """ - if not asyncio.iscoroutine(coroutine): - if isinstance(coroutine, wrapt.ObjectProxy): - return coroutine.__wrapped__ - else: - raise TypeError("`coroutine` must be a coroutine object") async def coroutine_wrapper() -> Union[R, None]: """ diff --git a/tests/public/flows/test_previously_awaited_methods.py b/tests/public/flows/test_previously_awaited_methods.py index 854927ee1b46..997d3c6be1ac 100644 --- a/tests/public/flows/test_previously_awaited_methods.py +++ b/tests/public/flows/test_previously_awaited_methods.py @@ -1,60 +1,23 @@ -from prefect import flow, task - +import pytest -async def test_awaiting_formerly_async_methods(): - import warnings +from prefect import flow, task - N = 5 +@pytest.mark.parametrize( + "method,args", + [ + ("submit", (None,)), + ("map", ([None],)), + ], +) +async def test_awaiting_previously_async_task_methods_fail(method, args): @task - def get_random_number(_) -> int: + async def get_random_number(_) -> int: return 42 @flow - async def get_some_numbers_old_await_syntax(): - state1 = await get_random_number.submit(None, return_state=True) - assert state1.is_completed() - - future1 = await get_random_number.submit(None) - - await future1.wait() - assert await future1.result() == 42 - - list_of_futures = await get_random_number.map([None] * N) - [await future.wait() for future in list_of_futures] - assert all([await future.result() == 42 for future in list_of_futures]) - - @flow - async def get_some_numbers_new_way(): - state1 = get_random_number.submit(None, return_state=True) - assert state1.is_completed() - - future1 = get_random_number.submit(None) - future1.wait() - assert future1.result() == 42 - - list_of_futures = get_random_number.map([None] * N) - [future.wait() for future in list_of_futures] - assert all(future.result() == 42 for future in list_of_futures) - - # Test the old way (with await) - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - await get_some_numbers_old_await_syntax() - deprecation_warnings = [ - _w for _w in w if issubclass(_w.category, DeprecationWarning) - ] - - assert all( - "please remove the `await` keyword" in str(warning.message) - for warning in deprecation_warnings - ) + async def run_a_task(): + await getattr(get_random_number, method)(*args) - # Test the new way (without await) - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - await get_some_numbers_new_way() - deprecation_warnings = [ - _w for _w in w if issubclass(_w.category, DeprecationWarning) - ] - assert len(deprecation_warnings) == 0 + with pytest.raises(TypeError, match="can't be used in 'await' expression"): + await run_a_task()