Skip to content

Commit

Permalink
fix int compare gt
Browse files Browse the repository at this point in the history
  • Loading branch information
dakk committed Oct 11, 2023
1 parent da9b2b4 commit 66d8b74
Show file tree
Hide file tree
Showing 6 changed files with 136 additions and 31 deletions.
8 changes: 5 additions & 3 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,14 @@
- [x] Typecheck all the expressions

### Week 3: (9 Oct 23)
- [ ] Test circuit and boolexp using the python code as reference
- [ ] Fix structure and typing location
- [x] Test circuit and boolexp using the python code as reference
- [x] Qubit garbage uncomputing and recycling
- [ ] Test: add qubit usage check
- [ ] Compiler: remove consecutive X gates
- [ ] Properly render documentation
- [ ] Doc: emphatize the compiler flow
- [ ] Doc: properly render documentation
- [ ] Fix structure and typing location
- [ ] Parametrize qint tests over bit_size

### Week 4: (16 Oct 23)
- [ ] Int arithmetic expressions (+, -, *, /)
Expand Down
41 changes: 37 additions & 4 deletions qlasskit/ast2logic/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,18 +69,33 @@ def __init__(self, value):
super().__init__()
self.value = value

def __getitem__(self, i):
if i > self.BIT_SIZE:
raise Exception("Unbound")

return self.to_bool_str()[i] == "1"

@classmethod
def from_bool(cls, v: List[bool]):
return cls(int("".join(map(lambda x: "1" if x else "0", v[::-1])), 2))

def to_bool_str(self) -> str:
s = bin(self.value)[2:][0 : self.BIT_SIZE]
return ("0" * (self.BIT_SIZE - len(s)) + s)[::-1]

@staticmethod
def const(v: int) -> List[bool]:
"""Return the list of bool representing an int"""
return list(map(lambda c: True if c == "1" else False, bin(v)[2:]))
return list(map(lambda c: True if c == "1" else False, bin(v)[2:]))[::-1]

@staticmethod
def fill(v: Tuple[TType, List[bool]]) -> Tuple[TType, List[bool]]:
"""Fill a Qint to reach its bit_size"""
if len(v[1]) < v[0].BIT_SIZE: # type: ignore
print("fillused!")
v = (
v[0],
[False] * (v[0].BIT_SIZE - len(v[1])) + v[1], # type: ignore
(v[0].BIT_SIZE - len(v[1])) * v[1] + [False], # type: ignore
)
return v

Expand Down Expand Up @@ -122,9 +137,27 @@ def not_eq(tleft: TExp, tcomp: TExp) -> TExp:
def gt(tleft: TExp, tcomp: TExp) -> TExp:
"""Compare two Qint for greater than"""
ex = false
prev: List[Symbol] = []

for x in list(zip(tleft[1], tcomp[1]))[::-1]:
ex = Or(ex, And(Not(ex), And(Not(x[1]), x[0])))
for a, b in list(zip(tleft[1], tcomp[1]))[::-1]:
if len(prev) == 0:
ex = And(a, Not(b))
else:
ex = Or(
ex,
And(*([e for e in prev] + [Not(b), a])),
And(*([Not(e) for e in prev] + [Not(b), a])),
)

prev.extend([a, b])

if len(tleft[1]) > len(tcomp[1]):
for x in tleft[1][len(tcomp[1]) :]:
ex = Or(ex, x)

if len(tleft[1]) < len(tcomp[1]):
for x in tcomp[1][len(tleft[1]) :]:
ex = Or(ex, x)

return (bool, ex)

Expand Down
2 changes: 1 addition & 1 deletion qlasskit/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def __init__(self):
def _symplify_exp(self, exp):
exp = simplify_logic(exp) # TODO: remove this
exp = optimizer(exp)
print("exp3", exp)
# print("exp3", exp)
return exp

