Skip to content

Commit

Permalink
refactor: share hexy validators
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey committed Aug 2, 2024
1 parent b046e1b commit 6599034
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 142 deletions.
76 changes: 15 additions & 61 deletions src/ape/api/transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from datetime import datetime
from typing import IO, TYPE_CHECKING, Any, NoReturn, Optional, Union

from eth_pydantic_types import HexBytes
from eth_utils import is_0x_prefixed, is_hex, to_hex, to_int
from eth_pydantic_types import HexBytes, HexStr
from eth_utils import is_hex, to_hex, to_int
from ethpm_types.abi import EventABI, MethodABI
from pydantic import ConfigDict, field_validator
from pydantic.fields import Field
Expand All @@ -24,6 +24,7 @@
AddressType,
AutoGasLimit,
ContractLogContainer,
HexInt,
SourceTraceback,
TransactionSignature,
)
Expand Down Expand Up @@ -51,19 +52,19 @@ class TransactionAPI(BaseInterfaceModel):
such as typed-transactions from `EIP-1559 <https://eips.ethereum.org/EIPS/eip-1559>`__.
"""

chain_id: Optional[int] = Field(default=0, alias="chainId")
chain_id: Optional[HexInt] = Field(default=0, alias="chainId")
receiver: Optional[AddressType] = Field(default=None, alias="to")
sender: Optional[AddressType] = Field(default=None, alias="from")
gas_limit: Optional[int] = Field(default=None, alias="gas")
nonce: Optional[int] = None # NOTE: `Optional` only to denote using default behavior
value: int = 0
gas_limit: Optional[HexInt] = Field(default=None, alias="gas")
nonce: Optional[HexInt] = None # NOTE: `Optional` only to denote using default behavior
value: HexInt = 0
data: HexBytes = HexBytes("")
type: int
max_fee: Optional[int] = None
max_priority_fee: Optional[int] = None
type: HexInt
max_fee: Optional[HexInt] = None
max_priority_fee: Optional[HexInt] = None

# If left as None, will get set to the network's default required confirmations.
required_confirmations: Optional[int] = Field(default=None, exclude=True)
required_confirmations: Optional[HexInt] = Field(default=None, exclude=True)

signature: Optional[TransactionSignature] = Field(default=None, exclude=True)

Expand All @@ -74,20 +75,6 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._raise_on_revert = raise_on_revert

@field_validator("nonce", mode="before")
@classmethod
def validate_nonce(cls, value):
if value is None or isinstance(value, int):
return value

elif isinstance(value, str) and value.startswith("0x"):
return to_int(hexstr=value)

elif isinstance(value, str):
return int(value)

return to_int(value)

@field_validator("gas_limit", mode="before")
@classmethod
def validate_gas_limit(cls, value):
Expand All @@ -114,33 +101,6 @@ def validate_gas_limit(cls, value):

return value

@field_validator("max_fee", "max_priority_fee", mode="before")
@classmethod
def convert_fees(cls, value):
if isinstance(value, str):
return cls.conversion_manager.convert(value, int)

return value

@field_validator("data", mode="before")
@classmethod
def validate_data(cls, value):
if isinstance(value, str):
return HexBytes(value)

return value

@field_validator("value", mode="before")
@classmethod
def validate_value(cls, value):
if isinstance(value, int):
return value

elif isinstance(value, str) and is_0x_prefixed(value):
return int(value, 16)

return int(value)

@property
def raise_on_revert(self) -> bool:
return self._raise_on_revert
Expand Down Expand Up @@ -173,7 +133,6 @@ def receipt(self) -> Optional["ReceiptAPI"]:
"""
This transaction's associated published receipt, if it exists.
"""

try:
txn_hash = to_hex(self.txn_hash)
except SignatureError:
Expand Down Expand Up @@ -295,11 +254,11 @@ class ReceiptAPI(ExtraAttributesMixin, BaseInterfaceModel):
"""

contract_address: Optional[AddressType] = None
block_number: int
gas_used: int
block_number: HexInt
gas_used: HexInt
logs: list[dict] = []
status: int
txn_hash: str
status: HexInt
txn_hash: HexStr
transaction: TransactionAPI
_error: Optional[TransactionError] = None

Expand Down Expand Up @@ -330,11 +289,6 @@ def _validate_transaction(cls, value):

return ecosystem.create_transaction(**value)

@field_validator("txn_hash", mode="before")
@classmethod
def validate_txn_hash(cls, value):
return HexBytes(value).hex()

