From 0dc1658a1d5dd046a17d2a8ea74699c168450993 Mon Sep 17 00:00:00 2001 From: Audrey Dutcher Date: Wed, 18 Sep 2024 13:48:14 -0700 Subject: [PATCH] Add type annotations for If --- claripy/ast/bool.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/claripy/ast/bool.py b/claripy/ast/bool.py index 1298a3d9f..c78fe0fa9 100644 --- a/claripy/ast/bool.py +++ b/claripy/ast/bool.py @@ -3,14 +3,18 @@ import atexit import logging from contextlib import suppress +from typing import TYPE_CHECKING, overload from claripy import operations from claripy.ast.base import ASTCacheKey, Base, _make_name from claripy.backend_manager import backends -from claripy.errors import BackendError, ClaripyOperationError, ClaripyTypeError +from claripy.errors import BackendError, ClaripyTypeError from .bits import Bits +if TYPE_CHECKING: + from .fp import FP + l = logging.getLogger("claripy.ast.bool") _boolv_cache = {} @@ -97,12 +101,17 @@ def BoolV(val) -> Bool: # -def If(*args): - # the coercion here is strange enough that we'll just implement it manually - if len(args) != 3: - raise ClaripyOperationError("invalid number of args passed to If") +@overload +def If(cond: bool | Bool, true_value: bool | Bool, false_value: bool | Bool) -> Bool: ... +@overload +def If(cond: bool | Bool, true_value: int | BV, false_value: int | BV) -> BV: ... +@overload +def If(cond: bool | Bool, true_value: float | FP, false_value: float | FP) -> FP: ... - args = list(args) + +def If(cond, true_value, false_value): + # the coercion here is strange enough that we'll just implement it manually + args = [cond, true_value, false_value] if isinstance(args[0], bool): args[0] = BoolV(args[0]) @@ -274,4 +283,4 @@ def constraint_to_si(expr): # pylint: disable=wrong-import-position -from .bv import BVS # noqa: E402 +from .bv import BV, BVS # noqa: E402