From 5254cb4e3889d14bb65b12958ba8df6c68e4ef1a Mon Sep 17 00:00:00 2001 From: Kevin Phoenix Date: Fri, 20 Sep 2024 19:20:51 -0700 Subject: [PATCH] Make true and false functions --- claripy/ast/bool.py | 28 +++++------ .../constraint_filter_mixin.py | 2 +- claripy/frontend_mixins/model_cache_mixin.py | 4 +- claripy/frontend_mixins/sat_cache_mixin.py | 4 +- claripy/frontends/replacement_frontend.py | 4 +- claripy/simplifications.py | 50 +++++++------------ tests/test_simplify.py | 6 +-- tests/test_solver.py | 6 +-- 8 files changed, 42 insertions(+), 62 deletions(-) diff --git a/claripy/ast/bool.py b/claripy/ast/bool.py index 641dfd696..5e0c7ebca 100644 --- a/claripy/ast/bool.py +++ b/claripy/ast/bool.py @@ -2,6 +2,7 @@ import logging from contextlib import suppress +from functools import lru_cache from typing import TYPE_CHECKING, overload from claripy import operations @@ -16,8 +17,6 @@ l = logging.getLogger("claripy.ast.bool") -_boolv_cache = {} - class Bool(Base): __slots__ = () @@ -60,21 +59,18 @@ def BoolS(name, explicit_name=None) -> Bool: return Bool("BoolS", (n,), variables=frozenset((n,)), symbolic=True) +@lru_cache(maxsize=2) def BoolV(val) -> Bool: - try: - return _boolv_cache[(val)] - except KeyError: - result = Bool("BoolV", (val,)) - _boolv_cache[val] = result - return result + return Bool("BoolV", (val,)) -# -# some standard ASTs -# +def true(): + return BoolV(True) + + +def false(): + return BoolV(False) -true = BoolV(True) -false = BoolV(False) # # Bound operations @@ -145,9 +141,9 @@ def If(cond, true_value, false_value): if args[1] is args[2]: return args[1] - if args[1] is true and args[2] is false: + if args[1] is true() and args[2] is false(): return args[0] - if args[1] is false and args[2] is true: + if args[1] is false() and args[2] is true(): return ~args[0] if issubclass(ty, Bits): @@ -236,7 +232,7 @@ def reverse_ite_cases(ast): :param ast: :return: """ - queue = [(true, ast)] + queue = [(true(), ast)] while queue: condition, ast = queue.pop(0) if ast.op == "If": diff --git a/claripy/frontend_mixins/constraint_filter_mixin.py b/claripy/frontend_mixins/constraint_filter_mixin.py index 93226c237..df62ebc6b 100644 --- a/claripy/frontend_mixins/constraint_filter_mixin.py +++ b/claripy/frontend_mixins/constraint_filter_mixin.py @@ -22,7 +22,7 @@ def _add(self, constraints, invalidate_cache=True): ec = self._constraint_filter(constraints) except UnsatError: # filter out concrete False - ec = [c for c in constraints if c not in {False, false}] + [false] + ec = [c for c in constraints if c not in {False, false()}] + [false()] if len(constraints) == 0: return [] diff --git a/claripy/frontend_mixins/model_cache_mixin.py b/claripy/frontend_mixins/model_cache_mixin.py index 736fe746c..e60ffd93a 100644 --- a/claripy/frontend_mixins/model_cache_mixin.py +++ b/claripy/frontend_mixins/model_cache_mixin.py @@ -157,7 +157,7 @@ def __setstate__(self, base_state): def simplify(self): results = super().simplify() - if len(results) > 0 and any(c is false for c in results): + if len(results) > 0 and any(c is false() for c in results): self._models.clear() return results @@ -195,7 +195,7 @@ def _add(self, constraints, invalidate_cache=True): new_vars = any(a.variables - old_vars for a in added) if new_vars or invalidate_cache: # shortcut for unsat - if any(c is false for c in constraints): + if any(c is false() for c in constraints): self._models.clear() still_valid = set(self._get_models(extra_constraints=added)) diff --git a/claripy/frontend_mixins/sat_cache_mixin.py b/claripy/frontend_mixins/sat_cache_mixin.py index 97cedf4e4..75c73985b 100644 --- a/claripy/frontend_mixins/sat_cache_mixin.py +++ b/claripy/frontend_mixins/sat_cache_mixin.py @@ -30,7 +30,7 @@ def __setstate__(self, s): def _add(self, constraints, invalidate_cache=True): added = super()._add(constraints, invalidate_cache=invalidate_cache) - if len(added) > 0 and any(c is false for c in added): + if len(added) > 0 and any(c is false() for c in added): self._cached_satness = False elif self._cached_satness is True: self._cached_satness = None @@ -38,7 +38,7 @@ def _add(self, constraints, invalidate_cache=True): def simplify(self): new_constraints = super().simplify() - if len(new_constraints) > 0 and any(c is false for c in new_constraints): + if len(new_constraints) > 0 and any(c is false() for c in new_constraints): self._cached_satness = False return new_constraints diff --git a/claripy/frontends/replacement_frontend.py b/claripy/frontends/replacement_frontend.py index 57d14d78f..1a6e666b6 100644 --- a/claripy/frontends/replacement_frontend.py +++ b/claripy/frontends/replacement_frontend.py @@ -272,7 +272,7 @@ def _add(self, constraints, invalidate_cache=True): if not self._complex_auto_replace: if rc.op == "Not": - self.add_replacement(c.args[0], false, replace=False, promote=True, invalidate_cache=True) + self.add_replacement(c.args[0], false(), replace=False, promote=True, invalidate_cache=True) elif rc.op == "__eq__" and rc.args[0].symbolic ^ rc.args[1].symbolic: old, new = rc.args if rc.args[0].symbolic else rc.args[::-1] self.add_replacement(old, new, replace=False, promote=True, invalidate_cache=True) @@ -281,7 +281,7 @@ def _add(self, constraints, invalidate_cache=True): backends.vsa, rc, validation_frontend=self._validation_frontend ).compat_ret if not satisfiable: - self.add_replacement(rc, false) + self.add_replacement(rc, false()) for old, new in replacements: if old.cardinality == 1: continue diff --git a/claripy/simplifications.py b/claripy/simplifications.py index df070685c..93e6bbd0a 100644 --- a/claripy/simplifications.py +++ b/claripy/simplifications.py @@ -174,15 +174,15 @@ def lshift_simplifier(val, shift): def eq_simplifier(a, b): if a is b: - return claripy.true + return claripy.true() - if isinstance(a, claripy.ast.Bool) and b is claripy.true: + if isinstance(a, claripy.ast.Bool) and b is claripy.true(): return a - if isinstance(b, claripy.ast.Bool) and a is claripy.true: + if isinstance(b, claripy.ast.Bool) and a is claripy.true(): return b - if isinstance(a, claripy.ast.Bool) and b is claripy.false: + if isinstance(a, claripy.ast.Bool) and b is claripy.false(): return claripy.Not(a) - if isinstance(b, claripy.ast.Bool) and a is claripy.false: + if isinstance(b, claripy.ast.Bool) and a is claripy.false(): return claripy.Not(b) if a.op == "Reverse" and b.op == "Reverse": @@ -211,10 +211,6 @@ def eq_simplifier(a, b): if a.args[2] is b and claripy.is_true(a.args[1] != b): # (If(c, x, y) == y, x != y) -> !c return claripy.Not(a.args[0]) - # elif a._claripy.is_true(a.args[1] == b) and a._claripy.is_true(a.args[2] == b): - # return a._claripy.true - # elif a._claripy.is_true(a.args[1] != b) and a._claripy.is_true(a.args[2] != b): - # return a._claripy.false if b.op == "If": if b.args[1] is a and claripy.is_true(b.args[2] != b): @@ -223,10 +219,6 @@ def eq_simplifier(a, b): if b.args[2] is a and claripy.is_true(b.args[1] != a): # (y == If(c, x, y)) -> !c return claripy.Not(b.args[0]) - # elif b._claripy.is_true(b.args[1] == a) and b._claripy.is_true(b.args[2] == a): - # return b._claripy.true - # elif b._claripy.is_true(b.args[1] != a) and b._claripy.is_true(b.args[2] != a): - # return b._claripy.false # Masking and comparing against a constant simp = and_mask_comparing_against_constant_simplifier(operator.__eq__, a, b) @@ -254,14 +246,14 @@ def eq_simplifier(a, b): break if claripy.is_false(a_bit == b_bit): - return claripy.false + return claripy.false() return None return None def ne_simplifier(a, b): if a is b: - return claripy.false + return claripy.false() if a.op == "Reverse" and b.op == "Reverse": return a.args[0] != b.args[0] @@ -273,10 +265,6 @@ def ne_simplifier(a, b): if a.args[1] is b and claripy.is_true(a.args[2] != b): # (If(c, x, y) == y, x != y) -> !c return claripy.Not(a.args[0]) - # elif a._claripy.is_true(a.args[1] == b) and a._claripy.is_true(a.args[2] == b): - # return a._claripy.false - # elif a._claripy.is_true(a.args[1] != b) and a._claripy.is_true(a.args[2] != b): - # return a._claripy.true if b.op == "If": if b.args[2] is a and claripy.is_true(b.args[1] != a): @@ -285,10 +273,6 @@ def ne_simplifier(a, b): if b.args[1] is a and claripy.is_true(b.args[2] != a): # (y == If(c, x, y)) -> !c return claripy.Not(b.args[0]) - # elif b._claripy.is_true(b.args[1] != a) and b._claripy.is_true(b.args[2] != a): - # return b._claripy.true - # elif b._claripy.is_true(b.args[1] == a) and b._claripy.is_true(b.args[2] == a): - # return b._claripy.false # 1 ^ expr != 0 -> expr == 0 if a.op == "__xor__" and b.op == "BVV" and b.args[0] == 0 and len(a.args) == 2: @@ -323,7 +307,7 @@ def ne_simplifier(a, b): break if claripy.is_true(a_bit != b_bit): - return claripy.true + return claripy.true() return None return None @@ -385,14 +369,14 @@ def boolean_and_simplifier(*args): for a in args: if a.op == "BoolV": if a.is_false(): - return claripy.false + return claripy.false() else: new_args[ctr] = a ctr += 1 new_args = new_args[:ctr] if not new_args: - return claripy.true + return claripy.true() if len(new_args) < len(args): return claripy.And(*new_args) @@ -459,11 +443,11 @@ def boolean_and_simplifier(*args): if not eq_list: return flattened if any(any(ne is eq for eq in eq_list) for ne in ne_list): - return claripy.false + return claripy.false() if all(v.op == "BVV" for v in eq_list) and all(v.op == "BVV" for v in ne_list): mustbe = eq_list[0] if any(eq.args[0] != mustbe.args[0] for eq in eq_list): - return claripy.false + return claripy.false() return target_var == eq_list[0] return flattened @@ -475,12 +459,12 @@ def boolean_or_simplifier(*args): new_args = [] for a in args: if a.is_true(): - return claripy.true + return claripy.true() if not a.is_false(): new_args.append(a) if not new_args: - return claripy.false + return claripy.false() if len(new_args) < len(args): return claripy.Or(*new_args) @@ -1039,7 +1023,7 @@ def and_mask_comparing_against_constant_simplifier(op, a, b): return None return op(a_arg0[b_highbit_idx:0] & a_arg1.args[0], b_lower) if b_higher_bits_are_0 is False: - return claripy.false if op is operator.__eq__ else claripy.true + return claripy.false() if op is operator.__eq__ else claripy.true() return None @@ -1109,7 +1093,7 @@ def zeroext_comparing_against_simplifier(op, a, b): return op(a.args[1], b[b.size() - a_zeroext_bits - 1 : 0]) if (b_highbits == 0).is_false(): # unsat - return claripy.false if op is operator.__eq__ else claripy.true + return claripy.false() if op is operator.__eq__ else claripy.true() if ( a.op == "Concat" and len(a.args) == 2 and a.args[0].op == "BVV" and a.args[0].args[0] == 0 @@ -1121,7 +1105,7 @@ def zeroext_comparing_against_simplifier(op, a, b): return op(a.args[1], b[b.size() - a_zero_bits - 1 : 0]) if (b_highbits == 0).is_false(): # unsat - return claripy.false if op is operator.__eq__ else claripy.true + return claripy.false() if op is operator.__eq__ else claripy.true() return None diff --git a/tests/test_simplify.py b/tests/test_simplify.py index 619b57df3..a480307ba 100644 --- a/tests/test_simplify.py +++ b/tests/test_simplify.py @@ -13,15 +13,15 @@ def assert_correct(a, b): a, b, c = (claripy.BoolS(name) for name in ("a", "b", "c")) - assert_correct(claripy.And(a, claripy.Not(a)), claripy.false) - assert_correct(claripy.Or(a, claripy.Not(a)), claripy.true) + assert_correct(claripy.And(a, claripy.Not(a)), claripy.false()) + assert_correct(claripy.Or(a, claripy.Not(a)), claripy.true()) complex_true_expression = claripy.Or( claripy.And(a, b), claripy.Or(claripy.And(a, claripy.Not(b)), claripy.And(claripy.Not(a), c)), claripy.Or(claripy.And(a, claripy.Not(b)), claripy.And(claripy.Not(a), claripy.Not(c))), ) - assert_correct(complex_true_expression, claripy.true) + assert_correct(complex_true_expression, claripy.true()) def test_simplification(self): def assert_correct(a, b): diff --git a/tests/test_solver.py b/tests/test_solver.py index e76c6ffb3..698d43868 100644 --- a/tests/test_solver.py +++ b/tests/test_solver.py @@ -80,7 +80,7 @@ def raw_replacement_solver(reuse_z3_solver): b = claripy.BoolS("b") assert sr._replacement(b) is b sr.add(claripy.Not(b)) - assert sr._replacement(b) is claripy.false + assert sr._replacement(b) is claripy.false() sr = claripy.SolverReplacement(claripy.SolverVSA(), complex_auto_replace=True) x = claripy.BVS("x", 64) @@ -409,7 +409,7 @@ def raw_ancestor_merge(solver, reuse_z3_solver): p.add(z == 1) q.add(z == 2) - r = p.merge([q], [claripy.true, claripy.true])[-1] + r = p.merge([q], [claripy.true(), claripy.true()])[-1] t = p.merge([q], [p.constraints[-1], q.constraints[-1]], common_ancestor=s)[-1] if not isinstance(r, claripy.frontends.CompositeFrontend): @@ -489,7 +489,7 @@ def test_unsatness(self): s = claripy.Solver() s.add(x == 10) assert s.satisfiable() - s.add(claripy.false) + s.add(claripy.false()) assert not s.satisfiable() def test_simplification_annotations(self):