diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a7a4502..dccf3e6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,7 +10,7 @@ repos: - id: isort - repo: https://github.com/psf/black - rev: 23.11.0 + rev: 23.12.0 hooks: - id: black name: black diff --git a/ape_foundry/provider.py b/ape_foundry/provider.py index dc12039..2d789f9 100644 --- a/ape_foundry/provider.py +++ b/ape_foundry/provider.py @@ -44,7 +44,8 @@ from evm_trace import CallType, ParityTraceList from evm_trace import TraceFrame as EvmTraceFrame from evm_trace import get_calltree_from_geth_trace, get_calltree_from_parity_trace -from pydantic import ConfigDict, model_validator +from pydantic import model_validator +from pydantic_settings import SettingsConfigDict from web3 import HTTPProvider, Web3 from web3.exceptions import ContractCustomError from web3.exceptions import ContractLogicError as Web3ContractLogicError @@ -109,7 +110,7 @@ class FoundryNetworkConfig(PluginConfig): Set a block time to allow mining to happen on an interval rather than only when a new transaction is submitted. """ - model_config = ConfigDict(extra="allow") + model_config = SettingsConfigDict(extra="allow") def _call(*args): @@ -595,7 +596,13 @@ def send_transaction(self, txn: TransactionAPI) -> ReceiptAPI: self.chain_manager.history.append(receipt) return receipt - def send_call(self, txn: TransactionAPI, **kwargs: Any) -> bytes: + def send_call( + self, + txn: TransactionAPI, + block_id: Optional[BlockID] = None, + state: Optional[Dict] = None, + **kwargs, + ) -> HexBytes: skip_trace = kwargs.pop("skip_trace", False) arguments = self._prepare_call(txn, **kwargs) @@ -682,7 +689,7 @@ def _trace_call(self, arguments: List[Any]) -> Tuple[Dict, Iterator[EvmTraceFram trace_data = result.get("structLogs", []) return result, (EvmTraceFrame(**f) for f in trace_data) - def get_balance(self, address: str) -> int: + def get_balance(self, address: AddressType, block_id: Optional[BlockID] = None) -> int: if hasattr(address, "address"): address = address.address @@ -806,7 +813,7 @@ def set_storage(self, address: AddressType, slot: int, value: HexBytes): [address, to_bytes32(slot).hex(), to_bytes32(value).hex()], ) - def _eth_call(self, arguments: List) -> bytes: + def _eth_call(self, arguments: List) -> HexBytes: # Override from Web3Provider because foundry is pickier. txn_dict = copy(arguments[0]) diff --git a/tests/conftest.py b/tests/conftest.py index 84ce1a1..b72fd0e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -116,7 +116,8 @@ def networks(): @pytest.fixture def get_contract_type(): def fn(name: str) -> ContractType: - return ContractType.parse_file(LOCAL_CONTRACTS_PATH / f"{name}.json") + json_file = json.loads((LOCAL_CONTRACTS_PATH / f"{name}.json").read_text()) + return ContractType.model_validate(json_file) return fn