From 058d68dd4ed3da15a120a08082f70328e22e7f5f Mon Sep 17 00:00:00 2001 From: "Davide Gessa (dakk)" Date: Sat, 16 Mar 2024 17:41:14 +0100 Subject: [PATCH] qfixed.add --- qlasskit/qlassfun.py | 1 + qlasskit/types/__init__.py | 9 ++++----- qlasskit/types/qfixed.py | 6 +++--- test/qlassf/test_fixed.py | 36 ++++++++++++++++++++++-------------- test/utils.py | 13 ++++++++++--- 5 files changed, 40 insertions(+), 25 deletions(-) diff --git a/qlasskit/qlassfun.py b/qlasskit/qlassfun.py index f89c21c0..adf1eee5 100644 --- a/qlasskit/qlassfun.py +++ b/qlasskit/qlassfun.py @@ -30,6 +30,7 @@ from .qcircuit import QCircuitWrapper from .types import * # noqa: F403, F401 from .types import Qtype, format_outcome, interpret_as_qtype, type_repr +from .types.qfixed import * # noqa: F403, F401 MAX_TRUTH_TABLE_SIZE = 20 diff --git a/qlasskit/types/__init__.py b/qlasskit/types/__init__.py index 6fd7b336..411d7fca 100644 --- a/qlasskit/types/__init__.py +++ b/qlasskit/types/__init__.py @@ -94,11 +94,10 @@ def const_to_qtype(value: Any) -> TExp: elif isinstance(value, float): for det_type in QFIXED_TYPES: # type: ignore - v_s = str(value).split(".") - - # TODO: check also for the fractional part - if int(v_s[0]) < 2**det_type.BIT_SIZE_INTEGER: # type: ignore - return det_type.const(value) # type: ignore + v = det_type.const(value) # type: ignore + c_val = det_type.from_bool(v[1]) + if c_val > value - 0.05 and c_val < value + 0.05: + return v raise Exception(f"Unable to infer type of constant: {value}") diff --git a/qlasskit/types/qfixed.py b/qlasskit/types/qfixed.py index 816a4fbd..8666a9e6 100644 --- a/qlasskit/types/qfixed.py +++ b/qlasskit/types/qfixed.py @@ -71,9 +71,9 @@ def from_bool(cls, v: List[bool]): return cls(integer_value + fractional_value) def to_bool(self) -> List[bool]: - integer_part = bin_to_bool_list(bin(int(self.value)), self.BIT_SIZE_INTEGER)[ - ::-1 - ] + integer_part = bin_to_bool_list( + bin(int(self.value))[::-1], self.BIT_SIZE_INTEGER + ) fractional_part = [] c_val = self.value diff --git a/test/qlassf/test_fixed.py b/test/qlassf/test_fixed.py index 1b77bcd1..01353cf5 100644 --- a/test/qlassf/test_fixed.py +++ b/test/qlassf/test_fixed.py @@ -32,6 +32,7 @@ class TestQfixedEncoding(unittest.TestCase): (Qfixed1_3, "1000", 1.0), (Qfixed2_3, "01000", 2.0), (Qfixed2_3, "01100", 2.5), + # (Qfixed2_3, "00000", 4.0), ] ) def test_fixed_from_bool_and_to_bin(self, qft, bin_v, val): @@ -58,6 +59,8 @@ def test_fixed_gt(self, qft, a, b, r): [Qfixed2_3, 0.5, 0.5], [Qfixed2_3, 0.75, 0.75], [Qfixed2_3, 1.0, 0.75], + [Qfixed2_3, 1.0, 0.5], + [Qfixed2_3, 3.5, 0.5], ] ) def test_fixed_add(self, qft, a, b): @@ -76,10 +79,10 @@ def test_fixed_const(self): qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) compute_and_compare_results(self, qf) - # def test_equal_const(self): - # f = "def test(a: Qfixed[1,4]) -> bool:\n\treturn a == 0.1" - # qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) - # compute_and_compare_results(self, qf) + def test_equal_const(self): + f = "def test(a: Qfixed[2,4]) -> bool:\n\treturn a == Qfixed2_4(0.5)" + qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) + compute_and_compare_results(self, qf) def test_equal(self): f = "def test(a: Qfixed[1,4], b: Qfixed[1,4]) -> bool:\n\treturn a == b" @@ -96,6 +99,11 @@ def test_gt(self): qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) compute_and_compare_results(self, qf) + def test_gt_const(self): + f = "def test(a: Qfixed[1,4], b: Qfixed[1,4]) -> bool:\n\treturn a > Qfixed1_4(0.5)" + qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) + compute_and_compare_results(self, qf) + def test_lt(self): f = "def test(a: Qfixed[1,4], b: Qfixed[1,4]) -> bool:\n\treturn a < b" qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) @@ -111,6 +119,16 @@ def test_gte(self): qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) compute_and_compare_results(self, qf) + def test_sum_const(self): + f = "def test(a: Qfixed[2,4]) -> Qfixed[2, 4]:\n\treturn Qfixed2_4(0.5) + a" + qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) + compute_and_compare_results(self, qf) + + def test_sum(self): + f = "def test(a: Qfixed[1,4], b: Qfixed[1,4]) -> Qfixed[1,4]:\n\treturn a + b" + qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) + compute_and_compare_results(self, qf) + # def test_to_int(self): # f = "def test(a: Qfixed[2,4]) -> Qint2:\n\treturn int(a)" # qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) @@ -120,13 +138,3 @@ def test_gte(self): # f = "def test(a: Qint2) -> Qfixed[2,4]:\n\treturn float(a)" # qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) # compute_and_compare_results(self, qf) - - # def test_sum_const(self): - # f = "def test(a: Qfixed[1,4]) -> Qfixed[1, 3]:\n\treturn 0.1 + a" - # qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) - # compute_and_compare_results(self, qf) - - # def test_sum(self): - # f = "def test(a: Qfixed[1,4], b: Qfixed[1,4]) -> Qfixed[1,4]:\n\treturn a + b" - # qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) - # compute_and_compare_results(self, qf) diff --git a/test/utils.py b/test/utils.py index e230a7ee..6eec672d 100644 --- a/test/utils.py +++ b/test/utils.py @@ -168,13 +168,16 @@ def res_to_str(res): return "1" if res else "0" elif type(res) is tuple or type(res) is list: return "".join([res_to_str(x) for x in res]) - elif type(res) is int or type(res) is str or type(res) is float: + elif type(res) is int or type(res) is str: qc = const_to_qtype(res) try: qi = qf.returns.ttype.from_bool(qc[1]) except: qi = qc[0].from_bool(qc[1]) - return qi.to_bin() + return qi.to_bin() + elif type(res) is float: + qi = qf.returns.ttype(res) + return qi.to_bin() else: return res.to_bin() @@ -188,7 +191,11 @@ def res_to_str(res): res_original = qf.original_f(*args) res_original_str = res_to_str(res_original) - # print('\n\n', args, res_original, res_original_str, truth_line) + + # truth_str = "".join( + # map(lambda x: "1" if x else "0", truth_line[-qf.output_size :]) + # ) + # print('\n\n', args, res_original, res_original_str, truth_str) # print (qf.expressions) cls.assertEqual(len(res_original_str), qf.output_size)