diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 0cc9deec..8748e80f 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -873,7 +873,7 @@ def test_extra_credential_value_encoding(mock_get_and_post): def test_extra_credential_value_object(mock_get_and_post): _, post = mock_get_and_post - class TestCredential(object): + class TestCredential: value = "initial" def __str__(self): @@ -972,7 +972,7 @@ def test_authentication_gssapi_init_arguments( assert session.auth.creds == expected_credentials -class RetryRecorder(object): +class RetryRecorder: def __init__(self, error=None, result=None): self.__name__ = "RetryRecorder" self._retry_count = 0 @@ -1058,6 +1058,33 @@ def test_5XX_error_retry(status_code, attempts, monkeypatch): assert post_retry.retry_count == attempts +def test_429_error_retry(monkeypatch): + http_resp = TrinoRequest.http.Response() + http_resp.status_code = 429 + http_resp.headers["Retry-After"] = 1 + + post_retry = RetryRecorder(result=http_resp) + monkeypatch.setattr(TrinoRequest.http.Session, "post", post_retry) + + get_retry = RetryRecorder(result=http_resp) + monkeypatch.setattr(TrinoRequest.http.Session, "get", get_retry) + + req = TrinoRequest( + host="coordinator", + port=8080, + client_session=ClientSession( + user="test", + ), + max_attempts=3 + ) + + req.post("URL") + assert post_retry.retry_count == 3 + + req.get("URL") + assert post_retry.retry_count == 3 + + @pytest.mark.parametrize("status_code", [ 501 ]) @@ -1087,7 +1114,7 @@ def test_error_no_retry(status_code, monkeypatch): assert post_retry.retry_count == 1 -class FakeGatewayResponse(object): +class FakeGatewayResponse: def __init__(self, http_response, redirect_count=1): self.__name__ = "FakeGatewayResponse" self.http_response = http_response @@ -1197,7 +1224,7 @@ def test_retry_with(): max_attempts=max_attempts, ) - class FailerUntil(object): + class FailerUntil: def __init__(self, until=1): self.attempt = 0 self._until = until diff --git a/trino/client.py b/trino/client.py index eca78537..4fea8321 100644 --- a/trino/client.py +++ b/trino/client.py @@ -43,6 +43,8 @@ import urllib.parse import warnings from dataclasses import dataclass +from datetime import datetime +from email.utils import parsedate_to_datetime from time import sleep from typing import Any, Dict, List, Optional, Tuple, Union @@ -76,7 +78,7 @@ ROLE_PATTERN = re.compile(r"^ROLE\{(.*)\}$") -class ClientSession(object): +class ClientSession: """ Manage the current Client Session properties of a specific connection. This class is thread-safe. @@ -319,9 +321,9 @@ def __repr__(self): ) -class _DelayExponential(object): +class _DelayExponential: def __init__( - self, base=0.1, exponent=2, jitter=True, max_delay=2 * 3600 # 100ms # 2 hours + self, base=0.1, exponent=2, jitter=True, max_delay=1800 # 100ms # 30 min ): self._base = base self._exponent = exponent @@ -336,9 +338,9 @@ def __call__(self, attempt): return delay -class _RetryWithExponentialBackoff(object): +class _RetryWithExponentialBackoff: def __init__( - self, base=0.1, exponent=2, jitter=True, max_delay=2 * 3600 # 100ms # 2 hours + self, base=0.1, exponent=2, jitter=True, max_delay=1800 # 100ms # 30 min ): self._get_delay = _DelayExponential(base, exponent, jitter, max_delay) @@ -347,7 +349,15 @@ def retry(self, func, args, kwargs, err, attempt): sleep(delay) -class TrinoRequest(object): +class _RetryAfterSleep: + def __init__(self, retry_after_header): + self._retry_after_header = retry_after_header + + def retry(self): + sleep(self._retry_after_header) + + +class TrinoRequest: """ Manage the HTTP requests of a Trino query. @@ -523,9 +533,9 @@ def max_attempts(self, value) -> None: self._handle_retry, handled_exceptions=self._exceptions, conditions=( - # need retry when there is no exception but the status code is 502, 503, or 504 + # need retry when there is no exception but the status code is 429, 502, 503, or 504 lambda response: getattr(response, "status_code", None) - in (502, 503, 504), + in (429, 502, 503, 504), ), max_attempts=self._max_attempts, ) @@ -683,7 +693,7 @@ def _verify_extra_credential(self, header): raise ValueError(f"only ASCII characters are allowed in extra credential '{key}'") -class TrinoResult(object): +class TrinoResult: """ Represent the result of a Trino query as an iterator on rows. @@ -721,7 +731,7 @@ def __iter__(self): self._rows = next_rows -class TrinoQuery(object): +class TrinoQuery: """Represent the execution of a SQL statement by Trino.""" def __init__( @@ -887,7 +897,12 @@ def decorated(*args, **kwargs): try: result = func(*args, **kwargs) if any(guard(result) for guard in conditions): - handle_retry.retry(func, args, kwargs, None, attempt) + if result.status_code == 429 and "Retry-After" in result.headers: + retry_after = _parse_retry_after_header(result.headers.get("Retry-After")) + handle_retry_sleep = _RetryAfterSleep(retry_after) + handle_retry_sleep.retry() + else: + handle_retry.retry(func, args, kwargs, None, attempt) continue return result except Exception as err: @@ -904,3 +919,14 @@ def decorated(*args, **kwargs): return decorated return wrapper + + +def _parse_retry_after_header(retry_after): + if isinstance(retry_after, int): + return retry_after + elif isinstance(retry_after, str) and retry_after.isdigit(): + return int(retry_after) + else: + retry_date = parsedate_to_datetime(retry_after) + now = datetime.utcnow() + return (retry_date - now).total_seconds() diff --git a/trino/dbapi.py b/trino/dbapi.py index bfa64886..9125a487 100644 --- a/trino/dbapi.py +++ b/trino/dbapi.py @@ -126,7 +126,7 @@ def connect(*args, **kwargs): return Connection(*args, **kwargs) -class Connection(object): +class Connection: """Trino supports transactions and the ability to either commit or rollback a sequence of SQL statements. A single query i.e. the execution of a SQL statement, can also be cancelled. Transactions are not supported by this @@ -329,7 +329,7 @@ def from_column(cls, column: Dict[str, Any]): ) -class Cursor(object): +class Cursor: """Database cursor. Cursors are not isolated, i.e., any changes done to the database by a diff --git a/trino/transaction.py b/trino/transaction.py index ebead938..b308875d 100644 --- a/trino/transaction.py +++ b/trino/transaction.py @@ -49,7 +49,7 @@ def check(cls, level: int) -> int: return level -class Transaction(object): +class Transaction: def __init__(self, request: trino.client.TrinoRequest) -> None: self._request = request self._id = NO_TRANSACTION