Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

shift and add multiplication #41

Merged
merged 4 commits into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions qlasskit/ast2ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ def _replace_types_annotations(ann, arg=None):
isinstance(ann, ast.Subscript)
and isinstance(ann.value, ast.Name)
and ann.value.id == "Tuple"
and hasattr(ann.slice, 'elts')
and hasattr(ann.slice, "elts")
):
_elts = ann.slice.elts
_elts = ann.slice.elts
_ituple = ast.Tuple(elts=[_replace_types_annotations(el) for el in _elts])

ann = ast.Subscript(
Expand All @@ -101,7 +101,7 @@ def _replace_types_annotations(ann, arg=None):
isinstance(ann, ast.Subscript)
and isinstance(ann.value, ast.Name)
and ann.value.id == "Qlist"
and hasattr(ann.slice, 'elts')
and hasattr(ann.slice, "elts")
):
_elts = ann.slice.elts
_ituple = ast.Tuple(elts=[copy.deepcopy(_elts[0])] * _elts[1].value)
Expand All @@ -116,7 +116,7 @@ def _replace_types_annotations(ann, arg=None):
isinstance(ann, ast.Subscript)
and isinstance(ann.value, ast.Name)
and ann.value.id == "Qmatrix"
and hasattr(ann.slice, 'elts')
and hasattr(ann.slice, "elts")
):
_elts = ann.slice.elts
_ituple_row = ast.Tuple(elts=[copy.deepcopy(_elts[0])] * _elts[2].value)
Expand Down Expand Up @@ -425,31 +425,31 @@ def visit_For(self, node): # noqa: C901
isinstance(iter, ast.Subscript)
and isinstance(iter.value, ast.Name)
and iter.value.id in self.env
and hasattr(iter.slice, 'value')
and hasattr(iter.slice, "value")
):
if isinstance(self.env[iter.value.id], ast.Tuple):
new_iter = self.env[iter.value.id].elts[iter.slice.value]

elif isinstance(self.env[iter.value.id], ast.Subscript):
_elts = self.env[iter.value.id].slice.elts[iter.slice.value]
elif isinstance(self.env[iter.value.id], ast.Subscript):
_elts = self.env[iter.value.id].slice.elts[iter.slice.value]

if isinstance(_elts, ast.Tuple):
_elts = _elts.elts

new_iter = [
ast.Subscript(
value=ast.Subscript(
value=ast.Name(id=iter.value.id, ctx=ast.Load()),
slice=ast.Constant(value=iter.slice.value),
value=ast.Name(id=iter.value.id, ctx=ast.Load()),
slice=ast.Constant(value=iter.slice.value),
ctx=ast.Load(),
),
slice=ast.Constant(value=e),
)
for e in range(len(_elts))
]
else:
new_iter = iter
new_iter = iter

iter = new_iter

if isinstance(iter, ast.Constant) and isinstance(iter.value, ast.Tuple):
Expand Down
98 changes: 79 additions & 19 deletions qlasskit/types/qint.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from sympy.logic import And, Not, Or, Xor, false, true

from . import TypeErrorException, _eq, _full_adder, _neq
from .qtype import Qtype, TExp, bin_to_bool_list, bool_list_to_bin
from .qtype import Qtype, TExp, TType, bin_to_bool_list, bool_list_to_bin


class QintImp(int, Qtype):
Expand Down Expand Up @@ -164,12 +164,86 @@ def add(cls, tleft: TExp, tright: TExp) -> TExp:

return (cls if cls.BIT_SIZE > tleft_e[0].BIT_SIZE else tleft_e[0], sums)

@staticmethod
def mul_even_const(t_num: TExp, const: int, result_type: Qtype) -> TExp:
"""Multiply by an even const using shift and add
(x << 3) + (x << 1) # Here 10*x is computed as x*2^3 + x*2
"""

# Multiply t_num by the nearest n | 2**n < t_const
n = 1
while 2**n <= const:
n += 1
if 2**n > const:
n -= 1

result_ttype = cast(TType, result_type)

t_num_r = result_type.shift_left((result_ttype, t_num[1]), n)

# Shift t_const by t_const - 2**n
r = const - 2**n
if r > 0:
# Add the shift result to t_num
res = result_type.add(
(result_ttype, t_num_r[1]),
result_type.shift_left((result_ttype, t_num[1]), int(r / 2)),
)
else:
res = (result_ttype, t_num_r[1])

return res

@classmethod
def mul(cls, tleft: TExp, tright: TExp) -> TExp: # noqa: C901
# TODO: use RGQFT multiplier
def mul(cls, tleft_: TExp, tright_: TExp) -> TExp: # noqa: C901
if not issubclass(tleft_[0], Qtype):
raise TypeErrorException(tleft_[0], Qtype)
if not issubclass(tright_[0], Qtype):
raise TypeErrorException(tright_[0], Qtype)

