Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generator support #83

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ dist
*.pyc
*.egg-info
build
venv
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ python:
- 3.3
- 3.4
- 3.5
- 3.6
- pypy

script: python setup.py test
33 changes: 33 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,39 @@ We can also use the result of the function to alter the behavior of retrying.

Any combination of stop, wait, etc. is also supported to give you the freedom to mix and match.

You might need to retry a (repeatable) generator function. The following kinds of generator functions are supported:
1. Generator functions whose values can be fetched as a whole. Ideal if the function may not preserve the order
of its elements and does not generate many elements.
2. Generator functions whose values are always produced in the same order but that generate many elements (enable this mode with ``deterministic_generators=True``).

.. code-block:: python

import datetime
import sys


@retry(stop_max_delay=3000)
def _few_elements(started: datetime.datetime):
for i in range(10):
if i == 5 and datetime.datetime.now() - started < datetime.timedelta(seconds=2):
raise ValueError
yield i
result = list(_few_elements(datetime.datetime.now())) # here we have [0, 1, ... 9]

@retry(stop_max_delay=3000, deterministic_generators=True)
def _many_elements(started: datetime.datetime):
for i in range(sys.maxsize):
if i == 5 and datetime.datetime.now() - started < datetime.timedelta(seconds=2):
raise ValueError
yield i

bounded_result = []
for i in _many_elements(datetime.datetime.now()):
if i > 9:
break
bounded_result.append(i)
# Here bounded_result is [0, 1, ..., 9] and your RAM is preserved


Contribute
----------

Expand Down
99 changes: 92 additions & 7 deletions retrying.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,22 @@
## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
## See the License for the specific language governing permissions and
## limitations under the License.

import inspect
import random
import six
import sys
import time
import traceback

import six


if sys.version_info >= (3, 6):
def _is_async(fn):
return inspect.isasyncgenfunction(fn) or inspect.isgeneratorfunction(fn)
else:
def _is_async(fn):
return inspect.isgeneratorfunction(fn)


# sys.maxint / 2, since Python 3.2 doesn't have a sys.maxint...
MAX_WAIT = 1073741823
Expand All @@ -41,8 +50,9 @@ def wrap_simple(f):

@six.wraps(f)
def wrapped_f(*args, **kw):
if dkw.get('deterministic_generators') and _is_async(f):
return Retrying().call_async(f, *args, **kw)
return Retrying().call(f, *args, **kw)

return wrapped_f

return wrap_simple(dargs[0])
Expand All @@ -52,6 +62,8 @@ def wrap(f):

@six.wraps(f)
def wrapped_f(*args, **kw):
if dkw.get('deterministic_generators') and _is_async(f):
return Retrying(*dargs, **dkw).call_async(f, *args, **kw)
return Retrying(*dargs, **dkw).call(f, *args, **kw)

return wrapped_f
Expand All @@ -77,7 +89,8 @@ def __init__(self,
wait_func=None,
wait_jitter_max=None,
before_attempts=None,
after_attempts=None):
after_attempts=None,
deterministic_generators=False):

