Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Formatting changes #4

Merged
merged 9 commits into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/mypy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ defaults:

on:
pull_request:
branches: [master, dev]
branches: [main, dev]

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
Expand Down
7 changes: 3 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
SHELL := /bin/bash

PY_MODULE := test-generator
PY_MODULE := test_generator
TEST_MODULE := tests

# Optionally overriden by the user, if they're using a virtual environment manager.
Expand Down Expand Up @@ -40,9 +40,8 @@ run: $(VENV)/pyvenv.cfg
lint: $(VENV)/pyvenv.cfg
. $(VENV_BIN)/activate && \
black --check . && \
pylint $(PY_MODULE) $(TEST_MODULE)
# ruff $(ALL_PY_SRCS) && \
# mypy $(PY_MODULE) &&
pylint $(PY_MODULE) $(TEST_MODULE) && \
mypy $(PY_MODULE)

.PHONY: reformat
reformat:
Expand Down
11 changes: 11 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
[mypy]
warn_incomplete_stub = true
ignore_missing_imports = true
disallow_untyped_calls = true
disallow_untyped_defs = true
disallow_incomplete_defs = true
check_untyped_defs = true
disallow_untyped_decorators = false
warn_redundant_casts = true
warn_no_return = true
warn_unreachable = true
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ logging-fstring-interpolation,
logging-not-lazy,
duplicate-code,
import-error,
unsubscriptable-object
unsubscriptable-object,
too-many-arguments,
unpacking-non-sequence
"""
[tool.mypy]
warn_incomplete_stub = true
Expand Down
94 changes: 55 additions & 39 deletions test_generator/fuzzers/Echidna.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
""" Generates a test file from Echidna reproducers """
import re
# type: ignore[misc] # Ignores 'Any' input parameter
import sys
from typing import Any
import jinja2

Expand All @@ -14,6 +15,7 @@
from test_generator.templates.foundry_templates import templates
from test_generator.utils.encoding import parse_echidna_byte_string


class Echidna:
"""
Handles the generation of Foundry test files from Echidna reproducers
Expand All @@ -23,20 +25,21 @@ def __init__(self, target_name: str, corpus_path: str, slither: Slither) -> None
self.name = "Echidna"
self.target_name = target_name
self.slither = slither
self.target = self._get_target_contract()
self.target = self.get_target_contract()
self.reproducer_dir = f"{corpus_path}/reproducers"

def _get_target_contract(self) -> Contract:
def get_target_contract(self) -> Contract:
"""Finds and returns Slither Contract"""
contracts = self.slither.get_contract_from_name(self.target_name)
# Loop in case slither fetches multiple contracts for some reason (e.g., similar names?)
for contract in contracts:
if contract.name == self.target_name:
return contract

# TODO throw error if no contract found
exit(-1)
sys.exit(-1)

def parse_reproducer(self, calls: list, index: int) -> str:
def parse_reproducer(self, calls: Any, index: int) -> str:
"""
Takes a list of call dicts and returns a Foundry unit test string containing the call sequence.
"""
Expand All @@ -56,15 +59,15 @@ def parse_reproducer(self, calls: list, index: int) -> str:
# 3. Using the call list to generate a test string
# 4. Return the test string

def _parse_call_object(self, call_dict) -> (str, str):
def _parse_call_object(self, call_dict: dict[Any, Any]) -> tuple[str, str]:
"""
Takes a single call dictionary, parses it, and returns the series of function calls as a string, along with
the name of the last function, which is used as the name of the test.
"""
# 1. Parse call object and save the variables
time_delay = int(call_dict["delay"][0], 16)
block_delay = int(call_dict["delay"][1], 16)
has_delay = True if time_delay > 0 or block_delay > 0 else False
has_delay = time_delay > 0 or block_delay > 0

if call_dict["call"]["tag"] == "NoCall":
template = jinja2.Template(templates["EMPTY"])
Expand All @@ -88,7 +91,6 @@ def _parse_call_object(self, call_dict) -> (str, str):
variable_definition, call_definition = self._decode_function_params(
function_parameters, False, slither_entry_point
)
params = ", ".join(call_definition)

# 3. Generate a call string and return it
template = jinja2.Template(templates["CALL"])
Expand All @@ -98,7 +100,7 @@ def _parse_call_object(self, call_dict) -> (str, str):
block_delay=block_delay,
caller=caller,
value=value,
function_parameters=params,
function_parameters=", ".join(call_definition),
function_name=function_name,
contract_name=self.target_name,
)
Expand All @@ -108,6 +110,7 @@ def _parse_call_object(self, call_dict) -> (str, str):

return call_str, function_name

