diff --git a/qlasskit/ast2ast.py b/qlasskit/ast2ast.py index 35b509ba..6bfaa473 100644 --- a/qlasskit/ast2ast.py +++ b/qlasskit/ast2ast.py @@ -72,9 +72,9 @@ def _replace_types_annotations(ann, arg=None): isinstance(ann, ast.Subscript) and isinstance(ann.value, ast.Name) and ann.value.id == "Tuple" - and hasattr(ann.slice, 'elts') + and hasattr(ann.slice, "elts") ): - _elts = ann.slice.elts + _elts = ann.slice.elts _ituple = ast.Tuple(elts=[_replace_types_annotations(el) for el in _elts]) ann = ast.Subscript( @@ -101,7 +101,7 @@ def _replace_types_annotations(ann, arg=None): isinstance(ann, ast.Subscript) and isinstance(ann.value, ast.Name) and ann.value.id == "Qlist" - and hasattr(ann.slice, 'elts') + and hasattr(ann.slice, "elts") ): _elts = ann.slice.elts _ituple = ast.Tuple(elts=[copy.deepcopy(_elts[0])] * _elts[1].value) @@ -116,7 +116,7 @@ def _replace_types_annotations(ann, arg=None): isinstance(ann, ast.Subscript) and isinstance(ann.value, ast.Name) and ann.value.id == "Qmatrix" - and hasattr(ann.slice, 'elts') + and hasattr(ann.slice, "elts") ): _elts = ann.slice.elts _ituple_row = ast.Tuple(elts=[copy.deepcopy(_elts[0])] * _elts[2].value) @@ -425,13 +425,13 @@ def visit_For(self, node): # noqa: C901 isinstance(iter, ast.Subscript) and isinstance(iter.value, ast.Name) and iter.value.id in self.env - and hasattr(iter.slice, 'value') + and hasattr(iter.slice, "value") ): if isinstance(self.env[iter.value.id], ast.Tuple): new_iter = self.env[iter.value.id].elts[iter.slice.value] - elif isinstance(self.env[iter.value.id], ast.Subscript): - _elts = self.env[iter.value.id].slice.elts[iter.slice.value] + elif isinstance(self.env[iter.value.id], ast.Subscript): + _elts = self.env[iter.value.id].slice.elts[iter.slice.value] if isinstance(_elts, ast.Tuple): _elts = _elts.elts @@ -439,8 +439,8 @@ def visit_For(self, node): # noqa: C901 new_iter = [ ast.Subscript( value=ast.Subscript( - value=ast.Name(id=iter.value.id, ctx=ast.Load()), - slice=ast.Constant(value=iter.slice.value), + value=ast.Name(id=iter.value.id, ctx=ast.Load()), + slice=ast.Constant(value=iter.slice.value), ctx=ast.Load(), ), slice=ast.Constant(value=e), @@ -448,8 +448,8 @@ def visit_For(self, node): # noqa: C901 for e in range(len(_elts)) ] else: - new_iter = iter - + new_iter = iter + iter = new_iter if isinstance(iter, ast.Constant) and isinstance(iter.value, ast.Tuple): diff --git a/qlasskit/types/qint.py b/qlasskit/types/qint.py index f818a83c..1d19edef 100644 --- a/qlasskit/types/qint.py +++ b/qlasskit/types/qint.py @@ -164,9 +164,52 @@ def add(cls, tleft: TExp, tright: TExp) -> TExp: return (cls if cls.BIT_SIZE > tleft_e[0].BIT_SIZE else tleft_e[0], sums) + @staticmethod + def mul_even_const(t_num: TExp, t_const: TExp) -> TExp: + """Multiply by an even const using shift and add + (x << 3) + (x << 1) # Here 10*x is computed as x*2^3 + x*2 + """ + const = t_const[0].from_bool(t_const[1]) + + # Multiply t_num by the nearest n | 2**n < t_const + n = 1 + while 2**n <= const: + n += 1 + if 2**n > const: + n -= 1 + + t_num_r = t_num[0].crop(t_num[0].shift_left(t_num, n)) + + # Shift t_const by t_const - 2**n + r = const - 2**n + if r > 0: + # Add the shift result to t_num + res = t_num_r[0].add( + t_num_r, t_num[0].crop(t_num[0].shift_left(t_num, int(r / 2))) + ) + else: + res = t_num_r + + return res + @classmethod def mul(cls, tleft: TExp, tright: TExp) -> TExp: # noqa: C901 - # TODO: use RGQFT multiplier + # Fill constants so explicit typecast is not needed + if cls.is_const(tleft): + tleft = tright[0].fill(tleft) + + if cls.is_const(tright): + tright = tleft[0].fill(tright) + + # If one operand is an even constant, use mul_even_const + if cls.is_const(tleft) or cls.is_const(tright): + t_num = tleft if cls.is_const(tright) else tright + t_const = tleft if cls.is_const(tleft) else tright + const = t_const[0].from_bool(t_const[1]) + + if const % 2 == 0: + return cls.mul_even_const(t_num, t_const) + n = len(tleft[1]) m = len(tright[1]) diff --git a/qlasskit/types/qtype.py b/qlasskit/types/qtype.py index 88904faf..982dd168 100644 --- a/qlasskit/types/qtype.py +++ b/qlasskit/types/qtype.py @@ -187,12 +187,12 @@ def bitwise_not(v: TExp) -> TExp: @staticmethod def shift_right(v: TExp, i: int = 1) -> TExp: """Apply a shift right""" - return (v[0], v[1][i:]) + return v[0].fill((v[0], v[1][i:])) @staticmethod def shift_left(v: TExp, i: int = 1) -> TExp: """Apply a shift left""" - return (v[0], [False] * i + v[1]) + return v[0].crop((v[0], [False] * i + v[1])) @staticmethod def add(tleft: TExp, tright: TExp) -> TExp: diff --git a/test/qlassf/test_int.py b/test/qlassf/test_int.py index 5251362c..b58574da 100644 --- a/test/qlassf/test_int.py +++ b/test/qlassf/test_int.py @@ -293,13 +293,27 @@ def test_composed_comparators(self): qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) compute_and_compare_results(self, qf) - def test_shift_left(self): - f = "def test(n: Qint[2]) -> Qint[4]: return n << 1" + @parameterized.expand( + [ + (1,), + (2,), + (3,), + ] + ) + def test_shift_left(self, v): + f = f"def test(n: Qint[4]) -> Qint[4]: return n << {v}" qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) compute_and_compare_results(self, qf) - def test_shift_right(self): - f = "def test(n: Qint[2]) -> Qint[4]: return n >> 1" + @parameterized.expand( + [ + (1,), + (2,), + (3,), + ] + ) + def test_shift_right(self, v): + f = f"def test(n: Qint[2]) -> Qint[4]: return n >> {v}" qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) compute_and_compare_results(self, qf) @@ -479,3 +493,21 @@ def test_mul5(self): qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) compute_and_compare_results(self, qf) self.assertEqual(qf.expressions[0][1], True) + + +@parameterized_class( + ("ttype_i", "ttype_o", "const", "compiler"), + inject_parameterized_compilers( + [ + (4, 4, 2), + (4, 4, 3), + (4, 4, 4), + (4, 4, 6), + ] + ), +) +class TestQlassfIntMulByConst(unittest.TestCase): + def test_mul(self): + f = f"def test(a: Qint[{self.ttype_i}]) -> Qint[{self.ttype_o}]: return a * {self.const}" + qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) + compute_and_compare_results(self, qf) diff --git a/test/utils.py b/test/utils.py index 8a93a858..4d2831c3 100644 --- a/test/utils.py +++ b/test/utils.py @@ -223,7 +223,7 @@ def compute_and_compare_results(cls, qf, test_original_f=True, test_qcircuit=Tru # circ_qi = qf.circuit().export("circuit", "qiskit") - # update_statistics(qf.circuit().num_qubits, qf.circuit().num_gates) + update_statistics(qf.circuit().num_qubits, qf.circuit().num_gates) # print(qf.expressions) # print(circ_qi.draw("text"))