From 7e81cc75ada85980ca6095b8ce49bdde50fecddc Mon Sep 17 00:00:00 2001 From: "Davide Gessa (dakk)" Date: Sat, 28 Oct 2023 16:15:33 +0200 Subject: [PATCH] crop return type --- TODO.md | 1 + qlasskit/ast2ast.py | 22 +++++++++++++++++++++- qlasskit/ast2logic/t_expression.py | 2 +- qlasskit/ast2logic/t_statement.py | 6 ++++++ qlasskit/qlassf.py | 1 + qlasskit/types/qint.py | 10 ++++++++++ qlasskit/types/qtype.py | 5 +++++ 7 files changed, 45 insertions(+), 2 deletions(-) diff --git a/TODO.md b/TODO.md index c7b05351..57d8b536 100644 --- a/TODO.md +++ b/TODO.md @@ -69,6 +69,7 @@ - [x] Inner function - [x] Groover algorithm - [x] Tuple-tuple comparison +- [x] Multi var assign ### Week 2: (30 Oct 23) diff --git a/qlasskit/ast2ast.py b/qlasskit/ast2ast.py index 96eb6288..18330d34 100644 --- a/qlasskit/ast2ast.py +++ b/qlasskit/ast2ast.py @@ -17,6 +17,21 @@ from .ast2logic import flatten +class NameConstantReplacer(ast.NodeTransformer): + def __init__(self, name_id, constant): + self.name_id = name_id + self.constant = constant + + def generic_visit(self, node): + return super().generic_visit(node) + + def visit_Name(self, node): + if node.id == self.name_id: + return ast.Constant(value=self.constant) + + return node + + class ASTRewriter(ast.NodeTransformer): def __init__(self, env={}, ret=None): self.env = {} @@ -145,7 +160,12 @@ def visit_For(self, node): ast.Assign(targets=[node.target], value=ast.Constant(value=i)) ) rolls.extend(flatten([tar_assign])) - rolls.extend(flatten([self.visit(copy.deepcopy(b)) for b in node.body])) + + new_body = [ + NameConstantReplacer(node.target.id, i).visit(copy.deepcopy(b)) + for b in node.body + ] + rolls.extend(flatten([self.visit(copy.deepcopy(b)) for b in new_body])) return rolls diff --git a/qlasskit/ast2logic/t_expression.py b/qlasskit/ast2logic/t_expression.py index d5c0dfc9..a9cce87b 100644 --- a/qlasskit/ast2logic/t_expression.py +++ b/qlasskit/ast2logic/t_expression.py @@ -253,7 +253,7 @@ def unfold(v_exps, op): elif isinstance(expr.op, ast.BitXor) and hasattr(tleft[0], "bitwise_xor"): return tleft[0].bitwise_xor(tleft, tright) elif isinstance(expr.op, ast.BitAnd) and hasattr(tleft[0], "bitwise_and"): - return tleft[0].bitwise_and(tleft, tright) + return tright[0].bitwise_and(tleft, tright) # type: ignore elif isinstance(expr.op, ast.BitOr) and hasattr(tleft[0], "bitwise_or"): return tleft[0].bitwise_or(tleft, tright) elif ( diff --git a/qlasskit/ast2logic/t_statement.py b/qlasskit/ast2logic/t_statement.py index 64e425ad..ada2db4a 100644 --- a/qlasskit/ast2logic/t_statement.py +++ b/qlasskit/ast2logic/t_statement.py @@ -71,6 +71,12 @@ def translate_statement( # noqa: C901 and texp.BIT_SIZE < ret_type.BIT_SIZE ): texp, vexp = ret_type.fill((texp, vexp)) # type: ignore + elif ( + hasattr(texp, "BIT_SIZE") + and hasattr(ret_type, "BIT_SIZE") + and texp.BIT_SIZE > ret_type.BIT_SIZE + ): + texp, vexp = ret_type.crop((texp, vexp)) # type: ignore elif texp != ret_type: raise exceptions.TypeErrorException(texp, ret_type) diff --git a/qlasskit/qlassf.py b/qlasskit/qlassf.py index 21630c14..33e4a187 100644 --- a/qlasskit/qlassf.py +++ b/qlasskit/qlassf.py @@ -284,6 +284,7 @@ def from_function( # Return the qlassf object qf = QlassF(fun_name, original_f, args, fun_ret, exps) + if to_compile: qf.compile() return qf diff --git a/qlasskit/types/qint.py b/qlasskit/types/qint.py index 308c6f7e..0a45feef 100644 --- a/qlasskit/types/qint.py +++ b/qlasskit/types/qint.py @@ -71,6 +71,16 @@ def fill(cls, v: TExp) -> TExp: ) return v + @classmethod + def crop(cls, v: TExp) -> TExp: + """Crop a Qint to reach its bit_size""" + if len(v[1]) > cls.BIT_SIZE: # type: ignore + v = ( + cls, + v[1][0 : cls.BIT_SIZE], # type: ignore + ) + return v + # Comparators @staticmethod diff --git a/qlasskit/types/qtype.py b/qlasskit/types/qtype.py index c12ca685..c70351ca 100644 --- a/qlasskit/types/qtype.py +++ b/qlasskit/types/qtype.py @@ -73,6 +73,11 @@ def fill(v: TExp) -> TExp: """Fill with leading false""" raise Exception("abstract") + @staticmethod + def crop(v: TExp) -> TExp: + """Crop to right size""" + raise Exception("abstract") + # Comparators @staticmethod