Skip to content

Commit

Permalink
optimize for unrolling
Browse files Browse the repository at this point in the history
  • Loading branch information
dakk committed Oct 23, 2023
1 parent 21d6cbd commit 7cce318
Show file tree
Hide file tree
Showing 8 changed files with 95 additions and 26 deletions.
3 changes: 2 additions & 1 deletion TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@
- [x] Symbol reassign and augassign
- [x] Remove unneccessary expressions
- [x] Remove quantum circuit identities
- [ ] For unrolling
- [x] For unrolling
- [ ] Aggregate cascading expressions in for unrolling

### Week 2: (30 Oct 23)
### Week 3: (6 Nov 23)
Expand Down
12 changes: 6 additions & 6 deletions qlasskit/compiler/poccompiler2.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,14 @@ def compile(self, name, args: Args, returns: Arg, exprs: BoolExpList) -> QCircui
self.expqmap.update_exp_for_qubit(iret, sym)
qc.map_qubit(sym, iret, promote=not is_temp)

# Mark temp symbols
if isinstance(symp_exp, Symbol) and symp_exp.name[0:2] == "__":
qc.ancilla_lst.add(iret)
qc.mark_ancilla(iret)
self.expqmap.remove_map_by_qubits([iret])

self.garbage_collect(qc)

# print(sym, exp)
# circ_qi = qc.export("circuit", "qiskit")
# print(circ_qi.draw("text"))
# print()
# print()

qc.remove_identities()
return qc

Expand Down
2 changes: 0 additions & 2 deletions qlasskit/compiler/poccompiler3.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ def compile_expr(self, qc: QCircuit, expr: Boolean) -> int: # noqa: C901
and isinstance(expr.args[0], Symbol)
and self.symbol_count[expr.args[0].name] <= 1
):
print("called not simp")
self.symbol_count[expr.args[0].name] = 0
eret = self.compile_expr(qc, expr.args[0])
qc.x(eret)
Expand All @@ -132,7 +131,6 @@ def compile_expr(self, qc: QCircuit, expr: Boolean) -> int: # noqa: C901
for e in expr.args
]
):
print("called xor simp")
erets = []

for e in expr.args:
Expand Down
29 changes: 26 additions & 3 deletions qlasskit/qcircuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,25 @@ def get_key_by_index(self, i: int):
return key
raise Exception(f"Qubit with index {i} not found")

def __contains__(self, key: Union[str, Symbol, int]):
if isinstance(key, str):
return key in self.qubit_map
elif isinstance(key, Symbol):
return key.name in self.qubit_map
return False

def __delitem__(self, key: Union[str, Symbol, int]):
if isinstance(key, str):
del self.qubit_map[key]
elif isinstance(key, Symbol):
del self.qubit_map[key.name]

def __setitem__(self, key: Union[str, Symbol, int], val):
if isinstance(key, str):
self.qubit_map[key] = val
elif isinstance(key, Symbol):
self.qubit_map[key.name] = val

def __getitem__(self, key: Union[str, Symbol, int]):
"""Return the qubit index given its name or index"""
if isinstance(key, str):
Expand All @@ -60,13 +79,17 @@ def remove_identities(self):
i = 0
len_g = len(self.gates)
while i < len_g:
if i < (len_g - 2) and self.gates[i] == self.gates[i + 1]:
if i < (len_g - 1) and self.gates[i] == self.gates[i + 1]:
if result[-1][0] == "bar":
result.pop()
i += 2
elif (
i < (len_g - 3)
i < (len_g - 2)
and self.gates[i] == self.gates[i + 2]
and self.gates[i + 1][0] == "bar"
):
if result[-1][0] == "bar":
result.pop()
i += 3
else:
result.append(self.gates[i])
Expand Down Expand Up @@ -144,7 +167,7 @@ def map_qubit(self, name, index, promote=False):
if promote and index in self.ancilla_lst:
self.ancilla_lst.remove(index)

self.qubit_map[name] = index
self[name] = index

