Skip to content

Commit

Permalink
fix subscript with constant in for unrolling
Browse files Browse the repository at this point in the history
  • Loading branch information
dakk committed Oct 23, 2023
1 parent 7cce318 commit bf6e7d9
Showing 1 changed file with 19 additions and 4 deletions.
23 changes: 19 additions & 4 deletions qlasskit/ast2ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import copy

from .ast2logic import flatten


class ASTRewriter(ast.NodeTransformer):
def __init__(self, env={}, ret=None):
self.env = {}
self.const = {}
self.ret = None

def __unroll_arg(self, arg):
Expand All @@ -42,9 +44,19 @@ def __unroll_arg(self, arg):
def generic_visit(self, node):
return super().generic_visit(node)

def visit_Subscript(self, node):
if (
isinstance(node.slice, ast.Index)
and isinstance(node.slice.value, ast.Name)
and node.slice.value.id in self.const
):
node.slice.value = self.const[node.slice.value.id]
return node

def visit_Name(self, node):
if node.id[0:2] == "__":
raise Exception("invalid name starting with __")

return node

def visit_FunctionDef(self, node):
Expand All @@ -64,14 +76,15 @@ def visit_Assign(self, node):

# TODO: support unrolling tuple
# TODO: if value is not self referencing, we can skip this (ie: a = b + 1)

# Reassigning an already present variable (use a temp variable)
if was_known and not isinstance(node.value, ast.Constant):
new_targ = ast.Name(id=f"__{node.targets[0].id}", ctx=ast.Load())

return [
ast.Assign(
targets=[new_targ],
value=node.value,
value=self.visit(node.value),
),
ast.Assign(
targets=node.targets,
Expand All @@ -90,7 +103,9 @@ def visit_AugAssign(self, node):
return [
ast.Assign(
targets=[new_targ],
value=ast.BinOp(left=node.target, op=node.op, right=node.value),
value=self.visit(
ast.BinOp(left=node.target, op=node.op, right=node.value)
),
),
ast.Assign(
targets=[node.target],
Expand All @@ -103,13 +118,13 @@ def visit_For(self, node):

rolls = []
for i in iter:
self.const[node.target.id] = ast.Constant(value=i)
tar_assign = self.visit(
ast.Assign(targets=[node.target], value=ast.Constant(value=i))
)
rolls.extend(flatten([tar_assign]))
rolls.extend(flatten([self.visit(b) for b in node.body]))
rolls.extend(flatten([self.visit(copy.deepcopy(b)) for b in node.body]))

# print(list(map(ast.dump, rolls)))
return rolls

def __call_range(self, node):
Expand Down

0 comments on commit bf6e7d9

Please sign in to comment.