Skip to content

Commit

Permalink
fix function call and groover search
Browse files Browse the repository at this point in the history
  • Loading branch information
dakk committed Oct 28, 2023
1 parent 7e81cc7 commit 1a877fd
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 12 deletions.
78 changes: 78 additions & 0 deletions examples/groover_hash_collision2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from typing import Tuple

from matplotlib import pyplot as plt
from qiskit import Aer, QuantumCircuit, transpile, ClassicalRegister
from qiskit.visualization import plot_histogram

from qlasskit import Qint2, Qint4, Qint8, Qint16, qlassf
from qlasskit.algorithms import Groover


def qiskit_simulate(qc, alog):
c = ClassicalRegister(len(algo.out_qubits()))
qc.add_bits(c)
qc.measure(algo.out_qubits(), c)
print(qc.draw("text"))

simulator = Aer.get_backend("aer_simulator")
circ = transpile(qc, simulator)
result = simulator.run(circ).result()

return result.get_counts(circ)


# @qlassf
# def md5_simp(message: Tuple[Qint8, Qint8]) -> Qint8:
# A = 0x12
# # A, B, C, D = 0x12, 0x34, 0x56, 0x78

# for i in range(2): # MESSAGE_LEN
# char = message[i]

# A = (A + char) & 0xFF
# # B = (B ^ char) & 0xFF
# # C = (C + (char << 1)) & 0xFF
# # D = (D - char) & 0xFF

# # return (A<<8) + B
# return A


@qlassf
def md5_simp(message: Tuple[Qint4, Qint4]) -> Qint8:
A = 0x12
B = 0x34
# A, B, C, D = 0x12, 0x34, 0x56, 0x78

for i in range(2): # MESSAGE_LEN
char = message[i]

A = (A + char) & 0xF
B = (B ^ char) & 0xF
# C = (C + (char << 1)) & 0xFF
# D = (D - char) & 0xFF

return (A<<4) + B

# @qlassf
# def md5_simp(m: Tuple[Qint2, Qint2]) -> Qint4:
# A = 0x1
# B = 0x3
# A = (A + m[0]) & 0xF
# B = (B ^ m[1]) & 0xF
# return (A << 2) + B


algo = Groover(md5_simp, (Qint8(0xCA)))

# print(hex(md5_simp.original_f((2,3))))

qc = algo.circuit().export("circuit", "qiskit")
counts = qiskit_simulate(qc, algo)
counts_readable = algo.interpet_counts(counts)
plot_histogram(counts_readable)
plt.show()


# print(md5_simp.circuit().export("circuit", "qiskit").draw("text"))
# plt.show()
29 changes: 25 additions & 4 deletions qlasskit/algorithms/groover.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import math
from typing import List, Optional, Union
from typing import List, Optional, Union, get_args

from ..qcircuit import QCircuit, gates
from ..qlassf import QlassF
Expand Down Expand Up @@ -41,7 +41,7 @@ def __init__(
self.search_space_size = len(self.oracle.args[0])

if n_iterations is None:
n_iterations = math.ceil(math.pi / 4.0 * math.sqrt(self.search_space_size))
n_iterations = math.ceil(math.pi / 4.0 * math.sqrt(2**self.search_space_size))

self.n_iterations = n_iterations

Expand All @@ -54,7 +54,16 @@ def __init__(

# Prepare and add the quantum oracle
if element_to_search is not None:
argt_name = self.oracle.args[0].ttype.__name__ # type: ignore
if hasattr(self.oracle.args[0].ttype, "__name__"):
argt_name = self.oracle.args[0].ttype.__name__ # type: ignore
elif self.oracle.args[0].ttype == bool:
argt_name = "bool"
else:
argt_name = "Tuple["
argt_name += ",".join(
[x.__name__ for x in get_args(self.oracle.args[0].ttype)]
)
argt_name += "]"

oracle_outer = QlassF.from_function(
f"""
Expand Down Expand Up @@ -118,4 +127,16 @@ def interpret_outcome(self, outcome: Union[str, int, List[bool]]) -> Qtype:
if len_a == 1:
return out[0] # type: ignore

return self.oracle.args[0].ttype.from_bool(out[::-1][0:len_a]) # type: ignore
if hasattr(self.oracle.args[0].ttype, "__name__"):
return self.oracle.args[0].ttype.from_bool(out[::-1][0:len_a]) # type: ignore
elif self.oracle.args[0].ttype == bool:
return out[::-1][0]
else: # Tuple
idx_s = 0
values = []
for x in get_args(self.oracle.args[0].ttype):
len_a = x.BIT_SIZE
values.append(x.from_bool(out[::-1][idx_s:idx_s+len_a]))
idx_s += len_a

return tuple(values)
9 changes: 4 additions & 5 deletions qlasskit/algorithms/qalgorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def format_outcome(out: Union[str, int, List[bool]]) -> List[bool]:
return format_outcome(str(bin(out))[2:])
elif isinstance(out, List):
return out
raise Exception("Invalid format")
raise Exception(f"Invalid format: {out}")


class QAlgorithm:
Expand Down Expand Up @@ -53,11 +53,10 @@ def interpet_counts(self, counts: Dict[str, int]) -> Dict[Any, int]:
outcomes = [(self.interpret_outcome(e), c) for (e, c) in counts.items()]
int_counts: Dict[Any, int] = {}
for e, c in outcomes:
inter = self.interpret_outcome(e)
if inter in int_counts:
int_counts[inter] += c
if e in int_counts:
int_counts[e] += c
else:
int_counts[inter] = c
int_counts[e] = c
return int_counts

def export(self, framework: SupportedFramework = "qiskit") -> Any:
Expand Down
6 changes: 5 additions & 1 deletion qlasskit/ast2logic/t_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,11 @@ def unfold(v_exps, op):
for a, fa in zip(args, def_f[1]):
if isinstance(a[1], List):
for i in range(len(a[1])): # type: ignore
subs[f"{fa.name}.{i}"] = a[1][i] # type: ignore
index = ".".join(a[1][i].name.split(".")[1:])
if index == '':
index = f'{i}'

subs[f"{fa.name}.{index}"] = a[1][i] # type: ignore

else:
subs[fa.name] = a[1]
Expand Down
7 changes: 5 additions & 2 deletions qlasskit/compiler/internalcompiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,11 @@ def compile(self, name, args: Args, returns: Arg, exprs: BoolExpList) -> QCircui
def compile_expr( # noqa: C901
self, qc: QCircuitEnhanced, expr: Boolean, dest=None
) -> int:
if isinstance(expr, Symbol) and expr.name in qc:
return qc[expr.name]
if isinstance(expr, Symbol):
if expr.name in qc:
return qc[expr.name]
else:
raise CompilerException(f'Symbol not found in qc: {expr.name}')

elif expr in self.expqmap:
return self.expqmap[expr]
Expand Down

0 comments on commit 1a877fd

Please sign in to comment.