Skip to content

Commit

Permalink
shift and add multiplication
Browse files Browse the repository at this point in the history
  • Loading branch information
dakk committed Apr 16, 2024
1 parent 74c61ab commit 48cf912
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 19 deletions.
22 changes: 11 additions & 11 deletions qlasskit/ast2ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ def _replace_types_annotations(ann, arg=None):
isinstance(ann, ast.Subscript)
and isinstance(ann.value, ast.Name)
and ann.value.id == "Tuple"
and hasattr(ann.slice, 'elts')
and hasattr(ann.slice, "elts")
):
_elts = ann.slice.elts
_elts = ann.slice.elts
_ituple = ast.Tuple(elts=[_replace_types_annotations(el) for el in _elts])

ann = ast.Subscript(
Expand All @@ -101,7 +101,7 @@ def _replace_types_annotations(ann, arg=None):
isinstance(ann, ast.Subscript)
and isinstance(ann.value, ast.Name)
and ann.value.id == "Qlist"
and hasattr(ann.slice, 'elts')
and hasattr(ann.slice, "elts")
):
_elts = ann.slice.elts
_ituple = ast.Tuple(elts=[copy.deepcopy(_elts[0])] * _elts[1].value)
Expand All @@ -116,7 +116,7 @@ def _replace_types_annotations(ann, arg=None):
isinstance(ann, ast.Subscript)
and isinstance(ann.value, ast.Name)
and ann.value.id == "Qmatrix"
and hasattr(ann.slice, 'elts')
and hasattr(ann.slice, "elts")
):
_elts = ann.slice.elts
_ituple_row = ast.Tuple(elts=[copy.deepcopy(_elts[0])] * _elts[2].value)
Expand Down Expand Up @@ -425,31 +425,31 @@ def visit_For(self, node): # noqa: C901
isinstance(iter, ast.Subscript)
and isinstance(iter.value, ast.Name)
and iter.value.id in self.env
and hasattr(iter.slice, 'value')
and hasattr(iter.slice, "value")
):
if isinstance(self.env[iter.value.id], ast.Tuple):
new_iter = self.env[iter.value.id].elts[iter.slice.value]

elif isinstance(self.env[iter.value.id], ast.Subscript):
_elts = self.env[iter.value.id].slice.elts[iter.slice.value]
elif isinstance(self.env[iter.value.id], ast.Subscript):
_elts = self.env[iter.value.id].slice.elts[iter.slice.value]

if isinstance(_elts, ast.Tuple):
_elts = _elts.elts

new_iter = [
ast.Subscript(
value=ast.Subscript(
value=ast.Name(id=iter.value.id, ctx=ast.Load()),
slice=ast.Constant(value=iter.slice.value),
value=ast.Name(id=iter.value.id, ctx=ast.Load()),
slice=ast.Constant(value=iter.slice.value),
ctx=ast.Load(),
),
slice=ast.Constant(value=e),
)
for e in range(len(_elts))
]
else:
new_iter = iter
new_iter = iter

iter = new_iter

if isinstance(iter, ast.Constant) and isinstance(iter.value, ast.Tuple):
Expand Down
45 changes: 44 additions & 1 deletion qlasskit/types/qint.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,52 @@ def add(cls, tleft: TExp, tright: TExp) -> TExp:

return (cls if cls.BIT_SIZE > tleft_e[0].BIT_SIZE else tleft_e[0], sums)

@staticmethod
def mul_even_const(t_num: TExp, t_const: TExp) -> TExp:
"""Multiply by an even const using shift and add
(x << 3) + (x << 1) # Here 10*x is computed as x*2^3 + x*2
"""
const = t_const[0].from_bool(t_const[1])

# Multiply t_num by the nearest n | 2**n < t_const
n = 1
while 2**n <= const:
n += 1
if 2**n > const:
n -= 1

t_num_r = t_num[0].crop(t_num[0].shift_left(t_num, n))

# Shift t_const by t_const - 2**n
r = const - 2**n
if r > 0:
# Add the shift result to t_num
res = t_num_r[0].add(
t_num_r, t_num[0].crop(t_num[0].shift_left(t_num, int(r / 2)))
)
else:
res = t_num_r

return res

@classmethod
def mul(cls, tleft: TExp, tright: TExp) -> TExp: # noqa: C901
# TODO: use RGQFT multiplier
# Fill constants so explicit typecast is not needed
if cls.is_const(tleft):
tleft = tright[0].fill(tleft)

if cls.is_const(tright):
tright = tleft[0].fill(tright)

# If one operand is an even constant, use mul_even_const
if cls.is_const(tleft) or cls.is_const(tright):
t_num = tleft if cls.is_const(tright) else tright
t_const = tleft if cls.is_const(tleft) else tright
const = t_const[0].from_bool(t_const[1])

if const % 2 == 0:
return cls.mul_even_const(t_num, t_const)

n = len(tleft[1])
m = len(tright[1])

Expand Down
4 changes: 2 additions & 2 deletions qlasskit/types/qtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,12 +187,12 @@ def bitwise_not(v: TExp) -> TExp:
@staticmethod
def shift_right(v: TExp, i: int = 1) -> TExp:
"""Apply a shift right"""
return (v[0], v[1][i:])
return v[0].fill((v[0], v[1][i:]))

@staticmethod
def shift_left(v: TExp, i: int = 1) -> TExp:
"""Apply a shift left"""
return (v[0], [False] * i + v[1])
return v[0].crop((v[0], [False] * i + v[1]))

@staticmethod
def add(tleft: TExp, tright: TExp) -> TExp:
Expand Down
40 changes: 36 additions & 4 deletions test/qlassf/test_int.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,13 +293,27 @@ def test_composed_comparators(self):
qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler)
compute_and_compare_results(self, qf)

def test_shift_left(self):
f = "def test(n: Qint[2]) -> Qint[4]: return n << 1"
@parameterized.expand(
[
(1,),
(2,),
(3,),
]
)
def test_shift_left(self, v):
f = f"def test(n: Qint[4]) -> Qint[4]: return n << {v}"
qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler)
compute_and_compare_results(self, qf)

def test_shift_right(self):
f = "def test(n: Qint[2]) -> Qint[4]: return n >> 1"
@parameterized.expand(
[
(1,),
(2,),
(3,),
]
)
def test_shift_right(self, v):
f = f"def test(n: Qint[2]) -> Qint[4]: return n >> {v}"
qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler)
compute_and_compare_results(self, qf)

Expand Down Expand Up @@ -479,3 +493,21 @@ def test_mul5(self):
qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler)
compute_and_compare_results(self, qf)
self.assertEqual(qf.expressions[0][1], True)


@parameterized_class(
("ttype_i", "ttype_o", "const", "compiler"),
inject_parameterized_compilers(
[
(4, 4, 2),
(4, 4, 3),
(4, 4, 4),
(4, 4, 6),
]
),
)
class TestQlassfIntMulByConst(unittest.TestCase):
def test_mul(self):
f = f"def test(a: Qint[{self.ttype_i}]) -> Qint[{self.ttype_o}]: return a * {self.const}"
qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler)
compute_and_compare_results(self, qf)
2 changes: 1 addition & 1 deletion test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def compute_and_compare_results(cls, qf, test_original_f=True, test_qcircuit=Tru

# circ_qi = qf.circuit().export("circuit", "qiskit")

# update_statistics(qf.circuit().num_qubits, qf.circuit().num_gates)
update_statistics(qf.circuit().num_qubits, qf.circuit().num_gates)

# print(qf.expressions)
# print(circ_qi.draw("text"))
Expand Down

0 comments on commit 48cf912

Please sign in to comment.