Skip to content

Commit

Permalink
Move replace and replace_dict to algorithms sub-package (#545)
Browse files Browse the repository at this point in the history
* Move replace and replace_dict to algorithms sub-package

* Remove unused expr arg from _check_replaceability
  • Loading branch information
twizmwazin authored Oct 4, 2024
1 parent baeced9 commit ba40a8a
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 100 deletions.
4 changes: 3 additions & 1 deletion claripy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from claripy import algorithm, ast, backends
from claripy.algorithm import burrow_ite, excavate_ite, is_false, is_true, simplify
from claripy.algorithm import burrow_ite, excavate_ite, is_false, is_true, replace, replace_dict, simplify
from claripy.annotation import Annotation, RegionAnnotation, SimplificationAvoidanceAnnotation
from claripy.ast.bool import (
And,
Expand Down Expand Up @@ -213,4 +213,6 @@
"burrow_ite",
"excavate_ite",
"backends",
"replace",
"replace_dict",
)
3 changes: 3 additions & 0 deletions claripy/algorithm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@

from .bool_check import is_false, is_true
from .ite_relocation import burrow_ite, excavate_ite
from .replace import replace, replace_dict
from .simplify import simplify

__all__ = (
"burrow_ite",
"excavate_ite",
"is_false",
"is_true",
"replace",
"replace_dict",
"simplify",
)
99 changes: 99 additions & 0 deletions claripy/algorithm/replace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from __future__ import annotations

from typing import TYPE_CHECKING, TypeVar

from claripy.ast.base import Base
from claripy.errors import ClaripyReplacementError

if TYPE_CHECKING:
from collections.abc import Callable

T = TypeVar("T", bound="Base")


def replace_dict(
expr: Base,
replacements: dict[int, Base],
variable_set: set[str] | None = None,
leaf_operation: Callable[[Base], Base] = lambda x: x,
) -> Base:
"""
Returns this AST with subexpressions replaced by those that can be found in `replacements`
dict.
:param variable_set: For optimization, ast's without these variables are not checked
for replacing.
:param replacements: A dictionary of hashes to their replacements.
:param leaf_operation: An operation that should be applied to the leaf nodes.
:returns: An AST with all instances of ast's in replacements.
"""
if variable_set is None:
variable_set = set()

arg_queue = [iter([expr])]
rep_queue = []
ast_queue = []

while arg_queue:
try:
ast = next(arg_queue[-1])
repl = ast

if not isinstance(ast, Base):
rep_queue.append(repl)
continue

if ast.hash() in replacements:
repl = replacements[ast.hash()]

elif ast.variables >= variable_set:
if ast.is_leaf():
repl = leaf_operation(ast)
if repl is not ast:
replacements[ast.hash()] = repl

elif ast.depth > 1:
arg_queue.append(iter(ast.args))
ast_queue.append(ast)
continue

rep_queue.append(repl)

except StopIteration:
arg_queue.pop()

if ast_queue:
ast = ast_queue.pop()
repl = ast

args = rep_queue[-len(ast.args) :]
del rep_queue[-len(ast.args) :]

# Check if replacement occurred.
if any((a is not b for a, b in zip(ast.args, args, strict=False))):
repl = ast.make_like(ast.op, tuple(args))
replacements[ast.hash()] = repl

rep_queue.append(repl)

assert len(arg_queue) == 0, "arg_queue is not empty"
assert len(ast_queue) == 0, "ast_queue is not empty"
assert len(rep_queue) == 1, ("rep_queue has unexpected length", len(rep_queue))

return rep_queue.pop()


def _check_replaceability(old: T, new: T) -> None:
if not isinstance(old, Base) or not isinstance(new, Base):
raise ClaripyReplacementError("replacements must be AST nodes")
if type(old) is not type(new):
raise ClaripyReplacementError(f"cannot replace type {type(old)} ast with type {type(new)} ast")


def replace(expr: Base, old: T, new: T) -> Base:
"""
Returns this AST but with the AST 'old' replaced with AST 'new' in its subexpressions.
"""
_check_replaceability(old, new)
replacements: dict[int, Base] = {old.hash(): new}
return replace_dict(expr, replacements, variable_set=old.variables)
89 changes: 3 additions & 86 deletions claripy/ast/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@

import claripy
from claripy import operations
from claripy.errors import BackendError, ClaripyOperationError, ClaripyReplacementError
from claripy.errors import BackendError, ClaripyOperationError
from claripy.fp import FSort

if TYPE_CHECKING:
from collections.abc import Callable, Iterable, Iterator
from collections.abc import Iterable, Iterator

from claripy.annotation import Annotation
from claripy.backends import Backend
Expand Down Expand Up @@ -912,89 +912,6 @@ def structurally_match(self, o: Base) -> bool:

return True

def replace_dict(
self, replacements, variable_set: set[str] | None = None, leaf_operation: Callable[[Base], Base] = lambda x: x
) -> Self:
"""
Returns this AST with subexpressions replaced by those that can be found in `replacements`
dict.
:param variable_set: For optimization, ast's without these variables are not checked
for replacing.
:param replacements: A dictionary of hashes to their replacements.
:param leaf_operation: An operation that should be applied to the leaf nodes.
:returns: An AST with all instances of ast's in replacements.
"""
if variable_set is None:
variable_set = set()

arg_queue = [iter([self])]
rep_queue = []
ast_queue = []

while arg_queue:
try:
ast = next(arg_queue[-1])
repl = ast

if not isinstance(ast, Base):
rep_queue.append(repl)
continue

if ast.hash() in replacements:
repl = replacements[ast.hash()]

elif ast.variables >= variable_set:
if ast.is_leaf():
repl = leaf_operation(ast)
if repl is not ast:
replacements[ast.hash()] = repl

elif ast.depth > 1:
arg_queue.append(iter(ast.args))
ast_queue.append(ast)
continue

rep_queue.append(repl)

except StopIteration:
arg_queue.pop()

if ast_queue:
ast = ast_queue.pop()
repl = ast

args = rep_queue[-len(ast.args) :]
del rep_queue[-len(ast.args) :]

# Check if replacement occurred.
if any((a is not b for a, b in zip(ast.args, args, strict=False))):
repl = ast.make_like(ast.op, tuple(args))
replacements[ast.hash()] = repl

rep_queue.append(repl)

assert len(arg_queue) == 0, "arg_queue is not empty"
assert len(ast_queue) == 0, "ast_queue is not empty"
assert len(rep_queue) == 1, ("rep_queue has unexpected length", len(rep_queue))

return rep_queue.pop()

def replace(self, old: T, new: T) -> Self:
"""
Returns this AST but with the AST 'old' replaced with AST 'new' in its subexpressions.
"""
self._check_replaceability(old, new)
replacements = {old.hash(): new}
return self.replace_dict(replacements, variable_set=old.variables)

@staticmethod
def _check_replaceability(old: T, new: T) -> None:
if not isinstance(old, Base) or not isinstance(new, Base):
raise ClaripyReplacementError("replacements must be AST nodes")
if type(old) is not type(new):
raise ClaripyReplacementError(f"cannot replace type {type(old)} ast with type {type(new)} ast")

def canonicalize(self, var_map=None, counter=None) -> Self:
counter = itertools.count() if counter is None else counter
var_map = {} if var_map is None else var_map
Expand All @@ -1004,7 +921,7 @@ def canonicalize(self, var_map=None, counter=None) -> Self:
new_name = "canonical_%d" % next(counter)
var_map[v.hash()] = v._rename(new_name)

return var_map, counter, self.replace_dict(var_map)
return var_map, counter, claripy.replace_dict(self, var_map)

#
# these are convenience operations
Expand Down
6 changes: 4 additions & 2 deletions claripy/frontend/mixin/model_cache_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,11 @@ def eval_ast(self, ast, allow_unconstrained: bool = True):
"""

if allow_unconstrained:
new_ast = ast.replace_dict(self.replacements, leaf_operation=self._leaf_op)
new_ast = claripy.replace_dict(ast, self.replacements, leaf_operation=self._leaf_op)
else:
new_ast = ast.replace_dict(self.constraint_only_replacements, leaf_operation=self._leaf_op_existonly)
new_ast = claripy.replace_dict(
ast, self.constraint_only_replacements, leaf_operation=self._leaf_op_existonly
)
return backends.concrete.eval(new_ast, 1)[0]

def eval_constraints(self, constraints):
Expand Down
3 changes: 2 additions & 1 deletion claripy/frontend/replacement_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numbers
from contextlib import suppress

import claripy
from claripy import backends
from claripy.ast.base import Base
from claripy.ast.bool import BoolV, false
Expand Down Expand Up @@ -111,7 +112,7 @@ def _replacement(self, old):
return self._replacement_cache[old.hash()]

# not found in the cache
new = old.replace_dict(self._replacement_cache)
new = claripy.replace_dict(old, self._replacement_cache)
if new is not old:
self._replacement_cache[old.hash()] = new
return new
Expand Down
20 changes: 10 additions & 10 deletions tests/test_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ def test_expression(self):
ooo = claripy.BVV(0, 32)

old_formula = claripy.If((old + 1) % 256 == 0, old + 10, old + 20)
new_formula = old_formula.replace(old, new)
ooo_formula = new_formula.replace(new, ooo)
new_formula = claripy.replace(old_formula, old, new)
ooo_formula = claripy.replace(new_formula, new, ooo)

self.assertNotEqual(old_formula.hash(), new_formula.hash())
self.assertNotEqual(old_formula.hash(), ooo_formula.hash())
Expand All @@ -139,7 +139,7 @@ def test_expression(self):
new = claripy.BVS("new", 32, explicit_name=True)
c = (old + 10) - (old + 20)
d = (old + 1) - (old + 2)
cr = c.replace_dict({(old + 10).hash(): (old + 1), (old + 20).hash(): (old + 2)})
cr = claripy.replace_dict(c, {(old + 10).hash(): (old + 1), (old + 20).hash(): (old + 2)})
self.assertIs(cr, d)

# test AST collapse
Expand Down Expand Up @@ -423,12 +423,12 @@ def test_multiarg(self):
assert len(x_xor.args) == 4
assert len(x_and.args) == 4

assert (x_add).replace(x, o).args[0] == 8
assert (x_mul).replace(x, o).args[0] == 16
assert (x_or).replace(x, o).args[0] == 7
assert (x_xor).replace(x, o).args[0] == 0
assert (x_and).replace(x, o).args[0] == 0
assert (100 + (x_sub).replace(x, o)).args[0] == 90
assert claripy.replace(x_add, x, o).args[0] == 8
assert claripy.replace(x_mul, x, o).args[0] == 16
assert claripy.replace(x_or, x, o).args[0] == 7
assert claripy.replace(x_xor, x, o).args[0] == 0
assert claripy.replace(x_and, x, o).args[0] == 0
assert (100 + claripy.replace(x_sub, x, o)).args[0] == 90

# make sure that z3 and vsa backends handle this properly
claripy.backends.z3.convert(x + x + x + x)
Expand Down Expand Up @@ -560,7 +560,7 @@ def test_bool_conversion(self):
def test_bool_replace_in_ite(self):
b = claripy.BoolS("b")
expr = claripy.If(b, claripy.BVV(2, 32), claripy.BVV(3, 32))
new_expr = expr.replace(b, claripy.BoolV(True))
new_expr = claripy.replace(expr, b, claripy.BoolV(True))

# Replace calls make_like which will simplify the expression. As a
# result, the new expression will be a BVV.
Expand Down

0 comments on commit ba40a8a

Please sign in to comment.