Skip to content

Commit

Permalink
Merge pull request #5 from crytic/dev/error-handling
Browse files Browse the repository at this point in the history
Add basic error handling
  • Loading branch information
tuturu-tech authored Feb 6, 2024
2 parents a70083c + 47c0def commit fcde9ba
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 30 deletions.
54 changes: 36 additions & 18 deletions test_generator/fuzzers/Echidna.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
""" Generates a test file from Echidna reproducers """
# type: ignore[misc] # Ignores 'Any' input parameter
import sys
from typing import Any
from typing import Any, NoReturn
import jinja2

from slither import Slither
from slither.core.declarations.contract import Contract
from slither.core.declarations.function_contract import FunctionContract
from slither.core.solidity_types.elementary_type import ElementaryType
from slither.core.solidity_types.user_defined_type import UserDefinedType
from slither.core.solidity_types.array_type import ArrayType
from slither.core.declarations.structure import Structure
from slither.core.declarations.structure_contract import StructureContract
from slither.core.declarations.enum import Enum
from test_generator.utils.crytic_print import CryticPrint
from test_generator.templates.foundry_templates import templates
from test_generator.utils.encoding import parse_echidna_byte_string
from test_generator.utils.error_handler import handle_exit


class Echidna:
Expand All @@ -36,8 +38,7 @@ def get_target_contract(self) -> Contract:
if contract.name == self.target_name:
return contract

# TODO throw error if no contract found
sys.exit(-1)
handle_exit(f"\n* Slither could not find the specified contract `{self.target_name}`.")

def parse_reproducer(self, calls: Any, index: int) -> str:
"""
Expand All @@ -46,18 +47,16 @@ def parse_reproducer(self, calls: Any, index: int) -> str:
call_list = []
end = len(calls) - 1
function_name = ""
# 1. For each object in the list process the call object and add it to the call list
for idx, call in enumerate(calls):
call_str, fn_name = self._parse_call_object(call)
call_list.append(call_str)
if idx == end:
function_name = fn_name + "_" + str(index)

# 2. Generate the test string and return it
template = jinja2.Template(templates["TEST"])
return template.render(function_name=function_name, call_list=call_list)
# 1. Take a reproducer list and create a test file based on the name of the last function of the list e.g. test_auto_$function_name
# 2. For each object in the list process the call object and add it to the call list
# 3. Using the call list to generate a test string
# 4. Return the test string

def _parse_call_object(self, call_dict: dict[Any, Any]) -> tuple[str, str]:
"""
Expand Down Expand Up @@ -89,12 +88,17 @@ def _parse_call_object(self, call_dict: dict[Any, Any]) -> tuple[str, str]:
if len(function_parameters) == 0:
function_parameters = ""

slither_entry_point = None
slither_entry_point: FunctionContract

for entry_point in self.target.functions_entry_points:
if entry_point.name == function_name:
slither_entry_point = entry_point

if "slither_entry_point" not in locals():
handle_exit(
f"\n* Slither could not find the function `{function_name}` specified in the call object"
)

