Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make RM the final argument in FP ops #498

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
37 changes: 15 additions & 22 deletions claripy/ast/fp.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,15 @@ class FP(Bits):

__slots__ = ()

def to_fp(self, sort, rm=None):
def to_fp(self, sort, rm=RM.RM_NearestTiesEven):
"""
Convert this float to a different sort

:param sort: The sort to convert to
:param rm: Optional: The rounding mode to use
:return: An FP AST
"""
if rm is None:
rm = RM.default()

return fpToFP(rm, self, sort)
return fpToFP(self, sort, rm)

def raw_to_fp(self):
"""
Expand Down Expand Up @@ -68,7 +65,7 @@ def val_to_bv(self, size, signed=True, rm=None):
rm = RM.default()

op = fpToSBV if signed else fpToUBV
return op(rm, self, size)
return op(self, size, rm)

@property
def sort(self):
Expand Down Expand Up @@ -125,20 +122,16 @@ def FPV(value, sort):
#


def _fp_length_calc(a1, a2, a3=None):
if isinstance(a1, RM) and a3 is None:
raise Exception
if a3 is None:
return a2.length
return a3.length
def _fp_length_calc(_a1, a2, _a3=None):
return a2.length


fpToFP = operations.op("fpToFP", object, FP, calc_length=_fp_length_calc)
fpToFPUnsigned = operations.op("fpToFPUnsigned", (RM, BV, FSort), FP, calc_length=_fp_length_calc)
fpToFPUnsigned = operations.op("fpToFPUnsigned", (BV, FSort, RM), FP, calc_length=_fp_length_calc)
fpFP = operations.op("fpFP", (BV, BV, BV), FP, calc_length=lambda a, b, c: a.length + b.length + c.length)
fpToIEEEBV = operations.op("fpToIEEEBV", (FP,), BV, calc_length=lambda fp: fp.length)
fpToSBV = operations.op("fpToSBV", (RM, FP, int), BV, calc_length=lambda _rm, _fp, len: len)
fpToUBV = operations.op("fpToUBV", (RM, FP, int), BV, calc_length=lambda _rm, _fp, len: len)
fpToSBV = operations.op("fpToSBV", (FP, int, RM), BV, calc_length=lambda _fp, len, _rm: len)
fpToUBV = operations.op("fpToUBV", (FP, int, RM), BV, calc_length=lambda _fp, len, _rm: len)

#
# unbound float point comparisons
Expand All @@ -162,21 +155,21 @@ def _fp_cmp_check(a, b):
#


def _fp_binop_check(rm, a, b): # pylint:disable=unused-argument
def _fp_binop_check(a, b, _):
return a.length == b.length, "Lengths must be equal"


def _fp_binop_length(rm, a, b): # pylint:disable=unused-argument
def _fp_binop_length(a, _b, _rm):
return a.length


fpAbs = operations.op("fpAbs", (FP,), FP, calc_length=lambda x: x.length)
fpNeg = operations.op("fpNeg", (FP,), FP, calc_length=lambda x: x.length)
fpSub = operations.op("fpSub", (RM, FP, FP), FP, extra_check=_fp_binop_check, calc_length=_fp_binop_length)
fpAdd = operations.op("fpAdd", (RM, FP, FP), FP, extra_check=_fp_binop_check, calc_length=_fp_binop_length)
fpMul = operations.op("fpMul", (RM, FP, FP), FP, extra_check=_fp_binop_check, calc_length=_fp_binop_length)
fpDiv = operations.op("fpDiv", (RM, FP, FP), FP, extra_check=_fp_binop_check, calc_length=_fp_binop_length)
fpSqrt = operations.op("fpSqrt", (RM, FP), FP, calc_length=lambda _, x: x.length)
fpSub = operations.op("fpSub", (FP, FP, RM), FP, extra_check=_fp_binop_check, calc_length=_fp_binop_length)
fpAdd = operations.op("fpAdd", (FP, FP, RM), FP, extra_check=_fp_binop_check, calc_length=_fp_binop_length)
fpMul = operations.op("fpMul", (FP, FP, RM), FP, extra_check=_fp_binop_check, calc_length=_fp_binop_length)
fpDiv = operations.op("fpDiv", (FP, FP, RM), FP, extra_check=_fp_binop_check, calc_length=_fp_binop_length)
fpSqrt = operations.op("fpSqrt", (FP, RM), FP, calc_length=lambda x, _: x.length)

#
# bound fp operations
Expand Down
2 changes: 1 addition & 1 deletion claripy/backends/backend_concrete/backend_concrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def _op_boolnot(arg):
return not arg

@staticmethod
def _op_fpSqrt(rm, a): # pylint:disable=unused-argument
def _op_fpSqrt(a, _):
return a.fpSqrt()

def convert(self, expr):
Expand Down
49 changes: 8 additions & 41 deletions claripy/backends/backend_concrete/fp.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,47 +177,14 @@ def __repr__(self):
return f"FPV({self.value:f}, {self.sort})"


def fpToFP(a1, a2, a3=None):
"""
Returns a FP AST and has three signatures:

