diff --git a/claripy/ast/fp.py b/claripy/ast/fp.py index 8d50957d1..7d9b6af2b 100644 --- a/claripy/ast/fp.py +++ b/claripy/ast/fp.py @@ -24,7 +24,7 @@ class FP(Bits): __slots__ = () - def to_fp(self, sort, rm=None): + def to_fp(self, sort, rm=RM.RM_NearestTiesEven): """ Convert this float to a different sort @@ -32,10 +32,7 @@ def to_fp(self, sort, rm=None): :param rm: Optional: The rounding mode to use :return: An FP AST """ - if rm is None: - rm = RM.default() - - return fpToFP(rm, self, sort) + return fpToFP(self, sort, rm) def raw_to_fp(self): """ @@ -68,7 +65,7 @@ def val_to_bv(self, size, signed=True, rm=None): rm = RM.default() op = fpToSBV if signed else fpToUBV - return op(rm, self, size) + return op(self, size, rm) @property def sort(self): @@ -125,20 +122,16 @@ def FPV(value, sort): # -def _fp_length_calc(a1, a2, a3=None): - if isinstance(a1, RM) and a3 is None: - raise Exception - if a3 is None: - return a2.length - return a3.length +def _fp_length_calc(_a1, a2, _a3=None): + return a2.length fpToFP = operations.op("fpToFP", object, FP, calc_length=_fp_length_calc) -fpToFPUnsigned = operations.op("fpToFPUnsigned", (RM, BV, FSort), FP, calc_length=_fp_length_calc) +fpToFPUnsigned = operations.op("fpToFPUnsigned", (BV, FSort, RM), FP, calc_length=_fp_length_calc) fpFP = operations.op("fpFP", (BV, BV, BV), FP, calc_length=lambda a, b, c: a.length + b.length + c.length) fpToIEEEBV = operations.op("fpToIEEEBV", (FP,), BV, calc_length=lambda fp: fp.length) -fpToSBV = operations.op("fpToSBV", (RM, FP, int), BV, calc_length=lambda _rm, _fp, len: len) -fpToUBV = operations.op("fpToUBV", (RM, FP, int), BV, calc_length=lambda _rm, _fp, len: len) +fpToSBV = operations.op("fpToSBV", (FP, int, RM), BV, calc_length=lambda _fp, len, _rm: len) +fpToUBV = operations.op("fpToUBV", (FP, int, RM), BV, calc_length=lambda _fp, len, _rm: len) # # unbound float point comparisons @@ -162,21 +155,21 @@ def _fp_cmp_check(a, b): # -def _fp_binop_check(rm, a, b): # pylint:disable=unused-argument +def _fp_binop_check(a, b, _): return a.length == b.length, "Lengths must be equal" -def _fp_binop_length(rm, a, b): # pylint:disable=unused-argument +def _fp_binop_length(a, _b, _rm): return a.length fpAbs = operations.op("fpAbs", (FP,), FP, calc_length=lambda x: x.length) fpNeg = operations.op("fpNeg", (FP,), FP, calc_length=lambda x: x.length) -fpSub = operations.op("fpSub", (RM, FP, FP), FP, extra_check=_fp_binop_check, calc_length=_fp_binop_length) -fpAdd = operations.op("fpAdd", (RM, FP, FP), FP, extra_check=_fp_binop_check, calc_length=_fp_binop_length) -fpMul = operations.op("fpMul", (RM, FP, FP), FP, extra_check=_fp_binop_check, calc_length=_fp_binop_length) -fpDiv = operations.op("fpDiv", (RM, FP, FP), FP, extra_check=_fp_binop_check, calc_length=_fp_binop_length) -fpSqrt = operations.op("fpSqrt", (RM, FP), FP, calc_length=lambda _, x: x.length) +fpSub = operations.op("fpSub", (FP, FP, RM), FP, extra_check=_fp_binop_check, calc_length=_fp_binop_length) +fpAdd = operations.op("fpAdd", (FP, FP, RM), FP, extra_check=_fp_binop_check, calc_length=_fp_binop_length) +fpMul = operations.op("fpMul", (FP, FP, RM), FP, extra_check=_fp_binop_check, calc_length=_fp_binop_length) +fpDiv = operations.op("fpDiv", (FP, FP, RM), FP, extra_check=_fp_binop_check, calc_length=_fp_binop_length) +fpSqrt = operations.op("fpSqrt", (FP, RM), FP, calc_length=lambda x, _: x.length) # # bound fp operations diff --git a/claripy/backends/backend_concrete/backend_concrete.py b/claripy/backends/backend_concrete/backend_concrete.py index 117b03cc4..13de04437 100644 --- a/claripy/backends/backend_concrete/backend_concrete.py +++ b/claripy/backends/backend_concrete/backend_concrete.py @@ -105,7 +105,7 @@ def _op_boolnot(arg): return not arg @staticmethod - def _op_fpSqrt(rm, a): # pylint:disable=unused-argument + def _op_fpSqrt(a, _): return a.fpSqrt() def convert(self, expr): diff --git a/claripy/backends/backend_concrete/fp.py b/claripy/backends/backend_concrete/fp.py index ad2a4b2aa..5a9fe8838 100644 --- a/claripy/backends/backend_concrete/fp.py +++ b/claripy/backends/backend_concrete/fp.py @@ -177,47 +177,14 @@ def __repr__(self): return f"FPV({self.value:f}, {self.sort})" -def fpToFP(a1, a2, a3=None): - """ - Returns a FP AST and has three signatures: - - fpToFP(ubvv, sort) - Returns a FP AST whose value is the same as the unsigned BVV `a1` - and whose sort is `a2`. - - fpToFP(rm, fpv, sort) - Returns a FP AST whose value is the same as the floating point `a2` - and whose sort is `a3`. - - fpToTP(rm, sbvv, sort) - Returns a FP AST whose value is the same as the signed BVV `a2` and - whose sort is `a3`. - """ - if isinstance(a1, BVV) and isinstance(a2, FSort): - sort = a2 - if sort == FSORT_FLOAT: - pack, unpack = "I", "f" - elif sort == FSORT_DOUBLE: - pack, unpack = "Q", "d" - else: - raise ClaripyOperationError("unrecognized float sort") - - try: - packed = struct.pack("<" + pack, a1.value) - (unpacked,) = struct.unpack("<" + unpack, packed) - except OverflowError as e: - # struct.pack sometimes overflows - raise ClaripyOperationError("OverflowError: " + str(e)) from e - - return FPV(unpacked, sort) - if isinstance(a1, RM) and isinstance(a2, FPV) and isinstance(a3, FSort): - return FPV(a2.value, a3) - if isinstance(a1, RM) and isinstance(a2, BVV) and isinstance(a3, FSort): - return FPV(float(a2.signed), a3) - raise ClaripyOperationError("unknown types passed to fpToFP") +def fpToFP(arg: BVV | FPV, sort: FSort, _rm: RM = RM.RM_NearestTiesEven): + """Returns a FP AST given a BVV or FPV.""" + if isinstance(arg, BVV): + arg = FPV(float(arg.value), sort) + return FPV(arg.value, sort) -def fpToFPUnsigned(_rm, thing, sort): +def fpToFPUnsigned(thing, sort, _rm): """ Returns a FP AST whose value is the same as the unsigned BVV `thing` and whose sort is `sort`. @@ -278,7 +245,7 @@ def fpFP(sgn, exp, mantissa): return FPV(unpacked, sort) -def fpToSBV(rm, fp, size): +def fpToSBV(fp, size, rm): try: rounding_mode = rm.pydecimal_equivalent_rounding_mode() val = int(Decimal(fp.value).to_integral_value(rounding_mode)) @@ -291,7 +258,7 @@ def fpToSBV(rm, fp, size): raise -def fpToUBV(rm, fp, size): +def fpToUBV(fp, size, rm): # todo: actually make unsigned try: rounding_mode = rm.pydecimal_equivalent_rounding_mode() diff --git a/claripy/backends/backend_z3.py b/claripy/backends/backend_z3.py index 9e9c772d3..9086c03f3 100644 --- a/claripy/backends/backend_z3.py +++ b/claripy/backends/backend_z3.py @@ -548,6 +548,9 @@ def _abstract_internal(self, ctx, ast, split_on=None): log.error(err) raise BackendError(err) + if ty is FP and isinstance(args[0], RM): + args = [*args[1:], args[0]] + if op_name == "If": # If is polymorphic and thus must be handled specially ty = type(args[1]) @@ -1019,23 +1022,23 @@ def _op_raw_fpNeg(self, a): return z3.FPRef(z3.Z3_mk_fpa_neg(self._context.ref(), a.as_ast()), self._context) @condom - def _op_raw_fpAdd(self, rm, a, b): + def _op_raw_fpAdd(self, a, b, rm): return z3.FPRef(z3.Z3_mk_fpa_add(self._context.ref(), rm.as_ast(), a.as_ast(), b.as_ast()), self._context) @condom - def _op_raw_fpSub(self, rm, a, b): + def _op_raw_fpSub(self, a, b, rm): return z3.FPRef(z3.Z3_mk_fpa_sub(self._context.ref(), rm.as_ast(), a.as_ast(), b.as_ast()), self._context) @condom - def _op_raw_fpMul(self, rm, a, b): + def _op_raw_fpMul(self, a, b, rm): return z3.FPRef(z3.Z3_mk_fpa_mul(self._context.ref(), rm.as_ast(), a.as_ast(), b.as_ast()), self._context) @condom - def _op_raw_fpDiv(self, rm, a, b): + def _op_raw_fpDiv(self, a, b, rm): return z3.FPRef(z3.Z3_mk_fpa_div(self._context.ref(), rm.as_ast(), a.as_ast(), b.as_ast()), self._context) @condom - def _op_raw_fpSqrt(self, rm, a): + def _op_raw_fpSqrt(self, a, rm): return z3.FPRef(z3.Z3_mk_fpa_sqrt(self._context.ref(), rm.as_ast(), a.as_ast()), self._context) @condom @@ -1071,11 +1074,11 @@ def _op_raw_fpFP(self, sgn, exp, sig): return z3.FPRef(z3.Z3_mk_fpa_fp(self._context.ref(), sgn.ast, exp.ast, sig.ast), self._context) @condom - def _op_raw_fpToSBV(self, rm, fp, bv_len): + def _op_raw_fpToSBV(self, fp, bv_len, rm): return z3.BitVecRef(z3.Z3_mk_fpa_to_sbv(self._context.ref(), rm.ast, fp.ast, bv_len), self._context) @condom - def _op_raw_fpToUBV(self, rm, fp, bv_len): + def _op_raw_fpToUBV(self, fp, bv_len, rm): return z3.BitVecRef(z3.Z3_mk_fpa_to_ubv(self._context.ref(), rm.ast, fp.ast, bv_len), self._context) @condom diff --git a/claripy/fp.py b/claripy/fp.py index 8ffaf2129..32733b46e 100644 --- a/claripy/fp.py +++ b/claripy/fp.py @@ -7,9 +7,10 @@ class RM(Enum): - """Rounding modes for floating point operations. + """Floating point rounding mode, as defined by IEEE754. - See https://en.wikipedia.org/wiki/IEEE_754#Rounding_rules for more information. + See this wikipedia entry for details: + https://en.wikipedia.org/wiki/IEEE_754#Rounding_rules """ RM_NearestTiesEven = "RM_RNE" @@ -33,7 +34,9 @@ def pydecimal_equivalent_rounding_mode(self): class FSort: - """A class representing a floating point sort.""" + """Floating point sort, desribing the size of the exponent and mantissa for + an IEEE754 floating point number. + """ def __init__(self, name, exp, mantissa): self.name = name diff --git a/claripy/operations.py b/claripy/operations.py index cfee13564..29d7b87d3 100644 --- a/claripy/operations.py +++ b/claripy/operations.py @@ -20,8 +20,8 @@ def op(name, arg_types, return_type, extra_check=None, calc_length=None): def _type_fixer(args): num_args = len(args) if expected_num_args is not None and num_args != expected_num_args: - if num_args + 1 == expected_num_args and arg_types[0] is claripy.fp.RM: - args = (claripy.fp.RM.default(), *args) + if num_args + 1 == expected_num_args and arg_types[-1] is claripy.fp.RM: + args = (*args, claripy.fp.RM.default()) else: raise ClaripyTypeError(f"Operation {name} takes exactly {len(arg_types)} arguments ({len(args)} given)") @@ -29,7 +29,7 @@ def _type_fixer(args): matches = list(itertools.starmap(isinstance, zip(args, actual_arg_types, strict=False))) # heuristically, this works! - thing = args[matches.index(True, 1 if actual_arg_types[0] is claripy.fp.RM else 0)] if True in matches else None + thing = args[matches.index(True, 0)] if True in matches else None for arg, argty, match in zip(args, actual_arg_types, matches, strict=False): if not match: diff --git a/tests/test_fp.py b/tests/test_fp.py index a3bfeddd4..774af2904 100644 --- a/tests/test_fp.py +++ b/tests/test_fp.py @@ -54,7 +54,7 @@ def test_negative_zero(self): def test_fp_ops(self): a = claripy.FPV(1.5, claripy.FSORT_DOUBLE) - b = claripy.fpToUBV(claripy.fp.RM.RM_NearestTiesEven, a, 32) + b = claripy.fpToUBV(a, 32) s = claripy.Solver() assert s.eval(b, 1)[0] == 2