# 2. Decode the function parameters
variable_definition, call_definition = self._decode_function_params(
function_parameters, False, slither_entry_point
Expand All @@ -119,7 +123,7 @@ def _parse_call_object(self, call_dict: dict[Any, Any]) -> tuple[str, str]:
return call_str, function_name

# pylint: disable=R0201
def _match_elementary_types(self, param: dict, recursive: bool) -> str:
def _match_elementary_types(self, param: dict, recursive: bool) -> str | NoReturn:
"""
Returns a string which represents a elementary type literal value. e.g. "5" or "uint256(5)"
Expand Down Expand Up @@ -169,11 +173,13 @@ def _match_elementary_types(self, param: dict, recursive: bool) -> str:
interpreted_string = f'string(hex"{hex_string}")'
return interpreted_string
case _:
return ""
handle_exit(
f"\n* The parameter tag `{param['tag']}` could not be found in the call object. This could indicate an issue in decoding the call sequence, or a missing feature. Please open an issue at https://github.com/crytic/test-generator/issues"
)

def _match_array_type(
self, param: dict, index: int, input_parameter: Any
) -> tuple[str, str, int]:
) -> tuple[str, str, int] | NoReturn:
match param["tag"]:
case "AbiArray":
# Consider cases where the array items are more complex types (bytes, string, tuples)
Expand All @@ -195,9 +201,13 @@ def _match_array_type(

return name, definitions, index
case _:
return "", "", index
handle_exit(
f"\n* The parameter tag `{param['tag']}` could not be found in the call object. This could indicate an issue in decoding the call sequence, or a missing feature. Please open an issue at https://github.com/crytic/test-generator/issues"
)

def _match_user_defined_type(self, param: dict, input_parameter: Any) -> tuple[str, str]:
def _match_user_defined_type(
self, param: dict, input_parameter: Any
) -> tuple[str, str] | NoReturn:
match param["tag"]:
case "AbiTuple":
match input_parameter.type:
Expand All @@ -207,16 +217,22 @@ def _match_user_defined_type(self, param: dict, input_parameter: Any) -> tuple[s
)
return definitions, f"{input_parameter}({','.join(func_params)})"
case _:
return "", ""
handle_exit(
f"\n* The parameter type `{input_parameter.type}` could not be found. This could indicate an issue in decoding the call sequence, or a missing feature. Please open an issue at https://github.com/crytic/test-generator/issues"
)
case "AbiUInt":
if isinstance(input_parameter.type, Enum):
enum_uint = self._match_elementary_types(param, False)
return "", f"{input_parameter}({enum_uint})"

# TODO is this even reachable?
return "", ""
handle_exit(
f"\n* The parameter type `{input_parameter.type}` does not match the intended type `Enum`. This could indicate an issue in decoding the call sequence, or a missing feature. Please open an issue at https://github.com/crytic/test-generator/issues"
)
case _:
return "", ""
handle_exit(
f"\n* The parameter tag `{param['tag']}` could not be found in the call object. This could indicate an issue in decoding the call sequence, or a missing feature. Please open an issue at https://github.com/crytic/test-generator/issues"
)

def _decode_function_params(
self, function_params: list, recursive: bool, entry_point: Any
Expand Down Expand Up @@ -253,7 +269,9 @@ def _decode_function_params(
params.append(func_params)
case _:
# TODO should handle all cases, but keeping this just in case
print("UNHANDLED INPUT TYPE -> DEFAULT CASE")
CryticPrint().print_information(
f"\n* Attempted to decode an unidentified type {input_parameter}, this call will be skipped. Please open an issue at https://github.com/crytic/test-generator/issues"
)
continue

# 3. Return a list of function parameters
Expand Down
33 changes: 24 additions & 9 deletions test_generator/fuzzers/Medusa.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
""" Generates a test file from Medusa reproducers """
import sys
from typing import Any
from typing import Any, NoReturn
import jinja2

from slither import Slither
Expand All @@ -13,8 +12,10 @@
from slither.core.declarations.structure_contract import StructureContract
from slither.core.declarations.enum import Enum
from slither.core.declarations.enum_contract import EnumContract
from test_generator.utils.crytic_print import CryticPrint
from test_generator.templates.foundry_templates import templates
from test_generator.utils.encoding import parse_medusa_byte_string
from test_generator.utils.error_handler import handle_exit


class Medusa:
Expand All @@ -38,8 +39,7 @@ def get_target_contract(self) -> Contract:
if contract.name == self.target_name:
return contract

# TODO throw error if no contract found
sys.exit(-1)
handle_exit(f"\n* Slither could not find the specified contract `{self.target_name}`.")

def parse_reproducer(self, calls: Any, index: int) -> str:
"""
Expand Down Expand Up @@ -86,6 +86,11 @@ def _parse_call_object(self, call_dict: dict) -> tuple[str, str]:
if entry_point.name == function_name:
slither_entry_point = entry_point

if "slither_entry_point" not in locals():
handle_exit(
f"\n* Slither could not find the function `{function_name}` specified in the call object"
)

# 2. Decode the function parameters
variable_definition, call_definition = self._decode_function_params(
function_parameters, False, slither_entry_point
Expand Down Expand Up @@ -138,8 +143,12 @@ def _match_elementary_types(self, param: str, recursive: bool, input_parameter:
hex_string = parse_medusa_byte_string(param)
interpreted_string = f'string(hex"{hex_string}")'
return interpreted_string
if "address" in input_type:
return param

return param
handle_exit(
f"\n* The parameter type `{input_type}` could not be found in the call object. This could indicate an issue in decoding the call sequence, or a missing feature. Please open an issue at https://github.com/crytic/test-generator/issues"
)

def _match_array_type(
self, param: dict, index: int, input_parameter: Any
Expand Down Expand Up @@ -175,11 +184,13 @@ def _match_user_defined_type(
case Enum() | EnumContract(): # type: ignore[misc]
return "", f"{input_parameter}({param})" # type: ignore[unreachable]
case _:
return "", ""
handle_exit(
f"\n* The parameter type `{input_parameter.type}` could not be found in the call object. This could indicate an issue in decoding the call sequence, or a missing feature. Please open an issue at https://github.com/crytic/test-generator/issues"
)

def _decode_function_params(
self, function_params: list | dict, recursive: bool, entry_point: Any
) -> tuple[str, list]:
) -> tuple[str, list] | NoReturn:
params = []
variable_definitions = ""
index = 0
Expand Down Expand Up @@ -212,7 +223,9 @@ def _decode_function_params(
params.append(func_params)
case _:
# TODO should handle all cases, but keeping this just in case
print("UNHANDLED INPUT TYPE -> DEFAULT CASE")
CryticPrint().print_information(
f"\n* Attempted to decode an unidentified type {input_parameter}, this call will be skipped. Please open an issue at https://github.com/crytic/test-generator/issues"
)
continue
else:
for param_idx, param in enumerate(function_params):
Expand Down Expand Up @@ -245,7 +258,9 @@ def _decode_function_params(
params.append(func_params)
case _:
# TODO should handle all cases, but keeping this just in case
print("UNHANDLED INPUT TYPE -> DEFAULT CASE")
CryticPrint().print_information(
f"\n* Attempted to decode an unidentified type {input_parameter}, this call will be skipped. Please open an issue at https://github.com/crytic/test-generator/issues"
)
continue

# 3. Return a list of function parameters
Expand Down
Empty file.
14 changes: 11 additions & 3 deletions test_generator/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from test_generator.templates.foundry_templates import templates
from test_generator.fuzzers.Medusa import Medusa
from test_generator.fuzzers.Echidna import Echidna
from test_generator.utils.error_handler import handle_exit


class FoundryTest:
Expand Down Expand Up @@ -100,7 +101,7 @@ def main() -> None: # type: ignore[func-returns-value]
)
parser.add_argument("file_path", help="Path to the Echidna test harness.")
parser.add_argument(
"-cd", "--corpus-dir", dest="corpus_dir", help="Path to the corpus directory"
"-cd", "--corpus-dir", dest="corpus_dir", help="Path to the corpus directory", required=True
)
parser.add_argument("-c", "--contract", dest="target_contract", help="Define the contract name")
parser.add_argument(
Expand Down Expand Up @@ -129,6 +130,12 @@ def main() -> None: # type: ignore[func-returns-value]
)

args = parser.parse_args()

missing_args = [arg for arg, value in vars(args).items() if value is None]
if missing_args:
parser.print_help()
handle_exit(f"\n* Missing required arguments: {', '.join(missing_args)}")

file_path = args.file_path
corpus_dir = args.corpus_dir
test_directory = args.test_directory
Expand All @@ -143,8 +150,9 @@ def main() -> None: # type: ignore[func-returns-value]
case "medusa":
fuzzer = Medusa(target_contract, corpus_dir, slither)
case _:
# TODO create a descriptive error
sys.exit(-1)
handle_exit(
f"\n* The requested fuzzer {args.selected_fuzzer} is not supported. Supported fuzzers: echidna, medusa."
)

CryticPrint().print_information(
f"Generating Foundry unit tests based on the {fuzzer.name} reproducers..."
Expand Down
Empty file.
10 changes: 10 additions & 0 deletions test_generator/utils/error_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
""" Utility function for error handling"""
import sys
from typing import NoReturn
from test_generator.utils.crytic_print import CryticPrint


def handle_exit(reason: str) -> NoReturn:
"""Print an error message to the console and exit"""
CryticPrint().print_error(reason)
sys.exit()

0 comments on commit fcde9ba

Please sign in to comment.