Skip to content

Commit

Permalink
fix int.gt and compiler optimizer
Browse files Browse the repository at this point in the history
  • Loading branch information
dakk committed Oct 11, 2023
1 parent 66d8b74 commit 380d8f5
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 25 deletions.
9 changes: 2 additions & 7 deletions qlasskit/ast2logic/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,20 +136,15 @@ def not_eq(tleft: TExp, tcomp: TExp) -> TExp:
@staticmethod
def gt(tleft: TExp, tcomp: TExp) -> TExp:
"""Compare two Qint for greater than"""
ex = false
prev: List[Symbol] = []

for a, b in list(zip(tleft[1], tcomp[1]))[::-1]:
if len(prev) == 0:
ex = And(a, Not(b))
else:
ex = Or(
ex,
And(*([e for e in prev] + [Not(b), a])),
And(*([Not(e) for e in prev] + [Not(b), a])),
)
ex = Or(ex, And(*(prev + [a, Not(b)])))

prev.extend([a, b])
prev.append(bool_eq(a, b))

if len(tleft[1]) > len(tcomp[1]):
for x in tleft[1][len(tcomp[1]) :]:
Expand Down
10 changes: 6 additions & 4 deletions qlasskit/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


from sympy import Symbol
from sympy.logic import ITE, And, Implies, Not, Or, Xor, simplify_logic
from sympy.logic import ITE, And, Implies, Not, Or, Xor # , simplify_logic
from sympy.logic.boolalg import Boolean, BooleanFalse, BooleanTrue

from .. import QCircuit
Expand All @@ -31,10 +31,12 @@ def optimizer(expr: Boolean) -> Boolean:

elif isinstance(expr, ITE):
c = optimizer(expr.args[0])
return Or(And(c, optimizer(expr.args[1])), And(Not(c), optimizer(expr.args[2])))
return optimizer(
Or(And(c, optimizer(expr.args[1])), And(Not(c), optimizer(expr.args[2])))
)

elif isinstance(expr, Implies):
return Or(Not(optimizer(expr.args[0])), optimizer(expr.args[1]))
return optimizer(Or(Not(optimizer(expr.args[0])), optimizer(expr.args[1])))

elif isinstance(expr, Not):
return Not(optimizer(expr.args[0]))
Expand Down Expand Up @@ -82,7 +84,7 @@ def __init__(self):
self.qmap = {}

def _symplify_exp(self, exp):
exp = simplify_logic(exp) # TODO: remove this
# exp = simplify_logic(exp) # TODO: remove this
exp = optimizer(exp)
# print("exp3", exp)
return exp
Expand Down
28 changes: 17 additions & 11 deletions test/test_qlassf_int.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import unittest

import pytest
from sympy import Symbol, symbols
from sympy.logic import ITE, And, Not, Or, Xor, false, simplify_logic, true

Expand All @@ -26,7 +25,6 @@
_ret = Symbol("_ret")


# @pytest.mark.parametrize("qint", [Qint2])
class TestQlassfInt(unittest.TestCase):
def test_int_arg(self):
f = "def test(a: Qint2) -> bool:\n\treturn a[0]"
Expand All @@ -37,12 +35,13 @@ def test_int_arg(self):
# compare_circuit_truth_table(self, qf)

def test_int_arg2(self):
f = "def test(a: Qint2, b: bool) -> bool:\n\treturn True if (a[0] and b) else a[1]"
f = "def test(a: Qint2, b: bool) -> bool:\n\treturn a[1] if (a[0] and b) else a[0]"
qf = qlassf(f, to_compile=COMPILATION_ENABLED)
self.assertEqual(len(qf.expressions), 1)
self.assertEqual(qf.expressions[0][0], _ret)
self.assertEqual(
qf.expressions[0][1], ITE(And(Symbol("a.0"), b), True, Symbol("a.1"))
qf.expressions[0][1],
ITE(And(Symbol("a.0"), b), Symbol("a.1"), Symbol("a.0")),
)
compare_circuit_truth_table(self, qf)

Expand Down Expand Up @@ -179,25 +178,32 @@ def test_const_int_compare_gt(self):
compare_circuit_truth_table(self, qf)

def test_const_int4_compare_gt(self):
f = "def test(a: Qint4) -> bool:\n\treturn a > 3"
f = "def test(a: Qint4) -> bool:\n\treturn a > 6"
qf = qlassf(f, to_compile=COMPILATION_ENABLED)
self.assertEqual(len(qf.expressions), 1)
self.assertEqual(qf.expressions[0][0], _ret)
compare_circuit_truth_table(self, qf)

# def test_int4_int4_compare_gt(self):
# f = "def test(a: Qint4, b: Qint4) -> bool:\n\treturn a > b"
# qf = qlassf(f, to_compile=COMPILATION_ENABLED)
# self.assertEqual(len(qf.expressions), 1)
# self.assertEqual(qf.expressions[0][0], _ret)
# compare_circuit_truth_table(self, qf)

def test_const_int_compare_lt(self):
f = "def test(a: Qint2) -> bool:\n\treturn a < 2"
qf = qlassf(f, to_compile=COMPILATION_ENABLED)
self.assertEqual(len(qf.expressions), 1)
self.assertEqual(qf.expressions[0][0], _ret)
compare_circuit_truth_table(self, qf)

# def test_const_int4_compare_lt(self):
# f = "def test(a: Qint4) -> bool:\n\treturn a < 6"
# qf = qlassf(f, to_compile=COMPILATION_ENABLED)
# self.assertEqual(len(qf.expressions), 1)
# self.assertEqual(qf.expressions[0][0], _ret)
# compare_circuit_truth_table(self, qf)
def test_const_int4_compare_lt(self):
f = "def test(a: Qint4) -> bool:\n\treturn a < 6"
qf = qlassf(f, to_compile=COMPILATION_ENABLED)
self.assertEqual(len(qf.expressions), 1)
self.assertEqual(qf.expressions[0][0], _ret)
compare_circuit_truth_table(self, qf)

def test_int_int_compare_gt(self):
f = "def test(a: Qint2, b: Qint2) -> bool:\n\treturn a > b"
Expand Down
6 changes: 3 additions & 3 deletions test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def truth_to_arg(truth, i, argtt):

res_original = qf.original_f(*args)

# print("Classical evalution", args, res_original)
print("\nClassical evalution", args, res_original)

def res_to_str(res):
if type(res) == bool:
Expand All @@ -121,8 +121,8 @@ def res_to_str(res):
return res.to_bool_str()

res_original_str = res_to_str(res_original)
# print("Res (th, or)", res_str, res_original_str, truth_line)
# print(qf.expressions)
print("Res (th, or)", res_str, res_original_str, truth_line)
print(qf.expressions)

cls.assertEqual(len(res_original_str), qf.ret_size)
cls.assertEqual(res_str, res_original_str)
Expand Down

0 comments on commit 380d8f5

Please sign in to comment.