Skip to content

Commit

Permalink
refactor: utilize core feature for impersonating accounts and sending…
Browse files Browse the repository at this point in the history
… txns (#113)
  • Loading branch information
antazoey authored Aug 7, 2024
1 parent e5b2c0d commit e7c7aae
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 132 deletions.
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,19 @@ repos:
- id: isort

- repo: https://github.com/psf/black
rev: 24.4.2
rev: 24.8.0
hooks:
- id: black
name: black

- repo: https://github.com/pycqa/flake8
rev: 7.1.0
rev: 7.1.1
hooks:
- id: flake8
additional_dependencies: [flake8-breakpoint, flake8-print, flake8-pydantic]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.11.0
rev: v1.11.1
hooks:
- id: mypy
additional_dependencies: [types-PyYAML, types-requests, types-setuptools, pydantic]
Expand Down
141 changes: 23 additions & 118 deletions ape_foundry/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import random
import shutil
from bisect import bisect_right
from copy import copy
from subprocess import PIPE, call
from typing import Literal, Optional, Union, cast

Expand Down Expand Up @@ -41,7 +40,6 @@
from web3.gas_strategies.rpc import rpc_gas_price_strategy
from web3.middleware import geth_poa_middleware
from web3.middleware.validation import MAX_EXTRADATA_LENGTH
from web3.types import TxParams
from yarl import URL

from ape_foundry.constants import EVM_VERSION_BY_NETWORK
Expand Down Expand Up @@ -489,20 +487,21 @@ def build_command(self) -> list[str]:

def set_balance(self, account: AddressType, amount: Union[int, float, str, bytes]):
is_str = isinstance(amount, str)
_is_hex = False if not is_str else is_0x_prefixed(str(amount))
is_key_word = is_str and len(str(amount).split(" ")) > 1
is_key_word = is_str and " " in amount # type: ignore
_is_hex = is_str and not is_key_word and amount.startswith("0x") # type: ignore

if is_key_word:
# This allows values such as "1000 ETH".
amount = self.conversion_manager.convert(amount, int)
is_str = False

amount_hex_str = str(amount)

# Convert to hex str
amount_hex_str: str
if is_str and not _is_hex:
amount_hex_str = to_hex(int(amount))
elif isinstance(amount, int) or isinstance(amount, bytes):
elif isinstance(amount, (int, bytes)):
amount_hex_str = to_hex(amount)
else:
amount_hex_str = str(amount)

self.make_request("anvil_setBalance", [account, amount_hex_str])

Expand All @@ -518,9 +517,7 @@ def snapshot(self) -> str:
return self.make_request("evm_snapshot", [])

def restore(self, snapshot_id: SnapshotID) -> bool:
if isinstance(snapshot_id, int):
snapshot_id = HexBytes(snapshot_id).hex()

snapshot_id = to_hex(snapshot_id) if isinstance(snapshot_id, int) else snapshot_id
result = self.make_request("evm_revert", [snapshot_id])
return result is True

Expand All @@ -530,103 +527,12 @@ def unlock_account(self, address: AddressType) -> bool:

def relock_account(self, address: AddressType):
self.make_request("anvil_stopImpersonatingAccount", [address])
if address in self.account_manager.test_accounts._impersonated_accounts:
del self.account_manager.test_accounts._impersonated_accounts[address]

def send_transaction(self, txn: TransactionAPI) -> ReceiptAPI:
"""
Creates a new message call transaction or a contract creation
for signed transactions.
"""
sender = txn.sender
if sender:
sender = self.conversion_manager.convert(txn.sender, AddressType)

vm_err = None
if sender and sender in self.unlocked_accounts:
# Allow for an unsigned transaction
txn = self.prepare_transaction(txn)
txn_dict = txn.model_dump(mode="json", by_alias=True)
if isinstance(txn_dict.get("type"), int):
txn_dict["type"] = HexBytes(txn_dict["type"]).hex()

tx_params = cast(TxParams, txn_dict)
try:
txn_hash = self.web3.eth.send_transaction(tx_params)
except ValueError as err:
raise self.get_virtual_machine_error(err, txn=txn) from err

else:
try:
txn_hash = self.web3.eth.send_raw_transaction(txn.serialize_transaction())
except ValueError as err:
vm_err = self.get_virtual_machine_error(err, txn=txn)

if "nonce too low" in str(vm_err):
# Add additional nonce information
new_err_msg = f"Nonce '{txn.nonce}' is too low"
vm_err = VirtualMachineError(
new_err_msg,
base_err=vm_err.base_err,
code=vm_err.code,
txn=txn,
source_traceback=vm_err.source_traceback,
trace=vm_err.trace,
contract_address=vm_err.contract_address,
)

txn_hash = txn.txn_hash
if txn.raise_on_revert:
raise vm_err from err

receipt = self.get_receipt(
txn_hash.hex(),
required_confirmations=(
txn.required_confirmations
if txn.required_confirmations is not None
else self.network.required_confirmations
),
)
if vm_err:
receipt.error = vm_err

if receipt.failed:
txn_dict = receipt.transaction.model_dump(mode="json", by_alias=True)
if isinstance(txn_dict.get("type"), int):
txn_dict["type"] = HexBytes(txn_dict["type"]).hex()

txn_params = cast(TxParams, txn_dict)

# Replay txn to get revert reason
# NOTE: For some reason, `nonce` can't be in the txn params or else it fails.
if "nonce" in txn_params:
del txn_params["nonce"]

try:
self.web3.eth.call(txn_params)
except Exception as err:
vm_err = self.get_virtual_machine_error(err, txn=receipt)
receipt.error = vm_err
if txn.raise_on_revert:
raise vm_err from err

if txn.raise_on_revert:
# If we get here, for some reason the tx-replay did not produce
# a VM error.
receipt.raise_for_status()

self.chain_manager.history.append(receipt)
return receipt

def get_balance(self, address: AddressType, block_id: Optional[BlockID] = None) -> int:
if hasattr(address, "address"):
address = address.address
if result := self.make_request("eth_getBalance", [address, block_id]):
return int(result, 16) if isinstance(result, str) else result

result = self.make_request("eth_getBalance", [address, block_id])
if not result:
raise FoundryProviderError(f"Failed to get balance for account '{address}'.")

return int(result, 16) if isinstance(result, str) else result
raise FoundryProviderError(f"Failed to get balance for account '{address}'.")

def get_transaction_trace(self, transaction_hash: str, **kwargs) -> TraceAPI:
if "debug_trace_transaction_parameters" not in kwargs:
Expand Down Expand Up @@ -673,7 +579,12 @@ def _handle_execution_reverted(
# Unlikely scenario where a transaction is on the error even though a receipt
# exists.
if isinstance(enriched.txn, TransactionAPI) and enriched.txn.receipt:
enriched.txn.receipt.show_trace()
try:
enriched.txn.receipt.show_trace()
except Exception:
# TODO: Fix this in Ape
pass

elif isinstance(enriched.txn, ReceiptAPI):
enriched.txn.show_trace()

Expand Down Expand Up @@ -732,8 +643,12 @@ def _extract_custom_error(self, **kwargs) -> str:
except Exception:
pass

if trace is not None and (revert_msg := trace.revert_message):
return revert_msg
try:
if trace is not None and (revert_msg := trace.revert_message):
return revert_msg
except Exception:
# TODO: Fix this in core
return ""

return ""

Expand Down Expand Up @@ -763,16 +678,6 @@ def set_storage(self, address: AddressType, slot: int, value: HexBytes):
],
)

