From 28b2b0638a36b972c8f0dee268649bbd1e0115a4 Mon Sep 17 00:00:00 2001 From: Kevin Phoenix Date: Thu, 3 Oct 2024 15:52:33 -0700 Subject: [PATCH] Move excavate_ite and burrow_ite to algorithms subpackage --- claripy/__init__.py | 4 +- claripy/algorithm/__init__.py | 3 + claripy/algorithm/ite_relocation.py | 161 ++++++++++++++++++++ claripy/ast/base.py | 154 ------------------- claripy/ast/bool.py | 5 +- claripy/backends/backend_vsa/backend_vsa.py | 3 +- claripy/balancer.py | 2 +- tests/test_expression.py | 9 +- tests/test_vsa.py | 14 +- 9 files changed, 186 insertions(+), 169 deletions(-) create mode 100644 claripy/algorithm/ite_relocation.py diff --git a/claripy/__init__.py b/claripy/__init__.py index 12ab00aae..10b7238ba 100644 --- a/claripy/__init__.py +++ b/claripy/__init__.py @@ -1,7 +1,7 @@ from __future__ import annotations from claripy import algorithm, ast -from claripy.algorithm import is_false, is_true, simplify +from claripy.algorithm import burrow_ite, excavate_ite, is_false, is_true, simplify from claripy.annotation import Annotation, RegionAnnotation, SimplificationAvoidanceAnnotation from claripy.ast.bool import ( And, @@ -210,4 +210,6 @@ "SolverReplacement", "SolverStrings", "SolverVSA", + "burrow_ite", + "excavate_ite", ) diff --git a/claripy/algorithm/__init__.py b/claripy/algorithm/__init__.py index f2f139ed0..de82a6003 100644 --- a/claripy/algorithm/__init__.py +++ b/claripy/algorithm/__init__.py @@ -1,9 +1,12 @@ from __future__ import annotations from .bool_check import is_false, is_true +from .ite_relocation import burrow_ite, excavate_ite from .simplify import simplify __all__ = ( + "burrow_ite", + "excavate_ite", "is_false", "is_true", "simplify", diff --git a/claripy/algorithm/ite_relocation.py b/claripy/algorithm/ite_relocation.py new file mode 100644 index 000000000..79bd215cd --- /dev/null +++ b/claripy/algorithm/ite_relocation.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +from typing import TypeVar, cast +from weakref import WeakValueDictionary + +import claripy +from claripy.ast.base import Base + +T = TypeVar("T", bound=Base) + +# +# This code handles burrowing ITEs deeper into the ast and excavating +# them to shallower levels. +# + +burrowed_cache: WeakValueDictionary[int, Base] = WeakValueDictionary() +excavated_cache: WeakValueDictionary[int, Base] = WeakValueDictionary() + + +def _burrow_ite(expr: T) -> T: + if expr.op != "If": + return expr.swap_args([(burrow_ite(a) if isinstance(a, Base) else a) for a in expr.args]) + + if not all(isinstance(a, Base) for a in expr.args): + return expr + + old_true = expr.args[1] + old_false = expr.args[2] + + if old_true.op != old_false.op or len(old_true.args) != len(old_false.args): + return expr + + if old_true.op == "If": + # let's no go into this right now + return expr + + if any(a.is_leaf() for a in expr.args): + # burrowing through these is pretty funny + return expr + + matches = [old_true.args[i] is old_false.args[i] for i in range(len(old_true.args))] + if matches.count(True) != 1 or all(matches): + # TODO: handle multiple differences for multi-arg ast nodes + # print("wrong number of matches:",matches,old_true,old_false) + return expr + + different_idx = matches.index(False) + inner_if = claripy.If(expr.args[0], old_true.args[different_idx], old_false.args[different_idx]) + new_args = list(old_true.args) + new_args[different_idx] = burrow_ite(inner_if) + return old_true.__class__(old_true.op, new_args, length=expr.length) + + +def _excavate_ite(expr: T) -> T: + ast_queue = [iter([expr])] + arg_queue = [] + op_queue = [] + + while ast_queue: + try: + ast = next(ast_queue[-1]) + + if not isinstance(ast, Base): + arg_queue.append(ast) + continue + + if ast.is_leaf(): + arg_queue.append(ast) + continue + + if ast.annotations: + arg_queue.append(ast) + continue + + op_queue.append(ast) + ast_queue.append(iter(ast.args)) + + except StopIteration: + ast_queue.pop() + + if op_queue: + op = op_queue.pop() + + args = arg_queue[-len(op.args) :] + del arg_queue[-len(op.args) :] + + ite_args = [isinstance(a, Base) and a.op == "If" for a in args] + + if op.op == "If": + # if we are an If, call the If handler so that we can take advantage of its simplifiers + excavated = claripy.If(*args) + + elif ite_args.count(True) == 0: + # if there are no ifs that came to the surface, there's nothing more to do + excavated = op.swap_args(args, simplify=True) + + else: + # this gets called when we're *not* in an If, but there are Ifs in the args. + # it pulls those Ifs out to the surface. + cond = args[ite_args.index(True)].args[0] + new_true_args = [] + new_false_args = [] + + for a in args: + if not isinstance(a, Base) or a.op != "If": + new_true_args.append(a) + new_false_args.append(a) + elif a.args[0] is cond: + new_true_args.append(a.args[1]) + new_false_args.append(a.args[2]) + elif a.args[0] is ~cond: + new_true_args.append(a.args[2]) + new_false_args.append(a.args[1]) + else: + # weird conditions -- giving up! + excavated = op.swap_args(args, simplify=True) + break + + else: + excavated = claripy.If( + cond, + op.swap_args(new_true_args, simplify=True), + op.swap_args(new_false_args, simplify=True), + ) + + # continue + arg_queue.append(excavated) + + assert len(op_queue) == 0, "op_queue is not empty" + assert len(ast_queue) == 0, "ast_queue is not empty" + assert len(arg_queue) == 1, ("arg_queue has unexpected length", len(arg_queue)) + + return arg_queue.pop() + + +def burrow_ite(expr: T) -> T: + """ + Returns an equivalent AST that "burrows" the ITE expressions as deep as + possible into the ast, for simpler printing. + """ + if expr.hash() in burrowed_cache and burrowed_cache[expr.hash()] is not None: + return cast(T, burrowed_cache[expr.hash()]) + + burrowed = _burrow_ite(expr) + burrowed_cache[burrowed.hash()] = burrowed + burrowed_cache[expr.hash()] = burrowed + return burrowed + + +def excavate_ite(expr: T) -> T: + """ + Returns an equivalent AST that "excavates" the ITE expressions out as far as + possible toward the root of the AST, for processing in static analyses. + """ + if expr.hash() in excavated_cache and excavated_cache[expr.hash()] is not None: + return cast(T, excavated_cache[expr.hash()]) + + excavated = _excavate_ite(expr) + excavated_cache[excavated.hash()] = excavated + excavated_cache[expr.hash()] = excavated + return excavated diff --git a/claripy/ast/base.py b/claripy/ast/base.py index 5a62444e8..84547ed15 100644 --- a/claripy/ast/base.py +++ b/claripy/ast/base.py @@ -144,8 +144,6 @@ class Base: _cached_encoded_name: bytes | None # Extra information - _excavated: Base | None - _burrowed: Base | None _uninitialized: bool __slots__ = [ @@ -163,8 +161,6 @@ class Base: "_errored", "_cache_key", "_cached_encoded_name", - "_excavated", - "_burrowed", "_uninitialized", "__weakref__", ] @@ -392,8 +388,6 @@ def __a_init__( self._simplified = simplified self._cache_key = ASTCacheKey(self) - self._excavated = None - self._burrowed = None self._uninitialized = uninitialized @@ -1055,154 +1049,6 @@ def canonicalize(self, var_map=None, counter=None) -> Self: return var_map, counter, self.replace_dict(var_map) - # - # This code handles burrowing ITEs deeper into the ast and excavating - # them to shallower levels. - # - - def _burrow_ite(self): - if self.op != "If": - # print("i'm not an if") - return self.swap_args([(a.ite_burrowed if isinstance(a, Base) else a) for a in self.args]) - - if not all(isinstance(a, Base) for a in self.args): - # print("not all my args are bases") - return self - - old_true = self.args[1] - old_false = self.args[2] - - if old_true.op != old_false.op or len(old_true.args) != len(old_false.args): - return self - - if old_true.op == "If": - # let's no go into this right now - return self - - if any(a.is_leaf() for a in self.args): - # burrowing through these is pretty funny - return self - - matches = [old_true.args[i] is old_false.args[i] for i in range(len(old_true.args))] - if matches.count(True) != 1 or all(matches): - # TODO: handle multiple differences for multi-arg ast nodes - # print("wrong number of matches:",matches,old_true,old_false) - return self - - different_idx = matches.index(False) - inner_if = claripy.If(self.args[0], old_true.args[different_idx], old_false.args[different_idx]) - new_args = list(old_true.args) - new_args[different_idx] = inner_if.ite_burrowed - # print("replaced the",different_idx,"arg:",new_args) - return old_true.__class__(old_true.op, new_args, length=self.length) - - def _excavate_ite(self): - ast_queue = [iter([self])] - arg_queue = [] - op_queue = [] - - while ast_queue: - try: - ast = next(ast_queue[-1]) - - if not isinstance(ast, Base): - arg_queue.append(ast) - continue - - if ast.is_leaf(): - arg_queue.append(ast) - continue - - if ast.annotations: - arg_queue.append(ast) - continue - - op_queue.append(ast) - ast_queue.append(iter(ast.args)) - - except StopIteration: - ast_queue.pop() - - if op_queue: - op = op_queue.pop() - - args = arg_queue[-len(op.args) :] - del arg_queue[-len(op.args) :] - - ite_args = [isinstance(a, Base) and a.op == "If" for a in args] - - if op.op == "If": - # if we are an If, call the If handler so that we can take advantage of its simplifiers - excavated = claripy.If(*args) - - elif ite_args.count(True) == 0: - # if there are no ifs that came to the surface, there's nothing more to do - excavated = op.swap_args(args, simplify=True) - - else: - # this gets called when we're *not* in an If, but there are Ifs in the args. - # it pulls those Ifs out to the surface. - cond = args[ite_args.index(True)].args[0] - new_true_args = [] - new_false_args = [] - - for a in args: - if not isinstance(a, Base) or a.op != "If": - new_true_args.append(a) - new_false_args.append(a) - elif a.args[0] is cond: - new_true_args.append(a.args[1]) - new_false_args.append(a.args[2]) - elif a.args[0] is ~cond: - new_true_args.append(a.args[2]) - new_false_args.append(a.args[1]) - else: - # weird conditions -- giving up! - excavated = op.swap_args(args, simplify=True) - break - - else: - excavated = claripy.If( - cond, - op.swap_args(new_true_args, simplify=True), - op.swap_args(new_false_args, simplify=True), - ) - - # continue - arg_queue.append(excavated) - - assert len(op_queue) == 0, "op_queue is not empty" - assert len(ast_queue) == 0, "ast_queue is not empty" - assert len(arg_queue) == 1, ("arg_queue has unexpected length", len(arg_queue)) - - return arg_queue.pop() - - @property - def ite_burrowed(self: T) -> T: - """ - Returns an equivalent AST that "burrows" the ITE expressions as deep as possible into the ast, for simpler - printing. - """ - if self._burrowed is None: - self._burrowed = self._burrow_ite() # pylint:disable=attribute-defined-outside-init - self._burrowed._burrowed = self._burrowed # pylint:disable=attribute-defined-outside-init - return self._burrowed - - @property - def ite_excavated(self: T) -> T: - """ - Returns an equivalent AST that "excavates" the ITE expressions out as far as possible toward the root of the - AST, for processing in static analyses. - """ - if self._excavated is None: - self._excavated = self._excavate_ite() # pylint:disable=attribute-defined-outside-init - - # we set the flag for the children so that we avoid re-excavating during - # VSA backend evaluation (since the backend evaluation recursively works on - # the excavated ASTs) - self._excavated._excavated = self._excavated - return self._excavated - # # these are convenience operations # diff --git a/claripy/ast/bool.py b/claripy/ast/bool.py index 4417a936e..39e7b39ec 100644 --- a/claripy/ast/bool.py +++ b/claripy/ast/bool.py @@ -2,7 +2,7 @@ import logging from functools import lru_cache -from typing import TYPE_CHECKING, overload +from typing import TYPE_CHECKING, TypeVar, overload import claripy from claripy import operations @@ -16,6 +16,7 @@ from .bv import BV from .fp import FP +T = TypeVar("T", bound=Base) log = logging.getLogger(__name__) @@ -95,6 +96,8 @@ def If(cond: bool | Bool, true_value: bool | Bool, false_value: bool | Bool) -> def If(cond: bool | Bool, true_value: int | BV, false_value: int | BV) -> BV: ... @overload def If(cond: bool | Bool, true_value: float | FP, false_value: float | FP) -> FP: ... +@overload +def If(cond: bool | Bool, true_value: T, false_value: T) -> T: ... def If(cond, true_value, false_value): diff --git a/claripy/backends/backend_vsa/backend_vsa.py b/claripy/backends/backend_vsa/backend_vsa.py index 7589f016b..35c83ba52 100644 --- a/claripy/backends/backend_vsa/backend_vsa.py +++ b/claripy/backends/backend_vsa/backend_vsa.py @@ -6,6 +6,7 @@ import operator from functools import reduce +import claripy from claripy.annotation import RegionAnnotation, StridedIntervalAnnotation from claripy.ast.base import Base from claripy.ast.bv import BV, BVV, ESI, SI, TSI, VS @@ -106,7 +107,7 @@ def _op_mod(*args): return reduce(operator.__mod__, args) def convert(self, expr): - return Backend.convert(self, expr.ite_excavated if isinstance(expr, Base) else expr) + return Backend.convert(self, claripy.excavate_ite(expr) if isinstance(expr, Base) else expr) def _convert(self, r): if isinstance(r, numbers.Number): diff --git a/claripy/balancer.py b/claripy/balancer.py index cddf04658..2539ec207 100644 --- a/claripy/balancer.py +++ b/claripy/balancer.py @@ -30,7 +30,7 @@ def __init__(self, helper, c, validation_frontend=None): self._lower_bounds = {} self._upper_bounds = {} - self._queue_truism(c.ite_excavated) + self._queue_truism(claripy.excavate_ite(c)) self.sat = True try: diff --git a/tests/test_expression.py b/tests/test_expression.py index 72cc1c8ff..f58711fe8 100644 --- a/tests/test_expression.py +++ b/tests/test_expression.py @@ -211,7 +211,6 @@ def test_cardinality(self): def test_if_stuff(self): x = claripy.BVS("x", 32) - # y = claripy.BVS('y', 32) c = claripy.If(x > 10, (claripy.If(x > 10, x * 3, x * 2)), x * 4) + 2 cc = claripy.If(x > 10, x * 3, x * 4) + 2 @@ -219,17 +218,17 @@ def test_if_stuff(self): cccc = x * claripy.If(x > 10, claripy.BVV(3, 32), claripy.BVV(4, 32)) + 2 self.assertIs(c, cc) - self.assertIs(c.ite_excavated, ccc) - self.assertIs(ccc.ite_burrowed, cccc) + self.assertIs(claripy.excavate_ite(c), ccc) + self.assertIs(claripy.burrow_ite(ccc), cccc) i = c + c ii = claripy.If(x > 10, (x * 3 + 2) + (x * 3 + 2), (x * 4 + 2) + (x * 4 + 2)) - self.assertIs(i.ite_excavated, ii) + self.assertIs(claripy.excavate_ite(i), ii) cn = claripy.If(x <= 10, claripy.BVV(0x10, 32), 0x20) iii = c + cn iiii = claripy.If(x > 10, (x * 3 + 2) + 0x20, (x * 4 + 2) + 0x10) - self.assertIs(iii.ite_excavated, iiii) + self.assertIs(claripy.excavate_ite(iii), iiii) def test_ite_Solver(self): self.raw_ite(claripy.Solver) diff --git a/tests/test_vsa.py b/tests/test_vsa.py index 46dc1fda6..1454eca5d 100644 --- a/tests/test_vsa.py +++ b/tests/test_vsa.py @@ -711,9 +711,10 @@ def test_if_proxy_and_operations(self): claripy.SI(bits=32, stride=0, lower_bound=0xFFFFFFFF, upper_bound=0xFFFFFFFF), ) assert claripy.backends.vsa.is_true( - vsa_model(if_1.ite_excavated.args[1]) == vsa_model(claripy.ValueSet(region="global", bits=32, value=0)) + vsa_model(claripy.excavate_ite(if_1).args[1]) + == vsa_model(claripy.ValueSet(region="global", bits=32, value=0)) ) - assert claripy.backends.vsa.is_true(vsa_model(if_1.ite_excavated.args[2]) == vsa_model(vs_2)) + assert claripy.backends.vsa.is_true(vsa_model(claripy.excavate_ite(if_1).args[2]) == vsa_model(vs_2)) def test_if_proxy_or_operations(self): # if_2 = And(VS_3, IfProxy(si != 0, 0, 1)) @@ -725,14 +726,15 @@ def test_if_proxy_or_operations(self): claripy.SI(bits=32, stride=0, lower_bound=0xFFFFFFFF, upper_bound=0xFFFFFFFF), ) assert claripy.backends.vsa.is_true( - vsa_model(if_2.ite_excavated.args[1]) == vsa_model(claripy.ValueSet(region="global", bits=32, value=0)) + vsa_model(claripy.excavate_ite(if_2).args[1]) + == vsa_model(claripy.ValueSet(region="global", bits=32, value=0)) ) - assert claripy.backends.vsa.is_true(vsa_model(if_2.ite_excavated.args[2]) == vsa_model(vs_3)) + assert claripy.backends.vsa.is_true(vsa_model(claripy.excavate_ite(if_2).args[2]) == vsa_model(vs_3)) # Something crazy is gonna happen... # if_3 = if_1 + if_2 - # assert claripy.backends.vsa.is_true(vsa_model(if_3.ite_excavated.args[1]) == vsa_model(vs_3))) - # assert claripy.backends.vsa.is_true(vsa_model(if_3.ite_excavated.args[1]) == vsa_model(vs_2))) + # assert claripy.backends.vsa.is_true(vsa_model(claripy.excavate_ite(if_3).args[1]) == vsa_model(vs_3))) + # assert claripy.backends.vsa.is_true(vsa_model(claripy.excavate_ite(if_3).args[1]) == vsa_model(vs_2))) class TestVSAConstraintToSI(unittest.TestCase): # pylint: disable=no-member,function-redefined