Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make true and false functions #505

Merged
merged 1 commit into from
Sep 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 12 additions & 16 deletions claripy/ast/bool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import logging
from contextlib import suppress
from functools import lru_cache
from typing import TYPE_CHECKING, overload

from claripy import operations
Expand All @@ -16,8 +17,6 @@

l = logging.getLogger("claripy.ast.bool")

_boolv_cache = {}


class Bool(Base):
__slots__ = ()
Expand Down Expand Up @@ -60,21 +59,18 @@ def BoolS(name, explicit_name=None) -> Bool:
return Bool("BoolS", (n,), variables=frozenset((n,)), symbolic=True)


@lru_cache(maxsize=2)
def BoolV(val) -> Bool:
try:
return _boolv_cache[(val)]
except KeyError:
result = Bool("BoolV", (val,))
_boolv_cache[val] = result
return result
return Bool("BoolV", (val,))


#
# some standard ASTs
#
def true():
return BoolV(True)


def false():
return BoolV(False)

true = BoolV(True)
false = BoolV(False)

#
# Bound operations
Expand Down Expand Up @@ -145,9 +141,9 @@ def If(cond, true_value, false_value):

if args[1] is args[2]:
return args[1]
if args[1] is true and args[2] is false:
if args[1] is true() and args[2] is false():
return args[0]
if args[1] is false and args[2] is true:
if args[1] is false() and args[2] is true():
return ~args[0]