# pylint: disable=R0201
def _match_elementary_types(self, param: dict, recursive: bool) -> str:
"""
Returns a string which represents a elementary type literal value. e.g. "5" or "uint256(5)"
Expand All @@ -130,9 +133,9 @@ def _match_elementary_types(self, param: dict, recursive: bool) -> str:
cast = "uint" if param["tag"] == "AbiUInt" else "int"
if not recursive:
return param["contents"][1]
else:
casting = f'{cast}{str(param["contents"][0])}({param["contents"][1]})'
return casting

casting = f'{cast}{str(param["contents"][0])}({param["contents"][1]})'
return casting
case "AbiAddress":
return param["contents"]
case "AbiBytes" | "AbiBytesDynamic":
Expand All @@ -144,17 +147,25 @@ def _match_elementary_types(self, param: dict, recursive: bool) -> str:
hex_string = parse_echidna_byte_string(contents.strip('"'))
interpreted_string = f'hex"{hex_string}"'
if not recursive:
result = f"bytes{size}({interpreted_string})" if is_fixed_size else interpreted_string
result = (
f"bytes{size}({interpreted_string})"
if is_fixed_size
else interpreted_string
)
return result
else:
casting = f"bytes{size}({interpreted_string})"
return casting

casting = f"bytes{size}({interpreted_string})"
return casting
case "AbiString":
hex_string = parse_echidna_byte_string(param["contents"].strip('"'))
interpreted_string = f'string(hex"{hex_string}")'
return interpreted_string
case _:
return ""

def _match_array_type(self, param: dict, index: int, input_parameter) -> tuple[str, str, int]:
def _match_array_type(
self, param: dict, index: int, input_parameter: Any
) -> tuple[str, str, int]:
match param["tag"]:
case "AbiArray":
# Consider cases where the array items are more complex types (bytes, string, tuples)
Expand All @@ -167,35 +178,41 @@ def _match_array_type(self, param: dict, index: int, input_parameter) -> tuple[s
definitions, func_params = self._decode_function_params(
param["contents"][1], True, input_parameter
)
name, var_def = self._get_memarr(param["contents"], index)
name, var_def = self._get_memarr(param["contents"], index) # type: ignore[unpacking-non-sequence]
definitions += var_def

for idx, temp_param in enumerate(func_params):
definitions += f"\t\t{name}[{idx}] = {temp_param};\n"
index += 1

return name, definitions, index
case _:
return "", "", index

def _match_user_defined_type(self, param: dict, input_parameter) -> tuple[str, str]:
def _match_user_defined_type(self, param: dict, input_parameter: Any) -> tuple[str, str]:
match param["tag"]:
case "AbiTuple":
match input_parameter.type:
case Structure() | StructureContract():
definitions, func_params = self._decode_function_params(
case Structure() | StructureContract(): # type: ignore[misc]
definitions, func_params = self._decode_function_params( # type: ignore[unreachable]
param["contents"], True, input_parameter.type.elems_ordered
)
return definitions, f"{input_parameter}({','.join(func_params)})"
case _:
return "", ""
case "AbiUInt":
if isinstance(input_parameter.type, Enum):
enum_uint = self._match_elementary_types(param, False)
return "", f"{input_parameter}({enum_uint})"
else:
# TODO is this even reachable?
return "", ""

# TODO is this even reachable?
return "", ""
case _:
return "", ""

def _decode_function_params(
self, function_params: list, recursive: bool, entry_point: Any
) -> (str | None, list):
) -> tuple[str, list]:
params = []
variable_definitions = ""
index = 0
Expand All @@ -204,28 +221,26 @@ def _decode_function_params(
for param_idx, param in enumerate(function_params):
input_parameter = None
if recursive:
try:
if isinstance(entry_point, list):
input_parameter = entry_point[param_idx].type
except:
else:
input_parameter = entry_point.type

else:
input_parameter = entry_point.parameters[param_idx].type

match input_parameter:
case ElementaryType():
params.append(self._match_elementary_types(param, recursive))
case ArrayType():
[inputs, definitions, new_index] = self._match_array_type(
case ElementaryType(): # type: ignore[misc]
params.append(self._match_elementary_types(param, recursive)) # type: ignore[unreachable]
case ArrayType(): # type: ignore[misc]
inputs, definitions, new_index = self._match_array_type( # type: ignore[unreachable,unpacking-non-sequence]
param, index, input_parameter
)
params.append(inputs)
variable_definitions += definitions
index = new_index
case UserDefinedType():
[definitions, func_params] = self._match_user_defined_type(
param, input_parameter
)
case UserDefinedType(): # type: ignore[misc]
definitions, func_params = self._match_user_defined_type(param, input_parameter) # type: ignore[unreachable, unpacking-non-sequence]
variable_definitions += definitions
params.append(func_params)
case _:
Expand All @@ -236,10 +251,11 @@ def _decode_function_params(
# 3. Return a list of function parameters
if len(variable_definitions) > 0:
return variable_definitions, params
else:
return "", params

def _get_memarr(self, function_params: dict, index: int) -> (str | None, str | None):
return "", params

# pylint: disable=R0201
def _get_memarr(self, function_params: dict, index: int) -> tuple[str, str]:
length = len(function_params[1])
match function_params[0]["tag"]:
case "AbiBoolType":
Expand Down Expand Up @@ -267,4 +283,4 @@ def _get_memarr(self, function_params: dict, index: int) -> (str | None, str | N
name = f"dynStringArr_{index}"
return name, f"string[] memory {name} = new string[]({length});\n"
case _:
return None, None
return "", ""
Loading