diff --git a/qlasskit/algorithms/__init__.py b/qlasskit/algorithms/__init__.py index 99f79212..b606274d 100644 --- a/qlasskit/algorithms/__init__.py +++ b/qlasskit/algorithms/__init__.py @@ -13,5 +13,11 @@ # limitations under the License. # isort:skip_file -from .qalgorithm import QAlgorithm # noqa: F401, E402 +from .qalgorithm import ( # noqa: F401, E402 + QAlgorithm, + format_outcome, + interpret_as_qtype, + oraclize, + ConstantOracleException, +) from .groover import Groover # noqa: F401, E402 diff --git a/qlasskit/algorithms/groover.py b/qlasskit/algorithms/groover.py index 66cba100..63ee2f3c 100644 --- a/qlasskit/algorithms/groover.py +++ b/qlasskit/algorithms/groover.py @@ -13,20 +13,16 @@ # limitations under the License. import math -import sys -from typing import List, Optional, Tuple, Union, get_args - -from sympy import Symbol -from sympy.logic.boolalg import BooleanFalse, BooleanTrue +from typing import List, Optional, Tuple, Union from ..qcircuit import QCircuit, gates from ..qlassf import QlassF from ..types import Qtype -from .qalgorithm import QAlgorithm, format_outcome +from .qalgorithm import QAlgorithm, interpret_as_qtype, oraclize class Groover(QAlgorithm): - def __init__( # noqa: C901 + def __init__( self, oracle: QlassF, element_to_search: Qtype, @@ -34,15 +30,16 @@ def __init__( # noqa: C901 ): """ Args: - oracle (QlassF): our f(x) -> bool that returns True if x satisfies the function + oracle (QlassF): our f(x) -> bool that returns True if x satisfies the function or + a generic function f(x) = y that we want to compare with element_to_search element_to_search (Qtype): the element we want to search n_iterations (int, optional): force a number of iterations (otherwise, pi/4*sqrt(N)) """ if len(oracle.args) != 1: raise Exception("the oracle should receive exactly one parameter") - self.oracle: QlassF = oracle - self.search_space_size = len(self.oracle.args[0]) + self.oracle: QlassF + self.search_space_size = len(oracle.args[0]) if n_iterations is None: n_iterations = math.ceil( @@ -60,47 +57,11 @@ def __init__( # noqa: C901 # Prepare and add the quantum oracle if element_to_search is not None: - if hasattr(self.oracle.args[0].ttype, "__name__"): - argt_name = self.oracle.args[0].ttype.__name__ # type: ignore - - args = get_args(self.oracle.args[0].ttype) - if len(args) > 0: - argt_name += "[" - argt_name += ",".join([x.__name__ for x in args]) - argt_name += "]" - - elif self.oracle.args[0].ttype == bool: - argt_name = "bool" - elif sys.version_info < (3, 9): - argt_name = "Tuple[" - argt_name += ",".join( - [x.__name__ for x in get_args(self.oracle.args[0].ttype)] - ) - argt_name += "]" - - oracle_outer = QlassF.from_function( - f""" -def oracle_outer(v: {argt_name}) -> bool: - return {self.oracle.name}(v) == {element_to_search} -""", - defs=[self.oracle.to_logicfun()], - ) - - if ( - len(oracle_outer.expressions) == 1 - and oracle_outer.expressions[0][0] == Symbol("_ret") - and ( - isinstance(oracle_outer.expressions[0][1], BooleanTrue) - or isinstance(oracle_outer.expressions[0][1], BooleanFalse) - ) - ): - raise Exception( - f"The oracle is constant: {oracle_outer.expressions[0][1]}" - ) + self.oracle = oraclize(oracle, element_to_search) else: - oracle_outer = self.oracle + self.oracle = oracle - oracle_qc = oracle_outer.circuit() + oracle_qc = self.oracle.circuit() # Add negative phase to result oracle_qc.add_qubit(name="_ret_phased") @@ -148,22 +109,6 @@ def out_qubits(self) -> List[int]: def interpret_outcome( self, outcome: Union[str, int, List[bool]] ) -> Union[bool, Tuple, Qtype]: - out = format_outcome(outcome, len(self.out_qubits())) - - len_a = len(self.oracle.args[0]) - if len_a == 1: - return out[0] # type: ignore - - if hasattr(self.oracle.args[0].ttype, "from_bool"): - return self.oracle.args[0].ttype.from_bool(out[::-1][0:len_a]) # type: ignore - elif self.oracle.args[0].ttype == bool: - return out[::-1][0] - else: # Tuple - idx_s = 0 - values = [] - for x in get_args(self.oracle.args[0].ttype): - len_a = x.BIT_SIZE - values.append(x.from_bool(out[::-1][idx_s : idx_s + len_a])) - idx_s += len_a - - return tuple(values) + return interpret_as_qtype( + outcome, self.oracle.args[0].ttype, len(self.oracle.args[0]) + ) diff --git a/qlasskit/algorithms/qalgorithm.py b/qlasskit/algorithms/qalgorithm.py index 6167ad45..33eb8e06 100644 --- a/qlasskit/algorithms/qalgorithm.py +++ b/qlasskit/algorithms/qalgorithm.py @@ -12,17 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Union +import sys +from typing import Any, Dict, List, Optional, Union, get_args + +from sympy import Symbol +from sympy.logic.boolalg import BooleanFalse, BooleanTrue from ..qcircuit import QCircuit, SupportedFramework +from ..qlassf import QlassF -def format_outcome(out: Union[str, int, List[bool]], out_len: int) -> List[bool]: +def format_outcome( + out: Union[str, int, List[bool]], out_len: Optional[int] = None +) -> List[bool]: if isinstance(out, str): return format_outcome([True if c == "1" else False for c in out], out_len) elif isinstance(out, int): return format_outcome(str(bin(out))[2:], out_len) elif isinstance(out, List): + if out_len is None: + out_len = len(out) + if len(out) < out_len: out += [False] * (out_len - len(out)) @@ -30,6 +40,74 @@ def format_outcome(out: Union[str, int, List[bool]], out_len: int) -> List[bool] raise Exception(f"Invalid format: {out}") +def interpret_as_qtype( + out: Union[str, int, List[bool]], qtype, out_len: Optional[int] = None +) -> Any: + out = list(reversed(format_outcome(out, out_len))) + + def _interpret(out, qtype, out_len): + if hasattr(qtype, "from_bool"): + return qtype.from_bool(out[0:out_len]) # type: ignore + elif qtype == bool: + return out[0] + else: # Tuple + idx_s = 0 + values = [] + for x in get_args(qtype): + len_a = x.BIT_SIZE if hasattr(x, "BIT_SIZE") else 1 + values.append(_interpret(out[idx_s : idx_s + len_a], x, len_a)) + idx_s += len_a + + return tuple(values) + + return _interpret(out, qtype, out_len) + + +class ConstantOracleException(Exception): + pass + + +def oraclize(qf: QlassF, element: Any, name="oracle"): + """Transform a QlassF qf and an element to an oracle {f(x) = x == element}""" + if hasattr(qf.args[0].ttype, "__name__"): + argt_name = qf.args[0].ttype.__name__ # type: ignore + + args = get_args(qf.args[0].ttype) + if len(args) > 0: + argt_name += "[" + argt_name += ",".join([x.__name__ for x in args]) + argt_name += "]" + + elif qf.args[0].ttype == bool: + argt_name = "bool" + elif sys.version_info < (3, 9): + argt_name = "Tuple[" + argt_name += ",".join([x.__name__ for x in get_args(qf.args[0].ttype)]) + argt_name += "]" + + if qf.name == name: + qf.name = f"_{name}" + + oracle = QlassF.from_function( + f"def {name}(v: {argt_name}) -> bool:\n return {qf.name}(v) == {element}", + defs=[qf.to_logicfun()], + ) + + if ( + len(oracle.expressions) == 1 + and oracle.expressions[0][0] == Symbol("_ret") + and ( + isinstance(oracle.expressions[0][1], BooleanTrue) + or isinstance(oracle.expressions[0][1], BooleanFalse) + ) + ): + raise ConstantOracleException( + f"The oracle is constant: {oracle.expressions[0][1]}" + ) + + return oracle + + class QAlgorithm: qc: QCircuit diff --git a/qlasskit/ast2ast.py b/qlasskit/ast2ast.py index 8292aa66..4408a4b2 100644 --- a/qlasskit/ast2ast.py +++ b/qlasskit/ast2ast.py @@ -50,7 +50,7 @@ def _replace_types_annotations(ann, arg=None): value=ast.Name(id="Tuple", ctx=ast.Load()), slice=_ituple, ) - + # Replace Qlist[T,n] with Tuple[(T,)*3] if isinstance(ann, ast.Subscript) and ann.value.id == "Qlist": _elts = ann.slice.elts @@ -66,7 +66,8 @@ def _replace_types_annotations(ann, arg=None): return arg else: return ann - + + class ASTRewriter(ast.NodeTransformer): def __init__(self, env={}, ret=None): self.env = {} @@ -111,16 +112,17 @@ def visit_Name(self, node): def visit_List(self, node): return ast.Tuple(elts=[self.visit(el) for el in node.elts]) - - def visit_AnnAssign(self, node): - node.annotation = _replace_types_annotations(node.annotation) + + def visit_AnnAssign(self, node): + node.annotation = _replace_types_annotations(node.annotation) node.value = self.visit(node.value) if node.value else node.value self.env[node.target] = node.annotation return node - def visit_FunctionDef(self, node): - node.args.args = [_replace_types_annotations(x.annotation, arg=x) for x in node.args.args] + node.args.args = [ + _replace_types_annotations(x.annotation, arg=x) for x in node.args.args + ] for x in node.args.args: self.env[x.arg] = x.annotation diff --git a/qlasskit/types/qint.py b/qlasskit/types/qint.py index 0a45feef..e4ba16cf 100644 --- a/qlasskit/types/qint.py +++ b/qlasskit/types/qint.py @@ -224,7 +224,3 @@ class Qint12(Qint): class Qint16(Qint): BIT_SIZE = 16 - - -# class Qlist -# class Qfixed diff --git a/qlasskit/types/qlist.py b/qlasskit/types/qlist.py index 463b8994..7e6fa363 100644 --- a/qlasskit/types/qlist.py +++ b/qlasskit/types/qlist.py @@ -22,7 +22,7 @@ def __getitem__(cls, params): if isinstance(params, tuple) and len(params) == 2: T, n = params if isinstance(T, type) and isinstance(n, int) and n >= 0: - return Tuple[T, ...] if n > 0 else Tuple[T] + return Tuple[(T,) * n] if n > 0 else Tuple[T] class Qlist(metaclass=QlistMeta): diff --git a/test/test_algo.py b/test/test_algo.py index 34db42e8..b1a63476 100644 --- a/test/test_algo.py +++ b/test/test_algo.py @@ -11,3 +11,112 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +import unittest +from typing import Tuple + +from sympy import And, Not, Symbol + +from qlasskit import Qint2, Qint4, Qlist, qlassf +from qlasskit.algorithms import ( + ConstantOracleException, + QAlgorithm, + format_outcome, + interpret_as_qtype, + oraclize, +) + + +class TestAlgo_format_outcome(unittest.TestCase): + def test_format_outcome_str(self): + out = "1011" + self.assertEqual(format_outcome(out), [True, False, True, True]) + + def test_format_outcome_str_with_out_len(self): + out = "1011" + self.assertEqual(format_outcome(out, 5), [True, False, True, True, False]) + + def test_format_outcome_int(self): + out = 15 + self.assertEqual(format_outcome(out), [True, True, True, True]) + + def test_format_outcome_int_with_out_len(self): + out = 15 + self.assertEqual(format_outcome(out, 5), [True, True, True, True, False]) + + def test_format_outcome_bool(self): + out = [True, True] + self.assertEqual(format_outcome(out), [True, True]) + + def test_format_outcome_bool_with_out_len(self): + out = [True, True] + self.assertEqual(format_outcome(out, 3), [True, True, False]) + + +class TestAlgo_interpret_as_type(unittest.TestCase): + def test_interpret_bool(self): + _out = interpret_as_qtype([False, True], bool, 1) + self.assertEqual(_out, True) + + def test_interpret_qint2(self): + _out = interpret_as_qtype([False, True], Qint2, 2) + self.assertEqual(_out, 1) + + def test_interpret_qint4(self): + _out = interpret_as_qtype([True, True, True, False], Qint4, 4) + self.assertEqual(_out, 14) + + def test_interpret_qlist_bool_3(self): + _out = list(interpret_as_qtype([True, True, False], Qlist[bool, 3], 3)) + self.assertEqual(_out, [False, True, True]) + + def test_interpret_tuple_bool(self): + _out = interpret_as_qtype([True, True, False], Tuple[bool, bool, bool], 3) + self.assertEqual(_out, (False, True, True)) + + +class QT(QAlgorithm): + def __init__(self, tt, ll): + self.tt = tt + self.ll = ll + + def interpret_outcome(self, oc): + return interpret_as_qtype(oc, self.tt, self.ll) + + +class TestAlgo_interpret_counts(unittest.TestCase): + def test_interpret_counts(self): + q = QT(bool, 1) + c = q.interpet_counts({"010": 3, "110": 2, "111": 1}) + self.assertEqual(c, {False: 5, True: 1}) + + def test_interpret_counts2(self): + q = QT(Tuple[bool, bool], 2) + c = q.interpet_counts({"010": 3, "110": 2, "111": 1}) + self.assertEqual(c, {(False, True): 5, (True, True): 1}) + + +class TestAlgo_oraclize(unittest.TestCase): + def test_constant_oracle(self): + q = qlassf("def test(a: bool) -> bool: return True") + self.assertRaises(ConstantOracleException, lambda x: oraclize(q, True), True) + + def test_oracle_identity(self): + q = qlassf("def test(a: bool) -> bool: return a") + orac = oraclize(q, True) + self.assertEqual(orac.name, "oracle") + self.assertEqual(orac.expressions, [(Symbol("_ret"), Symbol("v"))]) + + def test_oracle_tuple(self): + q = qlassf("def test(a: Tuple[bool, bool]) -> bool: return a[0] and a[1]") + orac = oraclize(q, True) + self.assertEqual( + orac.expressions, [(Symbol("_ret"), And(Symbol("v.0"), Symbol("v.1")))] + ) + + def test_oracle_qtype(self): + q = qlassf("def test(a: Qint2) -> Qint2: return a") + orac = oraclize(q, 1) + self.assertEqual( + orac.expressions, [(Symbol("_ret"), And(Symbol("v.0"), Not(Symbol("v.1"))))] + ) diff --git a/test/test_algo_groover.py b/test/test_algo_groover.py index 34db42e8..29aadeee 100644 --- a/test/test_algo_groover.py +++ b/test/test_algo_groover.py @@ -11,3 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +import unittest + +from qlasskit import Qint2, Qint4, ast2logic, exceptions + + +class TestAlgoGroover(unittest.TestCase): + # test interpret outcome + # test various combinations of input / output size + # test out_qubits + # test without element to search + pass diff --git a/test/test_ast2logic_t_arg.py b/test/test_ast2logic_t_arg.py index da79b5d9..6a7f974e 100644 --- a/test/test_ast2logic_t_arg.py +++ b/test/test_ast2logic_t_arg.py @@ -109,7 +109,6 @@ def test_list_of_int2(self): "a.1.1", ], ) - def test_tuple_of_list2(self): f = "a: Tuple[bool, Qlist[bool, 2]]" @@ -117,4 +116,4 @@ def test_tuple_of_list2(self): c = ast2logic.translate_argument(ann_ast, ast2logic.Env(), "a") self.assertEqual(c.name, "a") self.assertEqual(c.ttype, Tuple[bool, Tuple[bool, bool]]) - self.assertEqual(c.bitvec, ["a.0", "a.1.0", "a.1.1"]) \ No newline at end of file + self.assertEqual(c.bitvec, ["a.0", "a.1.0", "a.1.1"])