Skip to content

Commit

Permalink
feat: struct-log based trace-frames now include events emitted (#66)
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey authored Jun 26, 2024
1 parent 1b1cb33 commit dbbd986
Show file tree
Hide file tree
Showing 7 changed files with 239 additions and 43 deletions.
30 changes: 29 additions & 1 deletion evm_trace/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from eth_pydantic_types import HexBytes
from pydantic import BaseModel as _BaseModel
from pydantic import ConfigDict, field_validator
from pydantic import ConfigDict, Field, field_validator

from evm_trace.display import get_tree_display
from evm_trace.enums import CallType
Expand All @@ -16,6 +16,31 @@ class BaseModel(_BaseModel):
)


class EventNode(BaseModel):
"""
An event emitted during a CALL.
"""

call_type: CallType = CallType.EVENT
"""The call-type for events is always ``EVENT``."""

data: HexBytes = HexBytes(b"")
"""The remaining event data besides the topics."""

depth: int
"""The depth in a call-tree where the event took place."""

topics: list[HexBytes] = Field(min_length=1)
"""Event topics, including the selector."""

@property
def selector(self) -> HexBytes:
"""
The selector is always the first topic.
"""
return self.topics[0]


class CallTreeNode(BaseModel):
"""
A higher-level object modeling a node in an execution call tree.
Expand Down Expand Up @@ -59,6 +84,9 @@ class CallTreeNode(BaseModel):
failed: bool = False
"""Whether the call failed or not."""

events: list[EventNode] = []
"""All events made in the call."""

def __str__(self) -> str:
try:
return get_tree_display(self)
Expand Down
83 changes: 63 additions & 20 deletions evm_trace/display.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,36 @@
from collections.abc import Iterator
from typing import TYPE_CHECKING, Optional, cast
from typing import TYPE_CHECKING, Optional, Union, cast

from eth_typing import ChecksumAddress
from eth_utils import to_checksum_address

from evm_trace.enums import CallType

if TYPE_CHECKING:
from evm_trace.base import CallTreeNode
from evm_trace.base import CallTreeNode, EventNode


def get_tree_display(call: "CallTreeNode") -> str:
return "\n".join([str(t) for t in TreeRepresentation.make_tree(call)])


class TreeRepresentation:
FILE_MIDDLE_PREFIX = "├──"
FILE_LAST_PREFIX = "└──"
"""
A class for creating a simple tree-representation of a call-tree node.
**NOTE**: We purposely are not using the rich library here to keep
evm-trace small and simple while sill offering a nice stringified
version of a :class:`~evm_trace.base.CallTreeNode`.
"""

MIDDLE_PREFIX = "├──"
LAST_PREFIX = "└──"
PARENT_PREFIX_MIDDLE = " "
PARENT_PREFIX_LAST = "│ "

def __init__(
self,
call: "CallTreeNode",
call: Union["CallTreeNode", "EventNode"],
parent: Optional["TreeRepresentation"] = None,
is_last: bool = False,
):
Expand All @@ -32,13 +40,27 @@ def __init__(

@property
def depth(self) -> int:
"""
The depth in the call tree, such as the
number of calls deep.
"""
return self.call.depth

@property
def title(self) -> str:
"""
The title of the node representation, including address, calldata, and return-data.
For event-nodes, it is mostly the selector string.
"""
call_type = self.call.call_type.value
address_hex_str = self.call.address.hex() if self.call.address else None

if hasattr(self.call, "selector"):
# Is an Event-node
selector = self.call.selector.hex() if self.call.selector else None
return f"{call_type}: {selector}"
# else: Is a CallTreeNode

address_hex_str = self.call.address.hex() if self.call.address else None
try:
address = to_checksum_address(address_hex_str) if address_hex_str else None
except (ImportError, ValueError):
Expand Down Expand Up @@ -77,33 +99,54 @@ def title(self) -> str:
@classmethod
def make_tree(
cls,
root: "CallTreeNode",
root: Union["CallTreeNode", "EventNode"],
parent: Optional["TreeRepresentation"] = None,
is_last: bool = False,
) -> Iterator["TreeRepresentation"]:
"""
Create a node representation object from a :class:`~evm_trace.base.CallTreeNode`.
Args:
root (:class:`~evm_trace.base.CallTreeNode` | :class:`~evm_trace.base.EventNode`):
The call-tree node or event-node to display.
parent (Optional[:class:`~evm_trace.display.TreeRepresentation`]): The parent
node of this node.
is_last (bool): True if a leaf-node.
"""
displayable_root = cls(root, parent=parent, is_last=is_last)
yield displayable_root

count = 1
for child_node in root.calls:
is_last = count == len(root.calls)
if child_node.calls:
yield from cls.make_tree(child_node, parent=displayable_root, is_last=is_last)
else:
yield cls(child_node, parent=displayable_root, is_last=is_last)

count += 1
if hasattr(root, "topics"):
# Events have no children.
return

# Handle events, which won't have any sub-calls or anything.
total_events = len(root.events)
for index, event in enumerate(root.events, start=1):
is_last = index == total_events
yield cls(event, parent=displayable_root, is_last=is_last)

# Handle calls (and calls of calls).
total_calls = len(root.calls)
for index, child_node in enumerate(root.calls, start=1):
is_last = index == total_calls
# NOTE: `.make_tree()` will handle calls of calls (recursion).
yield from cls.make_tree(child_node, parent=displayable_root, is_last=is_last)

def __str__(self) -> str:
"""
The representation str via ``calling str()``.
"""
if self.parent is None:
return self.title

filename_prefix = self.FILE_LAST_PREFIX if self.is_last else self.FILE_MIDDLE_PREFIX

parts = [f"{filename_prefix} {self.title}"]
tree_prefix = self.LAST_PREFIX if self.is_last else self.MIDDLE_PREFIX
parts = [f"{tree_prefix} {self.title}"]
parent = self.parent
while parent and parent.parent is not None:
parts.append(self.PARENT_PREFIX_MIDDLE if parent.is_last else self.PARENT_PREFIX_LAST)
parent = parent.parent

return "".join(reversed(parts))

def __repr__(self) -> str:
return str(self)
9 changes: 5 additions & 4 deletions evm_trace/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@


class CallType(Enum):
INTERNAL = "INTERNAL" # Non-opcode internal call
CALL = "CALL"
CALLCODE = "CALLCODE"
CREATE = "CREATE"
CREATE2 = "CREATE2"
CALL = "CALL"
DELEGATECALL = "DELEGATECALL"
STATICCALL = "STATICCALL"
CALLCODE = "CALLCODE"
EVENT = "EVENT"
INTERNAL = "INTERNAL" # Non-opcode internal call
SELFDESTRUCT = "SELFDESTRUCT"
STATICCALL = "STATICCALL"

def __eq__(self, other):
return self.value == getattr(other, "value", other)
Expand Down
43 changes: 36 additions & 7 deletions evm_trace/geth.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from eth_utils import to_int
from pydantic import Field, RootModel, field_validator

from evm_trace.base import BaseModel, CallTreeNode
from evm_trace.base import BaseModel, CallTreeNode, EventNode
from evm_trace.enums import CALL_OPCODES, CallType


Expand Down Expand Up @@ -78,7 +78,7 @@ def create_trace_frames(data: Iterator[dict]) -> Iterator[TraceFrame]:
looking ahead and finding it.
Args:
data (Iterator[Dict]): An iterator of response struct logs.
data (Iterator[dict]): An iterator of response struct logs.
Returns:
Iterator[:class:`~evm_trace.geth.TraceFrame`]
Expand Down Expand Up @@ -132,7 +132,7 @@ def get_calltree_from_geth_call_trace(data: dict) -> CallTreeNode:
Creates a CallTreeNode from a given transaction call trace.
Args:
data (Dict): The response from ``debug_traceTransaction`` when using
data (dict): The response from ``debug_traceTransaction`` when using
``tracer=callTracer``.
Returns:
Expand Down Expand Up @@ -216,7 +216,7 @@ def extract_memory(offset: HexBytes, size: HexBytes, memory: list[HexBytes]) ->
Args:
offset (HexBytes): Offset byte location in memory.
size (HexBytes): Number of bytes to return.
memory (List[HexBytes]): Memory stack.
memory (list[HexBytes]): Memory stack.
Returns:
HexBytes: Byte value from memory stack.
Expand Down Expand Up @@ -252,7 +252,6 @@ def _create_node(
Use specified opcodes to create a branching callnode
https://www.evm.codes/
"""

