From 72718229c6507049d95a4f5c68b018aae529efca Mon Sep 17 00:00:00 2001 From: eliseygusev Date: Tue, 8 Nov 2022 18:06:00 +0400 Subject: [PATCH] Async client (#31) Add Async Client --- .gitignore | 4 +- dune_client/base_client.py | 30 ++++++ dune_client/client.py | 61 ++++++------ dune_client/client_async.py | 155 +++++++++++++++++++++++++++++ dune_client/interface.py | 5 +- dune_client/query.py | 8 +- dune_client/types.py | 4 +- requirements/dev.txt | 3 +- tests/e2e/test_async_client.py | 173 +++++++++++++++++++++++++++++++++ tests/unit/test_query.py | 11 +++ 10 files changed, 417 insertions(+), 37 deletions(-) create mode 100644 dune_client/base_client.py create mode 100644 dune_client/client_async.py create mode 100644 tests/e2e/test_async_client.py diff --git a/.gitignore b/.gitignore index 5760737..92d18f7 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,6 @@ .env __pycache__/ dist -*.egg-info \ No newline at end of file +*.egg-info +_version.py +.idea/ diff --git a/dune_client/base_client.py b/dune_client/base_client.py new file mode 100644 index 0000000..fa38b58 --- /dev/null +++ b/dune_client/base_client.py @@ -0,0 +1,30 @@ +"""" +Basic Dune Client Class responsible for refreshing Dune Queries +Framework built on Dune's API Documentation +https://duneanalytics.notion.site/API-Documentation-1b93d16e0fa941398e15047f643e003a +""" +from __future__ import annotations + +import logging.config +from typing import Dict + + +# pylint: disable=too-few-public-methods +class BaseDuneClient: + """ + A Base Client for Dune which sets up default values + and provides some convenient functions to use in other clients + """ + + BASE_URL = "https://api.dune.com" + API_PATH = "/api/v1" + DEFAULT_TIMEOUT = 10 + + def __init__(self, api_key: str): + self.token = api_key + self.logger = logging.getLogger(__name__) + logging.basicConfig(format="%(asctime)s %(levelname)s %(name)s %(message)s") + + def default_headers(self) -> Dict[str, str]: + """Return default headers containing Dune Api token""" + return {"x-dune-api-key": self.token} diff --git a/dune_client/client.py b/dune_client/client.py index 53586b8..1a7e3f9 100644 --- a/dune_client/client.py +++ b/dune_client/client.py @@ -5,14 +5,13 @@ """ from __future__ import annotations -import logging.config import time -from json import JSONDecodeError from typing import Any import requests -from requests import Response +from requests import Response, JSONDecodeError +from dune_client.base_client import BaseDuneClient from dune_client.interface import DuneInterface from dune_client.models import ( ExecutionResponse, @@ -24,53 +23,55 @@ from dune_client.query import Query -log = logging.getLogger(__name__) -logging.basicConfig( - format="%(asctime)s %(levelname)s %(name)s %(message)s", level=logging.DEBUG -) - -BASE_URL = "https://api.dune.com/api/v1" - -class DuneClient(DuneInterface): +class DuneClient(DuneInterface, BaseDuneClient): """ An interface for Dune API with a few convenience methods combining the use of endpoints (e.g. refresh) """ - def __init__(self, api_key: str): - self.token = api_key - - @staticmethod def _handle_response( + self, response: Response, ) -> Any: try: # Some responses can be decoded and converted to DuneErrors response_json = response.json() - log.debug(f"received response {response_json}") + self.logger.debug(f"received response {response_json}") return response_json except JSONDecodeError as err: # Others can't. Only raise HTTP error for not decodable errors response.raise_for_status() raise ValueError("Unreachable since previous line raises") from err - def _get(self, url: str) -> Any: - log.debug(f"GET received input url={url}") - response = requests.get(url, headers={"x-dune-api-key": self.token}, timeout=10) + def _route_url(self, route: str) -> str: + return f"{self.BASE_URL}{self.API_PATH}/{route}" + + def _get(self, route: str) -> Any: + url = self._route_url(route) + self.logger.debug(f"GET received input url={url}") + response = requests.get( + url, + headers={"x-dune-api-key": self.token}, + timeout=self.DEFAULT_TIMEOUT, + ) return self._handle_response(response) - def _post(self, url: str, params: Any) -> Any: - log.debug(f"POST received input url={url}, params={params}") + def _post(self, route: str, params: Any) -> Any: + url = self._route_url(route) + self.logger.debug(f"POST received input url={url}, params={params}") response = requests.post( - url=url, json=params, headers={"x-dune-api-key": self.token}, timeout=10 + url=url, + json=params, + headers={"x-dune-api-key": self.token}, + timeout=self.DEFAULT_TIMEOUT, ) return self._handle_response(response) def execute(self, query: Query) -> ExecutionResponse: """Post's to Dune API for execute `query`""" response_json = self._post( - url=f"{BASE_URL}/query/{query.query_id}/execute", + route=f"query/{query.query_id}/execute", params={ "query_parameters": { p.key: p.to_dict()["value"] for p in query.parameters() @@ -85,7 +86,7 @@ def execute(self, query: Query) -> ExecutionResponse: def get_status(self, job_id: str) -> ExecutionStatusResponse: """GET status from Dune API for `job_id` (aka `execution_id`)""" response_json = self._get( - url=f"{BASE_URL}/execution/{job_id}/status", + route=f"execution/{job_id}/status", ) try: return ExecutionStatusResponse.from_dict(response_json) @@ -94,7 +95,7 @@ def get_status(self, job_id: str) -> ExecutionStatusResponse: def get_result(self, job_id: str) -> ResultsResponse: """GET results from Dune API for `job_id` (aka `execution_id`)""" - response_json = self._get(url=f"{BASE_URL}/execution/{job_id}/results") + response_json = self._get(route=f"execution/{job_id}/results") try: return ResultsResponse.from_dict(response_json) except KeyError as err: @@ -102,9 +103,7 @@ def get_result(self, job_id: str) -> ResultsResponse: def cancel_execution(self, job_id: str) -> bool: """POST Execution Cancellation to Dune API for `job_id` (aka `execution_id`)""" - response_json = self._post( - url=f"{BASE_URL}/execution/{job_id}/cancel", params=None - ) + response_json = self._post(route=f"execution/{job_id}/cancel", params=None) try: # No need to make a dataclass for this since it's just a boolean. success: bool = response_json["success"] @@ -121,12 +120,14 @@ def refresh(self, query: Query, ping_frequency: int = 5) -> ResultsResponse: job_id = self.execute(query).execution_id status = self.get_status(job_id) while status.state not in ExecutionState.terminal_states(): - log.info(f"waiting for query execution {job_id} to complete: {status}") + self.logger.info( + f"waiting for query execution {job_id} to complete: {status}" + ) time.sleep(ping_frequency) status = self.get_status(job_id) full_response = self.get_result(job_id) if status.state == ExecutionState.FAILED: - log.error(status) + self.logger.error(status) raise Exception(f"{status}. Perhaps your query took too long to run!") return full_response diff --git a/dune_client/client_async.py b/dune_client/client_async.py new file mode 100644 index 0000000..d7fc609 --- /dev/null +++ b/dune_client/client_async.py @@ -0,0 +1,155 @@ +"""" +Async Dune Client Class responsible for refreshing Dune Queries +Framework built on Dune's API Documentation +https://duneanalytics.notion.site/API-Documentation-1b93d16e0fa941398e15047f643e003a +""" +import asyncio +from typing import Any + +from aiohttp import ( + ClientSession, + ClientResponse, + ContentTypeError, + TCPConnector, + ClientTimeout, +) + +from dune_client.base_client import BaseDuneClient +from dune_client.models import ( + ExecutionResponse, + DuneError, + ExecutionStatusResponse, + ResultsResponse, + ExecutionState, +) + +from dune_client.query import Query + + +# pylint: disable=duplicate-code +class AsyncDuneClient(BaseDuneClient): + """ + An asynchronous interface for Dune API with a few convenience methods + combining the use of endpoints (e.g. refresh) + """ + + _connection_limit = 3 + + def __init__(self, api_key: str, connection_limit: int = 3): + """ + api_key - Dune API key + connection_limit - number of parallel requests to execute. + For non-pro accounts Dune allows only up to 3 requests but that number can be increased. + """ + super().__init__(api_key=api_key) + self._connection_limit = connection_limit + self._session = self._create_session() + + def _create_session(self) -> ClientSession: + conn = TCPConnector(limit=self._connection_limit) + return ClientSession( + connector=conn, + base_url=self.BASE_URL, + timeout=ClientTimeout(total=self.DEFAULT_TIMEOUT), + ) + + async def close_session(self) -> None: + """Closes client session""" + await self._session.close() + + async def __aenter__(self) -> None: + self._session = self._create_session() + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + await self.close_session() + + async def _handle_response( + self, + response: ClientResponse, + ) -> Any: + try: + # Some responses can be decoded and converted to DuneErrors + response_json = await response.json() + self.logger.debug(f"received response {response_json}") + return response_json + except ContentTypeError as err: + # Others can't. Only raise HTTP error for not decodable errors + response.raise_for_status() + raise ValueError("Unreachable since previous line raises") from err + + async def _get(self, url: str) -> Any: + self.logger.debug(f"GET received input url={url}") + response = await self._session.get( + url=f"{self.API_PATH}{url}", + headers=self.default_headers(), + ) + return await self._handle_response(response) + + async def _post(self, url: str, params: Any) -> Any: + self.logger.debug(f"POST received input url={url}, params={params}") + response = await self._session.post( + url=f"{self.API_PATH}{url}", + json=params, + headers=self.default_headers(), + ) + return await self._handle_response(response) + + async def execute(self, query: Query) -> ExecutionResponse: + """Post's to Dune API for execute `query`""" + response_json = await self._post( + url=f"/query/{query.query_id}/execute", + params=query.request_format(), + ) + try: + return ExecutionResponse.from_dict(response_json) + except KeyError as err: + raise DuneError(response_json, "ExecutionResponse", err) from err + + async def get_status(self, job_id: str) -> ExecutionStatusResponse: + """GET status from Dune API for `job_id` (aka `execution_id`)""" + response_json = await self._get( + url=f"/execution/{job_id}/status", + ) + try: + return ExecutionStatusResponse.from_dict(response_json) + except KeyError as err: + raise DuneError(response_json, "ExecutionStatusResponse", err) from err + + async def get_result(self, job_id: str) -> ResultsResponse: + """GET results from Dune API for `job_id` (aka `execution_id`)""" + response_json = await self._get(url=f"/execution/{job_id}/results") + try: + return ResultsResponse.from_dict(response_json) + except KeyError as err: + raise DuneError(response_json, "ResultsResponse", err) from err + + async def cancel_execution(self, job_id: str) -> bool: + """POST Execution Cancellation to Dune API for `job_id` (aka `execution_id`)""" + response_json = await self._post(url=f"/execution/{job_id}/cancel", params=None) + try: + # No need to make a dataclass for this since it's just a boolean. + success: bool = response_json["success"] + return success + except KeyError as err: + raise DuneError(response_json, "CancellationResponse", err) from err + + async def refresh(self, query: Query, ping_frequency: int = 5) -> ResultsResponse: + """ + Executes a Dune `query`, waits until execution completes, + fetches and returns the results. + Sleeps `ping_frequency` seconds between each status request. + """ + job_id = (await self.execute(query)).execution_id + status = await self.get_status(job_id) + while status.state not in ExecutionState.terminal_states(): + self.logger.info( + f"waiting for query execution {job_id} to complete: {status}" + ) + await asyncio.sleep(ping_frequency) + status = await self.get_status(job_id) + + full_response = await self.get_result(job_id) + if status.state == ExecutionState.FAILED: + self.logger.error(status) + raise Exception(f"{status}. Perhaps your query took too long to run!") + return full_response diff --git a/dune_client/interface.py b/dune_client/interface.py index bf81f10..f7b2cc9 100644 --- a/dune_client/interface.py +++ b/dune_client/interface.py @@ -1,18 +1,19 @@ """ Abstract class for a basic Dune Interface with refresh method used by Query Runner. """ -from abc import ABC +import abc from dune_client.models import ResultsResponse from dune_client.query import Query # pylint: disable=too-few-public-methods -class DuneInterface(ABC): +class DuneInterface(abc.ABC): """ User Facing Methods for a Dune Client """ + @abc.abstractmethod def refresh(self, query: Query) -> ResultsResponse: """ Executes a Dune query, waits till query execution completes, diff --git a/dune_client/query.py b/dune_client/query.py index 8858df0..72eef67 100644 --- a/dune_client/query.py +++ b/dune_client/query.py @@ -3,7 +3,7 @@ """ import urllib.parse from dataclasses import dataclass -from typing import Optional, List +from typing import Optional, List, Dict from dune_client.types import QueryParameter @@ -40,3 +40,9 @@ def __hash__(self) -> int: Thus, it is unique for caching purposes """ return self.url().__hash__() + + def request_format(self) -> Dict[str, Dict[str, str]]: + """Transforms Query objects to params to pass in API""" + return { + "query_parameters": {p.key: p.to_dict()["value"] for p in self.parameters()} + } diff --git a/dune_client/types.py b/dune_client/types.py index 4247946..17d2f19 100644 --- a/dune_client/types.py +++ b/dune_client/types.py @@ -159,9 +159,9 @@ def value_str(self) -> str: return str(self.value.strftime("%Y-%m-%d %H:%M:%S")) raise TypeError(f"Type {self.type} not recognized!") - def to_dict(self) -> dict[str, str | list[str]]: + def to_dict(self) -> dict[str, str]: """Converts QueryParameter into string json format accepted by Dune API""" - results: dict[str, str | list[str]] = { + results: dict[str, str] = { "key": self.key, "type": self.type.value, "value": self.value_str(), diff --git a/requirements/dev.txt b/requirements/dev.txt index c2fcc63..4e77d21 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -3,4 +3,5 @@ black>=22.8.0 pylint>=2.15.0 pytest>=7.1.3 python-dotenv>=0.21.0 -mypy>=0.971 \ No newline at end of file +mypy>=0.971 +aiounittest>=1.4.2 diff --git a/tests/e2e/test_async_client.py b/tests/e2e/test_async_client.py new file mode 100644 index 0000000..c12a611 --- /dev/null +++ b/tests/e2e/test_async_client.py @@ -0,0 +1,173 @@ +import copy +import os +import time +import unittest + +import aiounittest +import dotenv + +from dune_client.client_async import AsyncDuneClient +from dune_client.types import QueryParameter +from dune_client.client import ( + ExecutionResponse, + ExecutionStatusResponse, + ExecutionState, + DuneError, +) +from dune_client.query import Query + + +class TestDuneClient(aiounittest.AsyncTestCase): + def setUp(self) -> None: + self.query = Query( + name="Sample Query", + query_id=1215383, + params=[ + # These are the queries default parameters. + QueryParameter.text_type(name="TextField", value="Plain Text"), + QueryParameter.number_type(name="NumberField", value=3.1415926535), + QueryParameter.date_type(name="DateField", value="2022-05-04 00:00:00"), + QueryParameter.enum_type(name="ListField", value="Option 1"), + ], + ) + dotenv.load_dotenv() + self.valid_api_key = os.environ["DUNE_API_KEY"] + + async def test_get_status(self): + query = Query(name="No Name", query_id=1276442, params=[]) + dune = AsyncDuneClient(self.valid_api_key) + job_id = (await dune.execute(query)).execution_id + status = await dune.get_status(job_id) + self.assertTrue( + status.state in [ExecutionState.EXECUTING, ExecutionState.PENDING] + ) + await dune.close_session() + + async def test_refresh(self): + dune = AsyncDuneClient(self.valid_api_key) + results = (await dune.refresh(self.query)).get_rows() + self.assertGreater(len(results), 0) + await dune.close_session() + + async def test_parameters_recognized(self): + query = copy.copy(self.query) + new_params = [ + # Using all different values for parameters. + QueryParameter.text_type(name="TextField", value="different word"), + QueryParameter.number_type(name="NumberField", value=22), + QueryParameter.date_type(name="DateField", value="1991-01-01 00:00:00"), + QueryParameter.enum_type(name="ListField", value="Option 2"), + ] + query.params = new_params + self.assertEqual(query.parameters(), new_params) + + dune = AsyncDuneClient(self.valid_api_key) + results = await dune.refresh(query) + self.assertEqual( + results.get_rows(), + [ + { + "text_field": "different word", + "number_field": "22", + "date_field": "1991-01-01 00:00:00", + "list_field": "Option 2", + } + ], + ) + await dune.close_session() + + async def test_endpoints(self): + dune = AsyncDuneClient(self.valid_api_key) + execution_response = await dune.execute(self.query) + self.assertIsInstance(execution_response, ExecutionResponse) + job_id = execution_response.execution_id + status = await dune.get_status(job_id) + self.assertIsInstance(status, ExecutionStatusResponse) + state = (await dune.get_status(job_id)).state + while state != ExecutionState.COMPLETED: + state = (await dune.get_status(job_id)).state + time.sleep(1) + results = (await dune.get_result(job_id)).result.rows + self.assertGreater(len(results), 0) + await dune.close_session() + + async def test_cancel_execution(self): + dune = AsyncDuneClient(self.valid_api_key) + query = Query( + name="Long Running Query", + query_id=1229120, + ) + execution_response = await dune.execute(query) + job_id = execution_response.execution_id + # POST Cancellation + success = await dune.cancel_execution(job_id) + self.assertTrue(success) + + results = await dune.get_result(job_id) + self.assertEqual(results.state, ExecutionState.CANCELLED) + await dune.close_session() + + async def test_invalid_api_key_error(self): + dune = AsyncDuneClient(api_key="Invalid Key") + with self.assertRaises(DuneError) as err: + await dune.execute(self.query) + self.assertEqual( + str(err.exception), + "Can't build ExecutionResponse from {'error': 'invalid API Key'}", + ) + with self.assertRaises(DuneError) as err: + await dune.get_status("wonky job_id") + self.assertEqual( + str(err.exception), + "Can't build ExecutionStatusResponse from {'error': 'invalid API Key'}", + ) + with self.assertRaises(DuneError) as err: + await dune.get_result("wonky job_id") + self.assertEqual( + str(err.exception), + "Can't build ResultsResponse from {'error': 'invalid API Key'}", + ) + await dune.close_session() + + async def test_query_not_found_error(self): + dune = AsyncDuneClient(self.valid_api_key) + query = copy.copy(self.query) + query.query_id = 99999999 # Invalid Query Id. + + with self.assertRaises(DuneError) as err: + await dune.execute(query) + self.assertEqual( + str(err.exception), + "Can't build ExecutionResponse from {'error': 'Query not found'}", + ) + await dune.close_session() + + async def test_internal_error(self): + dune = AsyncDuneClient(self.valid_api_key) + query = copy.copy(self.query) + # This query ID is too large! + query.query_id = 9999999999999 + + with self.assertRaises(DuneError) as err: + await dune.execute(query) + self.assertEqual( + str(err.exception), + "Can't build ExecutionResponse from {'error': 'An internal error occured'}", + ) + await dune.close_session() + + async def test_invalid_job_id_error(self): + dune = AsyncDuneClient(self.valid_api_key) + + with self.assertRaises(DuneError) as err: + await dune.get_status("Wonky Job ID") + self.assertEqual( + str(err.exception), + "Can't build ExecutionStatusResponse from " + "{'error': 'The requested execution ID (ID: Wonky Job ID) is invalid.'}", + ) + await dune.close_session() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/test_query.py b/tests/unit/test_query.py index 904bfeb..a08dc21 100644 --- a/tests/unit/test_query.py +++ b/tests/unit/test_query.py @@ -29,6 +29,17 @@ def test_url(self): def test_parameters(self): self.assertEqual(self.query.parameters(), self.query_params) + def test_request_format(self): + expected_answer = { + "query_parameters": { + "Enum": "option1", + "Text": "plain text", + "Number": "12", + "Date": "2021-01-01 12:34:56", + } + } + self.assertEqual(self.query.request_format(), expected_answer) + def test_hash(self): # Same ID, different params query1 = Query(query_id=0, params=[QueryParameter.text_type("Text", "word1")])