From c861bf30555e18c6431cfadcf880db18eec7d20a Mon Sep 17 00:00:00 2001 From: "Davide Gessa (dakk)" Date: Wed, 18 Oct 2023 16:55:07 +0200 Subject: [PATCH] add operation --- qlasskit/ast2logic/t_expression.py | 9 ++++++++- qlasskit/types/__init__.py | 10 +++++++++- qlasskit/types/qint.py | 25 ++++++++++++++++++++----- qlasskit/types/qtype.py | 6 ++++++ test/test_qlassf_int.py | 12 ++++++++++++ 5 files changed, 55 insertions(+), 7 deletions(-) diff --git a/qlasskit/ast2logic/t_expression.py b/qlasskit/ast2logic/t_expression.py index 71d27e2b..5a727008 100644 --- a/qlasskit/ast2logic/t_expression.py +++ b/qlasskit/ast2logic/t_expression.py @@ -199,7 +199,14 @@ def unfold(v_exps, op): elif isinstance(expr, ast.BinOp): # Add | Sub | Mult | MatMult | Div | Mod | Pow | LShift | RShift # | BitOr | BitXor | BitAnd | FloorDiv - raise exceptions.ExpressionNotHandledException(expr) + # print(ast.dump(expr)) + tleft = translate_expression(expr.left, env) + tright = translate_expression(expr.right, env) + + if isinstance(expr.op, ast.Add): + return tleft[0].add(tleft, tright) + else: + raise exceptions.ExpressionNotHandledException(expr) # Call elif isinstance(expr, ast.Call): diff --git a/qlasskit/types/__init__.py b/qlasskit/types/__init__.py index 052d7bff..5296c2d7 100644 --- a/qlasskit/types/__init__.py +++ b/qlasskit/types/__init__.py @@ -15,7 +15,7 @@ from typing import Any -from sympy.logic import Not, Xor +from sympy.logic import Not, Xor, And, Or def _neq(a, b): @@ -26,6 +26,14 @@ def _eq(a, b): return Not(_neq(a, b)) +def _half_adder(a, b): # Carry x Sum + return And(a, b), Xor(a, b) + + +def _full_adder(c, a, b): # Carry x Sum + return Or(And(a, b), And(b, c), And(a, c)), Xor(a, b, c) + + from .qtype import Qtype, TExp, TType # noqa: F401, E402 from .qbool import Qbool # noqa: F401, E402 from .qint import Qint, Qint2, Qint4, Qint8, Qint12, Qint16 # noqa: F401, E402 diff --git a/qlasskit/types/qint.py b/qlasskit/types/qint.py index 4abf8aa7..2808810e 100644 --- a/qlasskit/types/qint.py +++ b/qlasskit/types/qint.py @@ -17,7 +17,7 @@ from sympy import Symbol from sympy.logic import And, Not, Or, false, true -from . import _eq, _neq +from . import _eq, _full_adder, _neq from .qtype import Qtype, TExp @@ -26,7 +26,10 @@ class Qint(int, Qtype): def __init__(self, value): super().__init__() - self.value = value + self.value = value % 2**self.BIT_SIZE + + def __add__(self, b): + return (self.value + b) % 2**self.BIT_SIZE @classmethod def from_bool(cls, v: List[bool]): @@ -140,9 +143,21 @@ def gte(tleft: TExp, tcomp: TExp) -> TExp: """Compare two Qint for greater than - equal""" return (bool, Not(Qint.lt(tleft, tcomp)[1])) - # @staticmethod - # def add(tleft: TExp, tright: TExp) -> TExp: - # """Add two Qint""" + # Operations + + @classmethod + def add(cls, tleft: TExp, tright: TExp) -> TExp: + """Add two Qint""" + if len(tleft[1]) != len(tright[1]): # TODO: handle this + raise Exception("Ints have differnt sizes") + + carry = False + sums = [] + for x in zip(tleft[1], tright[1]): + carry, sum = _full_adder(carry, x[0], x[1]) + sums.append(sum) + + return (cls, sums) class Qint2(Qint): diff --git a/qlasskit/types/qtype.py b/qlasskit/types/qtype.py index 1c7f7e2c..df9a4c5c 100644 --- a/qlasskit/types/qtype.py +++ b/qlasskit/types/qtype.py @@ -98,3 +98,9 @@ def lt(tleft: TExp, tcomp: TExp) -> TExp: @staticmethod def lte(tleft: TExp, tcomp: TExp) -> TExp: raise Exception("abstract") + + # Operations + + @staticmethod + def add(tleft: TExp, tcomp: TExp) -> TExp: + raise Exception("abstract") diff --git a/test/test_qlassf_int.py b/test/test_qlassf_int.py index 988f57aa..fa93701b 100644 --- a/test/test_qlassf_int.py +++ b/test/test_qlassf_int.py @@ -291,3 +291,15 @@ def test_composed_comparators(self): # return a + 1 # def test(a: Qint2, b: Qint2) -> Qint2: # return a + b + + +class TestQlassfIntAdd(unittest.TestCase): + def test_add(self): + f = "def test(a: Qint2, b: Qint2) -> Qint2: return a + b" + qf = qlassf(f, to_compile=COMPILATION_ENABLED) + compute_and_compare_results(self, qf) + + def test_add_const(self): + f = "def test(a: Qint2) -> Qint2: return a + 1" + qf = qlassf(f, to_compile=COMPILATION_ENABLED) + compute_and_compare_results(self, qf)