Skip to content

Commit

Permalink
separate generic qalgorithm functions from groover, test qalgorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
dakk committed Nov 1, 2023
1 parent d458085 commit cff8846
Show file tree
Hide file tree
Showing 9 changed files with 232 additions and 85 deletions.
8 changes: 7 additions & 1 deletion qlasskit/algorithms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,11 @@
# limitations under the License.
# isort:skip_file

from .qalgorithm import QAlgorithm # noqa: F401, E402
from .qalgorithm import ( # noqa: F401, E402
QAlgorithm,
format_outcome,
interpret_as_qtype,
oraclize,
ConstantOracleException,
)
from .groover import Groover # noqa: F401, E402
81 changes: 13 additions & 68 deletions qlasskit/algorithms/groover.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,36 +13,33 @@
# limitations under the License.

import math
import sys
from typing import List, Optional, Tuple, Union, get_args

from sympy import Symbol
from sympy.logic.boolalg import BooleanFalse, BooleanTrue
from typing import List, Optional, Tuple, Union

from ..qcircuit import QCircuit, gates
from ..qlassf import QlassF
from ..types import Qtype
from .qalgorithm import QAlgorithm, format_outcome
from .qalgorithm import QAlgorithm, interpret_as_qtype, oraclize


class Groover(QAlgorithm):
def __init__( # noqa: C901
def __init__(
self,
oracle: QlassF,
element_to_search: Qtype,
n_iterations: Optional[int] = None,
):
"""
Args:
oracle (QlassF): our f(x) -> bool that returns True if x satisfies the function
oracle (QlassF): our f(x) -> bool that returns True if x satisfies the function or
a generic function f(x) = y that we want to compare with element_to_search
element_to_search (Qtype): the element we want to search
n_iterations (int, optional): force a number of iterations (otherwise, pi/4*sqrt(N))
"""
if len(oracle.args) != 1:
raise Exception("the oracle should receive exactly one parameter")

self.oracle: QlassF = oracle
self.search_space_size = len(self.oracle.args[0])
self.oracle: QlassF
self.search_space_size = len(oracle.args[0])

if n_iterations is None:
n_iterations = math.ceil(
Expand All @@ -60,47 +57,11 @@ def __init__( # noqa: C901

# Prepare and add the quantum oracle
if element_to_search is not None:
if hasattr(self.oracle.args[0].ttype, "__name__"):
argt_name = self.oracle.args[0].ttype.__name__ # type: ignore

args = get_args(self.oracle.args[0].ttype)
if len(args) > 0:
argt_name += "["
argt_name += ",".join([x.__name__ for x in args])
argt_name += "]"

elif self.oracle.args[0].ttype == bool:
argt_name = "bool"
elif sys.version_info < (3, 9):
argt_name = "Tuple["
argt_name += ",".join(
[x.__name__ for x in get_args(self.oracle.args[0].ttype)]
)
argt_name += "]"

oracle_outer = QlassF.from_function(
f"""
def oracle_outer(v: {argt_name}) -> bool:
return {self.oracle.name}(v) == {element_to_search}
""",
defs=[self.oracle.to_logicfun()],
)

if (
len(oracle_outer.expressions) == 1
and oracle_outer.expressions[0][0] == Symbol("_ret")
and (
isinstance(oracle_outer.expressions[0][1], BooleanTrue)
or isinstance(oracle_outer.expressions[0][1], BooleanFalse)
)
):
raise Exception(
f"The oracle is constant: {oracle_outer.expressions[0][1]}"
)
self.oracle = oraclize(oracle, element_to_search)
else:
oracle_outer = self.oracle
self.oracle = oracle

oracle_qc = oracle_outer.circuit()
oracle_qc = self.oracle.circuit()

# Add negative phase to result
oracle_qc.add_qubit(name="_ret_phased")
Expand Down Expand Up @@ -148,22 +109,6 @@ def out_qubits(self) -> List[int]:
def interpret_outcome(
self, outcome: Union[str, int, List[bool]]
) -> Union[bool, Tuple, Qtype]:
out = format_outcome(outcome, len(self.out_qubits()))

len_a = len(self.oracle.args[0])
if len_a == 1:
return out[0] # type: ignore

if hasattr(self.oracle.args[0].ttype, "from_bool"):
return self.oracle.args[0].ttype.from_bool(out[::-1][0:len_a]) # type: ignore
elif self.oracle.args[0].ttype == bool:
return out[::-1][0]
else: # Tuple
idx_s = 0
values = []
for x in get_args(self.oracle.args[0].ttype):
len_a = x.BIT_SIZE
values.append(x.from_bool(out[::-1][idx_s : idx_s + len_a]))
idx_s += len_a

return tuple(values)
return interpret_as_qtype(
outcome, self.oracle.args[0].ttype, len(self.oracle.args[0])
)
82 changes: 80 additions & 2 deletions qlasskit/algorithms/qalgorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,102 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Union
import sys
from typing import Any, Dict, List, Optional, Union, get_args

from sympy import Symbol
from sympy.logic.boolalg import BooleanFalse, BooleanTrue

from ..qcircuit import QCircuit, SupportedFramework
from ..qlassf import QlassF


def format_outcome(out: Union[str, int, List[bool]], out_len: int) -> List[bool]:
def format_outcome(
out: Union[str, int, List[bool]], out_len: Optional[int] = None
) -> List[bool]:
if isinstance(out, str):
return format_outcome([True if c == "1" else False for c in out], out_len)
elif isinstance(out, int):
return format_outcome(str(bin(out))[2:], out_len)
elif isinstance(out, List):
if out_len is None:
out_len = len(out)

if len(out) < out_len:
out += [False] * (out_len - len(out))

return out
raise Exception(f"Invalid format: {out}")


def interpret_as_qtype(
out: Union[str, int, List[bool]], qtype, out_len: Optional[int] = None
) -> Any:
out = list(reversed(format_outcome(out, out_len)))

def _interpret(out, qtype, out_len):
if hasattr(qtype, "from_bool"):
return qtype.from_bool(out[0:out_len]) # type: ignore
elif qtype == bool:
return out[0]
else: # Tuple
idx_s = 0
values = []
for x in get_args(qtype):
len_a = x.BIT_SIZE if hasattr(x, "BIT_SIZE") else 1
values.append(_interpret(out[idx_s : idx_s + len_a], x, len_a))
idx_s += len_a

return tuple(values)

return _interpret(out, qtype, out_len)


class ConstantOracleException(Exception):
pass


def oraclize(qf: QlassF, element: Any, name="oracle"):
"""Transform a QlassF qf and an element to an oracle {f(x) = x == element}"""
if hasattr(qf.args[0].ttype, "__name__"):
argt_name = qf.args[0].ttype.__name__ # type: ignore

args = get_args(qf.args[0].ttype)
if len(args) > 0:
argt_name += "["
argt_name += ",".join([x.__name__ for x in args])
argt_name += "]"

elif qf.args[0].ttype == bool:
argt_name = "bool"
elif sys.version_info < (3, 9):
argt_name = "Tuple["
argt_name += ",".join([x.__name__ for x in get_args(qf.args[0].ttype)])
argt_name += "]"

if qf.name == name:
qf.name = f"_{name}"

oracle = QlassF.from_function(
f"def {name}(v: {argt_name}) -> bool:\n return {qf.name}(v) == {element}",
defs=[qf.to_logicfun()],
)

if (
len(oracle.expressions) == 1
and oracle.expressions[0][0] == Symbol("_ret")
and (
isinstance(oracle.expressions[0][1], BooleanTrue)
or isinstance(oracle.expressions[0][1], BooleanFalse)
)
):
raise ConstantOracleException(
f"The oracle is constant: {oracle.expressions[0][1]}"
)

return oracle


class QAlgorithm:
qc: QCircuit

Expand Down
16 changes: 9 additions & 7 deletions qlasskit/ast2ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def _replace_types_annotations(ann, arg=None):
value=ast.Name(id="Tuple", ctx=ast.Load()),
slice=_ituple,
)

# Replace Qlist[T,n] with Tuple[(T,)*3]
if isinstance(ann, ast.Subscript) and ann.value.id == "Qlist":
_elts = ann.slice.elts
Expand All @@ -66,7 +66,8 @@ def _replace_types_annotations(ann, arg=None):
return arg
else:
return ann



class ASTRewriter(ast.NodeTransformer):
def __init__(self, env={}, ret=None):
self.env = {}
Expand Down Expand Up @@ -111,16 +112,17 @@ def visit_Name(self, node):

def visit_List(self, node):
return ast.Tuple(elts=[self.visit(el) for el in node.elts])
def visit_AnnAssign(self, node):
node.annotation = _replace_types_annotations(node.annotation)

def visit_AnnAssign(self, node):
node.annotation = _replace_types_annotations(node.annotation)
node.value = self.visit(node.value) if node.value else node.value
self.env[node.target] = node.annotation
return node


def visit_FunctionDef(self, node):
node.args.args = [_replace_types_annotations(x.annotation, arg=x) for x in node.args.args]
node.args.args = [
_replace_types_annotations(x.annotation, arg=x) for x in node.args.args
]

for x in node.args.args:
self.env[x.arg] = x.annotation
Expand Down
4 changes: 0 additions & 4 deletions qlasskit/types/qint.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,3 @@ class Qint12(Qint):

class Qint16(Qint):
BIT_SIZE = 16


# class Qlist
# class Qfixed
2 changes: 1 addition & 1 deletion qlasskit/types/qlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __getitem__(cls, params):
if isinstance(params, tuple) and len(params) == 2:
T, n = params
if isinstance(T, type) and isinstance(n, int) and n >= 0:
return Tuple[T, ...] if n > 0 else Tuple[T]
return Tuple[(T,) * n] if n > 0 else Tuple[T]


class Qlist(metaclass=QlistMeta):
Expand Down
Loading

0 comments on commit cff8846

Please sign in to comment.