Skip to content

Commit

Permalink
fix: manually solve dynamic array overflow conditions (a16z#366)
Browse files Browse the repository at this point in the history
  • Loading branch information
daejunpark authored Sep 24, 2024
1 parent d9eaad9 commit 20cd93a
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 8 deletions.
48 changes: 40 additions & 8 deletions src/halmos/sevm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
BitVec,
BitVecRef,
BoolVal,
CheckSatResult,
Concat,
Extract,
Function,
Expand Down Expand Up @@ -94,14 +95,19 @@
debug,
extract_bytes,
f_ecrecover,
f_sha3_256_name,
f_sha3_512_name,
f_sha3_name,
hexify,
int_of,
is_bool,
is_bv,
is_bv_value,
is_concrete,
is_f_sha3_name,
is_non_zero,
is_zero,
match_dynamic_array_overflow_condition,
restore_precomputed_hashes,
sha3_inv,
str_opcode,
Expand Down Expand Up @@ -994,15 +1000,35 @@ def dump(self, print_mem=False) -> str:
def advance_pc(self) -> None:
self.pc = self.pgm.next_pc(self.pc)

def check(self, cond: Any) -> Any:
cond = simplify(cond)
def quick_custom_check(self, cond: BitVecRef) -> CheckSatResult | None:
"""
Quick custom checker for specific known patterns.
This method checks for certain common conditions that can be evaluated
quickly without invoking the full SMT solver.
Returns:
sat if the condition is satisfiable
unsat if the condition is unsatisfiable
None if the condition requires full SMT solving
"""
if is_true(cond):
return sat

if is_false(cond):
return unsat

# Not(ULE(f_sha3_N(slot), offset + f_sha3_N(slot))), where offset < 2**64
if match_dynamic_array_overflow_condition(cond):
return unsat

def check(self, cond: Any) -> Any:
cond = simplify(cond)

# use quick custom checker for common patterns before falling back to SMT solver
if result := self.quick_custom_check(cond):
return result

return self.path.check(cond)

def select(
Expand Down Expand Up @@ -1063,7 +1089,7 @@ def sha3_data(self, data: Bytes) -> Word:
data = bytes_to_bv_value(data)

f_sha3 = Function(
f"f_sha3_{size * 8}", BitVecSorts[size * 8], BitVecSort256
f_sha3_name(size * 8), BitVecSorts[size * 8], BitVecSort256
)
sha3_expr = f_sha3(data)
else:
Expand Down Expand Up @@ -1288,17 +1314,17 @@ def get_key_structure(cls, loc) -> tuple:
def decode(cls, loc: Any) -> Any:
loc = normalize(loc)
# m[k] : hash(k.m)
if loc.decl().name() == "f_sha3_512":
if loc.decl().name() == f_sha3_512_name:
args = loc.arg(0)
offset = simplify(Extract(511, 256, args))
base = simplify(Extract(255, 0, args))
return cls.decode(base) + (offset, ZERO)
# a[i] : hash(a) + i
elif loc.decl().name() == "f_sha3_256":
elif loc.decl().name() == f_sha3_256_name:
base = loc.arg(0)
return cls.decode(base) + (ZERO,)
# m[k] : hash(k.m) where |k| != 256-bit
elif loc.decl().name().startswith("f_sha3_"):
elif is_f_sha3_name(loc.decl().name()):
sha3_input = normalize(loc.arg(0))
if sha3_input.decl().name() == "concat" and sha3_input.num_args() == 2:
offset = simplify(sha3_input.arg(0))
Expand Down Expand Up @@ -1417,12 +1443,12 @@ def store(cls, ex: Exec, addr: Any, loc: Any, val: Any) -> None:
@classmethod
def decode(cls, loc: Any) -> Any:
loc = normalize(loc)
if loc.decl().name() == "f_sha3_512": # hash(hi,lo), recursively
if loc.decl().name() == f_sha3_512_name: # hash(hi,lo), recursively
args = loc.arg(0)
hi = cls.decode(simplify(Extract(511, 256, args)))
lo = cls.decode(simplify(Extract(255, 0, args)))
return cls.simple_hash(Concat(hi, lo))
elif loc.decl().name().startswith("f_sha3_"):
elif is_f_sha3_name(loc.decl().name()):
sha3_input = normalize(loc.arg(0))
if sha3_input.decl().name() == "concat":
decoded_sha3_input_args = [
Expand Down Expand Up @@ -2359,6 +2385,12 @@ def jumpi(
follow_false = visited[False] < self.options.loop
if not (follow_true and follow_false):
self.logs.bounded_loops.append(jid)
if self.options.debug:
debug(f"\nloop id: {jid}")
debug(f"loop condition: {cond}")
debug(f"calldata: {ex.calldata()}")
debug("path condition:")
debug(ex.path)
else:
# for constant-bounded loops
follow_true = potential_true
Expand Down
50 changes: 50 additions & 0 deletions src/halmos/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from typing import Any

from z3 import (
Z3_OP_BADD,
Z3_OP_CONCAT,
Z3_OP_ULEQ,
BitVecNumRef,
BitVecRef,
BitVecSort,
Expand All @@ -21,11 +23,13 @@
SignExt,
SolverFor,
ZeroExt,
eq,
is_app,
is_app_of,
is_bool,
is_bv,
is_bv_value,
is_not,
simplify,
)

Expand Down Expand Up @@ -94,6 +98,18 @@ def __getitem__(self, size: int) -> BitVecSort:
)


def is_f_sha3_name(name: str) -> bool:
return name.startswith("f_sha3_")


def f_sha3_name(bitsize: int) -> str:
return f"f_sha3_{bitsize}"


f_sha3_256_name = f_sha3_name(256)
f_sha3_512_name = f_sha3_name(512)


def wrap(x: Any) -> Word:
if is_bv(x):
return x
Expand Down Expand Up @@ -349,6 +365,40 @@ def byte_length(x: Any, strict=True) -> int:
raise TypeError(f"byte_length({x}) of type {type(x)}")


def match_dynamic_array_overflow_condition(cond: BitVecRef) -> bool:
"""
Check if `cond` matches the following pattern:
Not(ULE(f_sha3_N(slot), offset + f_sha3_N(slot))), where offset < 2**64
This condition is satisfied when a dynamic array at `slot` exceeds the storage limit.
Since such an overflow is highly unlikely in practice, we assume that this condition is unsat.
Note: we already assume that any sha3 hash output is smaller than 2**256 - 2**64 (see SEVM.sha3_data()).
However, the smt solver may not be able to solve this condition within the branching timeout.
In such cases, this explicit pattern serves as a fallback to avoid exploring practically infeasible paths.
We don't need to handle the negation of this condition, because unknown conditions are conservatively assumed to be sat.
"""

# Not(ule)
if not is_not(cond):
return False
ule = cond.arg(0)

# Not(ULE(left, right)
if not is_app_of(ule, Z3_OP_ULEQ):
return False
left, right = ule.arg(0), ule.arg(1)

# Not(ULE(f_sha3_N(slot), offset + base))
if not (is_f_sha3_name(left.decl().name()) and is_app_of(right, Z3_OP_BADD)):
return False
offset, base = right.arg(0), right.arg(1)

# Not(ULE(f_sha3_N(slot), offset + f_sha3_N(slot))) and offset < 2**64
return eq(left, base) and is_bv_value(offset) and offset.as_long() < 2**64


def stripped(hexstring: str) -> str:
"""Remove 0x prefix from hexstring"""
return hexstring[2:] if hexstring.startswith("0x") else hexstring
Expand Down
11 changes: 11 additions & 0 deletions tests/expected/all.json
Original file line number Diff line number Diff line change
Expand Up @@ -2365,6 +2365,17 @@
"num_bounded_loops": null
}
],
"test/Solver.t.sol:SolverTest": [
{
"name": "check_dynamic_array_overflow()",
"exitcode": 0,
"num_models": 0,
"models": null,
"num_paths": null,
"time": null,
"num_bounded_loops": null
}
],
"test/StaticContexts.t.sol:StaticContextsTest": [
{
"name": "check_create2_fails()",
Expand Down
13 changes: 13 additions & 0 deletions tests/regression/test/Solver.t.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// SPDX-License-Identifier: AGPL-3.0
pragma solidity >=0.8.0 <0.9.0;

import "forge-std/Test.sol";
import {SymTest} from "halmos-cheatcodes/SymTest.sol";

contract SolverTest is SymTest, Test {
uint[] numbers;

function check_dynamic_array_overflow() public {
numbers = new uint[](5); // shouldn't generate loop bounds warning
}
}
56 changes: 56 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from z3 import (
ULE,
BitVec,
BitVecSort,
BitVecVal,
Function,
Not,
simplify,
)

from halmos.utils import f_sha3_256_name, match_dynamic_array_overflow_condition


def test_match_dynamic_array_overflow_condition():
# Create Z3 objects
f_sha3_256 = Function(f_sha3_256_name, BitVecSort(256), BitVecSort(256))
slot = BitVec("slot", 256)
offset = BitVecVal(1000, 256) # Less than 2**64

# Test the function
cond = Not(ULE(f_sha3_256(slot), offset + f_sha3_256(slot)))
assert match_dynamic_array_overflow_condition(cond)

# Test with opposite order of addition
opposite_order_cond = Not(ULE(f_sha3_256(slot), f_sha3_256(slot) + offset))
assert not match_dynamic_array_overflow_condition(opposite_order_cond)

# Test with opposite order after simplification
simplified_opposite_order_cond = simplify(
Not(ULE(f_sha3_256(slot), f_sha3_256(slot) + offset))
)
assert match_dynamic_array_overflow_condition(simplified_opposite_order_cond)

# Test with offset = 2**64 - 1 (should match)
max_valid_offset = BitVecVal(2**64 - 1, 256)
max_valid_cond = Not(ULE(f_sha3_256(slot), max_valid_offset + f_sha3_256(slot)))
assert match_dynamic_array_overflow_condition(max_valid_cond)

# Test with offset >= 2**64
large_offset = BitVecVal(2**64, 256)
large_offset_cond = Not(ULE(f_sha3_256(slot), large_offset + f_sha3_256(slot)))
assert not match_dynamic_array_overflow_condition(large_offset_cond)

# Test with a different function
different_func = Function("different_func", BitVecSort(256), BitVecSort(256))
non_matching_cond = Not(ULE(different_func(slot), offset + different_func(slot)))
assert not match_dynamic_array_overflow_condition(non_matching_cond)

# Test with just ULE, not Not(ULE(...))
ule_only = ULE(f_sha3_256(slot), offset + f_sha3_256(slot))
assert not match_dynamic_array_overflow_condition(ule_only)

# Test with mismatched slots
slot2 = BitVec("slot2", 256)
mismatched_slots = Not(ULE(f_sha3_256(slot), offset + f_sha3_256(slot2)))
assert not match_dynamic_array_overflow_condition(mismatched_slots)

0 comments on commit 20cd93a

Please sign in to comment.