From 66d8b74e4c25a501903329fdb820cfbc99fcf3c1 Mon Sep 17 00:00:00 2001 From: "Davide Gessa (dakk)" Date: Wed, 11 Oct 2023 14:28:44 +0200 Subject: [PATCH] fix int compare gt --- TODO.md | 8 ++-- qlasskit/ast2logic/typing.py | 41 +++++++++++++++++-- qlasskit/compiler/compiler.py | 2 +- qlasskit/compiler/poccompiler2.py | 2 +- test/test_qlassf_int.py | 48 +++++++++++++++++----- test/utils.py | 66 +++++++++++++++++++++++++------ 6 files changed, 136 insertions(+), 31 deletions(-) diff --git a/TODO.md b/TODO.md index 241c2c9d..ad0b7433 100644 --- a/TODO.md +++ b/TODO.md @@ -36,12 +36,14 @@ - [x] Typecheck all the expressions ### Week 3: (9 Oct 23) -- [ ] Test circuit and boolexp using the python code as reference -- [ ] Fix structure and typing location +- [x] Test circuit and boolexp using the python code as reference - [x] Qubit garbage uncomputing and recycling - [ ] Test: add qubit usage check - [ ] Compiler: remove consecutive X gates -- [ ] Properly render documentation +- [ ] Doc: emphatize the compiler flow +- [ ] Doc: properly render documentation +- [ ] Fix structure and typing location +- [ ] Parametrize qint tests over bit_size ### Week 4: (16 Oct 23) - [ ] Int arithmetic expressions (+, -, *, /) diff --git a/qlasskit/ast2logic/typing.py b/qlasskit/ast2logic/typing.py index fdf4d5c7..ab066717 100644 --- a/qlasskit/ast2logic/typing.py +++ b/qlasskit/ast2logic/typing.py @@ -69,18 +69,33 @@ def __init__(self, value): super().__init__() self.value = value + def __getitem__(self, i): + if i > self.BIT_SIZE: + raise Exception("Unbound") + + return self.to_bool_str()[i] == "1" + + @classmethod + def from_bool(cls, v: List[bool]): + return cls(int("".join(map(lambda x: "1" if x else "0", v[::-1])), 2)) + + def to_bool_str(self) -> str: + s = bin(self.value)[2:][0 : self.BIT_SIZE] + return ("0" * (self.BIT_SIZE - len(s)) + s)[::-1] + @staticmethod def const(v: int) -> List[bool]: """Return the list of bool representing an int""" - return list(map(lambda c: True if c == "1" else False, bin(v)[2:])) + return list(map(lambda c: True if c == "1" else False, bin(v)[2:]))[::-1] @staticmethod def fill(v: Tuple[TType, List[bool]]) -> Tuple[TType, List[bool]]: """Fill a Qint to reach its bit_size""" if len(v[1]) < v[0].BIT_SIZE: # type: ignore + print("fillused!") v = ( v[0], - [False] * (v[0].BIT_SIZE - len(v[1])) + v[1], # type: ignore + (v[0].BIT_SIZE - len(v[1])) * v[1] + [False], # type: ignore ) return v @@ -122,9 +137,27 @@ def not_eq(tleft: TExp, tcomp: TExp) -> TExp: def gt(tleft: TExp, tcomp: TExp) -> TExp: """Compare two Qint for greater than""" ex = false + prev: List[Symbol] = [] - for x in list(zip(tleft[1], tcomp[1]))[::-1]: - ex = Or(ex, And(Not(ex), And(Not(x[1]), x[0]))) + for a, b in list(zip(tleft[1], tcomp[1]))[::-1]: + if len(prev) == 0: + ex = And(a, Not(b)) + else: + ex = Or( + ex, + And(*([e for e in prev] + [Not(b), a])), + And(*([Not(e) for e in prev] + [Not(b), a])), + ) + + prev.extend([a, b]) + + if len(tleft[1]) > len(tcomp[1]): + for x in tleft[1][len(tcomp[1]) :]: + ex = Or(ex, x) + + if len(tleft[1]) < len(tcomp[1]): + for x in tcomp[1][len(tleft[1]) :]: + ex = Or(ex, x) return (bool, ex) diff --git a/qlasskit/compiler/compiler.py b/qlasskit/compiler/compiler.py index 01675308..eabdc0af 100644 --- a/qlasskit/compiler/compiler.py +++ b/qlasskit/compiler/compiler.py @@ -84,7 +84,7 @@ def __init__(self): def _symplify_exp(self, exp): exp = simplify_logic(exp) # TODO: remove this exp = optimizer(exp) - print("exp3", exp) + # print("exp3", exp) return exp def compile( diff --git a/qlasskit/compiler/poccompiler2.py b/qlasskit/compiler/poccompiler2.py index 036767ca..6f384a47 100644 --- a/qlasskit/compiler/poccompiler2.py +++ b/qlasskit/compiler/poccompiler2.py @@ -43,7 +43,7 @@ def compile(self, name, args: Args, ret_size: int, exprs: BoolExpList) -> QCircu self.mapped: Dict[Boolean, int] = {} for sym, exp in exprs: - print(sym, self._symplify_exp(exp)) + # print(sym, self._symplify_exp(exp)) iret = self.compile_expr(qc, self._symplify_exp(exp)) # print("iret", iret) qc.map_qubit(sym, iret, promote=True) diff --git a/test/test_qlassf_int.py b/test/test_qlassf_int.py index f082dab3..6843feb0 100644 --- a/test/test_qlassf_int.py +++ b/test/test_qlassf_int.py @@ -14,10 +14,11 @@ import unittest +import pytest from sympy import Symbol, symbols from sympy.logic import ITE, And, Not, Or, Xor, false, simplify_logic, true -from qlasskit import QlassF, exceptions, qlassf +from qlasskit import QlassF, exceptions, qlassf # Qint2 from .utils import COMPILATION_ENABLED, compare_circuit_truth_table @@ -25,6 +26,7 @@ _ret = Symbol("_ret") +# @pytest.mark.parametrize("qint", [Qint2]) class TestQlassfInt(unittest.TestCase): def test_int_arg(self): f = "def test(a: Qint2) -> bool:\n\treturn a[0]" @@ -90,7 +92,7 @@ def test_int_const_compare_eq(self): qf = qlassf(f, to_compile=COMPILATION_ENABLED) self.assertEqual(len(qf.expressions), 1) self.assertEqual(qf.expressions[0][0], _ret) - self.assertEqual(qf.expressions[0][1], And(Symbol("a.0"), Not(Symbol("a.1")))) + self.assertEqual(qf.expressions[0][1], And(Symbol("a.1"), Not(Symbol("a.0")))) compare_circuit_truth_table(self, qf) def test_int_const_compare_eq_different_type(self): @@ -101,8 +103,8 @@ def test_int_const_compare_eq_different_type(self): self.assertEqual( qf.expressions[0][1], And( - Symbol("a.0"), - Not(Symbol("a.1")), + Symbol("a.1"), + Not(Symbol("a.0")), Not(Symbol("a.2")), Not(Symbol("a.3")), ), @@ -117,8 +119,8 @@ def test_const_int_compare_eq_different_type(self): self.assertEqual( qf.expressions[0][1], And( - Symbol("a.0"), - Not(Symbol("a.1")), + Symbol("a.1"), + Not(Symbol("a.0")), Not(Symbol("a.2")), Not(Symbol("a.3")), ), @@ -133,8 +135,8 @@ def test_const_int_compare_neq_different_type(self): self.assertEqual( qf.expressions[0][1], Or( - Not(Symbol("a.0")), - Symbol("a.1"), + Not(Symbol("a.1")), + Symbol("a.0"), Symbol("a.2"), Symbol("a.3"), ), @@ -170,13 +172,41 @@ def test_int_int_compare_neq(self): compare_circuit_truth_table(self, qf) def test_const_int_compare_gt(self): - f = "def test(a: Qint2, b: Qint2) -> bool:\n\treturn a > b" + f = "def test(a: Qint2) -> bool:\n\treturn a > 1" + qf = qlassf(f, to_compile=COMPILATION_ENABLED) + self.assertEqual(len(qf.expressions), 1) + self.assertEqual(qf.expressions[0][0], _ret) + compare_circuit_truth_table(self, qf) + + def test_const_int4_compare_gt(self): + f = "def test(a: Qint4) -> bool:\n\treturn a > 3" qf = qlassf(f, to_compile=COMPILATION_ENABLED) self.assertEqual(len(qf.expressions), 1) self.assertEqual(qf.expressions[0][0], _ret) compare_circuit_truth_table(self, qf) def test_const_int_compare_lt(self): + f = "def test(a: Qint2) -> bool:\n\treturn a < 2" + qf = qlassf(f, to_compile=COMPILATION_ENABLED) + self.assertEqual(len(qf.expressions), 1) + self.assertEqual(qf.expressions[0][0], _ret) + compare_circuit_truth_table(self, qf) + + # def test_const_int4_compare_lt(self): + # f = "def test(a: Qint4) -> bool:\n\treturn a < 6" + # qf = qlassf(f, to_compile=COMPILATION_ENABLED) + # self.assertEqual(len(qf.expressions), 1) + # self.assertEqual(qf.expressions[0][0], _ret) + # compare_circuit_truth_table(self, qf) + + def test_int_int_compare_gt(self): + f = "def test(a: Qint2, b: Qint2) -> bool:\n\treturn a > b" + qf = qlassf(f, to_compile=COMPILATION_ENABLED) + self.assertEqual(len(qf.expressions), 1) + self.assertEqual(qf.expressions[0][0], _ret) + compare_circuit_truth_table(self, qf) + + def test_int_int_compare_lt(self): f = "def test(a: Qint2, b: Qint2) -> bool:\n\treturn a < b" qf = qlassf(f, to_compile=COMPILATION_ENABLED) self.assertEqual(len(qf.expressions), 1) diff --git a/test/utils.py b/test/utils.py index 986a67c0..37d07ecf 100644 --- a/test/utils.py +++ b/test/utils.py @@ -12,10 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect +from typing import Tuple, get_args + from qiskit import QuantumCircuit, transpile from qiskit_aer import Aer -from qlasskit import QlassF +from qlasskit import QlassF, Qtype COMPILATION_ENABLED = True @@ -49,42 +52,79 @@ def qiskit_measure_and_count(circ, shots=1): def compare_circuit_truth_table(cls, qf): if not COMPILATION_ENABLED: return + truth_table = qf.truth_table() gate = qf.gate() circ = qf.circuit() circ_qi = circ.export("circuit", "qiskit") - print(circ_qi.draw("text")) + # print(circ_qi.draw("text")) for truth_line in truth_table: qc = QuantumCircuit(gate.num_qubits) # Prepare inputs - for i in range(qf.input_size): - qc.initialize(1 if truth_line[i] else 0, i) - - # (truth_line) + [qc.initialize(1 if truth_line[i] else 0, i) for i in range(qf.input_size)] qc.append(gate, list(range(qf.num_qubits))) - # print(qc.decompose().draw("text")) + # Measure counts = qiskit_measure_and_count(qc) - # print(counts, circ.qubit_map) + # Extract str of truthtable and result truth_str = "".join( map(lambda x: "1" if x else "0", truth_line[-qf.ret_size :]) ) - # print(truth_str) - res = list(counts.keys())[0][::-1] res_str = "" for qname in qf.truth_table_header()[-qf.ret_size :]: res_str += res[circ.qubit_map[qname]] - # res = res[0 : len(truth_str)][::-1] - # print(res_str) - cls.assertEqual(len(counts), 1) cls.assertEqual(truth_str, res_str) + # Calculate original result from python function + def truth_to_arg(truth, i, argtt): + # print(arg.ttype) + if argtt == bool: + return truth[i], i + 1 + elif inspect.isclass(argtt) and issubclass(argtt, Qtype): + return ( + argtt.from_bool(truth[i : i + argtt.BIT_SIZE]), + i + argtt.BIT_SIZE, + ) + else: # A tuple + al = [] + for x in get_args(argtt): + a, i = truth_to_arg(truth, i, x) + al.append(a) + return tuple(al), i + + args = [] + i = 0 + for x in qf.args: + arg, i = truth_to_arg(truth_line, i, x.ttype) + args.append(arg) + + cls.assertEqual(i, qf.input_size) + + res_original = qf.original_f(*args) + + # print("Classical evalution", args, res_original) + + def res_to_str(res): + if type(res) == bool: + return "1" if res else "0" + elif type(res) == tuple: + return "".join([res_to_str(x) for x in res]) + else: + return res.to_bool_str() + + res_original_str = res_to_str(res_original) + # print("Res (th, or)", res_str, res_original_str, truth_line) + # print(qf.expressions) + + cls.assertEqual(len(res_original_str), qf.ret_size) + cls.assertEqual(res_str, res_original_str) + # cls.assertLessEqual(gate.num_qubits, len(qf.truth_table_header()))