Skip to content

Commit

Permalink
aggregate uncomputing and not caching
Browse files Browse the repository at this point in the history
  • Loading branch information
dakk committed Oct 10, 2023
1 parent 245dc7c commit d26cc43
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 12 deletions.
2 changes: 1 addition & 1 deletion qlasskit/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def _symplify_exp(self, exp):

# Simplify the expression
exp = simplify(exp)
exp = boolalg.to_cnf(exp)
# exp = boolalg.to_cnf(exp)
return exp

def compile(
Expand Down
24 changes: 22 additions & 2 deletions qlasskit/compiler/poccompiler2.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,33 @@ class POCCompiler2(Compiler):
"""POC2 compiler translating an expression list to quantum circuit"""

def compile(self, name, args: Args, ret_size: int, exprs: BoolExpList) -> QCircuit:
self.mapped_not = {}
qc = QCircuit(name=name)

for arg in args:
for arg_b in arg.bitvec:
qc.add_qubit(arg_b)

for sym, exp in exprs:
print(sym, exp, self._symplify_exp(exp))
iret = self.compile_expr(qc, self._symplify_exp(exp))
print("iret", iret)
qc.map_qubit(sym, iret, promote=True)
qc.uncompute2()

return qc

def compile_expr(self, qc: QCircuit, expr: Boolean) -> int:
if isinstance(expr, Symbol):
return qc[expr.name]

elif (
isinstance(expr, Not)
and isinstance(expr.args[0], Symbol)
and expr.args[0] in self.mapped_not
):
return self.mapped_not[expr.args[0]]

elif isinstance(expr, Not):
fa = qc.get_free_ancilla()
eret = self.compile_expr(qc, expr.args[0])
Expand All @@ -51,7 +61,13 @@ def compile_expr(self, qc: QCircuit, expr: Boolean) -> int:
qc.cx(eret, fa)
qc.x(fa)

#qc.free_ancilla(eret)
qc.uncompute2()

qc.mark_ancilla(eret)
# qc.free_ancilla(eret)

if isinstance(expr.args[0], Symbol):
self.mapped_not[expr.args[0]] = fa

return fa

Expand All @@ -63,13 +79,17 @@ def compile_expr(self, qc: QCircuit, expr: Boolean) -> int:

qc.mcx(erets, fa)

qc.free_ancillas(erets)
qc.uncompute2()

[qc.mark_ancilla(eret) for eret in erets]
# qc.free_ancillas(erets)

return fa

elif isinstance(expr, Or):
# Translate or to and
expr = Not(And(*[Not(e) for e in expr.args]))
print("trans", expr)
return self.compile_expr(qc, expr)

# OLD TRANSLATOR
Expand Down
19 changes: 19 additions & 0 deletions qlasskit/qcircuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,25 @@ def get_free_ancilla(self):

return anc

marked = []

def mark_ancilla(self, w):
self.marked.append(w)

def uncompute2(self):
print("uncomputing", self.ancilla_lst - self.free_ancilla_lst, self.marked)
for g, ws in self.gates_computed[::-1]:
if (
ws[-1] in self.marked
and ws[-1] in self.ancilla_lst
and ws[-1] not in self.free_ancilla_lst
):
self.append(g, ws)
self.free_ancilla_lst.add(ws[-1])

self.marked = []
self.gates_computed = []

def uncompute(self, w):
"""Uncompute a specific ancilla qubit.
Expand Down
13 changes: 9 additions & 4 deletions test/test_qcircuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,26 +41,31 @@ def test_base_mapping(self):
class TestQCircuitUncomputing(unittest.TestCase):
def test1(self):
qc = QCircuit()
a, b, c, d = qc.add_qubit(), qc.add_qubit(), qc.add_ancilla(), qc.add_ancilla()
a, b, c, d = (
qc.add_qubit(),
qc.add_qubit(),
qc.add_ancilla(is_free=False),
qc.add_ancilla(is_free=False),
)
f = qc.add_qubit("res")
qc.mcx([a, b], c)
qc.mcx([a, b, c], d)
qc.cx(d, f)
qc.uncompute(c)
qc.uncompute(d) # this is invalidated
qc.uncompute2()
qc.draw()

def test2(self):
qc = QCircuit()
q = [qc.add_qubit() for x in range(4)]
a = [qc.add_ancilla() for x in range(4)]
a = [qc.add_ancilla(is_free=False) for x in range(4)]
r = qc.add_qubit()

qc.mcx(q, a[0])
qc.mcx(q + [a[0]], a[1])
qc.mcx(q + a[:1], a[2])
qc.mcx(q + a[:2], a[3])
qc.cx(a[3], r)
qc.uncompute2()
qc.draw()


Expand Down
10 changes: 5 additions & 5 deletions test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,28 +63,28 @@ def compare_circuit_truth_table(cls, qf):
for i in range(qf.input_size):
qc.initialize(1 if truth_line[i] else 0, i)

#(truth_line)
# (truth_line)

qc.append(gate, list(range(qf.num_qubits)))
# print(qc.decompose().draw("text"))
print(circ_qi.draw("text"))
# print(circ_qi.draw("text"))

counts = qiskit_measure_and_count(qc)
#print(counts, circ.qubit_map)
# print(counts, circ.qubit_map)

truth_str = "".join(
map(lambda x: "1" if x else "0", truth_line[-qf.ret_size :])
)

#print(truth_str)
# print(truth_str)

res = list(counts.keys())[0][::-1]
res_str = ""
for qname in qf.truth_table_header()[-qf.ret_size :]:
res_str += res[circ.qubit_map[qname]]

# res = res[0 : len(truth_str)][::-1]
#print(res_str)
# print(res_str)

cls.assertEqual(len(counts), 1)
cls.assertEqual(truth_str, res_str)

0 comments on commit d26cc43

Please sign in to comment.