fpToFP(ubvv, sort)
Returns a FP AST whose value is the same as the unsigned BVV `a1`
and whose sort is `a2`.

fpToFP(rm, fpv, sort)
Returns a FP AST whose value is the same as the floating point `a2`
and whose sort is `a3`.

fpToTP(rm, sbvv, sort)
Returns a FP AST whose value is the same as the signed BVV `a2` and
whose sort is `a3`.
"""
if isinstance(a1, BVV) and isinstance(a2, FSort):
sort = a2
if sort == FSORT_FLOAT:
pack, unpack = "I", "f"
elif sort == FSORT_DOUBLE:
pack, unpack = "Q", "d"
else:
raise ClaripyOperationError("unrecognized float sort")

try:
packed = struct.pack("<" + pack, a1.value)
(unpacked,) = struct.unpack("<" + unpack, packed)
except OverflowError as e:
# struct.pack sometimes overflows
raise ClaripyOperationError("OverflowError: " + str(e)) from e

return FPV(unpacked, sort)
if isinstance(a1, RM) and isinstance(a2, FPV) and isinstance(a3, FSort):
return FPV(a2.value, a3)
if isinstance(a1, RM) and isinstance(a2, BVV) and isinstance(a3, FSort):
return FPV(float(a2.signed), a3)
raise ClaripyOperationError("unknown types passed to fpToFP")
def fpToFP(arg: BVV | FPV, sort: FSort, _rm: RM = RM.RM_NearestTiesEven):
"""Returns a FP AST given a BVV or FPV."""
if isinstance(arg, BVV):
arg = FPV(float(arg.value), sort)
return FPV(arg.value, sort)


def fpToFPUnsigned(_rm, thing, sort):
def fpToFPUnsigned(thing, sort, _rm):
"""
Returns a FP AST whose value is the same as the unsigned BVV `thing` and
whose sort is `sort`.
Expand Down Expand Up @@ -278,7 +245,7 @@ def fpFP(sgn, exp, mantissa):
return FPV(unpacked, sort)


def fpToSBV(rm, fp, size):
def fpToSBV(fp, size, rm):
try:
rounding_mode = rm.pydecimal_equivalent_rounding_mode()
val = int(Decimal(fp.value).to_integral_value(rounding_mode))
Expand All @@ -291,7 +258,7 @@ def fpToSBV(rm, fp, size):
raise


def fpToUBV(rm, fp, size):
def fpToUBV(fp, size, rm):
# todo: actually make unsigned
try:
rounding_mode = rm.pydecimal_equivalent_rounding_mode()
Expand Down
17 changes: 10 additions & 7 deletions claripy/backends/backend_z3.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,9 @@ def _abstract_internal(self, ctx, ast, split_on=None):
log.error(err)
raise BackendError(err)

if ty is FP and isinstance(args[0], RM):
args = [*args[1:], args[0]]

if op_name == "If":
# If is polymorphic and thus must be handled specially
ty = type(args[1])
Expand Down Expand Up @@ -1019,23 +1022,23 @@ def _op_raw_fpNeg(self, a):
return z3.FPRef(z3.Z3_mk_fpa_neg(self._context.ref(), a.as_ast()), self._context)

@condom
def _op_raw_fpAdd(self, rm, a, b):
def _op_raw_fpAdd(self, a, b, rm):
return z3.FPRef(z3.Z3_mk_fpa_add(self._context.ref(), rm.as_ast(), a.as_ast(), b.as_ast()), self._context)