def add_qubit(self, name=None):
"""Add a qubit to the circuit.
Expand Down
46 changes: 45 additions & 1 deletion qlasskit/qlassf.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,35 @@ def remove_const_exps(exps: BoolExpList, fun_ret: Arg) -> BoolExpList:
return n_exps


# Remove exp like: __a.0 = a.0, a.0 = __a.0
# Subsitute exps (replace a = ~a, a = ~a, a = ~a => a = ~a)
# def subsitute_exps(exps: BoolExpList, fun_ret: Arg) -> BoolExpList:
# const: Dict[Symbol, Boolean] = {}
# n_exps: BoolExpList = []
# print(exps)

# for i in range(len(exps)):
# (s, e) = exps[i]
# e = e.subs(const)
# const[s] = e

# for x in e.free_symbols:
# if x in const:
# n_exps.append((x, const[x]))
# del const[x]

# for (s,e) in const.items():
# if s == e:
# continue

# n_exps.append((s,e))

# print(n_exps)
# print()
# print()
# return n_exps


# Remove exp like: __a.0 = a.0, ..., a.0 = __a.0
def remove_unnecessary_assigns(exps: BoolExpList) -> BoolExpList:
n_exps: BoolExpList = []

Expand All @@ -69,6 +97,20 @@ def should_add(s, e, n_exps2):
return n_exps


# Translate exp like: __a.0 = !a, a = __a.0 ===> a = !a
def merge_unnecessary_assigns(exps: BoolExpList) -> BoolExpList:
n_exps: BoolExpList = []

for s, e in exps:
if len(n_exps) >= 1 and n_exps[-1][0] == e:
old = n_exps.pop()
n_exps.append((s, old[1]))
else:
n_exps.append((s, e))

return n_exps


class QlassF:
"""Class representing a quantum classical circuit"""

Expand Down Expand Up @@ -221,6 +263,8 @@ def from_function(
# Remove unnecessary expressions
exps = remove_const_exps(exps, fun_ret)
exps = remove_unnecessary_assigns(exps)
exps = merge_unnecessary_assigns(exps)
# exps = subsitute_exps(exps, fun_ret)

# Return the qlassf object
qf = QlassF(fun_name, original_f, args, fun_ret, exps)
Expand Down
12 changes: 4 additions & 8 deletions test/test_qlassf_bool.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,9 @@ def test_ifexp3(self):
def test_assign(self):
f = "def test(a: bool, b: bool) -> bool:\n\tc = a and b\n\treturn c"
qf = qlassf(f, to_compile=COMPILATION_ENABLED)
self.assertEqual(len(qf.expressions), 2)
self.assertEqual(qf.expressions[0][0], c)
self.assertEqual(len(qf.expressions), 1)
self.assertEqual(qf.expressions[0][0], _ret)
self.assertEqual(qf.expressions[0][1], And(a, b))
self.assertEqual(qf.expressions[1][0], _ret)
self.assertEqual(qf.expressions[1][1], c)
compute_and_compare_results(self, qf)

def test_assign2(self):
Expand All @@ -165,11 +163,9 @@ def test_assign2(self):
+ "\treturn True if d else False"
)
qf = qlassf(f, to_compile=COMPILATION_ENABLED)
self.assertEqual(len(qf.expressions), 2)
self.assertEqual(qf.expressions[0][0], d)
self.assertEqual(len(qf.expressions), 1)
self.assertEqual(qf.expressions[0][1], And(a, And(Not(b), c)))
self.assertEqual(qf.expressions[1][0], _ret)
self.assertEqual(qf.expressions[1][1], d)
self.assertEqual(qf.expressions[0][0], _ret)
compute_and_compare_results(self, qf)

def test_assign3(self):
Expand Down
15 changes: 12 additions & 3 deletions test/test_qlassf_for_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,17 @@ def test_for_nit_bool(self):
qf = qlassf(f, to_compile=COMPILATION_ENABLED)
compute_and_compare_results(self, qf)

# def test_for_cond(self):
# f = "def test(a: Qint2, b: bool) -> Qint2:\n\tfor i in range(3):\n\t\ta += (i if b else 1)\n\treturn a"
def test_for_nit_bool_many(self):
f = "def test(a: bool) -> bool:\n\tfor i in range(15):\n\t\ta = not a\n\treturn a"
qf = qlassf(f, to_compile=COMPILATION_ENABLED)
compute_and_compare_results(self, qf)

# def test_for_nit_tbool_many(self):
# f = "def test(a: Tuple[bool,bool]) -> Tuple[bool,bool]:\n\tfor i in range(32):\n\t\ta[0] = not a[0]\n\t\ta[1] = not a[1]\n\treturn a"
# qf = qlassf(f, to_compile=COMPILATION_ENABLED)
# print(qf.expressions)
# compute_and_compare_results(self, qf)

def test_for_cond(self):
f = "def test(a: Qint2, b: bool) -> Qint2:\n\tfor i in range(2):\n\t\ta += (i if b else 1)\n\treturn a"
qf = qlassf(f, to_compile=COMPILATION_ENABLED)
compute_and_compare_results(self, qf)
2 changes: 0 additions & 2 deletions test/test_qlassf_int.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,13 +286,11 @@ def test_composed_comparators(self):
def test_shift_left(self):
f = "def test(n: Qint2) -> Qint4: return n << 1"
qf = qlassf(f, to_compile=COMPILATION_ENABLED)
print(qf.expressions)
compute_and_compare_results(self, qf)

def test_shift_right(self):
f = "def test(n: Qint2) -> Qint4: return n >> 1"
qf = qlassf(f, to_compile=COMPILATION_ENABLED)
print(qf.expressions)
compute_and_compare_results(self, qf)

# Our Qint are unsigned
Expand Down

0 comments on commit 7cce318

Please sign in to comment.