Skip to content

Commit

Permalink
cache computations
Browse files Browse the repository at this point in the history
  • Loading branch information
dakk committed Oct 10, 2023
1 parent d26cc43 commit d442a9d
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 92 deletions.
2 changes: 1 addition & 1 deletion qlasskit/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


from sympy import simplify, symbols
from sympy.logic import ITE, And, Implies, Not, Or, boolalg
from sympy.logic import ITE, And, Implies, Not, Or # , boolalg

from .. import QCircuit
from ..ast2logic.typing import Args, BoolExpList
Expand Down
38 changes: 19 additions & 19 deletions qlasskit/compiler/poccompiler2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License..

from typing import Dict

from sympy import Symbol
from sympy.logic import And, Not, Or
from sympy.logic.boolalg import Boolean, BooleanFalse, BooleanTrue
Expand All @@ -25,32 +27,34 @@ class POCCompiler2(Compiler):
"""POC2 compiler translating an expression list to quantum circuit"""

def compile(self, name, args: Args, ret_size: int, exprs: BoolExpList) -> QCircuit:
self.mapped_not = {}
qc = QCircuit(name=name)

for arg in args:
for arg_b in arg.bitvec:
qc.add_qubit(arg_b)

self.mapped: Dict[Boolean, int] = {}

for sym, exp in exprs:
print(sym, exp, self._symplify_exp(exp))
# print(sym, self._symplify_exp(exp))
iret = self.compile_expr(qc, self._symplify_exp(exp))
print("iret", iret)
# print("iret", iret)
qc.map_qubit(sym, iret, promote=True)
qc.uncompute2()
uncomputed = qc.uncompute()

for k in self.mapped.keys():
if self.mapped[k] in uncomputed:
del self.mapped[k]

return qc

def compile_expr(self, qc: QCircuit, expr: Boolean) -> int:
if isinstance(expr, Symbol):
return qc[expr.name]

elif (
isinstance(expr, Not)
and isinstance(expr.args[0], Symbol)
and expr.args[0] in self.mapped_not
):
return self.mapped_not[expr.args[0]]
elif expr in self.mapped:
print("!!cachehit!!", expr)
return self.mapped[expr]

elif isinstance(expr, Not):
fa = qc.get_free_ancilla()
Expand All @@ -61,13 +65,11 @@ def compile_expr(self, qc: QCircuit, expr: Boolean) -> int:
qc.cx(eret, fa)
qc.x(fa)

qc.uncompute2()
qc.uncompute()

qc.mark_ancilla(eret)
# qc.free_ancilla(eret)

if isinstance(expr.args[0], Symbol):
self.mapped_not[expr.args[0]] = fa
self.mapped[expr] = fa

return fa

Expand All @@ -79,17 +81,17 @@ def compile_expr(self, qc: QCircuit, expr: Boolean) -> int:

qc.mcx(erets, fa)

qc.uncompute2()
qc.uncompute()

[qc.mark_ancilla(eret) for eret in erets]
# qc.free_ancillas(erets)
self.mapped[expr] = fa

return fa

elif isinstance(expr, Or):
# Translate or to and
expr = Not(And(*[Not(e) for e in expr.args]))
print("trans", expr)
# print("trans", expr)
return self.compile_expr(qc, expr)

# OLD TRANSLATOR
Expand All @@ -107,8 +109,6 @@ def compile_expr(self, qc: QCircuit, expr: Boolean) -> int:
# for j in range(i + 1, nclau - i):
# qc.x(iclau[j])

# qc.free_ancillas(iclau)

# return fa

elif isinstance(expr, BooleanFalse) or isinstance(expr, BooleanTrue):
Expand Down
81 changes: 11 additions & 70 deletions qlasskit/qcircuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(self, num_qubits=0, name="qc"):

self.ancilla_lst = set()
self.free_ancilla_lst = set()
self.marked_ancillas = []

for x in range(num_qubits):
self.qubit_map[f"q{x}"] = x
Expand Down Expand Up @@ -72,24 +73,6 @@ def add_ancilla(self, name=None, is_free=True):
self.free_ancilla_lst.add(i)
return i

def free_ancilla(self, w):
"""Freeing of an ancilla qubit"""
w = self[w]
if w not in self.ancilla_lst:
return # we don't care
# raise Exception(f"Qubit {w} is not in the ancilla set")

if w in self.free_ancilla_lst:
raise Exception(f"Ancilla {w} is already free")

self.uncompute(w)
self.free_ancilla_lst.add(w)

def free_ancillas(self, wl):
"""Freeing of a list of ancilla qubits"""
for w in wl:
self.free_ancilla(w)

def get_free_ancilla(self):
"""Get the first free ancilla available"""
if len(self.free_ancilla_lst) == 0:
Expand All @@ -99,69 +82,27 @@ def get_free_ancilla(self):

return anc

marked = []

def mark_ancilla(self, w):
self.marked.append(w)
self.marked_ancillas.append(w)

def uncompute2(self):
print("uncomputing", self.ancilla_lst - self.free_ancilla_lst, self.marked)
def uncompute(self, to_mark=[]):
"""Uncompute all the marked ancillas plus the to_mark list"""
[self.mark_ancilla(x) for x in to_mark]

uncomputed = set()
for g, ws in self.gates_computed[::-1]:
if (
ws[-1] in self.marked
ws[-1] in self.marked_ancillas
and ws[-1] in self.ancilla_lst
and ws[-1] not in self.free_ancilla_lst
):
uncomputed.add(ws[-1])
self.append(g, ws)
self.free_ancilla_lst.add(ws[-1])

self.marked = []
self.marked_ancillas = [] # self.marked_ancillas - uncomputed
self.gates_computed = []

def uncompute(self, w):
"""Uncompute a specific ancilla qubit.
Args:
w (int): The index of the qubit to be uncomputed.
"""
w = self[w]
if w not in self.ancilla_lst:
raise Exception("qubit not in the ancilla list")

print("uncomputing ", w)

g_comp = []
self.barrier(label=f"U{w}")
for g, ws in self.gates_computed[::-1]:
# w is the target
if w == ws[-1]:
self.append(g, ws)
# w is a control
# elif w in ws[:-1]:
# self.append(g, ws)
else:
g_comp.append((g, ws))

self.barrier(label=f"EU{w}")
self.gates_computed = g_comp[::-1]

# def uncompute(self):
# """Uncompute released ancilla qubits"""

# g_comp = []
# self.barrier(label=f"U{''.join(map(str,self.uncomputable))}")
# for g, ws in self.gates_computed[::-1]:
# # w is the target
# if ws[-1] in self.uncomputable:
# self.append(g, ws)
# # w is a control
# # elif w in ws[:-1]:
# # self.append(g, ws)
# else:
# g_comp.append((g, ws))

# self.barrier(label=f"EU{''.join(map(str,self.uncomputable))}")
# self.gates_computed = g_comp[::-1]
return uncomputed

def map_qubit(self, name, index, promote=False):
"""Map a name to a qubit
Expand Down
4 changes: 2 additions & 2 deletions test/test_qcircuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test1(self):
qc.mcx([a, b], c)
qc.mcx([a, b, c], d)
qc.cx(d, f)
qc.uncompute2()
qc.uncompute([c, d])
qc.draw()

def test2(self):
Expand All @@ -65,7 +65,7 @@ def test2(self):
qc.mcx(q + a[:1], a[2])
qc.mcx(q + a[:2], a[3])
qc.cx(a[3], r)
qc.uncompute2()
qc.uncompute(a)
qc.draw()


Expand Down

0 comments on commit d442a9d

Please sign in to comment.