self._stop_max_attempt_number = 5 if stop_max_attempt_number is None else stop_max_attempt_number
self._stop_max_delay = 100 if stop_max_delay is None else stop_max_delay
Expand All @@ -92,7 +105,8 @@ def __init__(self,
self._wait_jitter_max = 0 if wait_jitter_max is None else wait_jitter_max
self._before_attempts = before_attempts
self._after_attempts = after_attempts

self._deterministic_generators = deterministic_generators
self._deterministic_offset = -1
# TODO add chaining of stop behaviors
# stop behavior
stop_funcs = []
Expand Down Expand Up @@ -215,19 +229,38 @@ def should_reject(self, attempt):
return reject

def call(self, fn, *args, **kwargs):
self._deterministic_offset = -1
assert not self._deterministic_generators

is_generator = _is_async(fn)

start_time = int(round(time.time() * 1000))
attempt_number = 1
while True:
if self._before_attempts:
self._before_attempts(attempt_number)

try:
attempt = Attempt(fn(*args, **kwargs), attempt_number, False)
if is_generator:
# Here we do not know if the generator will fail.
# In order to avoid partial yield to the caller, which would
# produce partial data and then start from scratch upon error,
# we have to fetch the whole data and, in case of failures, we
# just recreate the result from scratch.
# We could also yield data incrementally and, when retrying,
# skip what we have already produced.
# This would require deterministic order of element production, though.
result = list(fn(*args, **kwargs))
else:
result = fn(*args, **kwargs)
attempt = Attempt(result, attempt_number, False)
except:
tb = sys.exc_info()
attempt = Attempt(tb, attempt_number, True)

if not self.should_reject(attempt):
if is_generator:
return self._yelded_data(attempt)
return attempt.get(self._wrap_exception)

if self._after_attempts:
Expand All @@ -249,6 +282,58 @@ def call(self, fn, *args, **kwargs):

attempt_number += 1

def call_async(self, fn, *args, **kwargs):
    """Lazily iterate the generator function ``fn``, retrying it on failure.

    Elements are handed to the caller as soon as ``fn`` produces them. When
    an attempt fails part-way through, ``fn`` is called again from scratch
    and the elements that were already delivered are skipped by
    ``_deterministic_generation``; ``fn`` must therefore produce its
    elements in the same order on every attempt.

    Despite its name this is a plain synchronous generator; it does not
    involve ``async``/``await``.

    :param fn: generator function to run
    :param args: positional arguments forwarded to ``fn``
    :param kwargs: keyword arguments forwarded to ``fn``
    :raises RetryError: when the stop condition is met and either exceptions
        are wrapped or the last attempt did not end with an exception
    """
    self._deterministic_offset = -1
    assert self._deterministic_generators
    assert _is_async(fn)

    start_time = int(round(time.time() * 1000))
    attempt_number = 1
    while True:
        if self._before_attempts:
            self._before_attempts(attempt_number)

        try:
            for d in self._deterministic_generation(fn, *args, **kwargs):
                yield d
            attempt = Attempt(None, attempt_number, False)
        except GeneratorExit:
            # The consumer stopped iterating (break / close() / garbage
            # collection). This is not a failure of fn: retrying would yield
            # again and trigger "generator ignored GeneratorExit".
            raise
        except:
            tb = sys.exc_info()
            attempt = Attempt(tb, attempt_number, True)

        if not self.should_reject(attempt):
            # Accepted attempt: a successful one simply ends the iteration,
            # a failed one must surface its exception instead of being
            # silently dropped.
            attempt.get(self._wrap_exception)
            return
        if self._after_attempts:
            self._after_attempts(attempt_number)

        delay_since_first_attempt_ms = int(round(time.time() * 1000)) - start_time
        if self.stop(attempt_number, delay_since_first_attempt_ms):
            if not self._wrap_exception and attempt.has_exception:
                # get() on an attempt with an exception should cause it to be raised, but raise just in case
                raise attempt.get()
            else:
                raise RetryError(attempt)
        else:
            sleep = self.wait(attempt_number, delay_since_first_attempt_ms)
            if self._wait_jitter_max:
                jitter = random.random() * self._wait_jitter_max
                sleep = sleep + max(0, jitter)
            # wait() works in milliseconds, time.sleep() in seconds
            time.sleep(sleep / 1000.0)

        attempt_number += 1

def _yelded_data(self, attempt):
for d in attempt.get(self._wrap_exception):
yield d

def _deterministic_generation(self, fn, *args, **kwargs):
for i, v in enumerate(fn(*args, **kwargs)):
if i <= self._deterministic_offset:
continue
yield v
self._deterministic_offset = i


class Attempt(object):
"""
Expand Down
55 changes: 54 additions & 1 deletion test_retrying.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
## See the License for the specific language governing permissions and
## limitations under the License.

import datetime
import sys
import time
import unittest

Expand Down Expand Up @@ -468,5 +469,57 @@ def _test_after():

self.assertTrue(TestBeforeAfterAttempts._attempt_number is 2)


class TestGenerators(unittest.TestCase):
    """Check that ``@retry`` works on generator functions.

    Every generator under test raises ``ValueError`` until two seconds have
    passed since ``started``, so at least one retry is needed before the
    full sequence can be produced.
    """

    def test(self):
        # Default mode, small finite generator.
        grace = datetime.timedelta(seconds=2)

        @retry(stop_max_delay=3000)
        def _f(started):
            for i in range(10):
                if i == 5:
                    if datetime.datetime.now() - started < grace:
                        raise ValueError
                yield i

        produced = list(_f(datetime.datetime.now()))
        self.assertEqual(list(range(10)), produced)

    def test_deterministic(self):
        # Same generator, but with deterministic_generators enabled.
        grace = datetime.timedelta(seconds=2)

        @retry(stop_max_delay=3000, deterministic_generators=True)
        def _f(started):
            for i in range(10):
                if i == 5:
                    if datetime.datetime.now() - started < grace:
                        raise ValueError
                yield i

        produced = list(_f(datetime.datetime.now()))
        self.assertEqual(list(range(10)), produced)

    def test_deterministic_big_values(self):
        # xrange only exists on Python 2; it is looked up only on that branch.
        # noinspection PyUnresolvedReferences
        safe_range = range if sys.version_info >= (3, 0) else xrange  # noqa
        grace = datetime.timedelta(seconds=2)

        # Do NOT use nondeterministic generators. You would get OOM.
        @retry(stop_max_delay=3000, deterministic_generators=True)
        def _f(started):
            for i in safe_range(sys.maxsize):
                if i == 5:
                    if datetime.datetime.now() - started < grace:
                        raise ValueError
                yield i

        bounded_result = []
        for value in _f(datetime.datetime.now()):
            if value <= 9:
                bounded_result.append(value)
                continue
            break
        self.assertEqual(list(range(10)), bounded_result)

    def test_simple(self):
        # Generator that fails before yielding anything at all.
        grace = datetime.timedelta(seconds=2)

        @retry(stop_max_delay=3000)
        def _f(started):
            too_early = datetime.datetime.now() - started < grace
            if too_early:
                raise ValueError
            yield 'OK'

        produced = list(_f(datetime.datetime.now()))
        self.assertEqual(['OK'], produced)


if __name__ == '__main__':
unittest.main()