From 5138fd8f70972b9768329299daade01cf3e9b622 Mon Sep 17 00:00:00 2001 From: Hasier Date: Thu, 4 Jul 2024 17:32:13 +0100 Subject: [PATCH 1/4] Account for non-base objects in retries --- tenacity/retry.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/tenacity/retry.py b/tenacity/retry.py index 9211631..6584c5b 100644 --- a/tenacity/retry.py +++ b/tenacity/retry.py @@ -30,15 +30,29 @@ def __call__(self, retry_state: "RetryCallState") -> bool: pass def __and__(self, other: "retry_base") -> "retry_all": - return other.__rand__(self) + if isinstance(other, retry_base): + # Delegate to the other object to allow for specific + # implementations, such as asyncio + return other.__rand__(self) + return retry_all(other, self) def __rand__(self, other: "retry_base") -> "retry_all": + # This is automatically invoked for inheriting classes, + # so it helps to keep the abstraction and delegate specific + # implementations, such as asyncio return retry_all(other, self) def __or__(self, other: "retry_base") -> "retry_any": - return other.__ror__(self) + if isinstance(other, retry_base): + # Delegate to the other object to allow for specific + # implementations, such as asyncio + return other.__ror__(self) + return retry_any(other, self) def __ror__(self, other: "retry_base") -> "retry_any": + # This is automatically invoked for inheriting classes, + # so it helps to keep the abstraction and delegate specific + # implementations, such as asyncio return retry_any(other, self) From e5a5387870f7b689fc631ad46b8cc27de10411a5 Mon Sep 17 00:00:00 2001 From: Hasier Date: Thu, 4 Jul 2024 17:32:19 +0100 Subject: [PATCH 2/4] Add tests --- tenacity/asyncio/retry.py | 6 +- tenacity/retry.py | 4 +- tests/test_asyncio.py | 352 +++++++++++++++++++++++++++++++++++++- tests/test_tenacity.py | 110 ++++++++++++ 4 files changed, 467 insertions(+), 5 deletions(-) diff --git a/tenacity/asyncio/retry.py b/tenacity/asyncio/retry.py index 94b8b15..f458bad 100644 --- a/tenacity/asyncio/retry.py +++ b/tenacity/asyncio/retry.py @@ -104,7 +104,7 @@ def __init__(self, *retries: typing.Union[retry_base, async_retry_base]) -> None async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] result = False for r in self.retries: - result = result or await _utils.wrap_to_async_func(r)(retry_state) + result = result or (await _utils.wrap_to_async_func(r)(retry_state) is True) if result: break return result @@ -119,7 +119,9 @@ def __init__(self, *retries: typing.Union[retry_base, async_retry_base]) -> None async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] result = True for r in self.retries: - result = result and await _utils.wrap_to_async_func(r)(retry_state) + result = result and ( + await _utils.wrap_to_async_func(r)(retry_state) is True + ) if not result: break return result diff --git a/tenacity/retry.py b/tenacity/retry.py index 6584c5b..c05c16b 100644 --- a/tenacity/retry.py +++ b/tenacity/retry.py @@ -283,7 +283,7 @@ def __init__(self, *retries: retry_base) -> None: self.retries = retries def __call__(self, retry_state: "RetryCallState") -> bool: - return any(r(retry_state) for r in self.retries) + return any(r(retry_state) is True for r in self.retries) class retry_all(retry_base): @@ -293,4 +293,4 @@ def __init__(self, *retries: retry_base) -> None: self.retries = retries def __call__(self, retry_state: "RetryCallState") -> bool: - return all(r(retry_state) for r in self.retries) + return all(r(retry_state) is True for r in self.retries) diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index 8716529..69ea883 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -28,7 +28,7 @@ import pytest import tenacity -from tenacity import AsyncRetrying, RetryError +from tenacity import AsyncRetrying, RetryCallState, RetryError from tenacity import asyncio as tasyncio from tenacity import retry, retry_if_exception, retry_if_result, stop_after_attempt from tenacity.wait import wait_fixed @@ -308,6 +308,98 @@ def is_exc(e: BaseException) -> bool: self.assertEqual(4, result) + @asynctest + async def test_retry_with_async_result_or_func(self): + async def test(): + attempts = 0 + called = False + + async def lt_3(x: float) -> bool: + return x < 3 + + def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return False + + retry_strategy = tasyncio.retry_if_result(lt_3) | should_retry # type: ignore[operator] + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + + assert attempt.retry_state.outcome # help mypy + if not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(attempts) + + self.assertTrue(called) + return attempts + + result = await test() + + self.assertEqual(3, result) + + @asynctest + async def test_retry_with_async_result_or_async_func(self): + async def test(): + attempts = 0 + called = False + + async def lt_3(x: float) -> bool: + return x < 3 + + async def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return False + + retry_strategy = tasyncio.retry_if_result(lt_3) | should_retry # type: ignore[operator] + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + + assert attempt.retry_state.outcome # help mypy + if not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(attempts) + + self.assertTrue(called) + return attempts + + result = await test() + + self.assertEqual(3, result) + + @asynctest + async def test_sync_retry_with_async_result_or_async_func(self): + called = False + + async def test(): + attempts = 0 + + def lt_3(x: float) -> bool: + return x < 3 + + async def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return False + + retry_strategy = tenacity.retry_if_result(lt_3) | should_retry # type: ignore[operator] + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + + assert attempt.retry_state.outcome # help mypy + if not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(attempts) + + return attempts + + result = await test() + + # It does not correctly work as the function is not called! + self.assertFalse(called) + self.assertEqual(3, result) + @asynctest async def test_retry_with_async_result_ror(self): async def test(): @@ -339,6 +431,98 @@ async def is_exc(e: BaseException) -> bool: self.assertEqual(4, result) + @asynctest + async def test_retry_with_async_result_ror_func(self): + async def test(): + attempts = 0 + called = False + + async def lt_3(x: float) -> bool: + return x < 3 + + def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return False + + retry_strategy = should_retry | tasyncio.retry_if_result(lt_3) # type: ignore[operator] + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + + assert attempt.retry_state.outcome # help mypy + if not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(attempts) + + self.assertTrue(called) + return attempts + + result = await test() + + self.assertEqual(3, result) + + @asynctest + async def test_retry_with_async_result_ror_async_func(self): + async def test(): + attempts = 0 + called = False + + async def lt_3(x: float) -> bool: + return x < 3 + + async def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return False + + retry_strategy = should_retry | tasyncio.retry_if_result(lt_3) # type: ignore[operator] + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + + assert attempt.retry_state.outcome # help mypy + if not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(attempts) + + self.assertTrue(called) + return attempts + + result = await test() + + self.assertEqual(3, result) + + @asynctest + async def test_sync_retry_with_async_result_ror_async_func(self): + called = False + + async def test(): + attempts = 0 + + def lt_3(x: float) -> bool: + return x < 3 + + async def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return False + + retry_strategy = should_retry | tenacity.retry_if_result(lt_3) # type: ignore[operator] + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + + assert attempt.retry_state.outcome # help mypy + if not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(attempts) + + return attempts + + result = await test() + + # It does not correctly work as the function is not called! + self.assertFalse(called) + self.assertEqual(3, result) + @asynctest async def test_retry_with_async_result_and(self): async def test(): @@ -362,6 +546,89 @@ def gt_0(x: float) -> bool: self.assertEqual(3, result) + @asynctest + async def test_retry_with_async_result_and_func(self): + async def test(): + attempts = 0 + called = False + + async def lt_3(x: float) -> bool: + return x < 3 + + def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return True + + retry_strategy = tasyncio.retry_if_result(lt_3) & should_retry # type: ignore[operator] + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + attempt.retry_state.set_result(attempts) + + self.assertTrue(called) + return attempts + + result = await test() + + self.assertEqual(3, result) + + @asynctest + async def test_retry_with_async_result_and_async_func(self): + async def test(): + attempts = 0 + called = False + + async def lt_3(x: float) -> bool: + return x < 3 + + async def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return True + + retry_strategy = tasyncio.retry_if_result(lt_3) & should_retry # type: ignore[operator] + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + attempt.retry_state.set_result(attempts) + + self.assertTrue(called) + return attempts + + result = await test() + + self.assertEqual(3, result) + + @asynctest + async def test_sync_retry_with_async_result_and_async_func(self): + called = False + + async def test(): + attempts = 0 + + def lt_3(x: float) -> bool: + return x < 3 + + async def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return True + + retry_strategy = tenacity.retry_if_result(lt_3) & should_retry # type: ignore[operator] + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + attempt.retry_state.set_result(attempts) + + return attempts + + result = await test() + + # It does not correctly work as the function is not called! + self.assertFalse(called) + self.assertEqual(1, result) + @asynctest async def test_retry_with_async_result_rand(self): async def test(): @@ -385,6 +652,89 @@ def gt_0(x: float) -> bool: self.assertEqual(3, result) + @asynctest + async def test_retry_with_async_result_rand_func(self): + async def test(): + attempts = 0 + called = False + + async def lt_3(x: float) -> bool: + return x < 3 + + def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return True + + retry_strategy = should_retry & tasyncio.retry_if_result(lt_3) # type: ignore[operator] + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + attempt.retry_state.set_result(attempts) + + self.assertTrue(called) + return attempts + + result = await test() + + self.assertEqual(3, result) + + @asynctest + async def test_retry_with_async_result_rand_async_func(self): + async def test(): + attempts = 0 + called = False + + async def lt_3(x: float) -> bool: + return x < 3 + + async def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return True + + retry_strategy = should_retry & tasyncio.retry_if_result(lt_3) # type: ignore[operator] + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + attempt.retry_state.set_result(attempts) + + self.assertTrue(called) + return attempts + + result = await test() + + self.assertEqual(3, result) + + @asynctest + async def test_sync_retry_with_async_result_rand_async_func(self): + called = False + + async def test(): + attempts = 0 + + def lt_3(x: float) -> bool: + return x < 3 + + async def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return True + + retry_strategy = should_retry & tenacity.retry_if_result(lt_3) # type: ignore[operator] + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + attempt.retry_state.set_result(attempts) + + return attempts + + result = await test() + + # It does not correctly work as the function is not called! + self.assertFalse(called) + self.assertEqual(1, result) + @asynctest async def test_async_retying_iterator(self): thing = NoIOErrorAfterCount(5) diff --git a/tests/test_tenacity.py b/tests/test_tenacity.py index e158fa6..07d4155 100644 --- a/tests/test_tenacity.py +++ b/tests/test_tenacity.py @@ -633,6 +633,58 @@ def r(fut): self.assertFalse(r(tenacity.Future.construct(1, 3, False))) self.assertFalse(r(tenacity.Future.construct(1, 1, True))) + async def test_retry_and_func(self): + def test(): + attempts = 0 + called = False + + def lt_3(x: float) -> bool: + return x < 3 + + def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return True + + retry_strategy = tenacity.retry_if_result(lt_3) & should_retry # type: ignore[operator] + for attempt in Retrying(retry=retry_strategy): + with attempt: + attempts += 1 + attempt.retry_state.set_result(attempts) + + self.assertTrue(called) + return attempts + + result = test() + + self.assertEqual(3, result) + + async def test_retry_rand_func(self): + def test(): + attempts = 0 + called = False + + def lt_3(x: float) -> bool: + return x < 3 + + def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return True + + retry_strategy = should_retry & tenacity.retry_if_result(lt_3) # type: ignore[operator] + for attempt in Retrying(retry=retry_strategy): + with attempt: + attempts += 1 + attempt.retry_state.set_result(attempts) + + self.assertTrue(called) + return attempts + + result = test() + + self.assertEqual(3, result) + def test_retry_or(self): retry = tenacity.retry_if_result( lambda x: x == "foo" @@ -647,6 +699,64 @@ def r(fut): self.assertFalse(r(tenacity.Future.construct(1, 2.2, False))) self.assertFalse(r(tenacity.Future.construct(1, 42, True))) + def test_retry_or_func(self): + def test(): + attempts = 0 + called = False + + def lt_3(x: float) -> bool: + return x < 3 + + def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return False + + retry_strategy = tenacity.retry_if_result(lt_3) | should_retry # type: ignore[operator] + for attempt in Retrying(retry=retry_strategy): + with attempt: + attempts += 1 + + assert attempt.retry_state.outcome # help mypy + if not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(attempts) + + self.assertTrue(called) + return attempts + + result = test() + + self.assertEqual(3, result) + + def test_retry_ror_func(self): + def test(): + attempts = 0 + called = False + + def lt_3(x: float) -> bool: + return x < 3 + + def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return False + + retry_strategy = should_retry | tenacity.retry_if_result(lt_3) # type: ignore[operator] + for attempt in Retrying(retry=retry_strategy): + with attempt: + attempts += 1 + + assert attempt.retry_state.outcome # help mypy + if not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(attempts) + + self.assertTrue(called) + return attempts + + result = test() + + self.assertEqual(3, result) + def _raise_try_again(self): self._attempts += 1 if self._attempts < 3: From d5f8bff55a0a7d2d299ec7b711e0b8dcce6dd35e Mon Sep 17 00:00:00 2001 From: Hasier Date: Thu, 4 Jul 2024 17:34:23 +0100 Subject: [PATCH 3/4] Add release note --- .../notes/allow-retry-callables-ba921a2b57229540.yaml | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 releasenotes/notes/allow-retry-callables-ba921a2b57229540.yaml diff --git a/releasenotes/notes/allow-retry-callables-ba921a2b57229540.yaml b/releasenotes/notes/allow-retry-callables-ba921a2b57229540.yaml new file mode 100644 index 0000000..e7be1cd --- /dev/null +++ b/releasenotes/notes/allow-retry-callables-ba921a2b57229540.yaml @@ -0,0 +1,7 @@ +--- +features: + - | + Allow for callables to be combined as retry values. This will only + work when used combined with their corresponding implementation + retry objects, e.g. only async functions will work when used together + with async retry strategies. From f8725d6abbfcaf38eb36725c31bf7fe80fee1ab0 Mon Sep 17 00:00:00 2001 From: Hasier Date: Mon, 8 Jul 2024 11:34:05 +0100 Subject: [PATCH 4/4] Check for async strategies in sync context --- tenacity/asyncio/retry.py | 6 ++-- tenacity/retry.py | 41 ++++++++++++++++++++++++-- tests/test_asyncio.py | 60 ++++++++++++++++++++++++++------------- 3 files changed, 81 insertions(+), 26 deletions(-) diff --git a/tenacity/asyncio/retry.py b/tenacity/asyncio/retry.py index f458bad..94b8b15 100644 --- a/tenacity/asyncio/retry.py +++ b/tenacity/asyncio/retry.py @@ -104,7 +104,7 @@ def __init__(self, *retries: typing.Union[retry_base, async_retry_base]) -> None async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] result = False for r in self.retries: - result = result or (await _utils.wrap_to_async_func(r)(retry_state) is True) + result = result or await _utils.wrap_to_async_func(r)(retry_state) if result: break return result @@ -119,9 +119,7 @@ def __init__(self, *retries: typing.Union[retry_base, async_retry_base]) -> None async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] result = True for r in self.retries: - result = result and ( - await _utils.wrap_to_async_func(r)(retry_state) is True - ) + result = result and await _utils.wrap_to_async_func(r)(retry_state) if not result: break return result diff --git a/tenacity/retry.py b/tenacity/retry.py index c05c16b..69329e7 100644 --- a/tenacity/retry.py +++ b/tenacity/retry.py @@ -18,6 +18,13 @@ import re import typing +from . import _utils + +try: + import tornado +except ImportError: + tornado = None + if typing.TYPE_CHECKING: from tenacity import RetryCallState @@ -283,7 +290,22 @@ def __init__(self, *retries: retry_base) -> None: self.retries = retries def __call__(self, retry_state: "RetryCallState") -> bool: - return any(r(retry_state) is True for r in self.retries) + result = False + for r in self.retries: + if _utils.is_coroutine_callable(r) or ( + tornado + and hasattr(tornado.gen, "is_coroutine_function") + and tornado.gen.is_coroutine_function(r) + ): + raise TypeError( + "Cannot use async functions in a sync context. Make sure " + "you use the correct retrying object and the corresponding " + "async strategies" + ) + result = result or r(retry_state) + if result: + break + return result class retry_all(retry_base): @@ -293,4 +315,19 @@ def __init__(self, *retries: retry_base) -> None: self.retries = retries def __call__(self, retry_state: "RetryCallState") -> bool: - return all(r(retry_state) is True for r in self.retries) + result = True + for r in self.retries: + if _utils.is_coroutine_callable(r) or ( + tornado + and hasattr(tornado.gen, "is_coroutine_function") + and tornado.gen.is_coroutine_function(r) + ): + raise TypeError( + "Cannot use async functions in a sync context. Make sure " + "you use the correct retrying object and the corresponding " + "async strategies" + ) + result = result and r(retry_state) + if not result: + break + return result diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index 69ea883..325c8a8 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -394,11 +394,16 @@ async def should_retry(retry_state: RetryCallState) -> bool: return attempts - result = await test() - - # It does not correctly work as the function is not called! - self.assertFalse(called) - self.assertEqual(3, result) + try: + await test() + except TypeError as exc: + self.assertEqual( + str(exc), + "Cannot use async functions in a sync context. Make sure you use " + "the correct retrying object and the corresponding async strategies", + ) + else: + self.fail("This is an invalid retry combination that should have failed") @asynctest async def test_retry_with_async_result_ror(self): @@ -517,11 +522,16 @@ async def should_retry(retry_state: RetryCallState) -> bool: return attempts - result = await test() - - # It does not correctly work as the function is not called! - self.assertFalse(called) - self.assertEqual(3, result) + try: + await test() + except TypeError as exc: + self.assertEqual( + str(exc), + "Cannot use async functions in a sync context. Make sure you use " + "the correct retrying object and the corresponding async strategies", + ) + else: + self.fail("This is an invalid retry combination that should have failed") @asynctest async def test_retry_with_async_result_and(self): @@ -623,11 +633,16 @@ async def should_retry(retry_state: RetryCallState) -> bool: return attempts - result = await test() - - # It does not correctly work as the function is not called! - self.assertFalse(called) - self.assertEqual(1, result) + try: + await test() + except TypeError as exc: + self.assertEqual( + str(exc), + "Cannot use async functions in a sync context. Make sure you use " + "the correct retrying object and the corresponding async strategies", + ) + else: + self.fail("This is an invalid retry combination that should have failed") @asynctest async def test_retry_with_async_result_rand(self): @@ -729,11 +744,16 @@ async def should_retry(retry_state: RetryCallState) -> bool: return attempts - result = await test() - - # It does not correctly work as the function is not called! - self.assertFalse(called) - self.assertEqual(1, result) + try: + await test() + except TypeError as exc: + self.assertEqual( + str(exc), + "Cannot use async functions in a sync context. Make sure you use " + "the correct retrying object and the corresponding async strategies", + ) + else: + self.fail("This is an invalid retry combination that should have failed") @asynctest async def test_async_retying_iterator(self):