From 65990343dda9fc1206fc543227ce4d2a09d12d32 Mon Sep 17 00:00:00 2001 From: Juliya Smith Date: Fri, 2 Aug 2024 16:11:33 -0500 Subject: [PATCH] refactor: share hexy validators --- src/ape/api/transactions.py | 76 ++++++---------------------- src/ape/types/__init__.py | 45 +++++++--------- src/ape/utils/misc.py | 7 ++- src/ape_ethereum/ecosystem.py | 15 ------ src/ape_ethereum/transactions.py | 56 ++++++++------------ tests/functional/test_transaction.py | 25 +++++++++ tests/functional/test_types.py | 11 +++- 7 files changed, 93 insertions(+), 142 deletions(-) diff --git a/src/ape/api/transactions.py b/src/ape/api/transactions.py index 9cb8984ab7..9ee0b1db1b 100644 --- a/src/ape/api/transactions.py +++ b/src/ape/api/transactions.py @@ -4,8 +4,8 @@ from datetime import datetime from typing import IO, TYPE_CHECKING, Any, NoReturn, Optional, Union -from eth_pydantic_types import HexBytes -from eth_utils import is_0x_prefixed, is_hex, to_hex, to_int +from eth_pydantic_types import HexBytes, HexStr +from eth_utils import is_hex, to_hex, to_int from ethpm_types.abi import EventABI, MethodABI from pydantic import ConfigDict, field_validator from pydantic.fields import Field @@ -24,6 +24,7 @@ AddressType, AutoGasLimit, ContractLogContainer, + HexInt, SourceTraceback, TransactionSignature, ) @@ -51,19 +52,19 @@ class TransactionAPI(BaseInterfaceModel): such as typed-transactions from `EIP-1559 `__. """ - chain_id: Optional[int] = Field(default=0, alias="chainId") + chain_id: Optional[HexInt] = Field(default=0, alias="chainId") receiver: Optional[AddressType] = Field(default=None, alias="to") sender: Optional[AddressType] = Field(default=None, alias="from") - gas_limit: Optional[int] = Field(default=None, alias="gas") - nonce: Optional[int] = None # NOTE: `Optional` only to denote using default behavior - value: int = 0 + gas_limit: Optional[HexInt] = Field(default=None, alias="gas") + nonce: Optional[HexInt] = None # NOTE: `Optional` only to denote using default behavior + value: HexInt = 0 data: HexBytes = HexBytes("") - type: int - max_fee: Optional[int] = None - max_priority_fee: Optional[int] = None + type: HexInt + max_fee: Optional[HexInt] = None + max_priority_fee: Optional[HexInt] = None # If left as None, will get set to the network's default required confirmations. - required_confirmations: Optional[int] = Field(default=None, exclude=True) + required_confirmations: Optional[HexInt] = Field(default=None, exclude=True) signature: Optional[TransactionSignature] = Field(default=None, exclude=True) @@ -74,20 +75,6 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._raise_on_revert = raise_on_revert - @field_validator("nonce", mode="before") - @classmethod - def validate_nonce(cls, value): - if value is None or isinstance(value, int): - return value - - elif isinstance(value, str) and value.startswith("0x"): - return to_int(hexstr=value) - - elif isinstance(value, str): - return int(value) - - return to_int(value) - @field_validator("gas_limit", mode="before") @classmethod def validate_gas_limit(cls, value): @@ -114,33 +101,6 @@ def validate_gas_limit(cls, value): return value - @field_validator("max_fee", "max_priority_fee", mode="before") - @classmethod - def convert_fees(cls, value): - if isinstance(value, str): - return cls.conversion_manager.convert(value, int) - - return value - - @field_validator("data", mode="before") - @classmethod - def validate_data(cls, value): - if isinstance(value, str): - return HexBytes(value) - - return value - - @field_validator("value", mode="before") - @classmethod - def validate_value(cls, value): - if isinstance(value, int): - return value - - elif isinstance(value, str) and is_0x_prefixed(value): - return int(value, 16) - - return int(value) - @property def raise_on_revert(self) -> bool: return self._raise_on_revert @@ -173,7 +133,6 @@ def receipt(self) -> Optional["ReceiptAPI"]: """ This transaction's associated published receipt, if it exists. """ - try: txn_hash = to_hex(self.txn_hash) except SignatureError: @@ -295,11 +254,11 @@ class ReceiptAPI(ExtraAttributesMixin, BaseInterfaceModel): """ contract_address: Optional[AddressType] = None - block_number: int - gas_used: int + block_number: HexInt + gas_used: HexInt logs: list[dict] = [] - status: int - txn_hash: str + status: HexInt + txn_hash: HexStr transaction: TransactionAPI _error: Optional[TransactionError] = None @@ -330,11 +289,6 @@ def _validate_transaction(cls, value): return ecosystem.create_transaction(**value) - @field_validator("txn_hash", mode="before") - @classmethod - def validate_txn_hash(cls, value): - return HexBytes(value).hex() - @cached_property def debug_logs_typed(self) -> list[tuple[Any]]: """Return any debug log data outputted by the transaction.""" diff --git a/src/ape/types/__init__.py b/src/ape/types/__init__.py index 4af0a17124..57390e348c 100644 --- a/src/ape/types/__init__.py +++ b/src/ape/types/__init__.py @@ -1,6 +1,6 @@ from collections.abc import Callable, Iterator, Sequence from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar, Union, cast, overload +from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, TypeVar, Union, cast, overload from eth_abi.abi import encode from eth_abi.packed import encode_packed @@ -19,7 +19,7 @@ ) from ethpm_types.abi import EventABI from ethpm_types.source import Closure -from pydantic import BaseModel, field_validator, model_validator +from pydantic import BaseModel, BeforeValidator, field_validator, model_validator from typing_extensions import TypeAlias from web3.types import FilterParams @@ -41,7 +41,7 @@ ManagerAccessMixin, cached_property, ) -from ape.utils.misc import ZERO_ADDRESS, log_instead_of_fail, to_int +from ape.utils.misc import ZERO_ADDRESS, log_instead_of_fail if TYPE_CHECKING: from ape.api.providers import BlockAPI @@ -69,6 +69,15 @@ """ +HexInt = Annotated[ + int, BeforeValidator(lambda v, info: ManagerAccessMixin.conversion_manager.convert(v, int)) +] +""" +Validate any hex-str or bytes into an integer. +To be used on pydantic-fields. +""" + + class AutoGasLimit(BaseModel): """ Additional settings for ``gas_limit: auto``. @@ -229,11 +238,6 @@ class BaseContractLog(BaseInterfaceModel): event_arguments: dict[str, Any] = {} """The arguments to the event, including both indexed and non-indexed data.""" - @field_validator("contract_address", mode="before") - @classmethod - def validate_address(cls, value): - return cls.conversion_manager.convert(value, AddressType) - def __eq__(self, other: Any) -> bool: if self.contract_address != other.contract_address or self.event_name != other.event_name: return False @@ -254,38 +258,21 @@ class ContractLog(ExtraAttributesMixin, BaseContractLog): transaction_hash: Any """The hash of the transaction containing this log.""" - block_number: int + block_number: HexInt """The number of the block containing the transaction that produced this log.""" block_hash: Any """The hash of the block containing the transaction that produced this log.""" - log_index: int + log_index: HexInt """The index of the log on the transaction.""" - transaction_index: Optional[int] = None + transaction_index: Optional[HexInt] = None """ The index of the transaction's position when the log was created. Is `None` when from the pending block. """ - @field_validator("block_number", "log_index", "transaction_index", mode="before") - @classmethod - def validate_hex_ints(cls, value): - if value is None: - # Should only happen for optionals. - return value - - elif not isinstance(value, int): - return to_int(value) - - return value - - @field_validator("contract_address", mode="before") - @classmethod - def validate_address(cls, value): - return cls.conversion_manager.convert(value, AddressType) - # NOTE: This class has an overridden `__getattr__` method, but `block` is a reserved keyword # in most smart contract languages, so it is safe to use. Purposely avoid adding # `.datetime` and `.timestamp` in case they are used as event arg names. @@ -524,6 +511,8 @@ def __eq__(self, other: Any) -> bool: "CurrencyValue", "CurrencyValueComparable", "GasReport", + "HexInt", + "HexBytes", "MessageSignature", "PackageManifest", "PackageMeta", diff --git a/src/ape/utils/misc.py b/src/ape/utils/misc.py index d0977f1334..12c8cac9a4 100644 --- a/src/ape/utils/misc.py +++ b/src/ape/utils/misc.py @@ -367,7 +367,11 @@ def _create_raises_not_implemented_error(fn): ) -def to_int(value) -> int: +def to_int(value: Any) -> int: + """ + Convert the given value, such as hex-strs or hex-bytes, to an integer. + """ + if isinstance(value, int): return value elif isinstance(value, str): @@ -594,5 +598,6 @@ def as_our_module(cls_or_def: _MOD_T, doc_str: Optional[str] = None) -> _MOD_T: "run_until_complete", "singledispatchmethod", "stream_response", + "to_int", "USER_AGENT", ] diff --git a/src/ape_ethereum/ecosystem.py b/src/ape_ethereum/ecosystem.py index 9a98b615cf..3706d7c841 100644 --- a/src/ape_ethereum/ecosystem.py +++ b/src/ape_ethereum/ecosystem.py @@ -341,21 +341,6 @@ class Block(BlockAPI): default=EMPTY_BYTES32, alias="parentHash" ) # NOTE: genesis block has no parent hash - @field_validator( - "base_fee", - "difficulty", - "gas_limit", - "gas_used", - "number", - "size", - "timestamp", - "total_difficulty", - mode="before", - ) - @classmethod - def validate_ints(cls, value): - return to_int(value) if value else 0 - @computed_field() # type: ignore[misc] @property def size(self) -> int: diff --git a/src/ape_ethereum/transactions.py b/src/ape_ethereum/transactions.py index 4ecbf59a33..b51a9d9f55 100644 --- a/src/ape_ethereum/transactions.py +++ b/src/ape_ethereum/transactions.py @@ -19,7 +19,7 @@ from ape.contracts import ContractEvent from ape.exceptions import OutOfGasError, SignatureError, TransactionError from ape.logging import logger -from ape.types import AddressType, ContractLog, ContractLogContainer, SourceTraceback +from ape.types import AddressType, ContractLog, ContractLogContainer, HexInt, SourceTraceback from ape.utils import ZERO_ADDRESS from ape_ethereum.trace import Trace, _events_to_trees @@ -115,19 +115,17 @@ class StaticFeeTransaction(BaseTransaction): Transactions that are pre-EIP-1559 and use the ``gasPrice`` field. """ - gas_price: Optional[int] = Field(default=None, alias="gasPrice") - max_priority_fee: Optional[int] = Field(default=None, exclude=True) # type: ignore - type: int = Field(default=TransactionType.STATIC.value, exclude=True) - max_fee: Optional[int] = Field(default=None, exclude=True) # type: ignore + gas_price: Optional[HexInt] = Field(default=None, alias="gasPrice") + max_priority_fee: Optional[HexInt] = Field(default=None, exclude=True) # type: ignore + type: HexInt = Field(default=TransactionType.STATIC.value, exclude=True) + max_fee: Optional[HexInt] = Field(default=None, exclude=True) # type: ignore - @model_validator(mode="before") + @model_validator(mode="after") @classmethod - def calculate_read_only_max_fee(cls, values) -> dict: - # NOTE: Work-around, Pydantic doesn't handle calculated fields well. - values["max_fee"] = cls.conversion_manager.convert( - values.get("gas_limit", 0), int - ) * cls.conversion_manager.convert(values.get("gas_price", 0), int) - return values + def calculate_read_only_max_fee(cls, tx): + # Work-around: we cannot use a computed field to override a non-computed field. + tx.max_fee = (tx.gas_limit or 0) * (tx.gas_price or 0) + return tx class AccessListTransaction(StaticFeeTransaction): @@ -152,9 +150,9 @@ class DynamicFeeTransaction(BaseTransaction): and ``maxPriorityFeePerGas`` fields. """ - max_priority_fee: Optional[int] = Field(default=None, alias="maxPriorityFeePerGas") - max_fee: Optional[int] = Field(default=None, alias="maxFeePerGas") - type: int = TransactionType.DYNAMIC.value + max_priority_fee: Optional[HexInt] = Field(default=None, alias="maxPriorityFeePerGas") + max_fee: Optional[HexInt] = Field(default=None, alias="maxFeePerGas") + type: HexInt = TransactionType.DYNAMIC.value access_list: list[AccessList] = Field(default_factory=list, alias="accessList") @field_validator("type") @@ -168,24 +166,19 @@ class SharedBlobTransaction(DynamicFeeTransaction): `EIP-4844 `__ transactions. """ - max_fee_per_blob_gas: int = Field(default=0, alias="maxFeePerBlobGas") + max_fee_per_blob_gas: HexInt = Field(default=0, alias="maxFeePerBlobGas") blob_versioned_hashes: list[HexBytes] = Field(default_factory=list, alias="blobVersionedHashes") + receiver: AddressType = Field(default=ZERO_ADDRESS, alias="to") """ Overridden because EIP-4844 states it cannot be nil. """ - receiver: AddressType = Field(default=ZERO_ADDRESS, alias="to") - - @field_validator("max_fee_per_blob_gas", mode="before") - @classmethod - def hex_to_int(cls, value): - return value if isinstance(value, int) else int(HexBytes(value).hex(), 16) class Receipt(ReceiptAPI): - gas_limit: int - gas_price: int - gas_used: int + gas_limit: HexInt + gas_price: HexInt + gas_used: HexInt @property def ran_out_of_gas(self) -> bool: @@ -419,21 +412,12 @@ class SharedBlobReceipt(Receipt): blob transaction. """ - blob_gas_used: int + blob_gas_used: HexInt """ The total amount of blob gas consumed by the transactions within the block. """ - blob_gas_price: int + blob_gas_price: HexInt """ The blob-gas price, independent from regular gas price. """ - - @field_validator("blob_gas_used", "blob_gas_price", mode="before") - @classmethod - def validate_hex(cls, value): - return cls._hex_to_int(value or 0) - - @classmethod - def _hex_to_int(cls, value) -> int: - return value if isinstance(value, int) else int(HexBytes(value).hex(), 16) diff --git a/tests/functional/test_transaction.py b/tests/functional/test_transaction.py index a4b75f81dc..65e70a8e45 100644 --- a/tests/functional/test_transaction.py +++ b/tests/functional/test_transaction.py @@ -1,10 +1,12 @@ import re import warnings +from typing import Optional import pytest from eth_pydantic_types import HexBytes from hexbytes import HexBytes as BaseHexBytes +from ape.api import TransactionAPI from ape.exceptions import SignatureError from ape_ethereum.transactions import ( AccessList, @@ -380,3 +382,26 @@ def test_address(self, address): def test_storage_keys(self, storage_key, zero_address): actual = AccessList(address=zero_address, storageKeys=[storage_key]) assert actual.storage_keys == [HexBytes(storage_key)] + + +def test_override_annotated_fields(): + """ + This test is to prove that a user may use an `int` for a base-class + when the API field is described as a `HexInt`. + """ + + class MyTransaction(TransactionAPI): + @property + def txn_hash(self) -> HexBytes: + return HexBytes("") + + def serialize_transaction(self) -> bytes: + return b"" + + chain_id: Optional[int] = None # The base type is `Optional[HexInt]`. + + chain_id = 123123123123123123123123123123 + tx_type = 120 + my_tx = MyTransaction.model_validate({"chain_id": chain_id, "type": tx_type}) + assert my_tx.chain_id == chain_id + assert my_tx.type == tx_type diff --git a/tests/functional/test_types.py b/tests/functional/test_types.py index 56bf2b5bbb..5eaedb178c 100644 --- a/tests/functional/test_types.py +++ b/tests/functional/test_types.py @@ -4,7 +4,7 @@ from hexbytes import HexBytes from pydantic import BaseModel -from ape.types import AddressType, ContractLog, LogFilter +from ape.types import AddressType, ContractLog, HexInt, LogFilter from ape.utils import ZERO_ADDRESS TXN_HASH = "0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa222222222222222222222222" @@ -126,3 +126,12 @@ class MyModel(BaseModel): # Show int works. instance_bytes = MyModel(addr=int(owner.address, 16)) assert instance_bytes.addr == owner.address + + +class TestHexInt: + class MyModel(BaseModel): + ual: HexInt = 0 + + act = MyModel.model_validate({"ual": "0x123"}) + expected = 291 # Base-10 form of 0x123. + assert act.ual == expected