Skip to content

Commit

Permalink
refactor: share hex-related validators (#2203)
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey authored Aug 5, 2024
1 parent b046e1b commit 011f776
Show file tree
Hide file tree
Showing 12 changed files with 124 additions and 161 deletions.
16 changes: 8 additions & 8 deletions src/ape/api/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
VirtualMachineError,
)
from ape.logging import LogLevel, logger
from ape.types import AddressType, BlockID, ContractCode, ContractLog, LogFilter, SnapshotID
from ape.types import AddressType, BlockID, ContractCode, ContractLog, HexInt, LogFilter, SnapshotID
from ape.utils import BaseInterfaceModel, JoinableQueue, abstractmethod, cached_property, spawn
from ape.utils.misc import (
EMPTY_BYTES32,
Expand All @@ -54,33 +54,33 @@ class BlockAPI(BaseInterfaceModel):
# NOTE: All fields in this class (and it's subclasses) should not be `Optional`
# except the edge cases noted below

num_transactions: HexInt = 0
"""
The number of transactions in the block.
"""
num_transactions: int = 0

hash: Optional[Any] = None # NOTE: pending block does not have a hash
"""
The block hash identifier.
"""
hash: Optional[Any] = None # NOTE: pending block does not have a hash

number: Optional[HexInt] = None # NOTE: pending block does not have a number
"""
The block number identifier.
"""
number: Optional[int] = None # NOTE: pending block does not have a number

"""
The preceeding block's hash.
"""
parent_hash: Any = Field(
default=EMPTY_BYTES32, alias="parentHash"
) # NOTE: genesis block has no parent hash
"""
The preceeding block's hash.
"""

timestamp: HexInt
"""
The timestamp the block was produced.
NOTE: The pending block uses the current timestamp.
"""
timestamp: int

_size: Optional[int] = None

Expand Down
76 changes: 15 additions & 61 deletions src/ape/api/transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from datetime import datetime
from typing import IO, TYPE_CHECKING, Any, NoReturn, Optional, Union

from eth_pydantic_types import HexBytes
from eth_utils import is_0x_prefixed, is_hex, to_hex, to_int
from eth_pydantic_types import HexBytes, HexStr
from eth_utils import is_hex, to_hex, to_int
from ethpm_types.abi import EventABI, MethodABI
from pydantic import ConfigDict, field_validator
from pydantic.fields import Field
Expand All @@ -24,6 +24,7 @@
AddressType,
AutoGasLimit,
ContractLogContainer,
HexInt,
SourceTraceback,
TransactionSignature,
)
Expand Down Expand Up @@ -51,19 +52,19 @@ class TransactionAPI(BaseInterfaceModel):
such as typed-transactions from `EIP-1559 <https://eips.ethereum.org/EIPS/eip-1559>`__.
"""

chain_id: Optional[int] = Field(default=0, alias="chainId")
chain_id: Optional[HexInt] = Field(default=0, alias="chainId")
receiver: Optional[AddressType] = Field(default=None, alias="to")
sender: Optional[AddressType] = Field(default=None, alias="from")
gas_limit: Optional[int] = Field(default=None, alias="gas")
nonce: Optional[int] = None # NOTE: `Optional` only to denote using default behavior
value: int = 0
gas_limit: Optional[HexInt] = Field(default=None, alias="gas")
nonce: Optional[HexInt] = None # NOTE: `Optional` only to denote using default behavior
value: HexInt = 0
data: HexBytes = HexBytes("")
type: int
max_fee: Optional[int] = None
max_priority_fee: Optional[int] = None
type: HexInt
max_fee: Optional[HexInt] = None
max_priority_fee: Optional[HexInt] = None

# If left as None, will get set to the network's default required confirmations.
required_confirmations: Optional[int] = Field(default=None, exclude=True)
required_confirmations: Optional[HexInt] = Field(default=None, exclude=True)

signature: Optional[TransactionSignature] = Field(default=None, exclude=True)

Expand All @@ -74,20 +75,6 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._raise_on_revert = raise_on_revert

@field_validator("nonce", mode="before")
@classmethod
def validate_nonce(cls, value):
if value is None or isinstance(value, int):
return value

elif isinstance(value, str) and value.startswith("0x"):
return to_int(hexstr=value)

elif isinstance(value, str):
return int(value)

return to_int(value)

@field_validator("gas_limit", mode="before")
@classmethod
def validate_gas_limit(cls, value):
Expand All @@ -114,33 +101,6 @@ def validate_gas_limit(cls, value):

return value

@field_validator("max_fee", "max_priority_fee", mode="before")
@classmethod
def convert_fees(cls, value):
if isinstance(value, str):
return cls.conversion_manager.convert(value, int)

return value

@field_validator("data", mode="before")
@classmethod
def validate_data(cls, value):
if isinstance(value, str):
return HexBytes(value)

return value

@field_validator("value", mode="before")
@classmethod
def validate_value(cls, value):
if isinstance(value, int):
return value

elif isinstance(value, str) and is_0x_prefixed(value):
return int(value, 16)

return int(value)

@property
def raise_on_revert(self) -> bool:
return self._raise_on_revert
Expand Down Expand Up @@ -173,7 +133,6 @@ def receipt(self) -> Optional["ReceiptAPI"]:
"""
This transaction's associated published receipt, if it exists.
"""

try:
txn_hash = to_hex(self.txn_hash)
except SignatureError:
Expand Down Expand Up @@ -295,11 +254,11 @@ class ReceiptAPI(ExtraAttributesMixin, BaseInterfaceModel):
"""

