Skip to content

Commit

Permalink
tuple-tuple comparison
Browse files Browse the repository at this point in the history
  • Loading branch information
dakk committed Oct 27, 2023
1 parent 39b0093 commit 74b068f
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 20 deletions.
1 change: 1 addition & 0 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@

### Week 2: (30 Oct 23)

- [x] Tuple-tuple comparison
- [ ] Groover algorithm tests

### Week 3: (6 Nov 23)
Expand Down
15 changes: 15 additions & 0 deletions examples/groover_hash_collision.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,21 @@ def qiskit_simulate(qc):
# return (k<<1) + 2 if inner(k) else 4


# from typing import Tuple

# @qlassf
# def hash(k: Qint8) -> Tuple[bool, bool]:
# return k[0] and k[1] and not k[2] and not k[3], k[4] and not k[5] and k[6] and not k[7]
# algo = Groover(hash, (True,True))


# @qlassf
# def hash(k: Qint8) -> bool:
# return k[0] and k[1] and not k[2] and not k[3] and k[4] and not k[5] and k[6] and not k[7]

# algo = Groover(hash, True)


@qlassf
def hash(k: Qint4) -> Qint4:
return (k << 1) + 2
Expand Down
5 changes: 1 addition & 4 deletions qlasskit/algorithms/groover.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,11 @@ def oracle_outer(v: {argt_name}) -> bool:

for i in range(n_iterations):
self.qc.barrier(label=f"g{i}")
# self.qc.barrier(label=f"orac_{i}")
self.qc += oracle_qc.copy()

# self.qc.barrier(label=f"diff_{i}")
self.qc.barrier()
self.qc += diffuser_qc.copy()

# self.qc.barrier(label="end")

def circuit(self) -> QCircuit:
return self.qc

Expand Down
28 changes: 28 additions & 0 deletions qlasskit/ast2logic/t_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,34 @@ def unfold(v_exps, op):
# Check comparability
if tleft[0] == bool and tcomp[0] == bool:
op_type = Qbool

# Compare tuples for equality / inequality
elif len(get_args(tleft[0])) > 0 and len(get_args(tcomp[0])) > 0:
arg_l = get_args(tleft[0])
arg_r = get_args(tcomp[0])
if arg_l != arg_r:
raise exceptions.TypeErrorException(tleft[0], tcomp[0])

if isinstance(expr.ops[0], ast.Eq):
op = Qbool.eq
elif isinstance(expr.ops[0], ast.NotEq):
op = Qbool.neq
else:
raise exceptions.OperationNotSupportedException(bool, expr.ops[0])

c = True
idx = 0
for left, right in zip(arg_l, arg_r):
if left == bool:
c = And(c, op((bool, tleft[1][idx]), (bool, tcomp[1][idx]))[1])
idx += 1
else:
for si in range(left.BIT_SIZE):
c = And(c, op((bool, tleft[1][i]), (bool, tcomp[1][idx]))[1])
idx += 1

return (bool, c)

elif issubclass(tleft[0], Qtype) and issubclass(tcomp[0], Qtype): # type: ignore
if not tleft[0].comparable(tcomp[0]): # type: ignore
raise exceptions.TypeErrorException(tcomp[0], tleft[0])
Expand Down
5 changes: 0 additions & 5 deletions qlasskit/compiler/internalcompiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,6 @@ def compile_expr( # noqa: C901
elif isinstance(expr, Not):
eret = self.compile_expr(qc, expr.args[0])

# qc.barrier("not")

if eret in qc.ancilla_lst:
qc.x(eret)
self.expqmap[expr] = eret
Expand All @@ -92,7 +90,6 @@ def compile_expr( # noqa: C901
if dest is None:
dest = qc.get_free_ancilla()

# qc.barrier("and")
qc.mcx(erets, dest)

[qc.mark_ancilla(eret) for eret in erets]
Expand All @@ -104,8 +101,6 @@ def compile_expr( # noqa: C901
erets = list(map(lambda e: self.compile_expr(qc, e), expr.args))
last = erets.pop()

# qc.barrier("xor")

if last in qc.ancilla_lst:
dest = last
self.expqmap[expr] = last
Expand Down
12 changes: 1 addition & 11 deletions qlasskit/qcircuit/qcircuitenhanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def uncompute_all(self, keep: List[Union[Symbol, int]] = []):
# TODO: replace with + invert(keep)
scopy = copy.deepcopy(self.gates)
uncomputed = set()
# self.barrier(label="un_all")

for g, qbs, p in reversed(scopy):
if (
isinstance(g, gates.NopGate)
Expand All @@ -115,10 +115,6 @@ def uncompute_all(self, keep: List[Union[Symbol, int]] = []):

self.append(g, qbs, p)

# Remove barrier if no uncomputed
# if len(uncomputed) == 0:
# self.gates.pop()

return uncomputed

def uncompute(self, to_mark=[]):
Expand All @@ -128,8 +124,6 @@ def uncompute(self, to_mark=[]):
if len(self.marked_ancillas) == 0:
return []

# self.barrier(label="un")

uncomputed = set()
new_gates_comp = []

Expand All @@ -145,8 +139,4 @@ def uncompute(self, to_mark=[]):
self.marked_ancillas = self.marked_ancillas - uncomputed
self.gates_computed = new_gates_comp[::-1]

# Remove barrier if no uncomputed
# if len(uncomputed) == 0:
# self.gates.pop()

return uncomputed
10 changes: 10 additions & 0 deletions test/test_qlassf_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,13 @@ def test_tuple_result(self):
self.assertEqual(qf.expressions[1][0], Symbol("_ret.1"))
self.assertEqual(qf.expressions[1][1], b)
# compute_and_compare_results(self, qf)

def test_tuple_compare(self):
f = "def test(a: Tuple[bool, bool], b: Tuple[bool, bool]) -> bool:\n\treturn a == b"
qf = qlassf(f, to_compile=COMPILATION_ENABLED)
compute_and_compare_results(self, qf)

def test_tuple_int_compare(self):
f = "def test(a: Tuple[Qint2, Qint2], b: Tuple[Qint2, Qint2]) -> bool:\n\treturn a == b"
qf = qlassf(f, to_compile=COMPILATION_ENABLED)
compute_and_compare_results(self, qf)

0 comments on commit 74b068f

Please sign in to comment.