Skip to content

Commit

Permalink
Add async methods to source contracts clients (#1505)
Browse files Browse the repository at this point in the history
  • Loading branch information
moisses89 authored Jan 2, 2025
1 parent cabd05f commit 8e75014
Show file tree
Hide file tree
Showing 8 changed files with 362 additions and 38 deletions.
7 changes: 6 additions & 1 deletion safe_eth/eth/clients/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# flake8: noqa F401
from .blockscout_client import (
AsyncBlockscoutClient,
BlockscoutClient,
BlockscoutClientException,
BlockScoutConfigurationProblem,
Expand All @@ -12,14 +13,18 @@
EtherscanClientException,
EtherscanRateLimitError,
)
from .etherscan_client_v2 import EtherscanClientV2
from .etherscan_client_v2 import AsyncEtherscanClientV2, EtherscanClientV2
from .sourcify_client import (
AsyncSourcifyClient,
SourcifyClient,
SourcifyClientConfigurationProblem,
SourcifyClientException,
)

__all__ = [
"AsyncBlockscoutClient",
"AsyncEtherscanClientV2",
"AsyncSourcifyClient",
"BlockScoutConfigurationProblem",
"BlockscoutClient",
"BlockscoutClientException",
Expand Down
79 changes: 69 additions & 10 deletions safe_eth/eth/clients/blockscout_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import json
import os
from typing import Any, Dict, Optional
from urllib.parse import urljoin

import aiohttp
import requests
from eth_typing import ChecksumAddress

Expand Down Expand Up @@ -150,9 +152,16 @@ class BlockscoutClient:
EthereumNetwork.EXSAT_TESTNET: "https://scan-testnet.exsat.network/api/v1/graphql",
}

def __init__(self, network: EthereumNetwork):
def __init__(
self,
network: EthereumNetwork,
request_timeout: int = int(
os.environ.get("BLOCKSCOUT_CLIENT_REQUEST_TIMEOUT", 10)
),
):
self.network = network
self.grahpql_url = self.NETWORK_WITH_URL.get(network, "")
self.request_timeout = request_timeout
if not self.grahpql_url:
raise BlockScoutConfigurationProblem(
f"Network {network.name} - {network.value} not supported"
Expand All @@ -169,19 +178,69 @@ def _do_request(self, url: str, query: str) -> Optional[Dict[str, Any]]:

return response.json()

def get_contract_metadata(
self, address: ChecksumAddress
@staticmethod
def _process_contract_metadata(
contract_data: dict[str, Any]
) -> Optional[ContractMetadata]:
query = '{address(hash: "%s") { hash, smartContract {name, abi} }}' % address
result = self._do_request(self.grahpql_url, query)
"""
Return a ContractMetadata from BlockScout response
:param contract_data:
:return:
"""
if (
result
and "error" not in result
and result.get("data", {}).get("address", {})
and result["data"]["address"]["smartContract"]
"error" not in contract_data
and contract_data.get("data", {}).get("address", {})
and contract_data["data"]["address"]["smartContract"]
):
smart_contract = result["data"]["address"]["smartContract"]
smart_contract = contract_data["data"]["address"]["smartContract"]
return ContractMetadata(
smart_contract["name"], json.loads(smart_contract["abi"]), False
)
return None

def get_contract_metadata(
self, address: ChecksumAddress
) -> Optional[ContractMetadata]:
query = '{address(hash: "%s") { hash, smartContract {name, abi} }}' % address
contract_data = self._do_request(self.grahpql_url, query)
if contract_data:
return self._process_contract_metadata(contract_data)
return None


class AsyncBlockscoutClient(BlockscoutClient):
def __init__(
self,
network: EthereumNetwork,
request_timeout: int = int(
os.environ.get("BLOCKSCOUT_CLIENT_REQUEST_TIMEOUT", 10)
),
max_requests: int = int(os.environ.get("BLOCKSCOUT_CLIENT_MAX_REQUESTS", 100)),
):
super().__init__(network, request_timeout)
# Limit simultaneous connections to the same host.
self.async_session = aiohttp.ClientSession(
connector=aiohttp.TCPConnector(limit_per_host=max_requests)
)

async def _async_do_request(self, url: str, query: str) -> Optional[Dict[str, Any]]:
"""
Asynchronous version of _do_request
"""
async with self.async_session.post(
url, json={"query": query}, timeout=self.request_timeout
) as response:
if not response.ok:
return None

return await response.json()

async def async_get_contract_metadata(
self, address: ChecksumAddress
) -> Optional[ContractMetadata]:
query = '{address(hash: "%s") { hash, smartContract {name, abi} }}' % address
contract_data = await self._async_do_request(self.grahpql_url, query)
if contract_data:
return self._process_contract_metadata(contract_data)
return None
45 changes: 28 additions & 17 deletions safe_eth/eth/clients/etherscan_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,19 +353,42 @@ def _retry_request(
time.sleep(5)
return None

@staticmethod
def _process_contract_metadata(
contract_data: Dict[str, Any]
) -> Optional[ContractMetadata]:
contract_name = contract_data["ContractName"]
contract_abi = contract_data["ABI"]
if contract_abi:
return ContractMetadata(contract_name, contract_abi, False)
return None

def get_contract_metadata(
self, contract_address: str, retry: bool = True
) -> Optional[ContractMetadata]:
contract_source_code = self.get_contract_source_code(
contract_address, retry=retry
)
if contract_source_code:
contract_name = contract_source_code["ContractName"]
contract_abi = contract_source_code["ABI"]
if contract_abi:
return ContractMetadata(contract_name, contract_abi, False)
return self._process_contract_metadata(contract_source_code)
return None

@staticmethod
def _process_get_contract_source_code_response(response):
if response and isinstance(response, list):
result = response[0]
abi_str = result.get("ABI")

if isinstance(abi_str, str) and abi_str.startswith("["):
try:
result["ABI"] = json.loads(abi_str)
except json.JSONDecodeError:
result["ABI"] = None # Handle the case where JSON decoding fails
else:
result["ABI"] = None

return result

def get_contract_source_code(self, contract_address: str, retry: bool = True):
"""
Get source code for a contract. Source code query also returns:
Expand All @@ -390,19 +413,7 @@ def get_contract_source_code(self, contract_address: str, retry: bool = True):
f"module=contract&action=getsourcecode&address={contract_address}"
)
response = self._retry_request(url, retry=retry) # Returns a list
if response and isinstance(response, list):
result = response[0]
abi_str = result.get("ABI")

if isinstance(abi_str, str) and abi_str.startswith("["):
try:
result["ABI"] = json.loads(abi_str)
except json.JSONDecodeError:
result["ABI"] = None # Handle the case where JSON decoding fails
else:
result["ABI"] = None

return result
return self._process_get_contract_source_code_response(response)

def get_contract_abi(self, contract_address: str, retry: bool = True):
url = self.build_url(
Expand Down
85 changes: 83 additions & 2 deletions safe_eth/eth/clients/etherscan_client_v2.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import json
import os
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union
from urllib.parse import urljoin

import aiohttp
import requests

from safe_eth.eth import EthereumNetwork
from safe_eth.eth.clients import EtherscanClient
from safe_eth.eth.clients import (
ContractMetadata,
EtherscanClient,
EtherscanRateLimitError,
)


class EtherscanClientV2(EtherscanClient):
Expand Down Expand Up @@ -78,3 +84,78 @@ def is_supported_network(cls, network: EthereumNetwork) -> bool:
return any(
item.get("chainid") == str(network.value) for item in supported_networks
)


class AsyncEtherscanClientV2(EtherscanClientV2):
def __init__(
self,
network: EthereumNetwork,
api_key: Optional[str] = None,
request_timeout: int = int(
os.environ.get("ETHERSCAN_CLIENT_REQUEST_TIMEOUT", 10)
),
max_requests: int = int(os.environ.get("ETHERSCAN_CLIENT_MAX_REQUESTS", 100)),
):
super().__init__(network, api_key, request_timeout)
self.async_session = aiohttp.ClientSession(
connector=aiohttp.TCPConnector(limit_per_host=max_requests)
)

async def _async_do_request(
self, url: str
) -> Optional[Union[Dict[str, Any], List[Any], str]]:
"""
Async version of _do_request
"""
async with self.async_session.get(
url, timeout=self.request_timeout
) as response:
if response.ok:
response_json = await response.json()
result = response_json["result"]
if "Max rate limit reached" in result:
# Max rate limit reached, please use API Key for higher rate limit
raise EtherscanRateLimitError
if response_json["status"] == "1":
return result
return None

async def async_get_contract_source_code(
self,
contract_address: str,
):
"""
Asynchronous version of get_contract_source_code
Does not implement retries
:param contract_address:
"""
url = self.build_url(
f"module=contract&action=getsourcecode&address={contract_address}"
)
response = await self._async_do_request(url) # Returns a list
return self._process_get_contract_source_code_response(response)

async def async_get_contract_metadata(
self, contract_address: str
) -> Optional[ContractMetadata]:
contract_source_code = await self.async_get_contract_source_code(
contract_address
)
if contract_source_code:
return self._process_contract_metadata(contract_source_code)
return None

async def async_get_contract_abi(self, contract_address: str):
url = self.build_url(
f"module=contract&action=getabi&address={contract_address}"
)
result = await self._async_do_request(url)
if isinstance(result, dict):
return result
elif isinstance(result, str):
try:
return json.loads(result)
except json.JSONDecodeError:
pass
return None
Loading

0 comments on commit 8e75014

Please sign in to comment.