Skip to content

Commit

Permalink
fix xor compilation for multiple args, add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dakk committed Oct 19, 2023
1 parent 023896d commit 3e3a5a0
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 17 deletions.
4 changes: 3 additions & 1 deletion TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,10 @@
- [x] Extensible type system
- [x] Builtin functions: max(), min(), len()
- [x] Function call (to builtin)
- [x] Int arithmetic: +
- [ ] Int: shift right / left
- [ ] Int: subtraction
- [ ] Publish doc
- [ ] Int arithmetic: +

## Month 2:

Expand Down
2 changes: 1 addition & 1 deletion qlasskit/ast2ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def visit_Call(self, node):
else:
args = node.args

op = ast.Gt() if node.func.id == "max" else ast.Lt()
op = ast.Gt() if node.func.id == "max" else ast.LtE()

def iterif(arg_l):
if len(arg_l) == 1:
Expand Down
2 changes: 1 addition & 1 deletion qlasskit/ast2logic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from .env import Env, Binding # noqa: F401, E402
from .utils import flatten # noqa: F401, E402
from .typing import Args, BoolExpList # noqa: F401, E402
from .typing import Args, Arg, BoolExpList # noqa: F401, E402
from .t_arguments import translate_argument, translate_arguments # noqa: F401, E402
from .t_expression import translate_expression, decompose_to_symbols # noqa: F401, E402
from .t_statement import translate_statement # noqa: F401, E402
Expand Down
21 changes: 18 additions & 3 deletions qlasskit/compiler/poccompiler2.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,27 @@ def compile_expr(self, qc: QCircuit, expr: Boolean) -> int:

elif isinstance(expr, Xor):
erets = list(map(lambda e: self.compile_expr(qc, e), expr.args))
last = erets.pop()

qc.barrier("xor")
[qc.mark_ancilla(eret) for eret in erets[:-1]]

qc.mcx(erets[0:-1], erets[-1])
return erets[-1]
if last in qc.ancilla_lst:
fa = last
self.expqmap.update_exp_for_qubit(last, expr)
else:
fa = qc.get_free_ancilla()

qc.cx(last, fa)
qc.mark_ancilla(last)
self.expqmap[expr] = fa

for x in erets:
qc.cx(x, fa)

[qc.mark_ancilla(eret) for eret in erets]
self.garbage_collect(qc)

return fa

elif isinstance(expr, BooleanFalse):
return qc.get_free_ancilla()
Expand Down
2 changes: 1 addition & 1 deletion test/test_qlassf_bool.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def test_assign3(self):
+ "\th = (not a) and b and (not c)\n"
+ "\treturn g if d and e else h"
)
qf = qlassf(f, to_compile=False)
qf = qlassf(f, to_compile=COMPILATION_ENABLED)
self.assertEqual(len(qf.expressions), 5)
self.assertEqual(qf.expressions[-1][1], ITE(d & e, g, h))
compute_and_compare_results(self, qf)
Expand Down
13 changes: 12 additions & 1 deletion test/test_qlassf_int.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,8 +293,9 @@ def test_composed_comparators(self):
# return a + b


# TODO: parameterize
class TestQlassfIntAdd(unittest.TestCase):
def test_add2(self):
def test_add_tuple(self):
f = "def test(a: Tuple[Qint2, Qint2]) -> Qint2: return a[0] + a[1]"
qf = qlassf(f, to_compile=COMPILATION_ENABLED)
compute_and_compare_results(self, qf)
Expand All @@ -313,3 +314,13 @@ def test_add_const2(self):
f = "def test() -> Qint4: return Qint4(3) + 3"
qf = qlassf(f, to_compile=COMPILATION_ENABLED)
compute_and_compare_results(self, qf)

def test_add_const3(self):
f = "def test(a: Qint2, b: Qint2) -> Qint4: return Qint4(3) + a if a == 3 else Qint4(1) + b"
qf = qlassf(f, to_compile=COMPILATION_ENABLED)
compute_and_compare_results(self, qf)

def test_add_const4(self):
f = "def test(a: Qint2) -> Qint2: return a + 2"
qf = qlassf(f, to_compile=COMPILATION_ENABLED)
compute_and_compare_results(self, qf)
21 changes: 12 additions & 9 deletions test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,13 @@ def compute_result_of_qcircuit(cls, qf, truth_line):
gate = qf.gate()
qc = QuantumCircuit(gate.num_qubits)

# circ_qi = circ.export("circuit", "qiskit")
# print(circ_qi.draw("text"))
# print(qf.expressions)

# Prepare inputs
[qc.initialize(1 if truth_line[i] else 0, i) for i in range(qf.input_size)]

qc.append(gate, list(range(qf.num_qubits)))

# print(qc.decompose().draw("text"))

# Measure
counts = qiskit_measure_and_count(qc)

Expand Down Expand Up @@ -111,7 +109,7 @@ def res_to_str(res):
return "".join([res_to_str(x) for x in res])
elif type(res) == int:
qc = const_to_qtype(res)
qi = qc[0].from_bool(qc[1])
qi = qf.returns.ttype.from_bool(qc[1])
return qi.to_bin()
else:
return res.to_bin()
Expand Down Expand Up @@ -145,18 +143,23 @@ def compute_and_compare_results(cls, qf):

if len(truth_table) > MAX_Q_SIM and COMPILATION_ENABLED:
qc_truth = [random.choice(truth_table) for x in range(MAX_Q_SIM)]
elif COMPILATION_ENABLED:
qc_truth = truth_table

# circ_qi = qf.circuit().export("circuit", "qiskit")
# print(circ_qi.draw("text"))

for truth_line in truth_table:
# Extract str of truthtable and result
truth_str = "".join(
map(lambda x: "1" if x else "0", truth_line[-qf.ret_size :])
)

# Calculate and compare the originalf result
res_original = compute_result_of_originalf(cls, qf, truth_line)
cls.assertEqual(truth_str, res_original)

# Calculate and compare the gate result
if qc_truth and truth_line in qc_truth and COMPILATION_ENABLED:
res_qc = compute_result_of_qcircuit(cls, qf, truth_line)
cls.assertEqual(truth_str, res_qc)

# Calculate and compare the originalf result
res_original = compute_result_of_originalf(cls, qf, truth_line)
cls.assertEqual(truth_str, res_original)

0 comments on commit 3e3a5a0

Please sign in to comment.