contract_address: Optional[AddressType] = None
block_number: int
gas_used: int
block_number: HexInt
gas_used: HexInt
logs: list[dict] = []
status: int
txn_hash: str
status: HexInt
txn_hash: HexStr
transaction: TransactionAPI
_error: Optional[TransactionError] = None

Expand Down Expand Up @@ -330,11 +289,6 @@ def _validate_transaction(cls, value):

return ecosystem.create_transaction(**value)

@field_validator("txn_hash", mode="before")
@classmethod
def validate_txn_hash(cls, value):
return HexBytes(value).hex()

@cached_property
def debug_logs_typed(self) -> list[tuple[Any]]:
"""Return any debug log data outputted by the transaction."""
Expand Down
47 changes: 19 additions & 28 deletions src/ape/types/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections.abc import Callable, Iterator, Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar, Union, cast, overload
from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, TypeVar, Union, cast, overload

from eth_abi.abi import encode
from eth_abi.packed import encode_packed
Expand All @@ -19,7 +19,7 @@
)
from ethpm_types.abi import EventABI
from ethpm_types.source import Closure
from pydantic import BaseModel, field_validator, model_validator
from pydantic import BaseModel, BeforeValidator, field_validator, model_validator
from typing_extensions import TypeAlias
from web3.types import FilterParams

Expand All @@ -41,7 +41,7 @@
ManagerAccessMixin,
cached_property,
)
from ape.utils.misc import ZERO_ADDRESS, log_instead_of_fail, to_int
from ape.utils.misc import ZERO_ADDRESS, log_instead_of_fail

if TYPE_CHECKING:
from ape.api.providers import BlockAPI
Expand All @@ -68,6 +68,17 @@
cases.
"""

HexInt = Annotated[
int,
BeforeValidator(
lambda v, info: None if v is None else ManagerAccessMixin.conversion_manager.convert(v, int)
),
]
"""
Validate any hex-str or bytes into an integer.
To be used on pydantic-fields.
"""


class AutoGasLimit(BaseModel):
"""
Expand Down Expand Up @@ -229,11 +240,6 @@ class BaseContractLog(BaseInterfaceModel):
event_arguments: dict[str, Any] = {}
"""The arguments to the event, including both indexed and non-indexed data."""

@field_validator("contract_address", mode="before")
@classmethod
def validate_address(cls, value):
return cls.conversion_manager.convert(value, AddressType)

def __eq__(self, other: Any) -> bool:
if self.contract_address != other.contract_address or self.event_name != other.event_name:
return False
Expand All @@ -254,38 +260,21 @@ class ContractLog(ExtraAttributesMixin, BaseContractLog):
transaction_hash: Any
"""The hash of the transaction containing this log."""

block_number: int
block_number: HexInt
"""The number of the block containing the transaction that produced this log."""

block_hash: Any
"""The hash of the block containing the transaction that produced this log."""

log_index: int
log_index: HexInt
"""The index of the log on the transaction."""

transaction_index: Optional[int] = None
transaction_index: Optional[HexInt] = None
"""
The index of the transaction's position when the log was created.
Is `None` when from the pending block.
"""

