Skip to content

Commit

Permalink
Move excavate_ite and burrow_ite to algorithms subpackage (#537)
Browse files Browse the repository at this point in the history
  • Loading branch information
twizmwazin authored Oct 3, 2024
1 parent 0393bd7 commit 4cff9ef
Show file tree
Hide file tree
Showing 9 changed files with 186 additions and 169 deletions.
4 changes: 3 additions & 1 deletion claripy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from claripy import algorithm, ast
from claripy.algorithm import is_false, is_true, simplify
from claripy.algorithm import burrow_ite, excavate_ite, is_false, is_true, simplify
from claripy.annotation import Annotation, RegionAnnotation, SimplificationAvoidanceAnnotation
from claripy.ast.bool import (
And,
Expand Down Expand Up @@ -210,4 +210,6 @@
"SolverReplacement",
"SolverStrings",
"SolverVSA",
"burrow_ite",
"excavate_ite",
)
3 changes: 3 additions & 0 deletions claripy/algorithm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from __future__ import annotations

from .bool_check import is_false, is_true
from .ite_relocation import burrow_ite, excavate_ite
from .simplify import simplify

__all__ = (
"burrow_ite",
"excavate_ite",
"is_false",
"is_true",
"simplify",
Expand Down
161 changes: 161 additions & 0 deletions claripy/algorithm/ite_relocation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
from __future__ import annotations

from typing import TypeVar, cast
from weakref import WeakValueDictionary

import claripy
from claripy.ast.base import Base

# Type variable used so the public helpers return the same AST type they receive.
T = TypeVar("T", bound=Base)

#
# This code handles burrowing ITEs deeper into the ast and excavating
# them to shallower levels.
#

# Memoization tables for burrow_ite() / excavate_ite(), keyed by ``expr.hash()``.
# Values are held weakly: an entry disappears once nothing else references the
# cached AST, so these tables never keep an AST alive on their own.
burrowed_cache: WeakValueDictionary[int, Base] = WeakValueDictionary()
excavated_cache: WeakValueDictionary[int, Base] = WeakValueDictionary()


def _burrow_ite(expr: T) -> T:
    """
    Push a top-level ``If`` one level down into its branches when that is safe.

    For a node that is not an ``If``, every AST argument is burrowed with
    :func:`burrow_ite` and the node is rebuilt through ``swap_args``.

    For ``If(c, op(..., x, ...), op(..., y, ...))`` whose two branches use the
    same op, have the same number of arguments and differ in exactly one
    argument position, the result is ``op(..., If(c, x, y), ...)`` with the
    inner ``If`` burrowed further. In every other case ``expr`` is returned
    unchanged.

    :param expr: The AST to process.
    :return: An AST equivalent to ``expr``.
    """
    if expr.op != "If":
        return expr.swap_args([(burrow_ite(a) if isinstance(a, Base) else a) for a in expr.args])

    if not all(isinstance(a, Base) for a in expr.args):
        return expr

    old_true = expr.args[1]
    old_false = expr.args[2]

    if old_true.op != old_false.op or len(old_true.args) != len(old_false.args):
        return expr

    if old_true.op == "If":
        # let's not go into this right now
        return expr

    if any(a.is_leaf() for a in expr.args):
        # burrowing through these is pretty funny
        return expr

    # Argument positions at which the two branches do not share the same object.
    differing = [i for i, (t, f) in enumerate(zip(old_true.args, old_false.args)) if t is not f]
    if len(differing) != 1:
        # Exactly one *difference* is required. Counting matches instead would
        # accept e.g. three-argument nodes with two differences and then drop
        # one of them, producing a non-equivalent AST.
        # TODO: handle multiple differences for multi-arg ast nodes
        return expr

    different_idx = differing[0]
    true_arg = old_true.args[different_idx]
    false_arg = old_false.args[different_idx]

    if not (isinstance(true_arg, Base) and isinstance(false_arg, Base)):
        # The differing operand is a plain Python value (for instance an int
        # parameter of the op); replacing it by an If AST would change the
        # operand kind, so leave the expression alone.
        return expr

    if getattr(true_arg, "length", None) != getattr(false_arg, "length", None):
        # NOTE(review): assumes claripy.If requires both branches to have the
        # same length -- confirm. Returning expr unchanged is always valid.
        return expr

    inner_if = claripy.If(expr.args[0], true_arg, false_arg)
    new_args = list(old_true.args)
    new_args[different_idx] = burrow_ite(inner_if)
    return old_true.__class__(old_true.op, new_args, length=expr.length)


def _excavate_ite(expr: T) -> T:
    """
    Rebuild ``expr`` bottom-up, lifting ``If`` nodes out of operator arguments.

    The AST is walked iteratively with explicit stacks rather than recursion.
    Non-AST values, leaves and annotated nodes are kept exactly as they are.
    Every other node is rebuilt once all of its arguments have been processed:

    * an ``If`` node is re-created through ``claripy.If``;
    * a node with no ``If`` among its processed arguments is re-created with
      ``swap_args(..., simplify=True)``;
    * otherwise the condition of the first ``If`` argument is hoisted, giving
      ``If(cond, op(true args), op(false args))``, provided every ``If``
      argument is guarded by that same condition object or by ``~cond``.
      When some ``If`` argument uses another condition, the node is rebuilt
      without hoisting.

    :param expr: The AST to process.
    :return: The rebuilt AST.
    """
    iter_stack = [iter([expr])]  # one argument iterator per node being descended into
    finished = []  # processed values, in visiting order
    open_ops = []  # nodes whose arguments are still being processed

    while iter_stack:
        try:
            node = next(iter_stack[-1])
        except StopIteration:
            # Every argument of the innermost open node has been handled.
            iter_stack.pop()
            if not open_ops:
                continue

            op = open_ops.pop()
            arity = len(op.args)
            args = finished[-arity:]
            del finished[-arity:]

            is_ite = [isinstance(a, Base) and a.op == "If" for a in args]

            if op.op == "If":
                # if we are an If, call the If handler so that we can take advantage of its simplifiers
                rebuilt = claripy.If(*args)
            elif True not in is_ite:
                # no Ifs came to the surface, so there is nothing to pull out
                rebuilt = op.swap_args(args, simplify=True)
            else:
                # we are *not* an If, but some arguments are: pull them out to the surface
                cond = args[is_ite.index(True)].args[0]
                true_side = []
                false_side = []
                hoistable = True

                for arg, arg_is_ite in zip(args, is_ite):
                    if not arg_is_ite:
                        true_side.append(arg)
                        false_side.append(arg)
                    elif arg.args[0] is cond:
                        true_side.append(arg.args[1])
                        false_side.append(arg.args[2])
                    elif arg.args[0] is ~cond:
                        # guarded by the negated condition: the branches swap roles
                        true_side.append(arg.args[2])
                        false_side.append(arg.args[1])
                    else:
                        # weird conditions -- giving up!
                        hoistable = False
                        break

                if hoistable:
                    rebuilt = claripy.If(
                        cond,
                        op.swap_args(true_side, simplify=True),
                        op.swap_args(false_side, simplify=True),
                    )
                else:
                    rebuilt = op.swap_args(args, simplify=True)

            finished.append(rebuilt)
            continue

        if not isinstance(node, Base) or node.is_leaf() or node.annotations:
            # plain values, leaves and annotated nodes are passed through untouched
            finished.append(node)
            continue

        open_ops.append(node)
        iter_stack.append(iter(node.args))

    assert len(open_ops) == 0, "op_queue is not empty"
    assert len(iter_stack) == 0, "ast_queue is not empty"
    assert len(finished) == 1, ("arg_queue has unexpected length", len(finished))

    return finished.pop()


def burrow_ite(expr: T) -> T:
    """
    Returns an equivalent AST that "burrows" the ITE expressions as deep as
    possible into the ast, for simpler printing.

    Results are memoized in ``burrowed_cache`` under ``expr.hash()``. The
    burrowed AST is also stored under its own hash, so burrowing the result
    again is a cache hit.

    :param expr: The AST to burrow.
    :return: The burrowed AST (possibly ``expr`` itself).
    """
    key = expr.hash()

    # A single .get() replaces the "key in cache" + "cache[key]" pair: entries
    # of a WeakValueDictionary may vanish between the membership test and the
    # lookup, which would surface as a KeyError.
    cached = burrowed_cache.get(key)
    if cached is not None:
        return cast(T, cached)

    burrowed = _burrow_ite(expr)
    burrowed_cache[burrowed.hash()] = burrowed
    burrowed_cache[key] = burrowed
    return burrowed


def excavate_ite(expr: T) -> T:
    """
    Returns an equivalent AST that "excavates" the ITE expressions out as far as
    possible toward the root of the AST, for processing in static analyses.

    Results are memoized in ``excavated_cache`` under ``expr.hash()``. The
    excavated AST is also stored under its own hash, so excavating the result
    again is a cache hit.

    :param expr: The AST to excavate.
    :return: The excavated AST.
    """
    key = expr.hash()

    # A single .get() replaces the "key in cache" + "cache[key]" pair: entries
    # of a WeakValueDictionary may vanish between the membership test and the
    # lookup, which would surface as a KeyError.
    cached = excavated_cache.get(key)
    if cached is not None:
        return cast(T, cached)

    excavated = _excavate_ite(expr)
    excavated_cache[excavated.hash()] = excavated
    excavated_cache[key] = excavated
    return excavated
154 changes: 0 additions & 154 deletions claripy/ast/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,6 @@ class Base:
_cached_encoded_name: bytes | None

# Extra information
_excavated: Base | None
_burrowed: Base | None
_uninitialized: bool

__slots__ = [
Expand All @@ -163,8 +161,6 @@ class Base:
"_errored",
"_cache_key",
"_cached_encoded_name",
"_excavated",
"_burrowed",
"_uninitialized",
"__weakref__",
]
Expand Down Expand Up @@ -392,8 +388,6 @@ def __a_init__(

self._simplified = simplified
self._cache_key = ASTCacheKey(self)
self._excavated = None
self._burrowed = None

self._uninitialized = uninitialized

Expand Down Expand Up @@ -1055,154 +1049,6 @@ def canonicalize(self, var_map=None, counter=None) -> Self:

return var_map, counter, self.replace_dict(var_map)

#
# This code handles burrowing ITEs deeper into the ast and excavating
# them to shallower levels.
#

def _burrow_ite(self):
    """
    Push a top-level ``If`` one level down into its branches when that is safe.

    For a node that is not an ``If``, every AST argument is replaced by its
    ``ite_burrowed`` form and the node is rebuilt through ``swap_args``.

    For ``If(c, op(..., x, ...), op(..., y, ...))`` whose two branches use the
    same op, have the same number of arguments and differ in exactly one
    argument position, the result is ``op(..., If(c, x, y), ...)`` with the
    inner ``If`` burrowed further. In every other case ``self`` is returned
    unchanged.

    :return: An AST equivalent to ``self``.
    """
    if self.op != "If":
        return self.swap_args([(a.ite_burrowed if isinstance(a, Base) else a) for a in self.args])

    if not all(isinstance(a, Base) for a in self.args):
        return self

    old_true = self.args[1]
    old_false = self.args[2]

    if old_true.op != old_false.op or len(old_true.args) != len(old_false.args):
        return self

    if old_true.op == "If":
        # let's not go into this right now
        return self

    if any(a.is_leaf() for a in self.args):
        # burrowing through these is pretty funny
        return self

    # Argument positions at which the two branches do not share the same object.
    differing = [i for i, (t, f) in enumerate(zip(old_true.args, old_false.args)) if t is not f]
    if len(differing) != 1:
        # Exactly one *difference* is required. Counting matches instead would
        # accept e.g. three-argument nodes with two differences and then drop
        # one of them, producing a non-equivalent AST.
        # TODO: handle multiple differences for multi-arg ast nodes
        return self

    different_idx = differing[0]
    true_arg = old_true.args[different_idx]
    false_arg = old_false.args[different_idx]

    if not (isinstance(true_arg, Base) and isinstance(false_arg, Base)):
        # The differing operand is a plain Python value (for instance an int
        # parameter of the op); replacing it by an If AST would change the
        # operand kind, so leave the expression alone.
        return self

    if getattr(true_arg, "length", None) != getattr(false_arg, "length", None):
        # NOTE(review): assumes claripy.If requires both branches to have the
        # same length -- confirm. Returning self unchanged is always valid.
        return self

    inner_if = claripy.If(self.args[0], true_arg, false_arg)
    new_args = list(old_true.args)
    new_args[different_idx] = inner_if.ite_burrowed
    return old_true.__class__(old_true.op, new_args, length=self.length)

def _excavate_ite(self):
    """
    Rebuild this AST bottom-up, lifting ``If`` nodes out of operator arguments.

    The AST is walked iteratively with explicit stacks rather than recursion.
    Non-AST values, leaves and annotated nodes are kept exactly as they are.
    Every other node is rebuilt once all of its arguments have been processed:

    * an ``If`` node is re-created through ``claripy.If``;
    * a node with no ``If`` among its processed arguments is re-created with
      ``swap_args(..., simplify=True)``;
    * otherwise the condition of the first ``If`` argument is hoisted, giving
      ``If(cond, op(true args), op(false args))``, provided every ``If``
      argument is guarded by that same condition object or by ``~cond``.
      When some ``If`` argument uses another condition, the node is rebuilt
      without hoisting.

    :return: The rebuilt AST.
    """
    iter_stack = [iter([self])]  # one argument iterator per node being descended into
    finished = []  # processed values, in visiting order
    open_ops = []  # nodes whose arguments are still being processed

    while iter_stack:
        try:
            node = next(iter_stack[-1])
        except StopIteration:
            # Every argument of the innermost open node has been handled.
            iter_stack.pop()
            if not open_ops:
                continue

            op = open_ops.pop()
            arity = len(op.args)
            args = finished[-arity:]
            del finished[-arity:]

            is_ite = [isinstance(a, Base) and a.op == "If" for a in args]

            if op.op == "If":
                # if we are an If, call the If handler so that we can take advantage of its simplifiers
                rebuilt = claripy.If(*args)
            elif True not in is_ite:
                # no Ifs came to the surface, so there is nothing to pull out
                rebuilt = op.swap_args(args, simplify=True)
            else:
                # we are *not* an If, but some arguments are: pull them out to the surface
                cond = args[is_ite.index(True)].args[0]
                true_side = []
                false_side = []
                hoistable = True

                for arg, arg_is_ite in zip(args, is_ite):
                    if not arg_is_ite:
                        true_side.append(arg)
                        false_side.append(arg)
                    elif arg.args[0] is cond:
                        true_side.append(arg.args[1])
                        false_side.append(arg.args[2])
                    elif arg.args[0] is ~cond:
                        # guarded by the negated condition: the branches swap roles
                        true_side.append(arg.args[2])
                        false_side.append(arg.args[1])
                    else:
                        # weird conditions -- giving up!
                        hoistable = False
                        break

                if hoistable:
                    rebuilt = claripy.If(
                        cond,
                        op.swap_args(true_side, simplify=True),
                        op.swap_args(false_side, simplify=True),
                    )
                else:
                    rebuilt = op.swap_args(args, simplify=True)

            finished.append(rebuilt)
            continue

        if not isinstance(node, Base) or node.is_leaf() or node.annotations:
            # plain values, leaves and annotated nodes are passed through untouched
            finished.append(node)
            continue

        open_ops.append(node)
        iter_stack.append(iter(node.args))

    assert len(open_ops) == 0, "op_queue is not empty"
    assert len(iter_stack) == 0, "ast_queue is not empty"
    assert len(finished) == 1, ("arg_queue has unexpected length", len(finished))

    return finished.pop()

@property
def ite_burrowed(self: T) -> T:
    """
    The equivalent AST with its ITE expressions "burrowed" as deep as possible, for simpler printing.

    The value is computed once by ``_burrow_ite`` and stored in ``_burrowed``; the result is also marked as its
    own burrowed form.
    """
    cached = self._burrowed
    if cached is not None:
        return cached

    burrowed = self._burrow_ite()
    self._burrowed = burrowed  # pylint:disable=attribute-defined-outside-init
    burrowed._burrowed = burrowed  # pylint:disable=attribute-defined-outside-init
    return burrowed

@property
def ite_excavated(self: T) -> T:
    """
    The equivalent AST with its ITE expressions "excavated" as far as possible toward the root, for processing
    in static analyses.

    The value is computed once by ``_excavate_ite`` and stored in ``_excavated``.
    """
    cached = self._excavated
    if cached is not None:
        return cached

    excavated = self._excavate_ite()
    self._excavated = excavated  # pylint:disable=attribute-defined-outside-init

    # the result is flagged as its own excavated form so that we avoid
    # re-excavating during VSA backend evaluation (since the backend
    # evaluation recursively works on the excavated ASTs)
    excavated._excavated = excavated
    return excavated

#
# these are convenience operations
#
Expand Down
5 changes: 4 additions & 1 deletion claripy/ast/bool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
from functools import lru_cache
from typing import TYPE_CHECKING, overload
from typing import TYPE_CHECKING, TypeVar, overload

import claripy
from claripy import operations
Expand All @@ -16,6 +16,7 @@
from .bv import BV
from .fp import FP

T = TypeVar("T", bound=Base)
log = logging.getLogger(__name__)


Expand Down Expand Up @@ -95,6 +96,8 @@ def If(cond: bool | Bool, true_value: bool | Bool, false_value: bool | Bool) ->
def If(cond: bool | Bool, true_value: int | BV, false_value: int | BV) -> BV: ...
@overload
def If(cond: bool | Bool, true_value: float | FP, false_value: float | FP) -> FP: ...
@overload
def If(cond: bool | Bool, true_value: T, false_value: T) -> T: ...


def If(cond, true_value, false_value):
Expand Down
Loading

0 comments on commit 4cff9ef

Please sign in to comment.