diff --git a/.gitignore b/.gitignore index 2e34cff..fb230e8 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ dist *.pyc *.egg-info build +venv diff --git a/.travis.yml b/.travis.yml index 8ce4a5a..63d4f08 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,6 +7,7 @@ python: - 3.3 - 3.4 - 3.5 + - 3.6 - pypy script: python setup.py test diff --git a/README.rst b/README.rst index ab5d6bd..6bcc1c5 100644 --- a/README.rst +++ b/README.rst @@ -144,6 +144,39 @@ We can also use the result of the function to alter the behavior of retrying. Any combination of stop, wait, etc. is also supported to give you the freedom to mix and match. +You might need to retry a (repeatable) generator function. The following generator functions are supported: +1. Generator functions whose values can be fetched as a whole. Ideal if the function may not preserve the order +and does not generate many elements +2. Generator functions whose values are always fetched in the same order, but generate many elements + +.. code-block:: python + + import datetime + + + @retry(stop_max_delay=3000) + def _few_elements(started: datetime.datetime): + for i in range(10): + if i == 5 and datetime.datetime.now() - started < datetime.timedelta(seconds=2): + raise ValueError + yield i + result = list(_few_elements(datetime.datetime.now())) # here we have [0, 1, ... 
9] + + @retry(stop_max_delay=3000, deterministic_generators=True) + def _many_elements(started: datetime.datetime): + for i in range(sys.maxsize): + if i == 5 and datetime.datetime.now() - started < datetime.timedelta(seconds=2): + raise ValueError + yield i + + bounded_result = [] + for i in _many_elements(datetime.datetime.now()): + if i > 9: + break + bounded_result.append(i) + # Here bounded_result is [0, 1, ..., 9] and your RAM is preserved + + Contribute ---------- diff --git a/retrying.py b/retrying.py index bcb7a9d..8adde2a 100644 --- a/retrying.py +++ b/retrying.py @@ -11,13 +11,22 @@ ## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ## See the License for the specific language governing permissions and ## limitations under the License. - +import inspect import random -import six import sys import time import traceback +import six + + +if sys.version_info >= (3, 6): + def _is_async(fn): + return inspect.isasyncgenfunction(fn) or inspect.isgeneratorfunction(fn) +else: + def _is_async(fn): + return inspect.isgeneratorfunction(fn) + # sys.maxint / 2, since Python 3.2 doesn't have a sys.maxint... 
MAX_WAIT = 1073741823 @@ -41,8 +50,9 @@ def wrap_simple(f): @six.wraps(f) def wrapped_f(*args, **kw): + if dkw.get('deterministic_generators') and _is_async(f): + return Retrying().call_async(f, *args, **kw) return Retrying().call(f, *args, **kw) - return wrapped_f return wrap_simple(dargs[0]) @@ -52,6 +62,8 @@ def wrap(f): @six.wraps(f) def wrapped_f(*args, **kw): + if dkw.get('deterministic_generators') and _is_async(f): + return Retrying(*dargs, **dkw).call_async(f, *args, **kw) return Retrying(*dargs, **dkw).call(f, *args, **kw) return wrapped_f @@ -77,7 +89,8 @@ def __init__(self, wait_func=None, wait_jitter_max=None, before_attempts=None, - after_attempts=None): + after_attempts=None, + deterministic_generators=False): self._stop_max_attempt_number = 5 if stop_max_attempt_number is None else stop_max_attempt_number self._stop_max_delay = 100 if stop_max_delay is None else stop_max_delay @@ -92,7 +105,8 @@ def __init__(self, self._wait_jitter_max = 0 if wait_jitter_max is None else wait_jitter_max self._before_attempts = before_attempts self._after_attempts = after_attempts - + self._deterministic_generators = deterministic_generators + self._deterministic_offset = -1 # TODO add chaining of stop behaviors # stop behavior stop_funcs = [] @@ -215,6 +229,11 @@ def should_reject(self, attempt): return reject def call(self, fn, *args, **kwargs): + self._deterministic_offset = -1 + assert not self._deterministic_generators + + is_generator = _is_async(fn) + start_time = int(round(time.time() * 1000)) attempt_number = 1 while True: @@ -222,12 +241,26 @@ def call(self, fn, *args, **kwargs): self._before_attempts(attempt_number) try: - attempt = Attempt(fn(*args, **kwargs), attempt_number, False) + if is_generator: + # Here we do not know if the generator will fail. 
+ # In order to avoid partial yield to the caller, which would + # produce partial data and then start from scratch upon error, + # we have to fetch the whole data and, in case of failures, we + # just recreate the result from scratch. + # We could also yield data incrementally and, when retrying, + # skip what we have already produced. + # This would require deterministic order of element production, though. + result = list(fn(*args, **kwargs)) + else: + result = fn(*args, **kwargs) + attempt = Attempt(result, attempt_number, False) except: tb = sys.exc_info() attempt = Attempt(tb, attempt_number, True) - + if not self.should_reject(attempt): + if is_generator: + return self._yelded_data(attempt) return attempt.get(self._wrap_exception) if self._after_attempts: @@ -249,6 +282,58 @@ def call(self, fn, *args, **kwargs): attempt_number += 1 + def call_async(self, fn, *args, **kwargs): + self._deterministic_offset = -1 + assert self._deterministic_generators + assert _is_async(fn) + + start_time = int(round(time.time() * 1000)) + attempt_number = 1 + while True: + if self._before_attempts: + self._before_attempts(attempt_number) + + try: + for d in self._deterministic_generation(fn, *args, **kwargs): + yield d + attempt = Attempt(None, attempt_number, False) + except: + tb = sys.exc_info() + attempt = Attempt(tb, attempt_number, True) + + if not self.should_reject(attempt): + self._yelded_data(attempt) + return + if self._after_attempts: + self._after_attempts(attempt_number) + + delay_since_first_attempt_ms = int(round(time.time() * 1000)) - start_time + if self.stop(attempt_number, delay_since_first_attempt_ms): + if not self._wrap_exception and attempt.has_exception: + # get() on an attempt with an exception should cause it to be raised, but raise just in case + raise attempt.get() + else: + raise RetryError(attempt) + else: + sleep = self.wait(attempt_number, delay_since_first_attempt_ms) + if self._wait_jitter_max: + jitter = random.random() * 
self._wait_jitter_max + sleep = sleep + max(0, jitter) + time.sleep(sleep / 1000.0) + + attempt_number += 1 + + def _yelded_data(self, attempt): + for d in attempt.get(self._wrap_exception): + yield d + + def _deterministic_generation(self, fn, *args, **kwargs): + for i, v in enumerate(fn(*args, **kwargs)): + if i <= self._deterministic_offset: + continue + yield v + self._deterministic_offset = i + class Attempt(object): """ diff --git a/test_retrying.py b/test_retrying.py index 8ce4ac3..82adbf7 100644 --- a/test_retrying.py +++ b/test_retrying.py @@ -11,7 +11,8 @@ ## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ## See the License for the specific language governing permissions and ## limitations under the License. - +import datetime +import sys import time import unittest @@ -468,5 +469,57 @@ def _test_after(): self.assertTrue(TestBeforeAfterAttempts._attempt_number is 2) + +class TestGenerators(unittest.TestCase): + def test(self): + @retry(stop_max_delay=3000) + def _f(started): + for i in range(10): + if i == 5 and datetime.datetime.now() - started < datetime.timedelta(seconds=2): + raise ValueError + yield i + + self.assertEqual(list(range(10)), list(_f(datetime.datetime.now()))) + + def test_deterministic(self): + @retry(stop_max_delay=3000, deterministic_generators=True) + def _f(started): + for i in range(10): + if i == 5 and datetime.datetime.now() - started < datetime.timedelta(seconds=2): + raise ValueError + yield i + self.assertEqual(list(range(10)), list(_f(datetime.datetime.now()))) + + def test_deterministic_big_values(self): + if sys.version_info >= (3, 0): + safe_range = range + else: + # noinspection PyUnresolvedReferences + safe_range = xrange + + # Do NOT use nondeterministic generators. You would get OOM. 
+ @retry(stop_max_delay=3000, deterministic_generators=True) + def _f(started): + for i in safe_range(sys.maxsize): + if i == 5 and datetime.datetime.now() - started < datetime.timedelta(seconds=2): + raise ValueError + yield i + + bounded_result = [] + for i in _f(datetime.datetime.now()): + if i > 9: + break + bounded_result.append(i) + self.assertEqual(list(range(10)), bounded_result) + + def test_simple(self): + @retry(stop_max_delay=3000) + def _f(started): + if datetime.datetime.now() - started < datetime.timedelta(seconds=2): + raise ValueError + yield 'OK' + self.assertEqual(['OK'], list(_f(datetime.datetime.now()))) + + if __name__ == '__main__': unittest.main()