@cached_property
def debug_logs_typed(self) -> list[tuple[Any]]:
"""Return any debug log data outputted by the transaction."""
Expand Down
45 changes: 17 additions & 28 deletions src/ape/types/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections.abc import Callable, Iterator, Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar, Union, cast, overload
from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, TypeVar, Union, cast, overload

from eth_abi.abi import encode
from eth_abi.packed import encode_packed
Expand All @@ -19,7 +19,7 @@
)
from ethpm_types.abi import EventABI
from ethpm_types.source import Closure
from pydantic import BaseModel, field_validator, model_validator
from pydantic import BaseModel, BeforeValidator, field_validator, model_validator
from typing_extensions import TypeAlias
from web3.types import FilterParams

Expand All @@ -41,7 +41,7 @@
ManagerAccessMixin,
cached_property,
)
from ape.utils.misc import ZERO_ADDRESS, log_instead_of_fail, to_int
from ape.utils.misc import ZERO_ADDRESS, log_instead_of_fail

if TYPE_CHECKING:
from ape.api.providers import BlockAPI
Expand Down Expand Up @@ -69,6 +69,15 @@
"""


HexInt = Annotated[
int, BeforeValidator(lambda v, info: ManagerAccessMixin.conversion_manager.convert(v, int))
]
"""
Validate any hex-str or bytes into an integer.
To be used on pydantic-fields.
"""


class AutoGasLimit(BaseModel):
"""
Additional settings for ``gas_limit: auto``.
Expand Down Expand Up @@ -229,11 +238,6 @@ class BaseContractLog(BaseInterfaceModel):
event_arguments: dict[str, Any] = {}
"""The arguments to the event, including both indexed and non-indexed data."""

@field_validator("contract_address", mode="before")
@classmethod
def validate_address(cls, value):
return cls.conversion_manager.convert(value, AddressType)

def __eq__(self, other: Any) -> bool:
if self.contract_address != other.contract_address or self.event_name != other.event_name:
return False
Expand All @@ -254,38 +258,21 @@ class ContractLog(ExtraAttributesMixin, BaseContractLog):
transaction_hash: Any
"""The hash of the transaction containing this log."""

block_number: int
block_number: HexInt
"""The number of the block containing the transaction that produced this log."""

block_hash: Any
"""The hash of the block containing the transaction that produced this log."""

log_index: int
log_index: HexInt
"""The index of the log on the transaction."""

transaction_index: Optional[int] = None
transaction_index: Optional[HexInt] = None
"""
The index of the transaction's position when the log was created.
Is `None` when from the pending block.
"""

@field_validator("block_number", "log_index", "transaction_index", mode="before")
@classmethod
def validate_hex_ints(cls, value):
if value is None:
# Should only happen for optionals.
return value

elif not isinstance(value, int):
return to_int(value)

return value

@field_validator("contract_address", mode="before")
@classmethod
def validate_address(cls, value):
return cls.conversion_manager.convert(value, AddressType)

# NOTE: This class has an overridden `__getattr__` method, but `block` is a reserved keyword
# in most smart contract languages, so it is safe to use. Purposely avoid adding
# `.datetime` and `.timestamp` in case they are used as event arg names.
Expand Down Expand Up @@ -524,6 +511,8 @@ def __eq__(self, other: Any) -> bool:
"CurrencyValue",
"CurrencyValueComparable",
"GasReport",
"HexInt",
"HexBytes",
"MessageSignature",
"PackageManifest",
"PackageMeta",
Expand Down
7 changes: 6 additions & 1 deletion src/ape/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,11 @@ def _create_raises_not_implemented_error(fn):
)


def to_int(value) -> int:
def to_int(value: Any) -> int:
"""
Convert the given value, such as hex-strs or hex-bytes, to an integer.
"""

if isinstance(value, int):
return value
elif isinstance(value, str):
Expand Down Expand Up @@ -594,5 +598,6 @@ def as_our_module(cls_or_def: _MOD_T, doc_str: Optional[str] = None) -> _MOD_T:
"run_until_complete",
"singledispatchmethod",
"stream_response",
"to_int",
"USER_AGENT",
]
15 changes: 0 additions & 15 deletions src/ape_ethereum/ecosystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,21 +341,6 @@ class Block(BlockAPI):
default=EMPTY_BYTES32, alias="parentHash"
) # NOTE: genesis block has no parent hash

@field_validator(
"base_fee",
"difficulty",
"gas_limit",
"gas_used",
"number",
"size",
"timestamp",
"total_difficulty",
mode="before",
)
@classmethod
def validate_ints(cls, value):
return to_int(value) if value else 0