@condom
def _op_raw_fpSub(self, rm, a, b):
def _op_raw_fpSub(self, a, b, rm):
return z3.FPRef(z3.Z3_mk_fpa_sub(self._context.ref(), rm.as_ast(), a.as_ast(), b.as_ast()), self._context)

@condom
def _op_raw_fpMul(self, rm, a, b):
def _op_raw_fpMul(self, a, b, rm):
return z3.FPRef(z3.Z3_mk_fpa_mul(self._context.ref(), rm.as_ast(), a.as_ast(), b.as_ast()), self._context)

@condom
def _op_raw_fpDiv(self, rm, a, b):
def _op_raw_fpDiv(self, a, b, rm):
return z3.FPRef(z3.Z3_mk_fpa_div(self._context.ref(), rm.as_ast(), a.as_ast(), b.as_ast()), self._context)

@condom
def _op_raw_fpSqrt(self, rm, a):
def _op_raw_fpSqrt(self, a, rm):
return z3.FPRef(z3.Z3_mk_fpa_sqrt(self._context.ref(), rm.as_ast(), a.as_ast()), self._context)

@condom
Expand Down Expand Up @@ -1071,11 +1074,11 @@ def _op_raw_fpFP(self, sgn, exp, sig):
return z3.FPRef(z3.Z3_mk_fpa_fp(self._context.ref(), sgn.ast, exp.ast, sig.ast), self._context)

@condom
def _op_raw_fpToSBV(self, rm, fp, bv_len):
def _op_raw_fpToSBV(self, fp, bv_len, rm):
return z3.BitVecRef(z3.Z3_mk_fpa_to_sbv(self._context.ref(), rm.ast, fp.ast, bv_len), self._context)

@condom
def _op_raw_fpToUBV(self, rm, fp, bv_len):
def _op_raw_fpToUBV(self, fp, bv_len, rm):
return z3.BitVecRef(z3.Z3_mk_fpa_to_ubv(self._context.ref(), rm.ast, fp.ast, bv_len), self._context)

@condom
Expand Down
9 changes: 6 additions & 3 deletions claripy/fp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@


class RM(Enum):
"""Rounding modes for floating point operations.
"""Floating point rounding mode, as defined by IEEE754.

See https://en.wikipedia.org/wiki/IEEE_754#Rounding_rules for more information.
See this wikipedia entry for details:
https://en.wikipedia.org/wiki/IEEE_754#Rounding_rules
"""

RM_NearestTiesEven = "RM_RNE"
Expand All @@ -33,7 +34,9 @@ def pydecimal_equivalent_rounding_mode(self):


class FSort:
"""A class representing a floating point sort."""
"""Floating point sort, desribing the size of the exponent and mantissa for
an IEEE754 floating point number.
"""

def __init__(self, name, exp, mantissa):
self.name = name
Expand Down
6 changes: 3 additions & 3 deletions claripy/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,16 @@ def op(name, arg_types, return_type, extra_check=None, calc_length=None):
def _type_fixer(args):
num_args = len(args)
if expected_num_args is not None and num_args != expected_num_args:
if num_args + 1 == expected_num_args and arg_types[0] is claripy.fp.RM:
args = (claripy.fp.RM.default(), *args)
if num_args + 1 == expected_num_args and arg_types[-1] is claripy.fp.RM:
args = (*args, claripy.fp.RM.default())
else:
raise ClaripyTypeError(f"Operation {name} takes exactly {len(arg_types)} arguments ({len(args)} given)")

actual_arg_types = (arg_types,) * num_args if isinstance(arg_types, type) else arg_types
matches = list(itertools.starmap(isinstance, zip(args, actual_arg_types, strict=False)))

# heuristically, this works!
thing = args[matches.index(True, 1 if actual_arg_types[0] is claripy.fp.RM else 0)] if True in matches else None
thing = args[matches.index(True, 0)] if True in matches else None

for arg, argty, match in zip(args, actual_arg_types, matches, strict=False):
if not match:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_fp.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_negative_zero(self):

def test_fp_ops(self):
a = claripy.FPV(1.5, claripy.FSORT_DOUBLE)
b = claripy.fpToUBV(claripy.fp.RM.RM_NearestTiesEven, a, 32)
b = claripy.fpToUBV(a, 32)

s = claripy.Solver()
assert s.eval(b, 1)[0] == 2
Expand Down
Loading