Skip to content

Commit

Permalink
improve uncomputing, test for max_qubits in circuit, aggregate not gates
Browse files Browse the repository at this point in the history
  • Loading branch information
dakk committed Oct 12, 2023
1 parent 1c592ff commit 9b9531e
Show file tree
Hide file tree
Showing 10 changed files with 124 additions and 56 deletions.
11 changes: 11 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
{
"python.testing.unittestArgs": [
"-v",
"-s",
"./test",
"-p",
"*test.py"
],
"python.testing.pytestEnabled": false,
"python.testing.unittestEnabled": true
}
3 changes: 1 addition & 2 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@
### Week 3: (9 Oct 23)
- [x] Test circuit and boolexp using the python code as reference
- [x] Qubit garbage uncomputing and recycling
- [ ] Test: add qubit usage check
- [ ] Compiler: remove consecutive X gates
- [x] Test: add qubit usage check
- [ ] Doc: emphatize the compiler flow
- [ ] Doc: properly render documentation
- [ ] Fix structure and typing location
Expand Down
1 change: 0 additions & 1 deletion qlasskit/ast2logic/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ def const(v: int) -> List[bool]:
def fill(v: Tuple[TType, List[bool]]) -> Tuple[TType, List[bool]]:
"""Fill a Qint to reach its bit_size"""
if len(v[1]) < v[0].BIT_SIZE: # type: ignore
print("fillused!")
v = (
v[0],
(v[0].BIT_SIZE - len(v[1])) * v[1] + [False], # type: ignore
Expand Down
5 changes: 1 addition & 4 deletions qlasskit/compiler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,15 @@
# isort:skip_file


from .compiler import Compiler, CompilerException # noqa: F401
from .compiler import Compiler, CompilerException, optimizer # noqa: F401

from .multipass import MultipassCompiler
from .poccompiler import POCCompiler
from .poccompiler2 import POCCompiler2


def to_quantum(name, args, ret_size, exprs, compiler="poc2"):
if compiler == "multipass":
s = MultipassCompiler()
elif compiler == "poc":
s = POCCompiler()
elif compiler == "poc2":
s = POCCompiler2()

Expand Down
5 changes: 2 additions & 3 deletions qlasskit/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,11 @@ def optimizer(expr: Boolean) -> Boolean:

class Compiler:
def __init__(self):
self.qmap = {}
pass

def _symplify_exp(self, exp):
# exp = simplify_logic(exp) # TODO: remove this
# exp = simplify_logic(exp)
exp = optimizer(exp)
# print("exp3", exp)
return exp

def compile(
Expand Down
67 changes: 50 additions & 17 deletions qlasskit/compiler/poccompiler2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,41 @@
from . import Compiler, CompilerException


class ExpQMap:
"""Mapping between qubit and boolexp and vice-versa"""

def __init__(self):
self.exp_map: Dict[Boolean, int] = {}

def __contains__(self, k):
return k in self.exp_map

def __getitem__(self, k):
return self.exp_map[k]

def __setitem__(self, k, v):
self.exp_map[k] = v

def remove_map_by_qubits(self, qbs):
todel = []
for k in self.exp_map.keys():
if self.exp_map[k] in qbs:
todel.append(k)

for k in todel:
del self.exp_map[k]

def update_exp_for_qubit(self, qb, exp):
self.remove_map_by_qubits([qb])
self[exp] = qb


class POCCompiler2(Compiler):
"""POC2 compiler translating an expression list to quantum circuit"""

def garbage_collect(self, qc):
uncomputed = qc.uncompute()

for k in self.mapped.keys():
if self.mapped[k] in uncomputed:
del self.mapped[k]
self.expqmap.remove_map_by_qubits(uncomputed)

def compile(self, name, args: Args, ret_size: int, exprs: BoolExpList) -> QCircuit:
qc = QCircuit(name=name)
Expand All @@ -40,7 +66,7 @@ def compile(self, name, args: Args, ret_size: int, exprs: BoolExpList) -> QCircu
for arg_b in arg.bitvec:
qc.add_qubit(arg_b)

self.mapped: Dict[Boolean, int] = {}
self.expqmap = ExpQMap()

for sym, exp in exprs:
# print(sym, self._symplify_exp(exp))
Expand All @@ -56,25 +82,28 @@ def compile_expr(self, qc: QCircuit, expr: Boolean) -> int:
if isinstance(expr, Symbol):
return qc[expr.name]

elif expr in self.mapped:
# print("!!cachehit!!", expr)
return self.mapped[expr]
elif expr in self.expqmap:
return self.expqmap[expr]

elif isinstance(expr, Not):
fa = qc.get_free_ancilla()
eret = self.compile_expr(qc, expr.args[0])

qc.barrier("not")

qc.cx(eret, fa)
qc.x(fa)

qc.mark_ancilla(eret)
self.garbage_collect(qc)
if eret in qc.ancilla_lst:
qc.x(eret)
self.expqmap.update_exp_for_qubit(eret, expr)
return eret
else:
fa = qc.get_free_ancilla()
qc.cx(eret, fa)
qc.x(fa)
qc.mark_ancilla(eret)

self.mapped[expr] = fa
self.garbage_collect(qc)
self.expqmap[expr] = fa

return fa
return fa

elif isinstance(expr, And):
erets = list(map(lambda e: self.compile_expr(qc, e), expr.args))
Expand All @@ -88,12 +117,16 @@ def compile_expr(self, qc: QCircuit, expr: Boolean) -> int:

self.garbage_collect(qc)

self.mapped[expr] = fa
self.expqmap[expr] = fa

return fa

elif isinstance(expr, Xor):
erets = list(map(lambda e: self.compile_expr(qc, e), expr.args))

qc.barrier("xor")
[qc.mark_ancilla(eret) for eret in erets[:-1]]

qc.mcx(erets[0:-1], erets[-1])
return erets[-1]

Expand Down
46 changes: 27 additions & 19 deletions qlasskit/qcircuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,21 +34,11 @@ def __init__(self, num_qubits=0, name="qc"):

self.ancilla_lst = set()
self.free_ancilla_lst = set()
self.marked_ancillas = []
self.marked_ancillas = set()

for x in range(num_qubits):
self.qubit_map[f"q{x}"] = x

def __add__(self, other: "QCircuit"):
"""Combine two quantum circuits.
Args:
other (QCircuit): The other quantum circuit to be combined with this one.
"""
self.num_qubits = max(self.num_qubits, other.num_qubits)
self.gates.extend(other.gates)

def get_key_by_index(self, i: int):
"""Return the qubit name given its index"""
for key in self.qubit_map:
Expand Down Expand Up @@ -83,25 +73,37 @@ def get_free_ancilla(self):
return anc

def mark_ancilla(self, w):
self.marked_ancillas.append(w)
if w in self.ancilla_lst:
self.marked_ancillas.add(w)

def uncompute(self, to_mark=[]):
"""Uncompute all the marked ancillas plus the to_mark list"""
[self.mark_ancilla(x) for x in to_mark]

self.barrier(label="un")

uncomputed = set()
new_gates_comp = []
not_to_uncompute = set()

for g, ws in self.gates_computed[::-1]:
if (
ws[-1] in self.marked_ancillas
and ws[-1] in self.ancilla_lst
and ws[-1] not in self.free_ancilla_lst
if ws[-1] in self.marked_ancillas and not all(
[ww in self.marked_ancillas for ww in ws[:-1]]
):
not_to_uncompute.add(ws[-1])

for g, ws in self.gates_computed[::-1]:
if ws[-1] in self.marked_ancillas and ws[-1] not in not_to_uncompute:
uncomputed.add(ws[-1])
self.append(g, ws)
self.free_ancilla_lst.add(ws[-1])
else:
new_gates_comp.append((g, ws))

for x in uncomputed:
self.free_ancilla_lst.add(x)
self.marked_ancillas = self.marked_ancillas - uncomputed
self.gates_computed = new_gates_comp[::-1]

self.marked_ancillas = [] # self.marked_ancillas - uncomputed
self.gates_computed = []
return uncomputed

def map_qubit(self, name, index, promote=False):
Expand Down Expand Up @@ -156,6 +158,12 @@ def append(self, gate_name: str, qubits: List[int]):
if self.num_qubits is None or x > self.num_qubits:
raise Exception(f"qubit {x} not present")

qs = set()
for q in qubits:
if q in qs:
raise Exception(f"duplicate qubit in gate append: {gate_name} {qubits}")
qs.add(q)

self.gates.append((gate_name, qubits))
self.gates_computed.append((gate_name, qubits))

Expand Down
4 changes: 1 addition & 3 deletions test/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@

import unittest

from qiskit import QuantumCircuit

from qlasskit import exceptions, qlassf
from qlasskit import compiler, qlassf

from .utils import compare_circuit_truth_table

Expand Down
21 changes: 21 additions & 0 deletions test/test_qcircuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,27 @@ def test_base_mapping(self):
self.assertEqual(qc.num_qubits, 3)
self.assertEqual(qc.gates, [("ccx", [0, 1, 2])])

def test_duplicate_qubit(self):
qc = QCircuit()
a, b = qc.add_qubit("a"), qc.add_qubit("b")
self.assertRaises(Exception, lambda qc: qc.toffoli(a, b, a), qc)

def test_mapping(self):
qc = QCircuit(4)
qc.ccx("q0", "q1", "q2")

def test_get_key_by_index(self):
qc = QCircuit()
a, b = qc.add_qubit("a"), qc.add_qubit("b")
self.assertRaises(Exception, lambda qc: qc.get_key_by_index(3), qc)
self.assertEqual(qc.get_key_by_index(0), "a")

def test_add_free_ancilla(self):
qc = QCircuit()
a = qc.add_ancilla(is_free=True)
b = qc.get_free_ancilla()
self.assertEqual(a, b)


class TestQCircuitUncomputing(unittest.TestCase):
def test1(self):
Expand Down
17 changes: 10 additions & 7 deletions test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@

from qiskit import QuantumCircuit, transpile
from qiskit_aer import Aer
from sympy.logic.boolalg import gateinputcount

from qlasskit import QlassF, Qtype
from qlasskit import QlassF, Qtype, compiler

COMPILATION_ENABLED = True

Expand Down Expand Up @@ -58,6 +59,7 @@ def compare_circuit_truth_table(cls, qf):
circ = qf.circuit()
circ_qi = circ.export("circuit", "qiskit")
# print(circ_qi.draw("text"))
# print(qf.expressions)

for truth_line in truth_table:
qc = QuantumCircuit(gate.num_qubits)
Expand Down Expand Up @@ -85,7 +87,6 @@ def compare_circuit_truth_table(cls, qf):

# Calculate original result from python function
def truth_to_arg(truth, i, argtt):
# print(arg.ttype)
if argtt == bool:
return truth[i], i + 1
elif inspect.isclass(argtt) and issubclass(argtt, Qtype):
Expand All @@ -110,8 +111,6 @@ def truth_to_arg(truth, i, argtt):

res_original = qf.original_f(*args)

print("\nClassical evalution", args, res_original)

def res_to_str(res):
if type(res) == bool:
return "1" if res else "0"
Expand All @@ -121,10 +120,14 @@ def res_to_str(res):
return res.to_bool_str()

res_original_str = res_to_str(res_original)
print("Res (th, or)", res_str, res_original_str, truth_line)
print(qf.expressions)

cls.assertEqual(len(res_original_str), qf.ret_size)
cls.assertEqual(res_str, res_original_str)

# cls.assertLessEqual(gate.num_qubits, len(qf.truth_table_header()))
max_qubits = (
qf.input_size
+ len(qf.expressions)
+ sum([gateinputcount(compiler.optimizer(e[1])) for e in qf.expressions])
)

cls.assertLessEqual(gate.num_qubits, max_qubits)

0 comments on commit 9b9531e

Please sign in to comment.