diff --git a/claripy/ast/base.py b/claripy/ast/base.py
index 21514d0a4..c7a5aaa74 100644
--- a/claripy/ast/base.py
+++ b/claripy/ast/base.py
@@ -491,6 +491,14 @@ def _arg_serialize(arg: ArgType | Annotation) -> bytes:
     def __hash__(self) -> int:
         return self._hash
 
+    def hash(self) -> int:
+        """Python's built-in hash function is not collision resistant, so we use our own.
+        When you call `hash(ast)`, the value is derived from the claripy hash, but it gets
+        passed through Python's non-resistant hash function first. This method skips that
+        step, allowing the claripy hash to be used directly, e.g. as a cache key.
+        """
+        return self._hash
+
     @property
     def cache_key(self: Self) -> ASTCacheKey[Self]:
         """
@@ -834,7 +842,7 @@ def children_asts(self) -> Iterator[Base]:
             if isinstance(ast, Base):
                 ast_queue.append(iter(ast.args))
 
-                l.debug("Yielding AST %s with hash %s with %d children", ast, hash(ast), len(ast.args))
+                l.debug("Yielding AST %s with hash %s with %d children", ast, ast.hash(), len(ast.args))
                 yield ast
 
     def leaf_asts(self) -> Iterator[Base]:
@@ -862,14 +870,14 @@ def is_leaf(self) -> bool:
         """
         return self.depth == 1
 
-    def dbg_is_looped(self) -> bool:
-        l.debug("Checking AST with hash %s for looping", hash(self))
+    def dbg_is_looped(self) -> Base | bool:  # TODO: this return type is bad
+        l.debug("Checking AST with hash %s for looping", self.hash())
 
         seen = set()
         for child_ast in self.children_asts():
-            if hash(child_ast) in seen:
+            if child_ast.hash() in seen:
                 return child_ast
-            seen.add(hash(child_ast))
+            seen.add(child_ast.hash())
 
         return False
 
diff --git a/claripy/frontend_mixins/constraint_deduplicator_mixin.py b/claripy/frontend_mixins/constraint_deduplicator_mixin.py
index 53c9b5651..a084cb4af 100644
--- a/claripy/frontend_mixins/constraint_deduplicator_mixin.py
+++ b/claripy/frontend_mixins/constraint_deduplicator_mixin.py
@@ -26,14 +26,14 @@ def simplify(self):
         # we only add to the constraint hashes because we want to
         # prevent previous (now simplified) constraints from
         # being re-added
-        self._constraint_hashes.update(map(hash, added))
+        self._constraint_hashes.update(c.hash() for c in added)
         return added
 
     def _add(self, constraints, invalidate_cache=True):
-        filtered = tuple(c for c in constraints if hash(c) not in self._constraint_hashes)
+        filtered = tuple(c for c in constraints if c.hash() not in self._constraint_hashes)
         if len(filtered) == 0:
             return filtered
 
         added = super()._add(filtered, invalidate_cache=invalidate_cache)
-        self._constraint_hashes.update(map(hash, added))
+        self._constraint_hashes.update(c.hash() for c in added)
         return added
diff --git a/tests/test_expression.py b/tests/test_expression.py
index 42d9d0427..bd0594443 100644
--- a/tests/test_expression.py
+++ b/tests/test_expression.py
@@ -119,9 +119,9 @@ def test_expression(self):
         new_formula = old_formula.replace(old, new)
         ooo_formula = new_formula.replace(new, ooo)
 
-        self.assertNotEqual(hash(old_formula), hash(new_formula))
-        self.assertNotEqual(hash(old_formula), hash(ooo_formula))
-        self.assertNotEqual(hash(new_formula), hash(ooo_formula))
+        self.assertNotEqual(old_formula.hash(), new_formula.hash())
+        self.assertNotEqual(old_formula.hash(), ooo_formula.hash())
+        self.assertNotEqual(new_formula.hash(), ooo_formula.hash())
 
         self.assertEqual(old_formula.variables, frozenset(("old",)))
         self.assertEqual(new_formula.variables, frozenset(("new",)))
diff --git a/tests/test_serial.py b/tests/test_serial.py
index f825185f7..b3174fc03 100644
--- a/tests/test_serial.py
+++ b/tests/test_serial.py
@@ -67,8 +67,8 @@ def test_identity(self):
         s.add(x == 3)
         s.finalize()
         ss = pickle.loads(pickle.dumps(s))
-        old_constraint_sets = [[hash(j) for j in k.constraints] for k in s._solver_list]
-        new_constraint_sets = [[hash(j) for j in k.constraints] for k in ss._solver_list]
+        old_constraint_sets = [[j.hash() for j in k.constraints] for k in s._solver_list]
+        new_constraint_sets = [[j.hash() for j in k.constraints] for k in ss._solver_list]
        assert old_constraint_sets == new_constraint_sets
        assert str(s.variables) == str(ss.variables)
 
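A minimal usage sketch of the new Base.hash() method as a cache key, assuming claripy's public BVS constructor; expr_cache is an illustrative name, not part of the claripy API:

import claripy

x = claripy.BVS("x", 32)
expr = x + 1

# hash(expr) runs the claripy hash through Python's hash() builtin first;
# expr.hash() returns the claripy hash directly, so it is the stabler cache key.
expr_cache = {}  # illustrative cache, not part of claripy
expr_cache[expr.hash()] = "cached result"
assert expr_cache[expr.hash()] == "cached result"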