diff --git a/qlasskit/types/qfixed.py b/qlasskit/types/qfixed.py index 8666a9e6..6968ce61 100644 --- a/qlasskit/types/qfixed.py +++ b/qlasskit/types/qfixed.py @@ -49,7 +49,11 @@ class QfixedImp(float, Qtype): def __init__(self, value: float): super().__init__() + self.value = value + # v = str(value).split('.') + # self.value = int(v[0]) % self.BIT_SIZE_INTEGER + (float(f'0.{v[1]}') + # if len(v) == 2 else 0) @classmethod def from_bool(cls, v: List[bool]): @@ -200,6 +204,40 @@ def add(cls, tleft: TExp, tright: TExp) -> TExp: return (tleft[0], QfixedImp._from_qint_repr((tleft[0], res[1]))) + @classmethod + def sub(cls, tleft: TExp, tright: TExp) -> TExp: + """Subtract two Qfixed""" + an = cls.bitwise_not(cls.fill(tleft)) # type: ignore + su = cls.add(an, cls.fill(tright)) # type: ignore + return cls.bitwise_not(su) # type: ignore + + @classmethod + def mul(cls, tleft: TExp, tright: TExp) -> TExp: # noqa: C901 + a = len(list(filter(lambda b: b is bool, tleft[1]))) + b = len(list(filter(lambda b: b is bool, tright[1]))) + + if a == 0 and issubclass(tleft[0], Qint): # type: ignore + tconst = tleft + top = tright + elif b == 0 and issubclass(tright[0], Qint): # type: ignore + top = tleft + tconst = tright + else: + raise Exception( + "Qfixed mul works only between a Qfixed and an integer constant" + ) + + v_const = int(bool_list_to_bin(tconst[1])[::-1], 2) + + if v_const == 0: + return cls.const(0.0) + + v = top + for i in range(v_const - 1): + v = cls.add(v, top) + + return v + class Qfixed1_2(QfixedImp): BIT_SIZE = 3 diff --git a/test/qlassf/test_fixed.py b/test/qlassf/test_fixed.py index 01353cf5..f34a753f 100644 --- a/test/qlassf/test_fixed.py +++ b/test/qlassf/test_fixed.py @@ -16,7 +16,7 @@ from parameterized import parameterized, parameterized_class -from qlasskit import qlassf +from qlasskit import Qint2, qlassf from qlasskit.types.qfixed import Qfixed1_3, Qfixed2_3, Qfixed2_4 from qlasskit.types.qtype import bin_to_bool_list @@ -63,8 +63,18 @@ def test_fixed_gt(self, qft, a, b, r): [Qfixed2_3, 3.5, 0.5], ] ) - def test_fixed_add(self, qft, a, b): + def test_fixed_add_sub(self, qft, a, b): self.assertEqual(qft.add(qft.const(a), qft.const(b))[1], qft.const(a + b)[1]) + self.assertEqual(qft.sub(qft.const(a), qft.const(b))[1], qft.const(a - b)[1]) + + @parameterized.expand( + [ + [Qfixed2_3, 0.5, Qint2, 2], + [Qfixed2_3, 0.5, Qint2, 0], + ] + ) + def test_fixed_mul(self, qft, a, qit, b): + self.assertEqual(qft.mul(qft.const(a), qit.const(b))[1], qft.const(a * b)[1]) @parameterized_class(("compiler"), ENABLED_COMPILERS) @@ -119,16 +129,38 @@ def test_gte(self): qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) compute_and_compare_results(self, qf) - def test_sum_const(self): + def test_add_const(self): f = "def test(a: Qfixed[2,4]) -> Qfixed[2, 4]:\n\treturn Qfixed2_4(0.5) + a" qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) compute_and_compare_results(self, qf) - def test_sum(self): + def test_add(self): f = "def test(a: Qfixed[1,4], b: Qfixed[1,4]) -> Qfixed[1,4]:\n\treturn a + b" qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) compute_and_compare_results(self, qf) + def test_mul(self): + f = "def test(a: Qfixed[1,4]) -> Qfixed[1,4]:\n\treturn a * 3" + qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) + compute_and_compare_results(self, qf) + + def test_sub(self): + f = ( + "def test(a: Qfixed[1,4], b: Qfixed[1,4]) -> Qfixed[1,4]:\n" + "\treturn a - b if a > b else b - a" + ) + qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) + compute_and_compare_results(self, qf) + + def test_sub_const(self): + # TODO: allow negative overflow + f = ( + "def test(a: Qfixed[2,4]) -> Qfixed[2, 4]:\n" + "\treturn (a - Qfixed2_4(0.5)) if a > Qfixed2_4(0.5) else 0" + ) + qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) + compute_and_compare_results(self, qf) + # def test_to_int(self): # f = "def test(a: Qfixed[2,4]) -> Qint2:\n\treturn int(a)" # qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) diff --git a/test/utils.py b/test/utils.py index 6eec672d..1284fa33 100644 --- a/test/utils.py +++ b/test/utils.py @@ -174,10 +174,10 @@ def res_to_str(res): qi = qf.returns.ttype.from_bool(qc[1]) except: qi = qc[0].from_bool(qc[1]) - return qi.to_bin() - elif type(res) is float: + return qi.to_bin() + elif type(res) is float: qi = qf.returns.ttype(res) - return qi.to_bin() + return qi.to_bin() else: return res.to_bin() @@ -191,7 +191,7 @@ def res_to_str(res): res_original = qf.original_f(*args) res_original_str = res_to_str(res_original) - + # truth_str = "".join( # map(lambda x: "1" if x else "0", truth_line[-qf.output_size :]) # )