Skip to content

Commit

Permalink
feat: add calltreenode model (#3)
Browse files Browse the repository at this point in the history
  • Loading branch information
NotPeopling2day authored May 2, 2022
1 parent ba8edde commit f344d29
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 19 deletions.
4 changes: 2 additions & 2 deletions evm_trace/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .base import TraceFrame
from .base import CallTreeNode, CallType, TraceFrame

__all__ = ["TraceFrame"]
__all__ = ["CallTreeNode", "CallType", "TraceFrame"]
54 changes: 39 additions & 15 deletions evm_trace/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
from enum import Enum
from typing import Any, Dict, List

from hexbytes import HexBytes
from pydantic import BaseModel, Field, validator
from pydantic import BaseModel, Field, ValidationError, validator


def _convert_hexbytes(cls, v: Any) -> HexBytes:
try:
return HexBytes(v)
except ValueError:
raise ValidationError(f"Value '{v}' could not be converted to Hexbytes.", cls)


class TraceFrame(BaseModel):
Expand All @@ -12,21 +20,37 @@ class TraceFrame(BaseModel):
depth: int
stack: List[Any]
memory: List[Any]
storage: Dict[Any, Any]
storage: Dict[Any, Any] = {}

@validator("stack", "memory", pre=True, each_item=True)
def convert_hexbytes(cls, v) -> HexBytes:
return _convert_hexbytes(cls, v)

@validator("storage", pre=True)
def convert_hexbytes_dict(cls, v) -> Dict[HexBytes, HexBytes]:
return {_convert_hexbytes(cls, k): _convert_hexbytes(cls, val) for k, val in v.items()}


class CallType(Enum):
INTERNAL = "INTERNAL" # Non-opcode internal call
STATIC = "STATIC" # STATICCALL opcode
MUTABLE = "MUTABLE" # CALL opcode
DELEGATE = "DELEGATE" # DELEGATECALL opcode
SELFDESTRUCT = "SELFDESTRUCT" # SELFDESTRUCT opcode

@validator("stack", "memory", pre=True, each_item=True)
def validate_hexbytes(value: Any) -> HexBytes:
if value and not isinstance(value, HexBytes):
raise ValueError(f"Hash `{value}` is not a valid Hexbyte.")
return value

class CallTreeNode(BaseModel):
call_type: CallType
address: Any
value: int = 0
gas_limit: int
gas_cost: int # calculated from call starting and return
calldata: Any = HexBytes(b"")
returndata: Any = HexBytes(b"")
calls: List["CallTreeNode"] = []
selfdestruct: bool = False
failed: bool = False

@validator("storage")
def validate_hexbytes_dict(value: Any) -> Dict[HexBytes, HexBytes]:
for k, v in value:
if k and not isinstance(k, HexBytes):
raise ValueError(f"Key `{value}` is not a valid Hexbyte.")
if v and not isinstance(v, HexBytes):
raise ValueError(f"Value `{value}` is not a valid Hexbyte.")
return value
@validator("address", "calldata", "returndata", pre=True)
def validate_hexbytes(cls, v) -> HexBytes:
return _convert_hexbytes(cls, v)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ markers = "fuzzing: Run Hypothesis fuzz test suite"
line_length = 100
force_grid_wrap = 0
include_trailing_comma = true
known_third_party = ["hexbytes", "pydantic", "setuptools"]
known_third_party = ["hexbytes", "pydantic", "pytest", "setuptools"]
known_first_party = ["MODULE_NAME"]
multi_line_output = 3
use_parentheses = true
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"hypothesis>=6.2.0,<7.0", # Strategy-based fuzzer
],
"lint": [
"black>=21.10b0,<22.0", # auto-formatter and linter
"black>=22.3.0,<23.0", # auto-formatter and linter
"mypy>=0.910,<1.0", # Static type analyzer
"flake8>=3.8.3,<4.0", # Style linter
"isort>=5.9.3,<6.0", # Import sorting linter
Expand Down
49 changes: 49 additions & 0 deletions tests/test_trace_frame.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from copy import deepcopy

import pytest
from pydantic import ValidationError

from evm_trace.base import TraceFrame

TRACE_FRAME_STRUCTURE = {
"pc": 1564,
"op": "RETURN",
"gas": 0,
"gasCost": 0,
"depth": 1,
"stack": [
"0000000000000000000000000000000000000000000000000000000040c10f19",
"0000000000000000000000000000000000000000000000000000000000000020",
"0000000000000000000000000000000000000000000000000000000000000140",
],
"memory": [
"0000000000000000000000001e59ce931b4cfea3fe4b875411e280e173cb7a9c",
"0000000000000000000000000000000000000000000000000000000000000001",
],
"storage": {
"0000000000000000000000000000000000000000000000000000000000000004": "0000000000000000000000001e59ce931b4cfea3fe4b875411e280e173cb7a9c", # noqa: E501
"ad3228b676f7d3cd4284a5443f17f1962b36e491b30a40b2405849e597ba5fb5": "0000000000000000000000001e59ce931b4cfea3fe4b875411e280e173cb7a9c", # noqa: E501
"aadb61a4b4c5d48b7a5669391b7c73852a3ab7795f24721b9a439220b54b591b": "0000000000000000000000000000000000000000000000000000000000000001", # noqa: E501
},
}


def test_trace_frame_validation_passes():
frame = TraceFrame(**TRACE_FRAME_STRUCTURE)
assert frame


trace_frame_test_cases = (
{"stack": ["potato"]},
{"memory": ["potato"]},
{"storage": {"piggy": "dippin"}},
)


@pytest.mark.parametrize("test_value", trace_frame_test_cases)
def test_trace_frame_validation_fails(test_value):
trace_frame_structure = deepcopy(TRACE_FRAME_STRUCTURE)
trace_frame_structure.update(test_value)

with pytest.raises(ValidationError):
TraceFrame(**trace_frame_structure)

0 comments on commit f344d29

Please sign in to comment.