diff --git a/src/ape/api/networks.py b/src/ape/api/networks.py index e579c58730..828d9d2873 100644 --- a/src/ape/api/networks.py +++ b/src/ape/api/networks.py @@ -879,17 +879,7 @@ def get_provider( provider_name = "geth" if provider_name in self.providers: - provider = self.providers[provider_name]() - - if not provider.is_connected: - # Apply the settings before first connection process. - provider.provider_settings = provider_settings - - elif provider.provider_settings != provider_settings: - # Handle new settings. - new_settings = {**provider.provider_settings, **provider_settings} - provider.update_settings(new_settings) - + provider = self.providers[provider_name](provider_settings=provider_settings) if provider.connection_id in ProviderContextManager.connected_providers: # Likely multi-chain testing or utilizing multiple on-going connections. return ProviderContextManager.connected_providers[provider.connection_id] diff --git a/src/ape/api/providers.py b/src/ape/api/providers.py index 01d7efe990..71325c9aa3 100644 --- a/src/ape/api/providers.py +++ b/src/ape/api/providers.py @@ -1568,7 +1568,8 @@ def _create_trace_frame(self, evm_frame: EvmTraceFrame) -> TraceFrame: raw=evm_frame.dict(), ) - def _make_request(self, endpoint: str, parameters: List) -> Any: + def _make_request(self, endpoint: str, parameters: Optional[List] = None) -> Any: + parameters = parameters or [] coroutine = self.web3.provider.make_request(RPCEndpoint(endpoint), parameters) result = run_until_complete(coroutine) diff --git a/src/ape_test/provider.py b/src/ape_test/provider.py index 655321516a..58778c5ee1 100644 --- a/src/ape_test/provider.py +++ b/src/ape_test/provider.py @@ -1,16 +1,34 @@ +import operator import re from ast import literal_eval -from typing import Optional, cast +from typing import Dict, Optional, cast from eth.exceptions import HeaderNotFound from eth_tester.backends import PyEVMBackend # type: ignore -from eth_tester.exceptions import TransactionFailed # type: ignore -from eth_utils import is_0x_prefixed +from eth_tester.exceptions import FilterNotFound, TransactionFailed # type: ignore +from eth_utils import decode_hex, encode_hex, is_0x_prefixed, is_null, keccak +from eth_utils.curried import apply_formatter_if from eth_utils.exceptions import ValidationError +from eth_utils.toolz import compose, excepts from ethpm_types import HexBytes from web3 import EthereumTesterProvider, Web3 from web3.exceptions import ContractPanicError -from web3.providers.eth_tester.defaults import API_ENDPOINTS, static_return +from web3.providers.eth_tester.defaults import ( + call_eth_tester, + client_version, + create_log_filter, + create_new_account, + get_logs, + get_transaction_by_block_hash_and_index, + get_transaction_by_block_number_and_index, + not_implemented, + null_if_block_not_found, + null_if_filter_not_found, + null_if_transaction_not_found, + personal_send_transaction, + static_return, + without_eth_tester, +) from web3.types import TxParams from ape.api import PluginConfig, ReceiptAPI, TestProviderAPI, TransactionAPI, Web3Provider @@ -47,7 +65,7 @@ def evm_backend(self) -> PyEVMBackend: def connect(self): chain_id = self.settings.chain_id if self._web3 is not None: - connected_chain_id = self.chain_id + connected_chain_id = self._make_request("eth_chainId") if connected_chain_id == chain_id: # Is already connected and settings have not changed. return @@ -56,7 +74,7 @@ def connect(self): mnemonic=self.config.mnemonic, num_accounts=self.config.number_of_accounts, ) - endpoints = {**API_ENDPOINTS} + endpoints = get_endpoints() endpoints["eth"]["chainId"] = static_return(chain_id) tester = EthereumTesterProvider(ethereum_tester=self._evm_backend, api_endpoints=endpoints) self._web3 = Web3(tester) @@ -65,9 +83,13 @@ def disconnect(self): self.cached_chain_id = None self._web3 = None self._evm_backend = None + self.provider_settings = {} - def update_settings(self, new_settings: dict): - pass + def update_settings(self, new_settings: Dict): + self.cached_chain_id = None + self.provider_settings = {**self.provider_settings, **new_settings} + self.disconnect() + self.connect() def estimate_gas_cost(self, txn: TransactionAPI, **kwargs) -> int: if isinstance(self.network.gas_limit, int): @@ -80,9 +102,7 @@ def estimate_gas_cost(self, txn: TransactionAPI, **kwargs) -> int: block_id = kwargs.pop("block_identifier", kwargs.pop("block_id", None)) estimate_gas = self.web3.eth.estimate_gas txn_dict = txn.dict() - if txn_dict.get("gas") == "auto": - # Remove from dict before estimating - txn_dict.pop("gas") + txn_dict.pop("gas", None) txn_data = cast(TxParams, txn_dict) try: @@ -110,10 +130,9 @@ def estimate_gas_cost(self, txn: TransactionAPI, **kwargs) -> int: @property def settings(self) -> EthTesterProviderConfig: - self.config.provider = EthTesterProviderConfig.parse_obj( + return EthTesterProviderConfig.parse_obj( {**self.config.provider.dict(), **self.provider_settings} ) - return self.config.provider @property def chain_id(self) -> int: @@ -121,7 +140,7 @@ def chain_id(self) -> int: return self.cached_chain_id try: - result = self._make_request("eth_chainId", []) + result = self._make_request("eth_chainId") except ProviderNotConnectedError: result = self.settings.chain_id @@ -276,3 +295,222 @@ def get_virtual_machine_error(self, exception: Exception, **kwargs) -> VirtualMa else: return VirtualMachineError(base_err=exception, **kwargs) + + +def get_endpoints(): + # HACK: Until https://github.com/ethereum/web3.py/issues/3136 resolved. + return { + "web3": { + "clientVersion": client_version, + "sha3": compose( + encode_hex, + keccak, + decode_hex, + without_eth_tester(operator.itemgetter(0)), + ), + }, + "net": { + "version": static_return("1"), + "listening": static_return(False), + "peerCount": static_return(0), + }, + "eth": { + "protocolVersion": static_return(63), + "syncing": static_return(False), + "coinbase": compose( + operator.itemgetter(0), + call_eth_tester("get_accounts"), + ), + "mining": static_return(False), + "hashrate": static_return(0), + "chainId": static_return(131277322940537), # from fixture generation file + "feeHistory": not_implemented, + "maxPriorityFeePerGas": static_return(10**9), + "gasPrice": static_return(10**9), # must be >= base fee post-London + "accounts": call_eth_tester("get_accounts"), + "blockNumber": compose( + operator.itemgetter("number"), + call_eth_tester("get_block_by_number", fn_kwargs={"block_number": "latest"}), + ), + "getBalance": call_eth_tester("get_balance"), + "getStorageAt": call_eth_tester("get_storage_at"), + "getProof": not_implemented, + "getTransactionCount": call_eth_tester("get_nonce"), + "getBlockTransactionCountByHash": null_if_block_not_found( + compose( + len, + operator.itemgetter("transactions"), + call_eth_tester("get_block_by_hash"), + ) + ), + "getBlockTransactionCountByNumber": null_if_block_not_found( + compose( + len, + operator.itemgetter("transactions"), + call_eth_tester("get_block_by_number"), + ) + ), + "getUncleCountByBlockHash": null_if_block_not_found( + compose( + len, + operator.itemgetter("uncles"), + call_eth_tester("get_block_by_hash"), + ) + ), + "getUncleCountByBlockNumber": null_if_block_not_found( + compose( + len, + operator.itemgetter("uncles"), + call_eth_tester("get_block_by_number"), + ) + ), + "getCode": call_eth_tester("get_code"), + "sign": not_implemented, + "signTransaction": not_implemented, + "sendTransaction": call_eth_tester("send_transaction"), + "sendRawTransaction": call_eth_tester("send_raw_transaction"), + "call": call_eth_tester("call"), # TODO: untested + "estimateGas": call_eth_tester("estimate_gas"), # TODO: untested + "getBlockByHash": null_if_block_not_found(call_eth_tester("get_block_by_hash")), + "getBlockByNumber": null_if_block_not_found(call_eth_tester("get_block_by_number")), + "getTransactionByHash": null_if_transaction_not_found( + call_eth_tester("get_transaction_by_hash") + ), + "getTransactionByBlockHashAndIndex": get_transaction_by_block_hash_and_index, + "getTransactionByBlockNumberAndIndex": get_transaction_by_block_number_and_index, # noqa: E501 + "getTransactionReceipt": null_if_transaction_not_found( + compose( + apply_formatter_if( + compose(is_null, operator.itemgetter("block_number")), + static_return(None), + ), + call_eth_tester("get_transaction_receipt"), + ) + ), + "getUncleByBlockHashAndIndex": not_implemented, + "getUncleByBlockNumberAndIndex": not_implemented, + "getCompilers": not_implemented, + "compileLLL": not_implemented, + "compileSolidity": not_implemented, + "compileSerpent": not_implemented, + "newFilter": create_log_filter, + "newBlockFilter": call_eth_tester("create_block_filter"), + "newPendingTransactionFilter": call_eth_tester("create_pending_transaction_filter"), + "uninstallFilter": excepts( + FilterNotFound, + compose( + is_null, + call_eth_tester("delete_filter"), + ), + static_return(False), + ), + "getFilterChanges": null_if_filter_not_found( + call_eth_tester("get_only_filter_changes") + ), + "getFilterLogs": null_if_filter_not_found(call_eth_tester("get_all_filter_logs")), + "getLogs": get_logs, + "getWork": not_implemented, + "submitWork": not_implemented, + "submitHashrate": not_implemented, + }, + "db": { + "putString": not_implemented, + "getString": not_implemented, + "putHex": not_implemented, + "getHex": not_implemented, + }, + "admin": { + "add_peer": not_implemented, + "datadir": not_implemented, + "node_info": not_implemented, + "peers": not_implemented, + "start_http": not_implemented, + "start_ws": not_implemented, + "stop_http": not_implemented, + "stop_ws": not_implemented, + }, + "debug": { + "backtraceAt": not_implemented, + "blockProfile": not_implemented, + "cpuProfile": not_implemented, + "dumpBlock": not_implemented, + "gtStats": not_implemented, + "getBlockRLP": not_implemented, + "goTrace": not_implemented, + "memStats": not_implemented, + "seedHashSign": not_implemented, + "setBlockProfileRate": not_implemented, + "setHead": not_implemented, + "stacks": not_implemented, + "startCPUProfile": not_implemented, + "startGoTrace": not_implemented, + "stopCPUProfile": not_implemented, + "stopGoTrace": not_implemented, + "traceBlock": not_implemented, + "traceBlockByNumber": not_implemented, + "traceBlockByHash": not_implemented, + "traceBlockFromFile": not_implemented, + "traceTransaction": not_implemented, + "verbosity": not_implemented, + "vmodule": not_implemented, + "writeBlockProfile": not_implemented, + "writeMemProfile": not_implemented, + }, + "miner": { + "make_dag": not_implemented, + "set_extra": not_implemented, + "set_gas_price": not_implemented, + "start": not_implemented, + "stop": not_implemented, + "start_auto_dag": not_implemented, + "stop_auto_dag": not_implemented, + }, + "personal": { + "ec_recover": not_implemented, + "import_raw_key": call_eth_tester("add_account"), + "list_accounts": call_eth_tester("get_accounts"), + "list_wallets": not_implemented, + "lock_account": excepts( + ValidationError, + compose(static_return(True), call_eth_tester("lock_account")), + static_return(False), + ), + "new_account": create_new_account, + "unlock_account": excepts( + ValidationError, + compose(static_return(True), call_eth_tester("unlock_account")), + static_return(False), + ), + "send_transaction": personal_send_transaction, + "sign": not_implemented, + # deprecated + "ecRecover": not_implemented, + "importRawKey": call_eth_tester("add_account"), + "listAccounts": call_eth_tester("get_accounts"), + "lockAccount": excepts( + ValidationError, + compose(static_return(True), call_eth_tester("lock_account")), + static_return(False), + ), + "newAccount": create_new_account, + "unlockAccount": excepts( + ValidationError, + compose(static_return(True), call_eth_tester("unlock_account")), + static_return(False), + ), + "sendTransaction": personal_send_transaction, + }, + "testing": { + "timeTravel": call_eth_tester("time_travel"), + }, + "txpool": { + "content": not_implemented, + "inspect": not_implemented, + "status": not_implemented, + }, + "evm": { + "mine": call_eth_tester("mine_blocks"), + "revert": call_eth_tester("revert_to_snapshot"), + "snapshot": call_eth_tester("take_snapshot"), + }, + } diff --git a/tests/functional/test_network_manager.py b/tests/functional/test_network_manager.py index e6e2d490a5..127cd6b382 100644 --- a/tests/functional/test_network_manager.py +++ b/tests/functional/test_network_manager.py @@ -167,6 +167,7 @@ def test_parse_network_choice_same_provider(chain, networks_connected_to_tester, assert provider._web3 is not None +@pytest.mark.xdist_group(name="multiple-eth-testers") def test_parse_network_choice_new_chain_id(get_provider_with_unused_chain_id, get_context): start_count = len(get_context().connected_providers) context = get_provider_with_unused_chain_id() @@ -182,17 +183,26 @@ def test_parse_network_choice_new_chain_id(get_provider_with_unused_chain_id, ge assert provider._web3 is not None +@pytest.mark.xdist_group(name="multiple-eth-testers") def test_disconnect_after(get_provider_with_unused_chain_id): context = get_provider_with_unused_chain_id() with context as provider: connection_id = provider.connection_id assert connection_id in context.connected_providers - assert context not in context.connected_providers + assert connection_id not in context.connected_providers -def test_parse_network_choice_multiple_contexts(get_provider_with_unused_chain_id): +@pytest.mark.xdist_group(name="multiple-eth-testers") +def test_parse_network_choice_multiple_contexts( + eth_tester_provider, get_provider_with_unused_chain_id +): first_context = get_provider_with_unused_chain_id() + assert ( + eth_tester_provider.chain_id == DEFAULT_TEST_CHAIN_ID + ), "Test setup failed - expecting to start on default chain ID" + assert eth_tester_provider._make_request("eth_chainId") == DEFAULT_TEST_CHAIN_ID + with first_context: start_count = len(first_context.connected_providers) expected_next_count = start_count + 1 @@ -202,6 +212,9 @@ def test_parse_network_choice_multiple_contexts(get_provider_with_unused_chain_i assert len(first_context.connected_providers) == expected_next_count assert len(second_context.connected_providers) == expected_next_count + assert eth_tester_provider.chain_id == DEFAULT_TEST_CHAIN_ID + assert eth_tester_provider._make_request("eth_chainId") == DEFAULT_TEST_CHAIN_ID + def test_getattr_ecosystem_with_hyphenated_name(networks, ethereum): networks.ecosystems["hyphen-in-name"] = networks.ecosystems["ethereum"]