Skip to content

Commit

Permalink
Remove wrapt and fail on incorrect await usage (#14837)
Browse files Browse the repository at this point in the history
  • Loading branch information
zzstoatzz authored Aug 5, 2024
1 parent 12983a5 commit 4a7424d
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 124 deletions.
63 changes: 62 additions & 1 deletion docs/3.0rc/resources/upgrade-prefect-3.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -333,4 +333,65 @@ print(my_flow()) # Output: Failed(message='Flow failed due to task failure')
```
</CodeGroup>

Choose the strategy that best fits your specific use case and error handling requirements.
Choose the strategy that best fits your specific use case and error handling requirements.

## Breaking changes: Errors and resolutions

#### `can't be used in 'await' expression`

In Prefect 3, certain methods that were contextually sync/async in Prefect 2 are now synchronous:

##### `Task`
- `submit`
- `map`

##### `PrefectFuture`
- `result`
- `wait`

Attempting to use `await` with these methods will result in a `TypeError`, like:

```python
TypeError: object PrefectConcurrentFuture can't be used in 'await' expression
```

**Example and Resolution**

You should **remove the `await` keyword** from calls of these methods in Prefect 3:

```python
from prefect import flow, task
import asyncio

@task
async def fetch_user_data(user_id):
return {"id": user_id, "score": user_id * 10}

@task
def calculate_average(user_data):
return sum(user["score"] for user in user_data) / len(user_data)

@flow
async def prefect_2_flow(n_users: int = 10): #
users = await fetch_user_data.map(range(1, n_users + 1))
avg = calculate_average.submit(users)
print(f"Users: {await users.result()}")
print(f"Average score: {await avg.result()}")

@flow
def prefect_3_flow(n_users: int = 10): #
users = fetch_user_data.map(range(1, n_users + 1))
avg = calculate_average.submit(users)
print(f"Users: {users.result()}")
print(f"Average score: {avg.result()}")

try:
asyncio.run(prefect_2_flow())
raise AssertionError("Expected a TypeError")
except TypeError as e:
assert "can't be used in 'await' expression" in str(e)

prefect_3_flow()
# Users: [{'id': 1, 'score': 10}, ... , {'id': 10, 'score': 100}]
# Average score: 55.0
```
3 changes: 1 addition & 2 deletions requirements-client.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,4 @@ toml >= 0.10.0
typing_extensions >= 4.5.0, < 5.0.0
ujson >= 5.8.0, < 6.0.0
uvicorn >=0.14.0, !=0.29.0
websockets >= 10.4, < 13.0
wrapt >= 1.16.0
websockets >= 10.4, < 13.0
53 changes: 0 additions & 53 deletions src/prefect/_internal/compatibility/deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from typing import Any, Callable, List, Optional, Type, TypeVar

import pendulum
import wrapt
from pydantic import BaseModel

from prefect.utilities.callables import get_call_parameters
Expand Down Expand Up @@ -273,55 +272,3 @@ def callback(_):
DEPRECATED_MODULE_ALIASES.append(
AliasedModuleDefinition(old_name, new_name, callback)
)


class AsyncCompatProxy(wrapt.ObjectProxy):
"""
A proxy object that allows for awaiting a method that is no longer async.
See https://wrapt.readthedocs.io/en/master/wrappers.html#object-proxy for more
"""

def __init__(self, wrapped, class_name: str, method_name: str):
super().__init__(wrapped)
self._self_class_name = class_name
self._self_method_name = method_name
self._self_already_awaited = False

def __await__(self):
if not self._self_already_awaited:
warnings.warn(
(
f"The {self._self_method_name!r} method on {self._self_class_name!r}"
" is no longer async and awaiting it will raise an error after Dec 2024"
" - please remove the `await` keyword."
),
DeprecationWarning,
stacklevel=2,
)
self._self_already_awaited = True
yield
return self.__wrapped__

def __repr__(self):
return repr(self.__wrapped__)

def __reduce_ex__(self, protocol):
return (
type(self),
(self.__wrapped__,),
{"_self_already_awaited": self._self_already_awaited},
)


def deprecated_async_method(wrapped):
"""Decorator that wraps a sync method to allow awaiting it even though it is no longer async."""

@wrapt.decorator
def wrapper(wrapped, instance, args, kwargs):
result = wrapped(*args, **kwargs)
return AsyncCompatProxy(
result, class_name=instance.__class__.__name__, method_name=wrapped.__name__
)

return wrapper(wrapped)
5 changes: 0 additions & 5 deletions src/prefect/futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

from typing_extensions import TypeVar

from prefect._internal.compatibility.deprecated import deprecated_async_method
from prefect.client.orchestration import get_client
from prefect.client.schemas.objects import TaskRun
from prefect.exceptions import ObjectNotFound
Expand Down Expand Up @@ -135,7 +134,6 @@ class PrefectConcurrentFuture(PrefectWrappedFuture[R, concurrent.futures.Future]
when the task run is submitted to a ThreadPoolExecutor.
"""

@deprecated_async_method
def wait(self, timeout: Optional[float] = None) -> None:
try:
result = self._wrapped_future.result(timeout=timeout)
Expand All @@ -144,7 +142,6 @@ def wait(self, timeout: Optional[float] = None) -> None:
if isinstance(result, State):
self._final_state = result

@deprecated_async_method
def result(
self,
timeout: Optional[float] = None,
Expand Down Expand Up @@ -198,7 +195,6 @@ class PrefectDistributedFuture(PrefectFuture[R]):
done_callbacks: List[Callable[[PrefectFuture], None]] = []
waiter = None

@deprecated_async_method
def wait(self, timeout: Optional[float] = None) -> None:
return run_coro_as_sync(self.wait_async(timeout=timeout))

Expand Down Expand Up @@ -235,7 +231,6 @@ async def wait_async(self, timeout: Optional[float] = None):
self._final_state = task_run.state
return

@deprecated_async_method
def result(
self,
timeout: Optional[float] = None,
Expand Down
5 changes: 0 additions & 5 deletions src/prefect/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,6 @@
from typing_extensions import Literal, ParamSpec

import prefect.states
from prefect._internal.compatibility.deprecated import (
deprecated_async_method,
)
from prefect.cache_policies import DEFAULT, NONE, CachePolicy
from prefect.client.orchestration import get_client
from prefect.client.schemas import TaskRun
Expand Down Expand Up @@ -1038,7 +1035,6 @@ def submit(
) -> State[T]:
...

@deprecated_async_method
def submit(
self,
*args: Any,
Expand Down Expand Up @@ -1203,7 +1199,6 @@ def map(
) -> PrefectFutureList[State[T]]:
...

@deprecated_async_method
def map(
self,
*args: Any,
Expand Down
6 changes: 0 additions & 6 deletions src/prefect/utilities/asyncutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import anyio.from_thread
import anyio.to_thread
import sniffio
import wrapt
from typing_extensions import Literal, ParamSpec, TypeGuard

from prefect._internal.concurrency.api import _cast_to_call, from_sync
Expand Down Expand Up @@ -210,11 +209,6 @@ def run_coro_as_sync(
Returns:
The result of the coroutine if wait_for_result is True, otherwise None.
"""
if not asyncio.iscoroutine(coroutine):
if isinstance(coroutine, wrapt.ObjectProxy):
return coroutine.__wrapped__
else:
raise TypeError("`coroutine` must be a coroutine object")

async def coroutine_wrapper() -> Union[R, None]:
"""
Expand Down
67 changes: 15 additions & 52 deletions tests/public/flows/test_previously_awaited_methods.py
Original file line number Diff line number Diff line change
@@ -1,60 +1,23 @@
from prefect import flow, task

import pytest

async def test_awaiting_formerly_async_methods():
import warnings
from prefect import flow, task

N = 5

@pytest.mark.parametrize(
"method,args",
[
("submit", (None,)),
("map", ([None],)),
],
)
async def test_awaiting_previously_async_task_methods_fail(method, args):
@task
def get_random_number(_) -> int:
async def get_random_number(_) -> int:
return 42

@flow
async def get_some_numbers_old_await_syntax():
state1 = await get_random_number.submit(None, return_state=True)
assert state1.is_completed()

future1 = await get_random_number.submit(None)

await future1.wait()
assert await future1.result() == 42

list_of_futures = await get_random_number.map([None] * N)
[await future.wait() for future in list_of_futures]
assert all([await future.result() == 42 for future in list_of_futures])

@flow
async def get_some_numbers_new_way():
state1 = get_random_number.submit(None, return_state=True)
assert state1.is_completed()

future1 = get_random_number.submit(None)
future1.wait()
assert future1.result() == 42

list_of_futures = get_random_number.map([None] * N)
[future.wait() for future in list_of_futures]
assert all(future.result() == 42 for future in list_of_futures)

# Test the old way (with await)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
await get_some_numbers_old_await_syntax()
deprecation_warnings = [
_w for _w in w if issubclass(_w.category, DeprecationWarning)
]

assert all(
"please remove the `await` keyword" in str(warning.message)
for warning in deprecation_warnings
)
async def run_a_task():
await getattr(get_random_number, method)(*args)

# Test the new way (without await)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
await get_some_numbers_new_way()
deprecation_warnings = [
_w for _w in w if issubclass(_w.category, DeprecationWarning)
]
assert len(deprecation_warnings) == 0
with pytest.raises(TypeError, match="can't be used in 'await' expression"):
await run_a_task()

0 comments on commit 4a7424d

Please sign in to comment.