@computed_field() # type: ignore[misc]
@property
def size(self) -> int:
Expand Down
56 changes: 20 additions & 36 deletions src/ape_ethereum/transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from ape.contracts import ContractEvent
from ape.exceptions import OutOfGasError, SignatureError, TransactionError
from ape.logging import logger
from ape.types import AddressType, ContractLog, ContractLogContainer, SourceTraceback
from ape.types import AddressType, ContractLog, ContractLogContainer, HexInt, SourceTraceback
from ape.utils import ZERO_ADDRESS
from ape_ethereum.trace import Trace, _events_to_trees

Expand Down Expand Up @@ -115,19 +115,17 @@ class StaticFeeTransaction(BaseTransaction):
Transactions that are pre-EIP-1559 and use the ``gasPrice`` field.
"""

gas_price: Optional[int] = Field(default=None, alias="gasPrice")
max_priority_fee: Optional[int] = Field(default=None, exclude=True) # type: ignore
type: int = Field(default=TransactionType.STATIC.value, exclude=True)
max_fee: Optional[int] = Field(default=None, exclude=True) # type: ignore
gas_price: Optional[HexInt] = Field(default=None, alias="gasPrice")
max_priority_fee: Optional[HexInt] = Field(default=None, exclude=True) # type: ignore
type: HexInt = Field(default=TransactionType.STATIC.value, exclude=True)
max_fee: Optional[HexInt] = Field(default=None, exclude=True) # type: ignore

@model_validator(mode="before")
@model_validator(mode="after")
@classmethod
def calculate_read_only_max_fee(cls, values) -> dict:
# NOTE: Work-around, Pydantic doesn't handle calculated fields well.
values["max_fee"] = cls.conversion_manager.convert(
values.get("gas_limit", 0), int
) * cls.conversion_manager.convert(values.get("gas_price", 0), int)
return values
def calculate_read_only_max_fee(cls, tx):
# Work-around: we cannot use a computed field to override a non-computed field.
tx.max_fee = (tx.gas_limit or 0) * (tx.gas_price or 0)
return tx


class AccessListTransaction(StaticFeeTransaction):
Expand All @@ -152,9 +150,9 @@ class DynamicFeeTransaction(BaseTransaction):
and ``maxPriorityFeePerGas`` fields.
"""

max_priority_fee: Optional[int] = Field(default=None, alias="maxPriorityFeePerGas")
max_fee: Optional[int] = Field(default=None, alias="maxFeePerGas")
type: int = TransactionType.DYNAMIC.value
max_priority_fee: Optional[HexInt] = Field(default=None, alias="maxPriorityFeePerGas")
max_fee: Optional[HexInt] = Field(default=None, alias="maxFeePerGas")
type: HexInt = TransactionType.DYNAMIC.value
access_list: list[AccessList] = Field(default_factory=list, alias="accessList")

@field_validator("type")
Expand All @@ -168,24 +166,19 @@ class SharedBlobTransaction(DynamicFeeTransaction):
`EIP-4844 <https://eips.ethereum.org/EIPS/eip-4844>`__ transactions.
"""

max_fee_per_blob_gas: int = Field(default=0, alias="maxFeePerBlobGas")
max_fee_per_blob_gas: HexInt = Field(default=0, alias="maxFeePerBlobGas")
blob_versioned_hashes: list[HexBytes] = Field(default_factory=list, alias="blobVersionedHashes")

receiver: AddressType = Field(default=ZERO_ADDRESS, alias="to")
"""
Overridden because EIP-4844 states it cannot be nil.
"""
receiver: AddressType = Field(default=ZERO_ADDRESS, alias="to")

@field_validator("max_fee_per_blob_gas", mode="before")
@classmethod
def hex_to_int(cls, value):
return value if isinstance(value, int) else int(HexBytes(value).hex(), 16)


class Receipt(ReceiptAPI):
gas_limit: int
gas_price: int
gas_used: int
gas_limit: HexInt
gas_price: HexInt
gas_used: HexInt

@property
def ran_out_of_gas(self) -> bool:
Expand Down Expand Up @@ -419,21 +412,12 @@ class SharedBlobReceipt(Receipt):
blob transaction.
"""

blob_gas_used: int
blob_gas_used: HexInt
"""
The total amount of blob gas consumed by the transactions within the block.
"""

blob_gas_price: int
blob_gas_price: HexInt
"""
The blob-gas price, independent from regular gas price.
"""

@field_validator("blob_gas_used", "blob_gas_price", mode="before")
@classmethod
def validate_hex(cls, value):
return cls._hex_to_int(value or 0)

@classmethod
def _hex_to_int(cls, value) -> int:
return value if isinstance(value, int) else int(HexBytes(value).hex(), 16)
Loading

0 comments on commit 6599034

Please sign in to comment.