diff --git a/qlasskit/ast2ast.py b/qlasskit/ast2ast.py index 35b509ba..6bfaa473 100644 --- a/qlasskit/ast2ast.py +++ b/qlasskit/ast2ast.py @@ -72,9 +72,9 @@ def _replace_types_annotations(ann, arg=None): isinstance(ann, ast.Subscript) and isinstance(ann.value, ast.Name) and ann.value.id == "Tuple" - and hasattr(ann.slice, 'elts') + and hasattr(ann.slice, "elts") ): - _elts = ann.slice.elts + _elts = ann.slice.elts _ituple = ast.Tuple(elts=[_replace_types_annotations(el) for el in _elts]) ann = ast.Subscript( @@ -101,7 +101,7 @@ def _replace_types_annotations(ann, arg=None): isinstance(ann, ast.Subscript) and isinstance(ann.value, ast.Name) and ann.value.id == "Qlist" - and hasattr(ann.slice, 'elts') + and hasattr(ann.slice, "elts") ): _elts = ann.slice.elts _ituple = ast.Tuple(elts=[copy.deepcopy(_elts[0])] * _elts[1].value) @@ -116,7 +116,7 @@ def _replace_types_annotations(ann, arg=None): isinstance(ann, ast.Subscript) and isinstance(ann.value, ast.Name) and ann.value.id == "Qmatrix" - and hasattr(ann.slice, 'elts') + and hasattr(ann.slice, "elts") ): _elts = ann.slice.elts _ituple_row = ast.Tuple(elts=[copy.deepcopy(_elts[0])] * _elts[2].value) @@ -425,13 +425,13 @@ def visit_For(self, node): # noqa: C901 isinstance(iter, ast.Subscript) and isinstance(iter.value, ast.Name) and iter.value.id in self.env - and hasattr(iter.slice, 'value') + and hasattr(iter.slice, "value") ): if isinstance(self.env[iter.value.id], ast.Tuple): new_iter = self.env[iter.value.id].elts[iter.slice.value] - elif isinstance(self.env[iter.value.id], ast.Subscript): - _elts = self.env[iter.value.id].slice.elts[iter.slice.value] + elif isinstance(self.env[iter.value.id], ast.Subscript): + _elts = self.env[iter.value.id].slice.elts[iter.slice.value] if isinstance(_elts, ast.Tuple): _elts = _elts.elts @@ -439,8 +439,8 @@ def visit_For(self, node): # noqa: C901 new_iter = [ ast.Subscript( value=ast.Subscript( - value=ast.Name(id=iter.value.id, ctx=ast.Load()), - slice=ast.Constant(value=iter.slice.value), + value=ast.Name(id=iter.value.id, ctx=ast.Load()), + slice=ast.Constant(value=iter.slice.value), ctx=ast.Load(), ), slice=ast.Constant(value=e), @@ -448,8 +448,8 @@ def visit_For(self, node): # noqa: C901 for e in range(len(_elts)) ] else: - new_iter = iter - + new_iter = iter + iter = new_iter if isinstance(iter, ast.Constant) and isinstance(iter.value, ast.Tuple): diff --git a/qlasskit/types/qint.py b/qlasskit/types/qint.py index f818a83c..d7a48343 100644 --- a/qlasskit/types/qint.py +++ b/qlasskit/types/qint.py @@ -18,7 +18,7 @@ from sympy.logic import And, Not, Or, Xor, false, true from . import TypeErrorException, _eq, _full_adder, _neq -from .qtype import Qtype, TExp, bin_to_bool_list, bool_list_to_bin +from .qtype import Qtype, TExp, TType, bin_to_bool_list, bool_list_to_bin class QintImp(int, Qtype): @@ -164,12 +164,86 @@ def add(cls, tleft: TExp, tright: TExp) -> TExp: return (cls if cls.BIT_SIZE > tleft_e[0].BIT_SIZE else tleft_e[0], sums) + @staticmethod + def mul_even_const(t_num: TExp, const: int, result_type: Qtype) -> TExp: + """Multiply by an even const using shift and add + (x << 3) + (x << 1) # Here 10*x is computed as x*2^3 + x*2 + """ + + # Multiply t_num by the nearest n | 2**n < t_const + n = 1 + while 2**n <= const: + n += 1 + if 2**n > const: + n -= 1 + + result_ttype = cast(TType, result_type) + + t_num_r = result_type.shift_left((result_ttype, t_num[1]), n) + + # Shift t_const by t_const - 2**n + r = const - 2**n + if r > 0: + # Add the shift result to t_num + res = result_type.add( + (result_ttype, t_num_r[1]), + result_type.shift_left((result_ttype, t_num[1]), int(r / 2)), + ) + else: + res = (result_ttype, t_num_r[1]) + + return res + @classmethod - def mul(cls, tleft: TExp, tright: TExp) -> TExp: # noqa: C901 - # TODO: use RGQFT multiplier + def mul(cls, tleft_: TExp, tright_: TExp) -> TExp: # noqa: C901 + if not issubclass(tleft_[0], Qtype): + raise TypeErrorException(tleft_[0], Qtype) + if not issubclass(tright_[0], Qtype): + raise TypeErrorException(tright_[0], Qtype) + + def __mul_sizing(n, m): + if (n + m) <= 2: + return Qint2 + elif (n + m) > 2 and (n + m) <= 4: + return Qint4 + elif (n + m) > 4 and (n + m) <= 6: + return Qint6 + elif (n + m) > 6 and (n + m) <= 8: + return Qint8 + elif (n + m) > 8 and (n + m) <= 12: + return Qint12 + elif (n + m) > 12 and (n + m) <= 16: + return Qint16 + elif (n + m) > 16: + return Qint16 + else: + raise Exception(f"Mul result size is too big ({n+m})") + + # Fill constants so explicit typecast is not needed + if cls.is_const(tleft_): + tleft = tright_[0].fill(tleft_) + else: + tleft = tleft_ + + if cls.is_const(tright_): + tright = tleft_[0].fill(tright_) + else: + tright = tright_ + n = len(tleft[1]) m = len(tright[1]) + # If one operand is an even constant, use mul_even_const + if cls.is_const(tleft) or cls.is_const(tright): + t_num = tleft if cls.is_const(tright) else tright + t_const = tleft if cls.is_const(tleft) else tright + const = cast(int, cast(Qtype, t_const[0]).from_bool(t_const[1])) + + if const % 2 == 0: + t = __mul_sizing(n, m) + res = cls.mul_even_const(t_num, const, t) + return t.crop(t.fill(res)) + if n != m: raise Exception(f"Mul works only on same size Qint: {n} != {m}") @@ -190,22 +264,8 @@ def mul(cls, tleft: TExp, tright: TExp) -> TExp: # noqa: C901 if i + m < n + m: product[i + m] = carry - if (n + m) <= 2: - return Qint2, product - elif (n + m) > 2 and (n + m) <= 4: - return Qint4, product - elif (n + m) > 4 and (n + m) <= 6: - return Qint6, product - elif (n + m) > 6 and (n + m) <= 8: - return Qint8, product - elif (n + m) > 8 and (n + m) <= 12: - return Qint12, product - elif (n + m) > 12 and (n + m) <= 16: - return Qint16, product - elif (n + m) > 16: - return Qint16.crop((Qint16, product)) - - raise Exception(f"Mul result size is too big ({n+m})") + t = __mul_sizing(n, m) + return t.crop(t.fill((t, product))) @classmethod def sub(cls, tleft: TExp, tright: TExp) -> TExp: diff --git a/qlasskit/types/qtype.py b/qlasskit/types/qtype.py index 88904faf..fab782a6 100644 --- a/qlasskit/types/qtype.py +++ b/qlasskit/types/qtype.py @@ -187,12 +187,18 @@ def bitwise_not(v: TExp) -> TExp: @staticmethod def shift_right(v: TExp, i: int = 1) -> TExp: """Apply a shift right""" - return (v[0], v[1][i:]) + if not issubclass(v[0], Qtype): + raise TypeErrorException(v[0], Qtype) + + return v[0].fill((v[0], v[1][i:])) @staticmethod def shift_left(v: TExp, i: int = 1) -> TExp: """Apply a shift left""" - return (v[0], [False] * i + v[1]) + if not issubclass(v[0], Qtype): + raise TypeErrorException(v[0], Qtype) + + return v[0].crop((v[0], [False] * i + v[1])) @staticmethod def add(tleft: TExp, tright: TExp) -> TExp: diff --git a/test/qlassf/test_int.py b/test/qlassf/test_int.py index 5251362c..8d2e672b 100644 --- a/test/qlassf/test_int.py +++ b/test/qlassf/test_int.py @@ -293,13 +293,27 @@ def test_composed_comparators(self): qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) compute_and_compare_results(self, qf) - def test_shift_left(self): - f = "def test(n: Qint[2]) -> Qint[4]: return n << 1" + @parameterized.expand( + [ + (1,), + (2,), + (3,), + ] + ) + def test_shift_left(self, v): + f = f"def test(n: Qint[4]) -> Qint[4]: return n << {v}" qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) compute_and_compare_results(self, qf) - def test_shift_right(self): - f = "def test(n: Qint[2]) -> Qint[4]: return n >> 1" + @parameterized.expand( + [ + (1,), + (2,), + (3,), + ] + ) + def test_shift_right(self, v): + f = f"def test(n: Qint[2]) -> Qint[4]: return n >> {v}" qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) compute_and_compare_results(self, qf) @@ -479,3 +493,23 @@ def test_mul5(self): qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) compute_and_compare_results(self, qf) self.assertEqual(qf.expressions[0][1], True) + + +@parameterized_class( + ("ttype_i", "ttype_o", "const", "compiler"), + inject_parameterized_compilers( + [ + (4, 6, 2), + (4, 6, 4), + (4, 6, 6), + (6, 8, 6), + (6, 8, 8), + (6, 8, 10), + ] + ), +) +class TestQlassfIntMulByConst(unittest.TestCase): + def test_mul(self): + f = f"def test(a: Qint[{self.ttype_i}]) -> Qint[{self.ttype_o}]: return a * {self.const}" + qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) + compute_and_compare_results(self, qf) diff --git a/test/utils.py b/test/utils.py index 8a93a858..4d2831c3 100644 --- a/test/utils.py +++ b/test/utils.py @@ -223,7 +223,7 @@ def compute_and_compare_results(cls, qf, test_original_f=True, test_qcircuit=Tru # circ_qi = qf.circuit().export("circuit", "qiskit") - # update_statistics(qf.circuit().num_qubits, qf.circuit().num_gates) + update_statistics(qf.circuit().num_qubits, qf.circuit().num_gates) # print(qf.expressions) # print(circ_qi.draw("text"))