Skip to content

Commit

Permalink
Merge pull request #107 from duneanalytics/ratelimit
Browse files Browse the repository at this point in the history
Handle rate limit responses and automatically retry the requests
  • Loading branch information
dune-eng authored Jan 23, 2024
2 parents d2195b2 + 8a465d3 commit 5859507
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 25 deletions.
22 changes: 16 additions & 6 deletions dune_client/api/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from json import JSONDecodeError
from typing import Dict, Optional, Any

import requests
from requests import Response
from requests import Response, Session
from requests.adapters import HTTPAdapter, Retry

from dune_client.util import get_package_version

Expand All @@ -38,6 +38,17 @@ def __init__( # pylint: disable=too-many-arguments
self.performance = performance
self.logger = logging.getLogger(__name__)
logging.basicConfig(format="%(asctime)s %(levelname)s %(name)s %(message)s")
retry_strategy = Retry(
total=5,
backoff_factor=0.5,
status_forcelist={429, 502, 503, 504},
allowed_methods={"GET", "POST", "PATCH"},
raise_on_status=True,
)
adapter = HTTPAdapter(max_retries=retry_strategy)
self.http = Session()
self.http.mount("https://", adapter)
self.http.mount("http://", adapter)

@classmethod
def from_env(cls) -> BaseDuneClient:
Expand Down Expand Up @@ -93,7 +104,7 @@ def _get(
"""Generic interface for the GET method of a Dune API request"""
url = self._route_url(route)
self.logger.debug(f"GET received input url={url}")
response = requests.get(
response = self.http.get(
url=url,
headers=self.default_headers(),
timeout=self.request_timeout,
Expand All @@ -107,7 +118,7 @@ def _post(self, route: str, params: Optional[Any] = None) -> Any:
"""Generic interface for the POST method of a Dune API request"""
url = self._route_url(route)
self.logger.debug(f"POST received input url={url}, params={params}")
response = requests.post(
response = self.http.post(
url=url,
json=params,
headers=self.default_headers(),
Expand All @@ -119,8 +130,7 @@ def _patch(self, route: str, params: Any) -> Any:
"""Generic interface for the PATCH method of a Dune API request"""
url = self._route_url(route)
self.logger.debug(f"PATCH received input url={url}, params={params}")
response = requests.request(
method="PATCH",
response = self.http.patch(
url=url,
json=params,
headers=self.default_headers(),
Expand Down
95 changes: 76 additions & 19 deletions dune_client/client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@

import asyncio
from io import BytesIO
from typing import Any, Optional, Union
from typing import Any, Callable, Optional, Union

from aiohttp import (
ClientResponseError,
ClientSession,
ClientResponse,
ContentTypeError,
Expand All @@ -31,6 +32,29 @@
from dune_client.query import QueryBase, parse_query_object_or_id


class RetryableError(Exception):
    """
    Internal exception used to signal that the request should be retried.

    The original aiohttp error is kept on ``base_error`` so the retry loop
    can report it if all attempts fail.
    """

    def __init__(self, base_error: ClientResponseError) -> None:
        # Forward the wrapped error's message to Exception. Without this,
        # constructing the exception with the keyword form
        # (``RetryableError(base_error=err)``) leaves ``args`` empty and
        # ``str(exc)`` blank, which makes logs and tracebacks useless.
        super().__init__(str(base_error))
        self.base_error = base_error


class MaxRetryError(Exception):
    """
    This exception is raised when the maximum number of retries is exceeded,
    e.g. due to rate limiting or internal server errors.

    Attributes:
        url: the url the failing request was sent to
        reason: the last underlying error, or None when not available
    """

    # ``Optional[...]`` matches the typing style used elsewhere in this file
    # and, unlike ``Exception | None``, does not depend on Python >= 3.10 or
    # on postponed evaluation of annotations.
    def __init__(self, url: str, reason: Optional[Exception] = None) -> None:
        self.url = url
        self.reason = reason

        message = f"Max retries exceeded with url: {url} (Caused by {reason!r})"

        super().__init__(message)


# pylint: disable=duplicate-code
class AsyncDuneClient(BaseDuneClient):
"""
Expand Down Expand Up @@ -77,6 +101,13 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
await self.disconnect()

async def _handle_response(self, response: ClientResponse) -> Any:
if response.status in {429, 502, 503, 504}:
try:
response.raise_for_status()
except ClientResponseError as err:
raise RetryableError(
base_error=err,
) from err
try:
# Some responses can be decoded and converted to DuneErrors
response_json = await response.json()
Expand All @@ -90,36 +121,62 @@ async def _handle_response(self, response: ClientResponse) -> Any:
def _route_url(self, route: str) -> str:
    """Return ``route`` prefixed with this client's ``api_version``."""
    prefix = self.api_version
    return "{}{}".format(prefix, route)

async def _handle_ratelimit(self, call: Callable[..., Any], url: str) -> Any:
    """Generic wrapper around request callables.

    Awaits ``call()`` up to five times. Whenever it raises ``RetryableError``
    (rate limiting or a server side error), a warning is logged and the
    coroutine sleeps ``attempt**2 * backoff_factor`` seconds before the next
    attempt, i.e. 0s, 0.5s, 2s and 4.5s between the five attempts.

    :param call: zero-argument callable returning an awaitable request result
    :param url: url of the request, only used in the final error message
    :return: whatever ``call()`` resolves to on the first successful attempt
    :raises MaxRetryError: when every attempt raised ``RetryableError``; the
        last underlying ``ClientResponseError`` is attached as its reason
    """
    max_attempts = 5
    backoff_factor = 0.5
    error: Optional[ClientResponseError] = None
    for attempt in range(max_attempts):
        try:
            return await call()
        except RetryableError as e:
            error = e.base_error
            if attempt == max_attempts - 1:
                # No attempt follows the last failure, so neither announce
                # a retry nor waste time sleeping before giving up.
                break
            # Compute the delay once so the log line reports the wait that
            # is actually performed.
            delay = attempt**2 * backoff_factor
            self.logger.warning(
                f"Rate limited or internal error. Retrying in {delay} seconds..."
            )
            await asyncio.sleep(delay)

    raise MaxRetryError(url, error)

async def _get(
    self,
    route: str,
    params: Optional[Any] = None,
    raw: bool = False,
) -> Any:
    """Send a GET request for ``route`` through the retry wrapper.

    :param route: route that ``_route_url`` turns into the request url
    :param params: optional query parameters forwarded to the session
    :param raw: when true, return the response object as received instead
        of passing it through ``_handle_response``
    :return: the raw response if ``raw`` is set, otherwise the result of
        ``_handle_response``
    :raises ValueError: if the client session has not been connected
    :raises MaxRetryError: if every attempt ended in a retryable error
    """
    url = self._route_url(route)
    self.logger.debug(f"GET received input url={url}")

    # Named differently from the enclosing method to avoid shadowing it.
    async def _request() -> Any:
        # Checked per attempt, inside the closure that uses the session.
        if self._session is None:
            raise ValueError("Client is not connected; call `await cl.connect()`")
        response = await self._session.get(
            url=url,
            headers=self.default_headers(),
            params=params,
        )
        if raw:
            # NOTE(review): raw responses skip `_handle_response`, so they
            # are never retried on 429/5xx -- confirm this is intended.
            return response
        return await self._handle_response(response)

    # Report the url that was actually requested, not the bare route.
    return await self._handle_ratelimit(_request, url)

async def _post(self, route: str, params: Any) -> Any:
    """Send a POST request for ``route`` through the retry wrapper.

    :param route: route that ``_route_url`` turns into the request url
    :param params: payload sent as the JSON body of the request
    :return: the result of ``_handle_response`` for the response
    :raises ValueError: if the client session has not been connected
    :raises MaxRetryError: if every attempt ended in a retryable error
    """
    url = self._route_url(route)
    self.logger.debug(f"POST received input url={url}, params={params}")

    # Named differently from the enclosing method to avoid shadowing it.
    async def _request() -> Any:
        # Checked per attempt, inside the closure that uses the session.
        if self._session is None:
            raise ValueError("Client is not connected; call `await cl.connect()`")
        response = await self._session.post(
            url=url,
            json=params,
            headers=self.default_headers(),
        )
        return await self._handle_response(response)

    # Report the url that was actually requested, not the bare route.
    return await self._handle_ratelimit(_request, url)

async def execute(
self, query: QueryBase, performance: Optional[str] = None
Expand Down

0 comments on commit 5859507

Please sign in to comment.