def __mul_sizing(n, m):
if (n + m) <= 2:
return Qint2
elif (n + m) > 2 and (n + m) <= 4:
return Qint4
elif (n + m) > 4 and (n + m) <= 6:
return Qint6
elif (n + m) > 6 and (n + m) <= 8:
return Qint8
elif (n + m) > 8 and (n + m) <= 12:
return Qint12
elif (n + m) > 12 and (n + m) <= 16:
return Qint16
elif (n + m) > 16:
return Qint16
else:
raise Exception(f"Mul result size is too big ({n+m})")

# Fill constants so explicit typecast is not needed
if cls.is_const(tleft_):
tleft = tright_[0].fill(tleft_)
else:
tleft = tleft_

if cls.is_const(tright_):
tright = tleft_[0].fill(tright_)
else:
tright = tright_

n = len(tleft[1])
m = len(tright[1])

# If one operand is an even constant, use mul_even_const
if cls.is_const(tleft) or cls.is_const(tright):
t_num = tleft if cls.is_const(tright) else tright
t_const = tleft if cls.is_const(tleft) else tright
const = cast(int, cast(Qtype, t_const[0]).from_bool(t_const[1]))

if const % 2 == 0:
t = __mul_sizing(n, m)
res = cls.mul_even_const(t_num, const, t)
return t.crop(t.fill(res))

if n != m:
raise Exception(f"Mul works only on same size Qint: {n} != {m}")

Expand All @@ -190,22 +264,8 @@ def mul(cls, tleft: TExp, tright: TExp) -> TExp: # noqa: C901
if i + m < n + m:
product[i + m] = carry

if (n + m) <= 2:
return Qint2, product
elif (n + m) > 2 and (n + m) <= 4:
return Qint4, product
elif (n + m) > 4 and (n + m) <= 6:
return Qint6, product
elif (n + m) > 6 and (n + m) <= 8:
return Qint8, product
elif (n + m) > 8 and (n + m) <= 12:
return Qint12, product
elif (n + m) > 12 and (n + m) <= 16:
return Qint16, product
elif (n + m) > 16:
return Qint16.crop((Qint16, product))

raise Exception(f"Mul result size is too big ({n+m})")
t = __mul_sizing(n, m)
return t.crop(t.fill((t, product)))

@classmethod
def sub(cls, tleft: TExp, tright: TExp) -> TExp:
Expand Down
10 changes: 8 additions & 2 deletions qlasskit/types/qtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,12 +187,18 @@ def bitwise_not(v: TExp) -> TExp:
@staticmethod
def shift_right(v: TExp, i: int = 1) -> TExp:
"""Apply a shift right"""
return (v[0], v[1][i:])
if not issubclass(v[0], Qtype):
raise TypeErrorException(v[0], Qtype)

return v[0].fill((v[0], v[1][i:]))

@staticmethod
def shift_left(v: TExp, i: int = 1) -> TExp:
"""Apply a shift left"""
return (v[0], [False] * i + v[1])
if not issubclass(v[0], Qtype):
raise TypeErrorException(v[0], Qtype)

return v[0].crop((v[0], [False] * i + v[1]))

@staticmethod
def add(tleft: TExp, tright: TExp) -> TExp:
Expand Down
42 changes: 38 additions & 4 deletions test/qlassf/test_int.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,13 +293,27 @@ def test_composed_comparators(self):
qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler)
compute_and_compare_results(self, qf)

def test_shift_left(self):
f = "def test(n: Qint[2]) -> Qint[4]: return n << 1"
@parameterized.expand(
[
(1,),
(2,),
(3,),
]
)
def test_shift_left(self, v):
f = f"def test(n: Qint[4]) -> Qint[4]: return n << {v}"
qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler)
compute_and_compare_results(self, qf)

def test_shift_right(self):
f = "def test(n: Qint[2]) -> Qint[4]: return n >> 1"
@parameterized.expand(
[
(1,),
(2,),
(3,),
]
)
def test_shift_right(self, v):
f = f"def test(n: Qint[2]) -> Qint[4]: return n >> {v}"
qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler)
compute_and_compare_results(self, qf)

Expand Down Expand Up @@ -479,3 +493,23 @@ def test_mul5(self):
qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler)
compute_and_compare_results(self, qf)
self.assertEqual(qf.expressions[0][1], True)


@parameterized_class(
("ttype_i", "ttype_o", "const", "compiler"),
inject_parameterized_compilers(
[
(4, 6, 2),
(4, 6, 4),
(4, 6, 6),
(6, 8, 6),
(6, 8, 8),
(6, 8, 10),
]
),
)
class TestQlassfIntMulByConst(unittest.TestCase):
def test_mul(self):
f = f"def test(a: Qint[{self.ttype_i}]) -> Qint[{self.ttype_o}]: return a * {self.const}"
qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler)
compute_and_compare_results(self, qf)
2 changes: 1 addition & 1 deletion test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def compute_and_compare_results(cls, qf, test_original_f=True, test_qcircuit=Tru

# circ_qi = qf.circuit().export("circuit", "qiskit")

# update_statistics(qf.circuit().num_qubits, qf.circuit().num_gates)
update_statistics(qf.circuit().num_qubits, qf.circuit().num_gates)

# print(qf.expressions)
# print(circ_qi.draw("text"))
Expand Down
Loading