Skip to content

Commit

Permalink
fix: custom error check fails when error has receipt instead of trans…
Browse files Browse the repository at this point in the history
…action (#87)
  • Loading branch information
antazoey authored Feb 1, 2024
1 parent 3df33c1 commit 0b6b55f
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 23 deletions.
56 changes: 35 additions & 21 deletions ape_foundry/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,9 +687,12 @@ def get_transaction_trace(self, txn_hash: str) -> Iterator[TraceFrame]:
for trace in self._get_transaction_trace(txn_hash):
yield self._create_trace_frame(trace)

def _get_transaction_trace(self, txn_hash: str) -> Iterator[EvmTraceFrame]:
def _get_transaction_trace(
self, txn_hash: str, steps_tracing: bool = True, enable_memory: bool = True
) -> Iterator[EvmTraceFrame]:
result = self._make_request(
"debug_traceTransaction", [txn_hash, {"stepsTracing": True, "enableMemory": True}]
"debug_traceTransaction",
[txn_hash, {"stepsTracing": steps_tracing, "enableMemory": enable_memory}],
)
frames = result.get("structLogs", [])
for frame in frames:
Expand Down Expand Up @@ -731,25 +734,10 @@ def get_virtual_machine_error(self, exception: Exception, **kwargs) -> VirtualMa

elif message.lower() == "execution reverted":
err = ContractLogicError(TransactionError.DEFAULT_MESSAGE, base_err=exception, **kwargs)

if isinstance(exception, Web3ContractLogicError):
# Check for custom error.
data = {}
if "trace" in kwargs:
kwargs["trace"], new_trace = tee(kwargs["trace"])
data = list(new_trace)[-1].raw

elif "txn" in kwargs:
try:
txn_hash = kwargs["txn"].txn_hash.hex()
data = list(self.get_transaction_trace(txn_hash))[-1].raw
except Exception:
pass

if data.get("op") == "REVERT":
custom_err = "".join([x[2:] for x in data["memory"][4:]])
if custom_err:
err.message = f"0x{custom_err}"
if isinstance(exception, Web3ContractLogicError) and (
msg := _extract_custom_error(self, **kwargs)
):
err.message = msg

err.message = (
TransactionError.DEFAULT_MESSAGE if err.message in ("", "0x", None) else err.message
Expand Down Expand Up @@ -991,3 +979,29 @@ def reset_fork(self, block_number: Optional[int] = None):
# # Rest the fork
result = self._make_request("anvil_reset", [{"forking": forking_params}])
return result


# Abstracted for easier testing conditions.
def _extract_custom_error(provider: FoundryProvider, **kwargs) -> str:
data = {}
if "trace" in kwargs:
kwargs["trace"], new_trace = tee(kwargs["trace"])
data = list(new_trace)[-1].raw

elif "txn" in kwargs:
txn = kwargs["txn"]
txn_hash = txn.txn_hash if isinstance(txn.txn_hash, str) else txn.txn_hash.hex()
try:
trace = list(provider._get_transaction_trace(txn_hash))[-1]
except Exception:
return ""

data = trace.model_dump(by_alias=True, mode="json") if trace else {}

if data and data.get("op") == "REVERT":
memory = data.get("memory", [])
custom_err = "".join([x[2:] for x in memory[4:]])
if custom_err:
return f"0x{custom_err}"

return ""
43 changes: 41 additions & 2 deletions tests/test_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
from ape.types import CallTreeNode, TraceFrame
from ape_ethereum.transactions import TransactionStatusEnum, TransactionType
from eth_pydantic_types import HashBytes32
from eth_utils import to_int
from eth_utils import encode_hex, to_int
from evm_trace import CallType
from hexbytes import HexBytes

from ape_foundry import FoundryProviderError
from ape_foundry.provider import FOUNDRY_CHAIN_ID
from ape_foundry.provider import FOUNDRY_CHAIN_ID, _extract_custom_error

TEST_WALLET_ADDRESS = "0xD9b7fdb3FC0A0Aa3A507dCf0976bc23D49a9C7A3"

Expand Down Expand Up @@ -435,3 +435,42 @@ def test_disable_block_gas_limit(temp_config, disconnected_provider):
with temp_config(data):
cmd = disconnected_provider.build_command()
assert "--disable-block-gas-limit" in cmd


def test_extract_custom_error_trace_given(mocker):
provider = mocker.MagicMock()
trace = mocker.MagicMock()
trace.raw = {"op": "REVERT", "memory": [None, None, None, None, encode_hex("CustomError")]}
trace = iter([trace])
actual = _extract_custom_error(provider, trace=trace)
assert actual.startswith("0x")


def test_extract_custom_error_transaction_given(
connected_provider, vyper_contract_instance, not_owner
):
with pytest.raises(ContractLogicError) as err:
vyper_contract_instance.setNumber(546, sender=not_owner, allow_fail=True)

actual = _extract_custom_error(connected_provider, txn=err.value.txn)
assert actual == ""


@pytest.mark.parametrize("tx_hash", ("0x0123", HexBytes("0x0123")))
def test_extract_custom_error_transaction_given_trace_fails(mocker, tx_hash):
provider = mocker.MagicMock()
tx = mocker.MagicMock()
tx.txn_hash = tx_hash
tracker = []

def trace(txn_hash: str, *args, **kwargs):
tracker.append(txn_hash)
raise ValueError("Connection failed.")

provider._get_transaction_trace.side_effect = trace

actual = _extract_custom_error(provider, txn=tx)
assert actual == ""

# Show failure was tracked
assert tracker[0] == HexBytes(tx.txn_hash).hex()

0 comments on commit 0b6b55f

Please sign in to comment.