Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add retry on 429 error code #455

Merged
merged 3 commits into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 31 additions & 4 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,7 +873,7 @@ def test_extra_credential_value_encoding(mock_get_and_post):
def test_extra_credential_value_object(mock_get_and_post):
_, post = mock_get_and_post

class TestCredential(object):
class TestCredential:
value = "initial"

def __str__(self):
Expand Down Expand Up @@ -972,7 +972,7 @@ def test_authentication_gssapi_init_arguments(
assert session.auth.creds == expected_credentials


class RetryRecorder(object):
class RetryRecorder:
def __init__(self, error=None, result=None):
self.__name__ = "RetryRecorder"
self._retry_count = 0
Expand Down Expand Up @@ -1058,6 +1058,33 @@ def test_5XX_error_retry(status_code, attempts, monkeypatch):
assert post_retry.retry_count == attempts


def test_429_error_retry(monkeypatch):
    """A 429 response is retried ``max_attempts`` times for both POST and GET."""
    http_resp = TrinoRequest.http.Response()
    http_resp.status_code = 429
    # Retry-After is given in seconds; the retry loop sleeps this long between attempts.
    http_resp.headers["Retry-After"] = 1

    post_retry = RetryRecorder(result=http_resp)
    monkeypatch.setattr(TrinoRequest.http.Session, "post", post_retry)

    get_retry = RetryRecorder(result=http_resp)
    monkeypatch.setattr(TrinoRequest.http.Session, "get", get_retry)

    req = TrinoRequest(
        host="coordinator",
        port=8080,
        client_session=ClientSession(
            user="test",
        ),
        max_attempts=3
    )

    req.post("URL")
    assert post_retry.retry_count == 3

    req.get("URL")
    # Check the GET recorder: asserting on post_retry again would pass even
    # if GET requests were never retried.
    assert get_retry.retry_count == 3


@pytest.mark.parametrize("status_code", [
501
])
Expand Down Expand Up @@ -1087,7 +1114,7 @@ def test_error_no_retry(status_code, monkeypatch):
assert post_retry.retry_count == 1


class FakeGatewayResponse(object):
class FakeGatewayResponse:
def __init__(self, http_response, redirect_count=1):
self.__name__ = "FakeGatewayResponse"
self.http_response = http_response
Expand Down Expand Up @@ -1197,7 +1224,7 @@ def test_retry_with():
max_attempts=max_attempts,
)

class FailerUntil(object):
class FailerUntil:
def __init__(self, until=1):
self.attempt = 0
self._until = until
Expand Down
48 changes: 37 additions & 11 deletions trino/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
import urllib.parse
import warnings
from dataclasses import dataclass
from datetime import datetime
from email.utils import parsedate_to_datetime
from time import sleep
from typing import Any, Dict, List, Optional, Tuple, Union

Expand Down Expand Up @@ -76,7 +78,7 @@
ROLE_PATTERN = re.compile(r"^ROLE\{(.*)\}$")


class ClientSession(object):
class ClientSession:
"""
Manage the current Client Session properties of a specific connection. This class is thread-safe.

Expand Down Expand Up @@ -319,9 +321,9 @@ def __repr__(self):
)


class _DelayExponential(object):
class _DelayExponential:
def __init__(
self, base=0.1, exponent=2, jitter=True, max_delay=2 * 3600 # 100ms # 2 hours
self, base=0.1, exponent=2, jitter=True, max_delay=1800 # 100ms # 30 min
):
self._base = base
self._exponent = exponent
Expand All @@ -336,9 +338,9 @@ def __call__(self, attempt):
return delay


class _RetryWithExponentialBackoff(object):
class _RetryWithExponentialBackoff:
def __init__(
self, base=0.1, exponent=2, jitter=True, max_delay=2 * 3600 # 100ms # 2 hours
self, base=0.1, exponent=2, jitter=True, max_delay=1800 # 100ms # 30 min
):
self._get_delay = _DelayExponential(base, exponent, jitter, max_delay)

Expand All @@ -347,7 +349,15 @@ def retry(self, func, args, kwargs, err, attempt):
sleep(delay)


class TrinoRequest(object):
class _RetryAfterSleep:
def __init__(self, retry_after_header):
self._retry_after_header = retry_after_header

def retry(self):
sleep(self._retry_after_header)


class TrinoRequest:
"""
Manage the HTTP requests of a Trino query.

Expand Down Expand Up @@ -523,9 +533,9 @@ def max_attempts(self, value) -> None:
self._handle_retry,
handled_exceptions=self._exceptions,
conditions=(
# need retry when there is no exception but the status code is 502, 503, or 504
# need retry when there is no exception but the status code is 429, 502, 503, or 504
lambda response: getattr(response, "status_code", None)
in (502, 503, 504),
in (429, 502, 503, 504),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to add a delay before retrying? An immediate retry will just make the problem worse (I'm not sure whether a delay is already built in).

),
max_attempts=self._max_attempts,
)
Expand Down Expand Up @@ -683,7 +693,7 @@ def _verify_extra_credential(self, header):
raise ValueError(f"only ASCII characters are allowed in extra credential '{key}'")


class TrinoResult(object):
class TrinoResult:
"""
Represent the result of a Trino query as an iterator on rows.

Expand Down Expand Up @@ -721,7 +731,7 @@ def __iter__(self):
self._rows = next_rows


class TrinoQuery(object):
class TrinoQuery:
"""Represent the execution of a SQL statement by Trino."""

def __init__(
Expand Down Expand Up @@ -887,7 +897,12 @@ def decorated(*args, **kwargs):
try:
result = func(*args, **kwargs)
if any(guard(result) for guard in conditions):
handle_retry.retry(func, args, kwargs, None, attempt)
if result.status_code == 429 and "Retry-After" in result.headers:
retry_after = _parse_retry_after_header(result.headers.get("Retry-After"))
handle_retry_sleep = _RetryAfterSleep(retry_after)
handle_retry_sleep.retry()
else:
handle_retry.retry(func, args, kwargs, None, attempt)
continue
return result
except Exception as err:
Expand All @@ -904,3 +919,14 @@ def decorated(*args, **kwargs):
return decorated

return wrapper


def _parse_retry_after_header(retry_after):
if isinstance(retry_after, int):
return retry_after
elif isinstance(retry_after, str) and retry_after.isdigit():
return int(retry_after)
else:
retry_date = parsedate_to_datetime(retry_after)
now = datetime.utcnow()
return (retry_date - now).total_seconds()
4 changes: 2 additions & 2 deletions trino/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def connect(*args, **kwargs):
return Connection(*args, **kwargs)


class Connection(object):
class Connection:
"""Trino supports transactions and the ability to either commit or rollback
a sequence of SQL statements. A single query i.e. the execution of a SQL
statement, can also be cancelled. Transactions are not supported by this
Expand Down Expand Up @@ -329,7 +329,7 @@ def from_column(cls, column: Dict[str, Any]):
)


class Cursor(object):
class Cursor:
"""Database cursor.

Cursors are not isolated, i.e., any changes done to the database by a
Expand Down
2 changes: 1 addition & 1 deletion trino/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def check(cls, level: int) -> int:
return level


class Transaction(object):
class Transaction:
def __init__(self, request: trino.client.TrinoRequest) -> None:
self._request = request
self._id = NO_TRANSACTION
Expand Down