if isinstance(trace, list):
# NOTE: We don't officially support lists here,
# but if we don't do this, the user gets a recursion error
Expand Down Expand Up @@ -298,6 +297,13 @@ def _create_node(
else:
node_kwargs["calls"] = [subcall]

elif frame.op.startswith("LOG") and len(frame.op) > 3 and frame.op[3].isnumeric():
event = _create_event_node(frame)
if "events" in node_kwargs:
node_kwargs["events"].append(event)
else:
node_kwargs["events"] = [event]

# TODO: Handle internal nodes using JUMP and JUMPI

elif frame.op == CallType.SELFDESTRUCT.value:
Expand All @@ -324,14 +330,37 @@ def _create_node(
if "last_create_depth" in node_kwargs:
del node_kwargs["last_create_depth"]

if "callType" in node_kwargs:
node_kwargs["call_type"] = node_kwargs.pop("callType")
elif "call_type" not in node_kwargs:
node_kwargs["call_type"] = CallType.CALL # Default.

if node_kwargs["call_type"] in (CallType.CREATE, CallType.CREATE2) and not node_kwargs.get(
"address"
):
# Set temporary address so validation succeeds.
node_kwargs["address"] = 20 * b"\x00"

node = CallTreeNode(**node_kwargs)
return node
return CallTreeNode(**node_kwargs)


def _create_event_node(frame: TraceFrame) -> EventNode:
# The number of topics is derived from the opcode,
# e.g. LOG2 meaning 2 topics (not counting the selector).
num_topics = int(frame.op[3])

# The selector always seems to be here.
selector_idx = -3
selector = frame.stack[-3]

# Figure out topics.
start_topic_idx = selector_idx - num_topics + 1
topics = [selector, *[HexBytes(t) for t in reversed(frame.stack[start_topic_idx:selector_idx])]]

# Figure out data.
data = frame.memory.get(frame.stack[-1], frame.stack[-2])

return EventNode(data=data, depth=frame.depth, topics=topics)


def _validate_data_from_call_tracer(data: dict) -> dict:
Expand Down
32 changes: 21 additions & 11 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,20 @@
GETH_DATA = DATA_PATH / "geth"
EVM_TRACE_DATA = DATA_PATH / "evm_trace"
PARITY_DATA = DATA_PATH / "parity"
TRACE_FRAME_DATA = json.loads((EVM_TRACE_DATA / "frame.json").read_text())
CALL_FRAME_DATA = json.loads((EVM_TRACE_DATA / "call.json").read_text())
MUTABLE_CALL_TREE_DATA = json.loads((EVM_TRACE_DATA / "mutable_call.json").read_text())
STATIC_CALL_TREE_DATA = json.loads((EVM_TRACE_DATA / "static_call.json").read_text())
DELEGATE_CALL_TREE_DATA = json.loads((EVM_TRACE_DATA / "delegate_call.json").read_text())
TRACE_FRAME_DATA = json.loads((EVM_TRACE_DATA / "frame.json").read_text(encoding="utf8"))
CALL_FRAME_DATA = json.loads((EVM_TRACE_DATA / "call.json").read_text(encoding="utf8"))
MUTABLE_CALL_TREE_DATA = json.loads(
(EVM_TRACE_DATA / "mutable_call.json").read_text(encoding="utf8")
)
STATIC_CALL_TREE_DATA = json.loads((EVM_TRACE_DATA / "static_call.json").read_text(encoding="utf8"))
DELEGATE_CALL_TREE_DATA = json.loads(
(EVM_TRACE_DATA / "delegate_call.json").read_text(encoding="utf8")
)
CALL_TRACE_DATA = json.loads((GETH_DATA / "call.json").read_text())
CREATE_CALL_TRACE_DATA = json.loads((GETH_DATA / "create_call.json").read_text())
GETH_CREATE2_TRACE = json.loads((GETH_DATA / "create2_structlogs.json").read_text())
PARITY_CREATE2_TRACE = json.loads((PARITY_DATA / "create2.json").read_text())
CREATE_CALL_TRACE_DATA = json.loads((GETH_DATA / "create_call.json").read_text(encoding="utf8"))
GETH_TRACE = json.loads((GETH_DATA / "structlogs.json").read_text(encoding="utf8"))
GETH_CREATE2_TRACE = json.loads((GETH_DATA / "create2_structlogs.json").read_text(encoding="utf8"))
PARITY_CREATE2_TRACE = json.loads((PARITY_DATA / "create2.json").read_text(encoding="utf8"))
CALL_TREE_DATA_MAP = {
CallType.CALL.value: MUTABLE_CALL_TREE_DATA,
CallType.STATICCALL.value: STATIC_CALL_TREE_DATA,
Expand Down Expand Up @@ -70,18 +75,23 @@ def call_tree_data(request):
yield CALL_TREE_DATA_MAP[request.param]


@pytest.fixture
@pytest.fixture(scope="session")
def parity_create2_trace_list():
trace_list = [ParityTrace.model_validate(x) for x in PARITY_CREATE2_TRACE]
return ParityTraceList(root=trace_list)


@pytest.fixture
@pytest.fixture(scope="session")
def geth_structlogs():
return GETH_TRACE


@pytest.fixture(scope="session")
def geth_create2_struct_logs():
return GETH_CREATE2_TRACE


@pytest.fixture
@pytest.fixture(scope="session")
def geth_create2_trace_frames(geth_create2_struct_logs):
# NOTE: These frames won't have the CREATE address set.
return [TraceFrame(**x) for x in geth_create2_struct_logs]
1 change: 1 addition & 0 deletions tests/data/geth/structlogs.json

Large diffs are not rendered by default.

Loading

0 comments on commit dbbd986

Please sign in to comment.