def _eth_call(self, arguments: list, raise_on_revert: bool = True) -> HexBytes:
# Overridden to handle unique Foundry pickiness.
txn_dict = copy(arguments[0])
if isinstance(txn_dict.get("type"), int):
txn_dict["type"] = HexBytes(txn_dict["type"]).hex()

txn_dict.pop("chainId", None)
arguments[0] = txn_dict
return super()._eth_call(arguments, raise_on_revert=raise_on_revert)


class FoundryForkProvider(FoundryProvider):
"""
Expand Down
10 changes: 5 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
"ape-optimism", # For Optimism integration tests
],
"lint": [
"black>=24.4.2,<25", # Auto-formatter and linter
"mypy>=1.11.0,<2", # Static type analyzer
"black>=24.8.0,<25", # Auto-formatter and linter
"mypy>=1.11.1,<2", # Static type analyzer
"types-setuptools", # Needed for mypy type shed
"types-requests", # Needed for mypy type shed
"types-PyYAML", # Needed for mypy type shed
"flake8>=7.1.0,<8", # Style linter
"flake8>=7.1.1,<8", # Style linter
"flake8-breakpoint>=1.1.0,<2", # Detect breakpoints left in code
"flake8-print>=5.0.0,<6", # Detect print statements left in code
"flake8-pydantic", # For detecting issues with Pydantic models
Expand Down Expand Up @@ -74,15 +74,15 @@
url="https://github.com/ApeWorX/ape-foundry",
include_package_data=True,
install_requires=[
"eth-ape>=0.8.10,<0.9",
"eth-ape>=0.8.11,<0.9",
"ethpm-types", # Use same version as eth-ape
"eth-pydantic-types", # Use same version as eth-ape
"evm-trace", # Use same version as ape
"web3", # Use same version as ape
"yarl", # Use same version as ape
"hexbytes", # Use same version as ape
],
python_requires=">=3.8,<4",
python_requires=">=3.9,<4",
extras_require=extras_require,
py_modules=["ape_foundry"],
license="Apache-2.0",
Expand Down
12 changes: 6 additions & 6 deletions tests/expected_traces.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,16 @@
│ └── ContractC\.methodC1\(
│ windows95="simpler",
│ jamaica=345457847457457458457457457,
│ cardinal=ContractA
│ cardinal=Contract[A|C]
│ \) \[\d+ gas\]
├── SYMBOL\.callMe\(blue=tx\.origin\) -> tx\.origin \[\d+ gas\]
├── SYMBOL\.methodB2\(trombone=tx\.origin\) \[\d+ gas\]
│ ├── ContractC\.paperwork\(ContractA\) -> \(
│ ├── ContractC\.paperwork\(Contract[A|C]\) -> \(
│ │ os="simpler",
│ │ country=345457847457457458457457457,
│ │ wings=ContractA
│ │ wings=Contract[A|C]
│ │ \) \[\d+ gas\]
│ ├── ContractC\.methodC1\(windows95="simpler", jamaica=0, cardinal=ContractC\) \[\d+ gas\]
│ ├── ContractC\.methodC1\(windows95="simpler", jamaica=0, cardinal=Contract[A|C]\) \[\d+ gas\]
│ ├── ContractC\.methodC2\(\) \[\d+ gas\]
│ └── ContractC\.methodC2\(\) \[\d+ gas\]
├── ContractC\.addressToValue\(tx.origin\) -> 0 \[\d+ gas\]
Expand All @@ -45,14 +45,14 @@
│ │ 111344445534535353,
│ │ 993453434534534534534977788884443333
│ │ \] \[\d+ gas\]
│ └── ContractC\.methodC1\(windows95="simpler", jamaica=0, cardinal=ContractA\) \[\d+ gas\]
│ └── ContractC\.methodC1\(windows95="simpler", jamaica=0, cardinal=Contract[A|C]\) \[\d+ gas\]
└── SYMBOL\.methodB1\(lolol="snitches_get_stiches", dynamo=111\) \[\d+ gas\]
├── ContractC\.getSomeList\(\) -> \[
│ 3425311345134513461345134534531452345,
│ 111344445534535353,
│ 993453434534534534534977788884443333
│ \] \[\d+ gas\]
└── ContractC\.methodC1\(windows95="simpler", jamaica=111, cardinal=ContractA\) \[\d+ gas\]
└── ContractC\.methodC1\(windows95="simpler", jamaica=111, cardinal=Contract[A|C]\) \[\d+ gas\]
"""
MAINNET_FAIL_TRACE_FIRST_10_LINES = r"""
Call trace for '0x053cba5c12172654d894f66d5670bab6215517a94189a9ffc09bc40a589ec04d'
Expand Down

0 comments on commit e7c7aae

Please sign in to comment.