Skip to content

Commit

Permalink
feat: wip
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey committed Aug 6, 2024
1 parent bce1824 commit 223ddd0
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 56 deletions.
3 changes: 3 additions & 0 deletions src/ape/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,9 @@ def __eq__(self, other: Any) -> bool:
return NotImplemented


CurrencyValueComparable.__name__ = int.__name__


CurrencyValue: TypeAlias = CurrencyValueComparable
"""
An alias to :class:`~ape.types.CurrencyValueComparable` for
Expand Down
101 changes: 54 additions & 47 deletions src/ape_ethereum/ecosystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
is_hex,
is_hex_address,
keccak,
to_bytes,
to_checksum_address,
to_hex,
)
from ethpm_types import ContractType
from ethpm_types.abi import ABIType, ConstructorABI, EventABI, MethodABI
Expand Down Expand Up @@ -161,7 +163,7 @@ def validate_gas_limit(cls, value):
return int(value)

elif isinstance(value, str) and is_hex(value) and is_0x_prefixed(value):
return to_int(HexBytes(value))
return int(value, 16)

elif is_hex(value):
raise ValueError("Gas limit hex str must include '0x' prefix.")
Expand Down Expand Up @@ -400,7 +402,7 @@ def decode_address(cls, raw_address: RawAddress) -> AddressType:

@classmethod
def encode_address(cls, address: AddressType) -> RawAddress:
return str(address)
return f"{address}"

def decode_transaction_type(self, transaction_type_id: Any) -> type[TransactionAPI]:
if isinstance(transaction_type_id, TransactionType):
Expand Down Expand Up @@ -742,7 +744,7 @@ def _enrich_value(self, value: Any, **kwargs) -> Any:
if len(value) > 24:
return humanize_hash(cast(Hash32, value))

hex_str = HexBytes(value).hex()
hex_str = to_hex(value)
if is_hex_address(hex_str):
return self._enrich_value(hex_str, **kwargs)

Expand Down Expand Up @@ -1023,6 +1025,7 @@ def get_abi(_topic: HexStr) -> Optional[LogInputABICollection]:
def enrich_trace(self, trace: TraceAPI, **kwargs) -> TraceAPI:
kwargs["trace"] = trace
if not isinstance(trace, Trace):
# Can only enrich `ape_ethereum.trace.Trace` (or subclass) implementations.
return trace

elif trace._enriched_calltree is not None:
Expand All @@ -1042,11 +1045,8 @@ def enrich_trace(self, trace: TraceAPI, **kwargs) -> TraceAPI:
# Return value was discovered already.
kwargs["return_value"] = return_value

enriched_calltree = self._enrich_calltree(data, **kwargs)

# Cache the result back on the trace.
trace._enriched_calltree = enriched_calltree

trace._enriched_calltree = self._enrich_calltree(data, **kwargs)
return trace

def _enrich_calltree(self, call: dict, **kwargs) -> dict:
Expand Down Expand Up @@ -1075,10 +1075,10 @@ def _enrich_calltree(self, call: dict, **kwargs) -> dict:
call["calls"] = [self._enrich_calltree(c, **kwargs) for c in subcalls]

# Figure out the contract.
address = call.pop("address", "")
address: AddressType = call.pop("address", "")
try:
call["contract_id"] = address = kwargs["contract_address"] = str(
self.decode_address(address)
call["contract_id"] = address = kwargs["contract_address"] = self.decode_address(
address
)
except Exception:
# Tx was made with a weird address.
Expand All @@ -1099,25 +1099,24 @@ def _enrich_calltree(self, call: dict, **kwargs) -> dict:
else:
# Collapse pre-compile address calls
if 1 <= address_int <= 9:
if len(call.get("calls", [])) == 1:
return call["calls"][0]

return {"contract_id": f"{address_int}", "calls": call["calls"]}
return (
call["calls"][0]
if len(call.get("calls", [])) == 1
else {"contract_id": f"{address_int}", "calls": call["calls"]}
)

depth = call.get("depth", 0)
if depth == 0 and address in self.account_manager:
call["contract_id"] = f"__{self.fee_token_symbol}_transfer__"
else:
call["contract_id"] = self._enrich_contract_id(call["contract_id"], **kwargs)

if not (contract_type := self.chain_manager.contracts.get(address)):
# Without a contract, we can enrich no further.
if not (contract_type := self._get_contract_type_for_enrichment(address, **kwargs)):
# Without a contract type, we can enrich no further.
return call

if events := call.get("events"):
call["events"] = self._enrich_trace_events(
events, address=address, contract_type=contract_type
)
call["events"] = self._enrich_trace_events(events, address=address, **kwargs)

method_abi: Optional[Union[MethodABI, ConstructorABI]] = None
if is_create:
Expand All @@ -1126,17 +1125,19 @@ def _enrich_calltree(self, call: dict, **kwargs) -> dict:

elif call["method_id"] != "0x":
method_id_bytes = HexBytes(call["method_id"])
if method_id_bytes in contract_type.methods:

# perf: use try/except instead of __contains__ check.
try:
method_abi = contract_type.methods[method_id_bytes]
except KeyError:
name = call["method_id"]
else:
assert isinstance(method_abi, MethodABI) # For mypy

# Check if method name duplicated. If that is the case, use selector.
times = len([x for x in contract_type.methods if x.name == method_abi.name])
name = (method_abi.name if times == 1 else method_abi.selector) or call["method_id"]
call = self._enrich_calldata(call, method_abi, contract_type, **kwargs)

else:
name = call["method_id"]
else:
name = call.get("method_id") or "0x"

Expand Down Expand Up @@ -1167,10 +1168,11 @@ def _enrich_contract_id(self, address: AddressType, **kwargs) -> str:
elif address == ZERO_ADDRESS:
return "ZERO_ADDRESS"

if not (contract_type := self.chain_manager.contracts.get(address)):
elif not (contract_type := self._get_contract_type_for_enrichment(address, **kwargs)):
# Without a contract type, we can enrich no further.
return address

elif kwargs.get("use_symbol_for_tokens") and "symbol" in contract_type.view_methods:
if kwargs.get("use_symbol_for_tokens") and "symbol" in contract_type.view_methods:
# Use token symbol as name
contract = self.chain_manager.contracts.instance_at(
address, contract_type=contract_type
Expand Down Expand Up @@ -1202,12 +1204,13 @@ def _enrich_calldata(
**kwargs,
) -> dict:
calldata = call["calldata"]
if isinstance(calldata, (str, bytes, int)):
calldata_arg = HexBytes(calldata)
if isinstance(calldata, str):
calldata_arg = to_bytes(hexstr=calldata)
elif isinstance(calldata, bytes):
calldata_arg = calldata
else:
# Not sure if we can get here.
# Mostly for mypy's sake.
return call
# Already enriched.
return calldata

if call.get("call_type") and "CREATE" in call.get("call_type", ""):
# Strip off bytecode
Expand All @@ -1217,7 +1220,7 @@ def _enrich_calldata(
else b""
)
# TODO: Handle Solidity Metadata (delegate to Compilers again?)
calldata_arg = HexBytes(calldata_arg.split(bytecode)[-1])
calldata_arg = HexBytes(calldata.split(bytecode)[-1])

try:
call["calldata"] = self.decode_calldata(method_abi, calldata_arg)
Expand Down Expand Up @@ -1311,18 +1314,15 @@ def _enrich_trace_events(
self,
events: list[dict],
address: Optional[AddressType] = None,
contract_type: Optional[ContractType] = None,
**kwargs,
) -> list[dict]:
return [
self._enrich_trace_event(e, address=address, contract_type=contract_type)
for e in events
]
return [self._enrich_trace_event(e, address=address, **kwargs) for e in events]

def _enrich_trace_event(
self,
event: dict,
address: Optional[AddressType] = None,
contract_type: Optional[ContractType] = None,
**kwargs,
) -> dict:
if "topics" not in event or len(event["topics"]) < 1:
# Already enriched or wrong.
Expand All @@ -1334,16 +1334,9 @@ def _enrich_trace_event(
# Cannot enrich further w/o an address.
return event

if not contract_type:
try:
contract_type = self.chain_manager.contracts.get(address)
except Exception as err:
logger.debug(f"Error getting contract type during event enrichment: {err}")
return event

if not contract_type:
# Cannot enrich further w/o an contract type.
return event
if not (contract_type := self._get_contract_type_for_enrichment(address, **kwargs)):
# Without a contract type, we can enrich no further.
return event

# The selector is always the first topic.
selector = event["topics"][0]
Expand Down Expand Up @@ -1388,6 +1381,20 @@ def _enrich_revert_message(self, call: dict) -> dict:

return call

def _get_contract_type_for_enrichment(
self, address: AddressType, **kwargs
) -> Optional[ContractType]:
if not (contract_type := kwargs.get("contract_type")):
try:
contract_type = self.chain_manager.contracts.get(address)
except Exception as err:
logger.debug(f"Error getting contract type during event enrichment: {err}")
else:
# Set even if None to avoid re-checking.
kwargs["contract_type"] = contract_type

return contract_type

def get_python_types(self, abi_type: ABIType) -> Union[type, Sequence]:
return self._python_type_for_abi_type(abi_type)

Expand Down
15 changes: 8 additions & 7 deletions src/ape_ethereum/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,11 +227,12 @@ def return_value(self) -> Any:

elif abi := self.root_method_abi:
return_data = self._return_data_from_trace_frames
try:
return self._ecosystem.decode_returndata(abi, return_data)
except Exception as err:
logger.debug(f"Failed decoding return data from trace frames. Error: {err}")
# Use enrichment method. It is slow but it'll at least work.
if return_data is not None:
try:
return self._ecosystem.decode_returndata(abi, return_data)
except Exception as err:
logger.debug(f"Failed decoding return data from trace frames. Error: {err}")
# Use enrichment method. It is slow but it'll at least work.

return self._return_value_from_enriched_calltree

Expand Down Expand Up @@ -289,13 +290,13 @@ def _revert_str_from_trace_frames(self) -> Optional[HexBytes]:
return None

@cached_property
def _return_data_from_trace_frames(self) -> HexBytes:
def _return_data_from_trace_frames(self) -> Optional[HexBytes]:
if frame := self._last_frame:
memory = frame["memory"]
start_pos = int(frame["stack"][2], 16) // 32
return HexBytes("".join(memory[start_pos:]))

return HexBytes("")
return None

""" API Methods """

Expand Down
17 changes: 16 additions & 1 deletion src/ape_test/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from web3.providers.eth_tester.defaults import API_ENDPOINTS, static_return
from web3.types import TxParams

from ape.api import BlockAPI, PluginConfig, ReceiptAPI, TestProviderAPI, TransactionAPI
from ape.api import BlockAPI, PluginConfig, ReceiptAPI, TestProviderAPI, TraceAPI, TransactionAPI
from ape.exceptions import (
APINotImplementedError,
ContractLogicError,
Expand All @@ -31,6 +31,7 @@
from ape.types import AddressType, BlockID, ContractLog, LogFilter, SnapshotID
from ape.utils import DEFAULT_TEST_CHAIN_ID, DEFAULT_TEST_HD_PATH, gas_estimation_error_message
from ape_ethereum.provider import Web3Provider
from ape_ethereum.trace import TraceApproach, TransactionTrace

if TYPE_CHECKING:
from ape.api.accounts import TestAccountAPI
Expand Down Expand Up @@ -388,6 +389,12 @@ def _get_last_base_fee(self) -> int:

raise APINotImplementedError("No base fee found in block.")

def get_transaction_trace(self, transaction_hash: str, **kwargs) -> TraceAPI:
if "call_trace_approach" not in kwargs:
kwargs["call_trace_approach"] = TraceApproach.BASIC

return EthTesterTransactionTrace(transaction_hash=transaction_hash, **kwargs)

def get_virtual_machine_error(self, exception: Exception, **kwargs) -> VirtualMachineError:
if isinstance(exception, ValidationError):
match = self._CANNOT_AFFORD_GAS_PATTERN.match(str(exception))
Expand Down Expand Up @@ -438,3 +445,11 @@ def _get_latest_block(self) -> BlockAPI:

def _get_latest_block_rpc(self) -> dict:
return self.evm_backend.get_block_by_number("latest")


class EthTesterTransactionTrace(TransactionTrace):
@cached_property
def return_value(self) -> Any:
# perf: skip trying anything else, because eth-tester doesn't
# yet implement any tracing RPCs.
return self._return_value_from_enriched_calltree
1 change: 0 additions & 1 deletion tests/functional/geth/test_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,6 @@ def test_return_value_list(geth_account, geth_contract, geth_provider):
def test_return_value_nested_address_array(
geth_account, geth_contract, geth_provider, zero_address
):
geth_contract.getNestedAddressArray()
receipt = geth_contract.getNestedAddressArray.transact(sender=geth_account)
expected = [
[geth_account.address, geth_account.address, geth_account.address],
Expand Down

0 comments on commit 223ddd0

Please sign in to comment.