if issubclass(ty, Bits):
Expand Down Expand Up @@ -236,7 +232,7 @@ def reverse_ite_cases(ast):
:param ast:
:return:
"""
queue = [(true, ast)]
queue = [(true(), ast)]
while queue:
condition, ast = queue.pop(0)
if ast.op == "If":
Expand Down
2 changes: 1 addition & 1 deletion claripy/frontend_mixins/constraint_filter_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def _add(self, constraints, invalidate_cache=True):
ec = self._constraint_filter(constraints)
except UnsatError:
# filter out concrete False
ec = [c for c in constraints if c not in {False, false}] + [false]
ec = [c for c in constraints if c not in {False, false()}] + [false()]

if len(constraints) == 0:
return []
Expand Down
4 changes: 2 additions & 2 deletions claripy/frontend_mixins/model_cache_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def __setstate__(self, base_state):

def simplify(self):
results = super().simplify()
if len(results) > 0 and any(c is false for c in results):
if len(results) > 0 and any(c is false() for c in results):
self._models.clear()
return results

Expand Down Expand Up @@ -195,7 +195,7 @@ def _add(self, constraints, invalidate_cache=True):
new_vars = any(a.variables - old_vars for a in added)
if new_vars or invalidate_cache:
# shortcut for unsat
if any(c is false for c in constraints):
if any(c is false() for c in constraints):
self._models.clear()

still_valid = set(self._get_models(extra_constraints=added))
Expand Down
4 changes: 2 additions & 2 deletions claripy/frontend_mixins/sat_cache_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,15 @@ def __setstate__(self, s):

def _add(self, constraints, invalidate_cache=True):
added = super()._add(constraints, invalidate_cache=invalidate_cache)
if len(added) > 0 and any(c is false for c in added):
if len(added) > 0 and any(c is false() for c in added):
self._cached_satness = False
elif self._cached_satness is True:
self._cached_satness = None
return added

def simplify(self):
new_constraints = super().simplify()
if len(new_constraints) > 0 and any(c is false for c in new_constraints):
if len(new_constraints) > 0 and any(c is false() for c in new_constraints):
self._cached_satness = False
return new_constraints

Expand Down
4 changes: 2 additions & 2 deletions claripy/frontends/replacement_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def _add(self, constraints, invalidate_cache=True):

if not self._complex_auto_replace:
if rc.op == "Not":
self.add_replacement(c.args[0], false, replace=False, promote=True, invalidate_cache=True)
self.add_replacement(c.args[0], false(), replace=False, promote=True, invalidate_cache=True)
elif rc.op == "__eq__" and rc.args[0].symbolic ^ rc.args[1].symbolic:
old, new = rc.args if rc.args[0].symbolic else rc.args[::-1]
self.add_replacement(old, new, replace=False, promote=True, invalidate_cache=True)
Expand All @@ -281,7 +281,7 @@ def _add(self, constraints, invalidate_cache=True):
backends.vsa, rc, validation_frontend=self._validation_frontend
).compat_ret
if not satisfiable:
self.add_replacement(rc, false)
self.add_replacement(rc, false())
for old, new in replacements:
if old.cardinality == 1:
continue
Expand Down
50 changes: 17 additions & 33 deletions claripy/simplifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,15 +174,15 @@ def lshift_simplifier(val, shift):

def eq_simplifier(a, b):
if a is b:
return claripy.true
return claripy.true()

if isinstance(a, claripy.ast.Bool) and b is claripy.true:
if isinstance(a, claripy.ast.Bool) and b is claripy.true():
return a
if isinstance(b, claripy.ast.Bool) and a is claripy.true:
if isinstance(b, claripy.ast.Bool) and a is claripy.true():
return b
if isinstance(a, claripy.ast.Bool) and b is claripy.false:
if isinstance(a, claripy.ast.Bool) and b is claripy.false():
return claripy.Not(a)
if isinstance(b, claripy.ast.Bool) and a is claripy.false:
if isinstance(b, claripy.ast.Bool) and a is claripy.false():
return claripy.Not(b)

if a.op == "Reverse" and b.op == "Reverse":
Expand Down Expand Up @@ -211,10 +211,6 @@ def eq_simplifier(a, b):
if a.args[2] is b and claripy.is_true(a.args[1] != b):
# (If(c, x, y) == y, x != y) -> !c
return claripy.Not(a.args[0])
# elif a._claripy.is_true(a.args[1] == b) and a._claripy.is_true(a.args[2] == b):
# return a._claripy.true
# elif a._claripy.is_true(a.args[1] != b) and a._claripy.is_true(a.args[2] != b):
# return a._claripy.false

if b.op == "If":
if b.args[1] is a and claripy.is_true(b.args[2] != b):
Expand All @@ -223,10 +219,6 @@ def eq_simplifier(a, b):
if b.args[2] is a and claripy.is_true(b.args[1] != a):
# (y == If(c, x, y)) -> !c
return claripy.Not(b.args[0])
# elif b._claripy.is_true(b.args[1] == a) and b._claripy.is_true(b.args[2] == a):
# return b._claripy.true
# elif b._claripy.is_true(b.args[1] != a) and b._claripy.is_true(b.args[2] != a):
# return b._claripy.false

# Masking and comparing against a constant
simp = and_mask_comparing_against_constant_simplifier(operator.__eq__, a, b)
Expand Down Expand Up @@ -254,14 +246,14 @@ def eq_simplifier(a, b):
break

if claripy.is_false(a_bit == b_bit):
return claripy.false
return claripy.false()
return None
return None


def ne_simplifier(a, b):
if a is b:
return claripy.false
return claripy.false()

if a.op == "Reverse" and b.op == "Reverse":
return a.args[0] != b.args[0]
Expand All @@ -273,10 +265,6 @@ def ne_simplifier(a, b):
if a.args[1] is b and claripy.is_true(a.args[2] != b):
# (If(c, x, y) == y, x != y) -> !c
return claripy.Not(a.args[0])
# elif a._claripy.is_true(a.args[1] == b) and a._claripy.is_true(a.args[2] == b):
# return a._claripy.false
# elif a._claripy.is_true(a.args[1] != b) and a._claripy.is_true(a.args[2] != b):
# return a._claripy.true

if b.op == "If":
if b.args[2] is a and claripy.is_true(b.args[1] != a):
Expand All @@ -285,10 +273,6 @@ def ne_simplifier(a, b):
if b.args[1] is a and claripy.is_true(b.args[2] != a):
# (y == If(c, x, y)) -> !c
return claripy.Not(b.args[0])
# elif b._claripy.is_true(b.args[1] != a) and b._claripy.is_true(b.args[2] != a):
# return b._claripy.true
# elif b._claripy.is_true(b.args[1] == a) and b._claripy.is_true(b.args[2] == a):
# return b._claripy.false

# 1 ^ expr != 0 -> expr == 0
if a.op == "__xor__" and b.op == "BVV" and b.args[0] == 0 and len(a.args) == 2:
Expand Down Expand Up @@ -323,7 +307,7 @@ def ne_simplifier(a, b):
break

if claripy.is_true(a_bit != b_bit):
return claripy.true
return claripy.true()
return None
return None

Expand Down Expand Up @@ -385,14 +369,14 @@ def boolean_and_simplifier(*args):
for a in args:
if a.op == "BoolV":
if a.is_false():
return claripy.false
return claripy.false()
else:
new_args[ctr] = a
ctr += 1
new_args = new_args[:ctr]

if not new_args:
return claripy.true
return claripy.true()

if len(new_args) < len(args):
return claripy.And(*new_args)
Expand Down Expand Up @@ -459,11 +443,11 @@ def boolean_and_simplifier(*args):
if not eq_list:
return flattened
if any(any(ne is eq for eq in eq_list) for ne in ne_list):
return claripy.false
return claripy.false()
if all(v.op == "BVV" for v in eq_list) and all(v.op == "BVV" for v in ne_list):
mustbe = eq_list[0]
if any(eq.args[0] != mustbe.args[0] for eq in eq_list):
return claripy.false
return claripy.false()
return target_var == eq_list[0]
return flattened

Expand All @@ -475,12 +459,12 @@ def boolean_or_simplifier(*args):
new_args = []
for a in args:
if a.is_true():
return claripy.true
return claripy.true()
if not a.is_false():
new_args.append(a)

if not new_args:
return claripy.false
return claripy.false()
if len(new_args) < len(args):
return claripy.Or(*new_args)

Expand Down Expand Up @@ -1039,7 +1023,7 @@ def and_mask_comparing_against_constant_simplifier(op, a, b):
return None
return op(a_arg0[b_highbit_idx:0] & a_arg1.args[0], b_lower)
if b_higher_bits_are_0 is False:
return claripy.false if op is operator.__eq__ else claripy.true
return claripy.false() if op is operator.__eq__ else claripy.true()

return None

Expand Down Expand Up @@ -1109,7 +1093,7 @@ def zeroext_comparing_against_simplifier(op, a, b):
return op(a.args[1], b[b.size() - a_zeroext_bits - 1 : 0])
if (b_highbits == 0).is_false():
# unsat
return claripy.false if op is operator.__eq__ else claripy.true
return claripy.false() if op is operator.__eq__ else claripy.true()

if (
a.op == "Concat" and len(a.args) == 2 and a.args[0].op == "BVV" and a.args[0].args[0] == 0
Expand All @@ -1121,7 +1105,7 @@ def zeroext_comparing_against_simplifier(op, a, b):
return op(a.args[1], b[b.size() - a_zero_bits - 1 : 0])
if (b_highbits == 0).is_false():
# unsat
return claripy.false if op is operator.__eq__ else claripy.true
return claripy.false() if op is operator.__eq__ else claripy.true()

return None

Expand Down
6 changes: 3 additions & 3 deletions tests/test_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@ def assert_correct(a, b):

a, b, c = (claripy.BoolS(name) for name in ("a", "b", "c"))

assert_correct(claripy.And(a, claripy.Not(a)), claripy.false)
assert_correct(claripy.Or(a, claripy.Not(a)), claripy.true)
assert_correct(claripy.And(a, claripy.Not(a)), claripy.false())
assert_correct(claripy.Or(a, claripy.Not(a)), claripy.true())

complex_true_expression = claripy.Or(
claripy.And(a, b),
claripy.Or(claripy.And(a, claripy.Not(b)), claripy.And(claripy.Not(a), c)),
claripy.Or(claripy.And(a, claripy.Not(b)), claripy.And(claripy.Not(a), claripy.Not(c))),
)
assert_correct(complex_true_expression, claripy.true)
assert_correct(complex_true_expression, claripy.true())

def test_simplification(self):
def assert_correct(a, b):
Expand Down
6 changes: 3 additions & 3 deletions tests/test_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def raw_replacement_solver(reuse_z3_solver):
b = claripy.BoolS("b")
assert sr._replacement(b) is b
sr.add(claripy.Not(b))
assert sr._replacement(b) is claripy.false
assert sr._replacement(b) is claripy.false()

sr = claripy.SolverReplacement(claripy.SolverVSA(), complex_auto_replace=True)
x = claripy.BVS("x", 64)
Expand Down Expand Up @@ -409,7 +409,7 @@ def raw_ancestor_merge(solver, reuse_z3_solver):
p.add(z == 1)
q.add(z == 2)

r = p.merge([q], [claripy.true, claripy.true])[-1]
r = p.merge([q], [claripy.true(), claripy.true()])[-1]
t = p.merge([q], [p.constraints[-1], q.constraints[-1]], common_ancestor=s)[-1]

if not isinstance(r, claripy.frontends.CompositeFrontend):
Expand Down Expand Up @@ -489,7 +489,7 @@ def test_unsatness(self):
s = claripy.Solver()
s.add(x == 10)
assert s.satisfiable()
s.add(claripy.false)
s.add(claripy.false())
assert not s.satisfiable()

def test_simplification_annotations(self):
Expand Down
Loading