Skip to content

Commit

Permalink
handle annassign
Browse files Browse the repository at this point in the history
  • Loading branch information
dakk committed Oct 31, 2023
1 parent 8313b87 commit 6f01928
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 20 deletions.
55 changes: 35 additions & 20 deletions qlasskit/ast2ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,32 @@ def visit_Name(self, node):
return node


def _replace_types_annotations(ann, arg=None):
if isinstance(ann, ast.Subscript) and ann.value.id == "Tuple":
_elts = ann.slice.elts
_ituple = ast.Tuple(elts=[_replace_types_annotations(el) for el in _elts])

ann = ast.Subscript(
value=ast.Name(id="Tuple", ctx=ast.Load()),
slice=_ituple,
)

# Replace Qlist[T,n] with Tuple[(T,)*3]
if isinstance(ann, ast.Subscript) and ann.value.id == "Qlist":
_elts = ann.slice.elts
_ituple = ast.Tuple(elts=[copy.deepcopy(_elts[0])] * _elts[1].value)

ann = ast.Subscript(
value=ast.Name(id="Tuple", ctx=ast.Load()),
slice=_ituple,
)

if arg is not None:
arg.annotation = ann
return arg
else:
return ann

class ASTRewriter(ast.NodeTransformer):
def __init__(self, env={}, ret=None):
self.env = {}
Expand Down Expand Up @@ -85,32 +111,21 @@ def visit_Name(self, node):

def visit_List(self, node):
return ast.Tuple(elts=[self.visit(el) for el in node.elts])

def visit_AnnAssign(self, node):
node.annotation = _replace_types_annotations(node.annotation)
node.value = self.visit(node.value) if node.value else node.value
self.env[node.target] = node.annotation
return node


def visit_FunctionDef(self, node):
def _replace_types(ann, arg=None):
# Replace Qlist[T,n] with Tuple[(T,)*3]
if isinstance(ann, ast.Subscript) and ann.value.id == "Qlist":
_elts = ann.slice.elts

_ituple = ast.Tuple(elts=[copy.deepcopy(_elts[0])] * _elts[1].value)

ann = ast.Subscript(
value=ast.Name(id="Tuple", ctx=ast.Load()),
slice=_ituple,
)

if arg is not None:
arg.annotation = ann
return arg
else:
return ann

node.args.args = [_replace_types(x.annotation, arg=x) for x in node.args.args]
node.args.args = [_replace_types_annotations(x.annotation, arg=x) for x in node.args.args]

for x in node.args.args:
self.env[x.arg] = x.annotation

node.returns = _replace_types(node.returns)
node.returns = _replace_types_annotations(node.returns)
self.ret = node.returns

return super().generic_visit(node)
Expand Down
25 changes: 25 additions & 0 deletions test/test_ast2logic_t_arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,28 @@ def test_tuple_of_int2(self):
"a.1.1",
],
)

def test_list_of_int2(self):
f = "a: Qlist[Qint2, 2]"
ann_ast = ast2ast(ast.parse(f)).body[0].annotation
c = ast2logic.translate_argument(ann_ast, ast2logic.Env(), "a")
self.assertEqual(c.name, "a")
self.assertEqual(c.ttype, Tuple[Qint2, Qint2])
self.assertEqual(
c.bitvec,
[
"a.0.0",
"a.0.1",
"a.1.0",
"a.1.1",
],
)


def test_tuple_of_list2(self):
f = "a: Tuple[bool, Qlist[bool, 2]]"
ann_ast = ast2ast(ast.parse(f)).body[0].annotation
c = ast2logic.translate_argument(ann_ast, ast2logic.Env(), "a")
self.assertEqual(c.name, "a")
self.assertEqual(c.ttype, Tuple[bool, Tuple[bool, bool]])
self.assertEqual(c.bitvec, ["a.0", "a.1.0", "a.1.1"])

0 comments on commit 6f01928

Please sign in to comment.