def compile(
Expand Down
2 changes: 1 addition & 1 deletion qlasskit/compiler/poccompiler2.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def compile(self, name, args: Args, ret_size: int, exprs: BoolExpList) -> QCircu
self.mapped: Dict[Boolean, int] = {}

for sym, exp in exprs:
print(sym, self._symplify_exp(exp))
# print(sym, self._symplify_exp(exp))
iret = self.compile_expr(qc, self._symplify_exp(exp))
# print("iret", iret)
qc.map_qubit(sym, iret, promote=True)
Expand Down
48 changes: 39 additions & 9 deletions test/test_qlassf_int.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,19 @@

import unittest

import pytest
from sympy import Symbol, symbols
from sympy.logic import ITE, And, Not, Or, Xor, false, simplify_logic, true

from qlasskit import QlassF, exceptions, qlassf
from qlasskit import QlassF, exceptions, qlassf # Qint2

from .utils import COMPILATION_ENABLED, compare_circuit_truth_table

a, b, c, d = symbols("a,b,c,d")
_ret = Symbol("_ret")


# @pytest.mark.parametrize("qint", [Qint2])
class TestQlassfInt(unittest.TestCase):
def test_int_arg(self):
f = "def test(a: Qint2) -> bool:\n\treturn a[0]"
Expand Down Expand Up @@ -90,7 +92,7 @@ def test_int_const_compare_eq(self):
qf = qlassf(f, to_compile=COMPILATION_ENABLED)
self.assertEqual(len(qf.expressions), 1)
self.assertEqual(qf.expressions[0][0], _ret)
self.assertEqual(qf.expressions[0][1], And(Symbol("a.0"), Not(Symbol("a.1"))))
self.assertEqual(qf.expressions[0][1], And(Symbol("a.1"), Not(Symbol("a.0"))))
compare_circuit_truth_table(self, qf)

def test_int_const_compare_eq_different_type(self):
Expand All @@ -101,8 +103,8 @@ def test_int_const_compare_eq_different_type(self):
self.assertEqual(
qf.expressions[0][1],
And(
Symbol("a.0"),
Not(Symbol("a.1")),
Symbol("a.1"),
Not(Symbol("a.0")),
Not(Symbol("a.2")),
Not(Symbol("a.3")),
),
Expand All @@ -117,8 +119,8 @@ def test_const_int_compare_eq_different_type(self):
self.assertEqual(
qf.expressions[0][1],
And(
Symbol("a.0"),
Not(Symbol("a.1")),
Symbol("a.1"),
Not(Symbol("a.0")),
Not(Symbol("a.2")),
Not(Symbol("a.3")),
),
Expand All @@ -133,8 +135,8 @@ def test_const_int_compare_neq_different_type(self):
self.assertEqual(
qf.expressions[0][1],
Or(
Not(Symbol("a.0")),
Symbol("a.1"),
Not(Symbol("a.1")),
Symbol("a.0"),
Symbol("a.2"),
Symbol("a.3"),
),
Expand Down Expand Up @@ -170,13 +172,41 @@ def test_int_int_compare_neq(self):
compare_circuit_truth_table(self, qf)

def test_const_int_compare_gt(self):
f = "def test(a: Qint2, b: Qint2) -> bool:\n\treturn a > b"
f = "def test(a: Qint2) -> bool:\n\treturn a > 1"
qf = qlassf(f, to_compile=COMPILATION_ENABLED)
self.assertEqual(len(qf.expressions), 1)
self.assertEqual(qf.expressions[0][0], _ret)
compare_circuit_truth_table(self, qf)

def test_const_int4_compare_gt(self):
f = "def test(a: Qint4) -> bool:\n\treturn a > 3"
qf = qlassf(f, to_compile=COMPILATION_ENABLED)
self.assertEqual(len(qf.expressions), 1)
self.assertEqual(qf.expressions[0][0], _ret)
compare_circuit_truth_table(self, qf)

def test_const_int_compare_lt(self):
f = "def test(a: Qint2) -> bool:\n\treturn a < 2"
qf = qlassf(f, to_compile=COMPILATION_ENABLED)
self.assertEqual(len(qf.expressions), 1)
self.assertEqual(qf.expressions[0][0], _ret)
compare_circuit_truth_table(self, qf)

# def test_const_int4_compare_lt(self):
# f = "def test(a: Qint4) -> bool:\n\treturn a < 6"
# qf = qlassf(f, to_compile=COMPILATION_ENABLED)
# self.assertEqual(len(qf.expressions), 1)
# self.assertEqual(qf.expressions[0][0], _ret)
# compare_circuit_truth_table(self, qf)

def test_int_int_compare_gt(self):
f = "def test(a: Qint2, b: Qint2) -> bool:\n\treturn a > b"
qf = qlassf(f, to_compile=COMPILATION_ENABLED)
self.assertEqual(len(qf.expressions), 1)
self.assertEqual(qf.expressions[0][0], _ret)
compare_circuit_truth_table(self, qf)

def test_int_int_compare_lt(self):
f = "def test(a: Qint2, b: Qint2) -> bool:\n\treturn a < b"
qf = qlassf(f, to_compile=COMPILATION_ENABLED)
self.assertEqual(len(qf.expressions), 1)
Expand Down
66 changes: 53 additions & 13 deletions test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Tuple, get_args

from qiskit import QuantumCircuit, transpile
from qiskit_aer import Aer

from qlasskit import QlassF
from qlasskit import QlassF, Qtype

COMPILATION_ENABLED = True

Expand Down Expand Up @@ -49,42 +52,79 @@ def qiskit_measure_and_count(circ, shots=1):
def compare_circuit_truth_table(cls, qf):
if not COMPILATION_ENABLED:
return

truth_table = qf.truth_table()
gate = qf.gate()
circ = qf.circuit()
circ_qi = circ.export("circuit", "qiskit")
print(circ_qi.draw("text"))
# print(circ_qi.draw("text"))

for truth_line in truth_table:
qc = QuantumCircuit(gate.num_qubits)

# Prepare inputs
for i in range(qf.input_size):
qc.initialize(1 if truth_line[i] else 0, i)

# (truth_line)
[qc.initialize(1 if truth_line[i] else 0, i) for i in range(qf.input_size)]

qc.append(gate, list(range(qf.num_qubits)))
# print(qc.decompose().draw("text"))

# Measure
counts = qiskit_measure_and_count(qc)
# print(counts, circ.qubit_map)

# Extract str of truthtable and result
truth_str = "".join(
map(lambda x: "1" if x else "0", truth_line[-qf.ret_size :])
)

# print(truth_str)

res = list(counts.keys())[0][::-1]
res_str = ""
for qname in qf.truth_table_header()[-qf.ret_size :]:
res_str += res[circ.qubit_map[qname]]

# res = res[0 : len(truth_str)][::-1]
# print(res_str)

cls.assertEqual(len(counts), 1)
cls.assertEqual(truth_str, res_str)

# Calculate original result from python function
def truth_to_arg(truth, i, argtt):
# print(arg.ttype)
if argtt == bool:
return truth[i], i + 1
elif inspect.isclass(argtt) and issubclass(argtt, Qtype):
return (
argtt.from_bool(truth[i : i + argtt.BIT_SIZE]),
i + argtt.BIT_SIZE,
)
else: # A tuple
al = []
for x in get_args(argtt):
a, i = truth_to_arg(truth, i, x)
al.append(a)
return tuple(al), i

args = []
i = 0
for x in qf.args:
arg, i = truth_to_arg(truth_line, i, x.ttype)
args.append(arg)

cls.assertEqual(i, qf.input_size)

res_original = qf.original_f(*args)

# print("Classical evalution", args, res_original)

def res_to_str(res):
if type(res) == bool:
return "1" if res else "0"
elif type(res) == tuple:
return "".join([res_to_str(x) for x in res])
else:
return res.to_bool_str()

res_original_str = res_to_str(res_original)
# print("Res (th, or)", res_str, res_original_str, truth_line)
# print(qf.expressions)

cls.assertEqual(len(res_original_str), qf.ret_size)
cls.assertEqual(res_str, res_original_str)

# cls.assertLessEqual(gate.num_qubits, len(qf.truth_table_header()))

0 comments on commit 66d8b74

Please sign in to comment.