From 82dbf56b08db59e6beb0d2510f38d7ec493172aa Mon Sep 17 00:00:00 2001 From: Davide Gessa Date: Wed, 20 Mar 2024 11:10:17 +0100 Subject: [PATCH] int and float builting cast (#25) --- docs/source/supported.rst | 2 ++ qlasskit/ast2ast.py | 31 ++++++++++++------ qlasskit/ast2logic/t_expression.py | 52 +++++++++++++++++++++++++----- qlasskit/types/qfixed.py | 6 +++- qlasskit/types/qint.py | 6 +++- test/qlassf/test_builtin.py | 25 ++++++++++++++ 6 files changed, 102 insertions(+), 20 deletions(-) diff --git a/docs/source/supported.rst b/docs/source/supported.rst index dc9ce1af..dbfa4113 100644 --- a/docs/source/supported.rst +++ b/docs/source/supported.rst @@ -205,6 +205,8 @@ Bultin functions: - `any(Tuple)`, `any(Qlist)`: returns True if any of the elemnts are True - `ord(Qchar)`: returns the integer value of the given Qchar - `chr(Qint)`: returns the char given its ascii code +- `int(Qfixed | Qint)`: returns the integer part of a Qfixed +- `float(Qint | Qfixed)`: returns a Qfixed representing the Qint Statements diff --git a/qlasskit/ast2ast.py b/qlasskit/ast2ast.py index 5920e203..24bf2ce2 100644 --- a/qlasskit/ast2ast.py +++ b/qlasskit/ast2ast.py @@ -73,6 +73,20 @@ def _replace_types_annotations(ann, arg=None): slice=_ituple, ) + # Replace QintX with Qint[X] + if isinstance(ann, ast.Name) and ann.id[:4] == "Qint": + ann = ast.Subscript( + value=ast.Name(id="Qint", ctx=ast.Load()), + slice=ast.Constant(value=int(ann.id[4:])), + ) + + # Replace QfixedX with Qfixed[X] + if isinstance(ann, ast.Name) and ann.id[:6] == "Qfixed": + ann = ast.Subscript( + value=ast.Name(id="Qfixed", ctx=ast.Load()), + slice=ast.Constant(value=int(ann.id[6:])), + ) + # Replace Qlist[T,n] with Tuple[(T,)*n] if isinstance(ann, ast.Subscript) and ann.value.id == "Qlist": _elts = ann.slice.elts @@ -471,11 +485,11 @@ def iterif(arg_l): return iterif(args) def __call_sum(self, node): - if len(node.args) == 1: - args = self.__unroll_arg(node.args[0]) - else: + if len(node.args) != 1: raise Exception(f"sum() takes at most 1 argument ({len(node.args)} given)") + args = self.__unroll_arg(node.args[0]) + def iterif(arg_l): if len(arg_l) == 1: return arg_l[0] @@ -485,25 +499,22 @@ def iterif(arg_l): return iterif(args) def __call_anyall(self, node): - if len(node.args) == 1: - args = self.__unroll_arg(node.args[0]) - else: + if len(node.args) != 1: raise Exception(f"any() takes exactly 1 argument ({len(node.args)} given)") + args = self.__unroll_arg(node.args[0]) op = ast.Or() if node.func.id == "any" else ast.And() return ast.BoolOp(op=op, values=args) def __call_chr(self, node): if len(node.args) != 1: raise Exception(f"chr() takes exactly 1 argument ({len(node.args)} given)") - args = self.__unroll_arg(node.args[0]) - return args[0] + return node.args[0] def __call_ord(self, node): if len(node.args) != 1: raise Exception(f"ord() takes exactly 1 argument ({len(node.args)} given)") - args = self.__unroll_arg(node.args[0]) - return args[0] + return node.args[0] def visit_Call(self, node): node.args = [self.visit(ar) for ar in node.args] diff --git a/qlasskit/ast2logic/t_expression.py b/qlasskit/ast2logic/t_expression.py index 36f3a4eb..34abcf25 100644 --- a/qlasskit/ast2logic/t_expression.py +++ b/qlasskit/ast2logic/t_expression.py @@ -18,7 +18,7 @@ from sympy.logic import ITE, And, Not, Or, Xor, false, true from ..boolquant import QuantumBooleanGate -from ..types import Qbool, Qtype, TExp, const_to_qtype +from ..types import Qbool, Qfixed, Qint, Qtype, TExp, const_to_qtype from . import Env, exceptions @@ -313,19 +313,55 @@ def unfold(v_exps, op): else: return args[0][0], q_gate(*args_v) - if not hasattr(expr.func, "id"): + elif not hasattr(expr.func, "id"): raise exceptions.ExpressionNotHandledException(expr) # Typecast - if ( - env.know_type(expr.func.id) - and len(expr.args) == 1 - and isinstance(expr.args[0], ast.Constant) - ): + elif env.know_type(expr.func.id): + if len(expr.args) != 1: + raise Exception( + f"{expr.func.id}() takes exactly 1 argument ({len(expr.args)} given)" + ) + + if not isinstance(expr.args[0], ast.Constant): + raise Exception(f"{expr.func.id}() accepts only constant values") + return env.gettype(expr.func.id).const(expr.args[0].value) + # int() + elif expr.func.id == "int": + if len(expr.args) != 1: + raise Exception( + f"int() takes exactly 1 argument ({len(expr.args)} given)" + ) + + (ta, te) = translate_expression(expr.args[0], env) + if ta.__name__[:4] == "Qint": # type: ignore + return (ta, te) + elif ta.__name__[:6] == "Qfixed": # type: ignore + ip = ta.integer_part((ta, te)) # type: ignore + return (Qint.type_for_size(len(ip)), ip) + else: + raise Exception(f"int() accepts only Qfixed and Qint: {ta} given") + + # float() + elif expr.func.id == "float": + if len(expr.args) != 1: + raise Exception( + f"float() takes exactly 1 argument ({len(expr.args)} given)" + ) + + (ta, te) = translate_expression(expr.args[0], env) + if ta.__name__[:6] == "Qfixed": # type: ignore + return (ta, te) + elif ta.__name__[:4] == "Qint": # type: ignore + tf = Qfixed.type_for_size(len(te)) + return tf.fill((tf, te)) + else: + raise Exception(f"float() accepts only Qfixed and Qint: {ta} given") + # Known function - if env.know_function(expr.func.id): + elif env.know_function(expr.func.id): def_f = env.getdef(expr.func.id) args = [translate_expression(e, env) for e in expr.args] diff --git a/qlasskit/types/qfixed.py b/qlasskit/types/qfixed.py index 149e150a..801b1dec 100644 --- a/qlasskit/types/qfixed.py +++ b/qlasskit/types/qfixed.py @@ -361,4 +361,8 @@ def __getitem__(cls, params): class Qfixed(metaclass=QfixedMeta): - pass + @staticmethod + def type_for_size(s: int): + for det_type in QFIXED_TYPES: + if det_type.BIT_SIZE_INTEGER == s: + return det_type diff --git a/qlasskit/types/qint.py b/qlasskit/types/qint.py index 2572cebf..49828f95 100644 --- a/qlasskit/types/qint.py +++ b/qlasskit/types/qint.py @@ -296,4 +296,8 @@ def __getitem__(cls, params): class Qint(metaclass=QintMeta): - pass + @staticmethod + def type_for_size(s: int): + for det_type in QINT_TYPES: + if det_type.BIT_SIZE == s: + return det_type diff --git a/test/qlassf/test_builtin.py b/test/qlassf/test_builtin.py index 532b16be..187064b1 100644 --- a/test/qlassf/test_builtin.py +++ b/test/qlassf/test_builtin.py @@ -144,3 +144,28 @@ def test_max_in_list(self): # f = "def test(a: Qlist[bool, 3]) -> Qint[4]:\n\treturn range(len(a))" # qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) # compute_and_compare_results(self, qf) + + def test_int_of_int(self): + f = "def test(a: Qint[2]) -> Qint[4]:\n\treturn int(a) * 2" + qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) + compute_and_compare_results(self, qf) + + def test_int_of_int2(self): + f = "def test(a: Qint2) -> Qint[4]:\n\treturn int(a) * 2" + qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) + compute_and_compare_results(self, qf) + + def test_int_of_fixed(self): + f = "def test(a: Qfixed[2,2]) -> Qint[4]:\n\treturn int(a) * 2" + qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) + compute_and_compare_results(self, qf) + + def test_float_of_fixed(self): + f = "def test(a: Qfixed[2,2]) -> Qfixed[2,2]:\n\treturn float(a)" + qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) + compute_and_compare_results(self, qf) + + def test_float_of_int(self): + f = "def test(a: Qint[2]) -> Qfixed[2,2]:\n\treturn float(a)" + qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) + compute_and_compare_results(self, qf)