Skip to content

Commit

Permalink
feat: ethereum.decode_returndata values to be comparable with curre…
Browse files Browse the repository at this point in the history
…ncy strings (#2159)
  • Loading branch information
antazoey authored Jun 25, 2024
1 parent c153052 commit 2dc9b10
Show file tree
Hide file tree
Showing 11 changed files with 133 additions and 76 deletions.
21 changes: 19 additions & 2 deletions src/ape/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
from ethpm_types.abi import EventABI
from ethpm_types.source import Closure
from pydantic import BaseModel, field_validator, model_validator
from typing_extensions import TypeAlias
from web3.types import FilterParams

from ape.exceptions import ConversionError
from ape.types.address import AddressType, RawAddress
from ape.types.coverage import (
ContractCoverage,
Expand Down Expand Up @@ -471,7 +473,7 @@ def generator(self) -> Iterator:
yield from self._generator


class CurrencyValue(int):
class CurrencyValueComparable(int):
"""
An integer you can compare with currency-value
strings, such as ``"1 ether"``.
Expand All @@ -480,14 +482,28 @@ class CurrencyValue(int):
def __eq__(self, other: Any) -> bool:
if isinstance(other, int):
return super().__eq__(other)

elif isinstance(other, str):
other_value = ManagerAccessMixin.conversion_manager.convert(other, int)
try:
other_value = ManagerAccessMixin.conversion_manager.convert(other, int)
except ConversionError:
# Not a currency-value, it's ok.
return False

return super().__eq__(other_value)

# Try from the other end, if hasn't already.
return NotImplemented


CurrencyValue: TypeAlias = CurrencyValueComparable
"""
An alias to :class:`~ape.types.CurrencyValueComparable` for
situations when you know for sure the type is a currency-value
(and not just comparable to one).
"""


__all__ = [
"ABI",
"AddressType",
Expand All @@ -506,6 +522,7 @@ def __eq__(self, other: Any) -> bool:
"CoverageReport",
"CoverageStatement",
"CurrencyValue",
"CurrencyValueComparable",
"GasReport",
"MessageSignature",
"PackageManifest",
Expand Down
2 changes: 1 addition & 1 deletion src/ape/utils/abi.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def encode_input(self, values: Union[list, tuple, dict]) -> Any:
Convert dicts and other objects to struct inputs.
Args:
values (Union[list, ttuple]): A list of of input values.
values (Union[list, tuple]): A list of of input values.
Returns:
Any: The same input values only decoded into structs when applicable.
Expand Down
18 changes: 16 additions & 2 deletions src/ape_ethereum/ecosystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
AddressType,
AutoGasLimit,
ContractLog,
CurrencyValueComparable,
GasLimit,
RawAddress,
TransactionSignature,
Expand Down Expand Up @@ -723,7 +724,7 @@ def decode_returndata(self, abi: MethodABI, raw_data: bytes) -> tuple[Any, ...]:
],
)
if issubclass(type(output_values[0]), Struct)
else ([o for o in output_values[0]],)
else ([o for o in output_values[0]],) # type: ignore[union-attr]
)

elif returns_array(abi):
Expand Down Expand Up @@ -756,6 +757,9 @@ def _enrich_value(self, value: Any, **kwargs) -> Any:
# Surround non-address strings with quotes.
return f'"{value}"'

elif isinstance(value, int):
return int(value) # Eliminate int-base classes.

elif isinstance(value, (list, tuple)):
return [self._enrich_value(v, **kwargs) for v in value]

Expand All @@ -766,7 +770,7 @@ def _enrich_value(self, value: Any, **kwargs) -> Any:

def decode_primitive_value(
self, value: Any, output_type: Union[str, tuple, list]
) -> Union[str, HexBytes, tuple, list]:
) -> Union[str, HexBytes, int, tuple, list]:
if output_type == "address":
try:
return self.decode_address(value)
Expand All @@ -776,6 +780,11 @@ def decode_primitive_value(
elif isinstance(value, bytes):
return HexBytes(value)

elif isinstance(value, int):
# Wrap integers in a special type that allows us to compare
# them with currency-value strings.
return CurrencyValueComparable(value)

elif isinstance(output_type, str) and is_array(output_type):
sub_type = "[".join(output_type.split("[")[:-1])

Expand Down Expand Up @@ -989,6 +998,11 @@ def get_abi(_topic: HexStr) -> Optional[LogInputABICollection]:
[self.decode_address(v) for v in value] if sub_type == "address" else value
)

elif isinstance(value, int):
# This allows integers to be comparable with currency-value
# strings, such as "1 ETH".
converted_arguments[key] = CurrencyValueComparable(value)

else:
# No change.
converted_arguments[key] = value
Expand Down
11 changes: 10 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ape.logging import LogLevel, logger
from ape.pytest.config import ConfigWrapper
from ape.pytest.gas import GasTracker
from ape.types import AddressType
from ape.types import AddressType, CurrencyValue
from ape.utils import DEFAULT_TEST_CHAIN_ID, ZERO_ADDRESS
from ape.utils.basemodel import only_raise_attribute_error

Expand All @@ -35,6 +35,10 @@
ape.config.DATA_FOLDER = DATA_FOLDER


EXPECTED_MYSTRUCT_C = 244
"""Hardcoded everywhere in contracts."""


def pytest_addoption(parser):
parser.addoption(
"--includepip", action="store_true", help="run tests that depend on pip install operations"
Expand Down Expand Up @@ -172,6 +176,11 @@ def helper(test_accounts):
return test_accounts[4]


@pytest.fixture(scope="session")
def mystruct_c():
return CurrencyValue(EXPECTED_MYSTRUCT_C)


@pytest.fixture
def signer(test_accounts):
return test_accounts[5]
Expand Down

Large diffs are not rendered by default.

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion tests/functional/data/sources/SolidityContract.sol
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.2;

contract TestContractSol {
contract SolidityContract {
address public owner;
uint256 public myNumber;
uint256 public prevNumber;
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/data/sources/VyperContract.vy
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def setBalance(_address: address, bal: uint256):
@view
@external
def getStruct() -> MyStruct:
return MyStruct({a: msg.sender, b: block.prevhash})
return MyStruct({a: msg.sender, b: block.prevhash, c: 244})

@view
@external
Expand Down
7 changes: 5 additions & 2 deletions tests/functional/test_contract_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def _assert_log_values(log: ContractLog, number: int, previous_number: Optional[
expected_previous_number = number - 1 if previous_number is None else previous_number
assert log.prevNum == expected_previous_number, "Event param 'prevNum' has unexpected value"
assert log.newNum == number, "Event param 'newNum' has unexpected value"
assert log.newNum == f"{number} wei", "string comparison with number not working"
assert log.dynData == "Dynamic"
assert log.dynIndexed == HexBytes(
"0x9f3d45ac20ccf04b45028b8080bb191eab93e29f7898ed43acf480dd80bba94d"
Expand Down Expand Up @@ -321,10 +322,12 @@ def test_filter_events_with_same_abi(
assert result_c == [leaf_contract.OneOfMany(addr=contract_with_call_depth.address)]


def test_structs_in_events(contract_instance, owner):
def test_structs_in_events(contract_instance, owner, mystruct_c):
tx = contract_instance.logStruct(sender=owner)
expected_bytes = HexBytes(0x1234567890ABCDEF0123456789ABCDEF0123456789ABCDEF0123456789ABCDEF)
expected = contract_instance.EventWithStruct(a_struct={"a": owner, "b": expected_bytes})
expected = contract_instance.EventWithStruct(
a_struct={"a": owner, "b": expected_bytes, "c": mystruct_c}
)
assert tx.events == [expected]


Expand Down
46 changes: 33 additions & 13 deletions tests/functional/test_contract_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ def data_object(owner):
class DataObject(BaseModel):
a: AddressType = owner.address
b: HexBytes = HexBytes(123)
c: str = "GETS IGNORED"
c: int = 888
# Showing that extra keys don't matter.
extra_key: str = "GETS IGNORED"

return DataObject()

Expand All @@ -49,7 +51,12 @@ def test_contract_interaction(eth_tester_provider, owner, vyper_contract_instanc
assert receipt.gas_limit == max_gas

# Show contract state changed.
assert vyper_contract_instance.myNumber() == 102
num_val = vyper_contract_instance.myNumber()
assert num_val == 102

# Show that numbers from contract-return values are always currency-value
# comparable.
assert num_val == "102 WEI"

# Verify the estimate gas RPC was not used (since we are using max_gas).
assert estimate_gas_spy.call_count == 0
Expand Down Expand Up @@ -213,9 +220,9 @@ def test_repr(vyper_contract_instance):
)


def test_structs(contract_instance, owner, chain):
def test_structs(contract_instance, owner, chain, mystruct_c):
actual = contract_instance.getStruct()
actual_sender, actual_prev_block = actual
actual_sender, actual_prev_block, actual_c = actual
tx_hash = chain.blocks[-2].hash

# Expected: a == msg.sender
Expand All @@ -226,24 +233,27 @@ def test_structs(contract_instance, owner, chain):
assert actual.b == actual["b"] == actual[1] == actual_prev_block == tx_hash
assert isinstance(actual.b, bytes)

expected_dict = {"a": owner, "b": tx_hash}
# Expected: c == 244
assert actual.c == actual["c"] == actual_c == mystruct_c == f"{mystruct_c} wei"

expected_dict = {"a": owner, "b": tx_hash, "c": mystruct_c}
assert actual == expected_dict

expected_tuple = (owner, tx_hash)
expected_tuple = (owner, tx_hash, mystruct_c)
assert actual == expected_tuple

expected_list = [owner, tx_hash]
expected_list = [owner, tx_hash, mystruct_c]
assert actual == expected_list

expected_struct = contract_instance.getStruct()
assert actual == expected_struct


def test_nested_structs(contract_instance, owner, chain):
def test_nested_structs(contract_instance, owner, chain, mystruct_c):
actual_1 = contract_instance.getNestedStruct1()
actual_2 = contract_instance.getNestedStruct2()
actual_sender_1, actual_prev_block_1 = actual_1.t
actual_sender_2, actual_prev_block_2 = actual_2.t
actual_sender_1, actual_prev_block_1, actual_c_1 = actual_1.t
actual_sender_2, actual_prev_block_2, actual_c_2 = actual_2.t

# Expected: t.a == msg.sender
assert actual_1.t.a == actual_1.t["a"] == actual_1.t[0] == actual_sender_1 == owner
Expand Down Expand Up @@ -275,6 +285,9 @@ def test_nested_structs(contract_instance, owner, chain):
assert isinstance(actual_2.t.b, bytes)
assert isinstance(actual_2.t.b, HexBytes)

# Expected: t.c == 244
assert actual_c_1 == actual_c_2 == mystruct_c == f"{mystruct_c} wei"


def test_nested_structs_in_tuples(contract_instance, owner, chain):
result_1 = contract_instance.getNestedStructWithTuple1()
Expand Down Expand Up @@ -305,13 +318,18 @@ def test_get_empty_tuple_of_dyn_array_structs(contract_instance):


def test_get_empty_tuple_of_array_of_structs_and_dyn_array_of_structs(
contract_instance, zero_address
contract_instance,
zero_address,
):
actual = contract_instance.getEmptyTupleOfArrayOfStructsAndDynArrayOfStructs()

# empty address, bytes, and int.
expected_fixed_array = (
zero_address,
HexBytes("0x0000000000000000000000000000000000000000000000000000000000000000"),
0,
)

assert actual[0] == [expected_fixed_array, expected_fixed_array, expected_fixed_array]
assert actual[1] == []

Expand Down Expand Up @@ -693,7 +711,8 @@ def test_obj_as_struct_input(contract_instance, owner, data_object):


def test_dict_as_struct_input(contract_instance, owner):
data = {"a": owner, "b": HexBytes(123), "c": "GETS IGNORED"}
# NOTE: Also showing extra keys like "extra_key" don't matter and are ignored.
data = {"a": owner, "b": HexBytes(123), "c": 999, "extra_key": "GETS_IGNORED"}
assert contract_instance.setStruct(data) is None


Expand All @@ -708,7 +727,8 @@ def test_obj_list_as_struct_array_input(contract_instance, data_object, sequence

@pytest.mark.parametrize("sequence_type", (list, tuple))
def test_dict_list_as_struct_array_input(contract_instance, owner, sequence_type):
data = {"a": owner, "b": HexBytes(123), "c": "GETS IGNORED"}
# NOTE: Also showing extra keys like "extra_key" don't matter and are ignored.
data = {"a": owner, "b": HexBytes(123), "c": 444, "extra_key": "GETS IGNORED"}
parameter = sequence_type([data, data])
actual = contract_instance.setStructArray(parameter)
# The function is pure and doesn't return anything.
Expand Down
Loading

0 comments on commit 2dc9b10

Please sign in to comment.