@field_validator("block_number", "log_index", "transaction_index", mode="before")
@classmethod
def validate_hex_ints(cls, value):
if value is None:
# Should only happen for optionals.
return value

elif not isinstance(value, int):
return to_int(value)

return value

@field_validator("contract_address", mode="before")
@classmethod
def validate_address(cls, value):
return cls.conversion_manager.convert(value, AddressType)

# NOTE: This class has an overridden `__getattr__` method, but `block` is a reserved keyword
# in most smart contract languages, so it is safe to use. Purposely avoid adding
# `.datetime` and `.timestamp` in case they are used as event arg names.
Expand Down Expand Up @@ -524,6 +513,8 @@ def __eq__(self, other: Any) -> bool:
"CurrencyValue",
"CurrencyValueComparable",
"GasReport",
"HexInt",
"HexBytes",
"MessageSignature",
"PackageManifest",
"PackageMeta",
Expand Down
7 changes: 6 additions & 1 deletion src/ape/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,11 @@ def _create_raises_not_implemented_error(fn):
)


def to_int(value) -> int:
def to_int(value: Any) -> int:
"""
Convert the given value, such as hex-strs or hex-bytes, to an integer.
"""

if isinstance(value, int):
return value
elif isinstance(value, str):
Expand Down Expand Up @@ -594,5 +598,6 @@ def as_our_module(cls_or_def: _MOD_T, doc_str: Optional[str] = None) -> _MOD_T:
"run_until_complete",
"singledispatchmethod",
"stream_response",
"to_int",
"USER_AGENT",
]
34 changes: 12 additions & 22 deletions src/ape_ethereum/ecosystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
ContractLog,
CurrencyValueComparable,
GasLimit,
HexInt,
RawAddress,
TransactionSignature,
)
Expand Down Expand Up @@ -328,11 +329,11 @@ class Block(BlockAPI):
Class for representing a block on a chain.
"""

gas_limit: int = Field(alias="gasLimit")
gas_used: int = Field(alias="gasUsed")
base_fee: int = Field(default=0, alias="baseFeePerGas")
difficulty: int = 0
total_difficulty: int = Field(default=0, alias="totalDifficulty")
gas_limit: HexInt = Field(alias="gasLimit")
gas_used: HexInt = Field(alias="gasUsed")
base_fee: HexInt = Field(default=0, alias="baseFeePerGas")
difficulty: HexInt = 0
total_difficulty: HexInt = Field(default=0, alias="totalDifficulty")
uncles: list[HexBytes] = []

# Type re-declares.
Expand All @@ -341,21 +342,6 @@ class Block(BlockAPI):
default=EMPTY_BYTES32, alias="parentHash"
) # NOTE: genesis block has no parent hash

@field_validator(
"base_fee",
"difficulty",
"gas_limit",
"gas_used",
"number",
"size",
"timestamp",
"total_difficulty",
mode="before",
)
@classmethod
def validate_ints(cls, value):
return to_int(value) if value else 0

@computed_field() # type: ignore[misc]
@property
def size(self) -> int:
Expand Down Expand Up @@ -591,7 +577,9 @@ def decode_receipt(self, data: dict) -> ReceiptAPI:
):
receipt_cls = SharedBlobReceipt
receipt_kwargs["blob_gas_price"] = data.get("blob_gas_price", data.get("blobGasPrice"))
receipt_kwargs["blob_gas_used"] = data.get("blob_gas_used", data.get("blobGasUsed"))
receipt_kwargs["blob_gas_used"] = (
data.get("blob_gas_used", data.get("blobGasUsed")) or 0
)
else:
receipt_cls = Receipt

Expand All @@ -611,7 +599,9 @@ def decode_block(self, data: dict) -> BlockAPI:
if "transaction_ids" in data:
data["transactions"] = data.pop("transaction_ids")
if "total_difficulty" in data:
data["totalDifficulty"] = data.pop("total_difficulty")
data["totalDifficulty"] = data.pop("total_difficulty") or 0
elif "totalDifficulty" in data:
data["totalDifficulty"] = data.pop("totalDifficulty") or 0
if "base_fee" in data:
data["baseFeePerGas"] = data.pop("base_fee")
elif "baseFee" in data:
Expand Down
Loading

0 comments on commit 011f776

Please sign in to comment.