diff --git a/qlasskit/ast2ast.py b/qlasskit/ast2ast.py index 9aa3a736..3f622c52 100644 --- a/qlasskit/ast2ast.py +++ b/qlasskit/ast2ast.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import ast +import copy from .ast2logic import flatten @@ -19,6 +20,7 @@ class ASTRewriter(ast.NodeTransformer): def __init__(self, env={}, ret=None): self.env = {} + self.const = {} self.ret = None def __unroll_arg(self, arg): @@ -42,9 +44,19 @@ def __unroll_arg(self, arg): def generic_visit(self, node): return super().generic_visit(node) + def visit_Subscript(self, node): + if ( + isinstance(node.slice, ast.Index) + and isinstance(node.slice.value, ast.Name) + and node.slice.value.id in self.const + ): + node.slice.value = self.const[node.slice.value.id] + return node + def visit_Name(self, node): if node.id[0:2] == "__": raise Exception("invalid name starting with __") + return node def visit_FunctionDef(self, node): @@ -64,6 +76,7 @@ def visit_Assign(self, node): # TODO: support unrolling tuple # TODO: if value is not self referencing, we can skip this (ie: a = b + 1) + # Reassigning an already present variable (use a temp variable) if was_known and not isinstance(node.value, ast.Constant): new_targ = ast.Name(id=f"__{node.targets[0].id}", ctx=ast.Load()) @@ -71,7 +84,7 @@ def visit_Assign(self, node): return [ ast.Assign( targets=[new_targ], - value=node.value, + value=self.visit(node.value), ), ast.Assign( targets=node.targets, @@ -90,7 +103,9 @@ def visit_AugAssign(self, node): return [ ast.Assign( targets=[new_targ], - value=ast.BinOp(left=node.target, op=node.op, right=node.value), + value=self.visit( + ast.BinOp(left=node.target, op=node.op, right=node.value) + ), ), ast.Assign( targets=[node.target], @@ -103,13 +118,13 @@ def visit_For(self, node): rolls = [] for i in iter: + self.const[node.target.id] = ast.Constant(value=i) tar_assign = self.visit( ast.Assign(targets=[node.target], value=ast.Constant(value=i)) ) rolls.extend(flatten([tar_assign])) - rolls.extend(flatten([self.visit(b) for b in node.body])) + rolls.extend(flatten([self.visit(copy.deepcopy(b)) for b in node.body])) - # print(list(map(ast.dump, rolls))) return rolls def __call_range(self, node):