Skip to content

Commit

Permalink
Merge pull request #2 from yunline/fix-#1
Browse files Browse the repository at this point in the history
Fix #1
  • Loading branch information
yunline authored Sep 12, 2023
2 parents 40665b9 + 72af25a commit bacda4e
Show file tree
Hide file tree
Showing 5 changed files with 211 additions and 5 deletions.
39 changes: 39 additions & 0 deletions oneliner/expr_transform.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import typing
from ast import *

from oneliner.namespaces import Namespace
Expand Down Expand Up @@ -80,6 +81,42 @@ def get_result(self) -> expr:
return self.nsp.get_load_name(self.node.id)


_CompNode: typing.TypeAlias = ListComp | SetComp | DictComp | GeneratorExp


class PendingComp(PendingExprGeneric):
target_names: set[str]
node: _CompNode

def __init__(self, node: _CompNode, nsp: Namespace):
super().__init__(node)
self.nsp = nsp
self.target_names = set()

for comp in self.node.generators:
self.get_comp_target_names(comp.target)

self.nsp.comp_stack.append(self)

def get_result(self) -> expr:
assert self.nsp.comp_stack[-1] is self
self.nsp.comp_stack.pop()

return super().get_result()

def get_comp_target_names(self, target):
"""
Recursion warning
"""
if isinstance(target, Name):
self.target_names.add(target.id)
elif isinstance(target, (Tuple, List)):
for sub_target in target.elts:
self.get_comp_target_names(sub_target)
else: # pragma: no cover
raise RuntimeError("Unknown comprehension target")


class ExpressionTransformer:
def __init__(self, nsp: Namespace):
self.pending_stack: list[PendingExprGeneric] = []
Expand All @@ -90,6 +127,8 @@ def get_pending(self, node: expr) -> PendingExprGeneric:
return PendingNamedExpr(node, self.nsp)
elif isinstance(node, Name):
return PendingName(node, self.nsp)
elif isinstance(node, (ListComp, SetComp, DictComp, GeneratorExp)):
return PendingComp(node, self.nsp)
else:
return PendingExprGeneric(node)

Expand Down
66 changes: 62 additions & 4 deletions oneliner/namespaces.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import itertools
import symtable
import sys
from ast import *

from oneliner.reserved_identifiers import *
Expand All @@ -18,12 +19,16 @@ class Namespace:
outer_nsp: "Namespace"
inner_nsp: list["Namespace"]

loop_stack: list["oneliner.pending_nodes._PendingLoop"]
comp_stack: list["oneliner.expr_transform.PendingComp"]

@classmethod
def _generate(cls):
raise NotImplementedError() # pragma: no cover

def __init__(self):
self.loop_stack: list["pending_nodes._PendingLoop"] = []
self.loop_stack = []
self.comp_stack = []
self.inner_nsp = []

def get_assign(self, name: str, value_expr: expr) -> expr:
Expand All @@ -43,6 +48,10 @@ def get_load_name(self, name: str) -> expr:


class NamespaceGlobal(Namespace):
use_itertools: bool
use_importlib: bool
use_preset_iter_wrapper: bool

@classmethod
def _generate(cls, symt: symtable.SymbolTable):
self = cls()
Expand Down Expand Up @@ -189,6 +198,10 @@ def get_assign(self, name: str, value_expr: expr) -> expr:
)

def get_load_name(self, name: str) -> expr:
for comp in self.comp_stack:
if name in comp.target_names:
return Name(id=name, ctx=Load())

if name in self.inner_nonlocal_names:
return Subscript(
value=self.nonlocal_dict_expr,
Expand All @@ -215,6 +228,9 @@ class NamespaceClass(Namespace):
# keys --> nonlocal names of THIS namespace
# values --> where the nonlocal name was born

if sys.version_info < (3, 12):
globals_used_in_comp: set[str] # global names used in comprehensions

@classmethod
def _generate(cls, symt: symtable.Class, stack: list[Namespace]):
# don't push/pop the stack in this function
Expand All @@ -225,6 +241,8 @@ def _generate(cls, symt: symtable.Class, stack: list[Namespace]):
self.outer_nsp = stack[-1]
self.outer_nsp.inner_nsp.append(self)
self.outer_nonlocal_map = {}
if sys.version_info < (3, 12):
self.globals_used_in_comp = set()

for symbol in self.symt.get_symbols():
if not (symbol.is_nonlocal() or symbol.is_free()):
Expand Down Expand Up @@ -294,6 +312,13 @@ def get_assign(self, name: str, value_expr: expr) -> expr:
)

def get_load_name(self, name: str) -> expr:
for comp in self.comp_stack:
if name in comp.target_names:
return Name(id=name, ctx=Load())

if sys.version_info < (3, 12) and name in self.globals_used_in_comp:
return Name(id=name, ctx=Load())

symbol = self.symt.lookup(name)
if name in self.outer_nonlocal_map:
outer = self.outer_nonlocal_map[name]
Expand All @@ -313,6 +338,33 @@ def get_load_name(self, name: str) -> expr:
)


if sys.version_info < (3, 12):

def _comp_check(symt: symtable.Function):
if symt.get_name() not in ["listcomp", "genexpr", "setcomp", "dictcomp"]:
return False
if ".0" not in symt.get_parameters():
return False
return True


def update_globals_from_lambda_or_comp(symt: symtable.Function, stack: list[Namespace]):
if not isinstance(stack[-1], NamespaceClass):
return

_globals: set[str] = set()
comp_stack = [symt]
while comp_stack:
symt = comp_stack.pop()
for symbol in symt.get_symbols():
if symbol.is_global():
_globals.add(symbol.get_name())
for child_symt in symt.get_children():
assert isinstance(child_symt, symtable.Function)
comp_stack.append(child_symt)
stack[-1].globals_used_in_comp.update(_globals)


def generate_nsp(symt: symtable.SymbolTable):
walk_stack = []
root = NamespaceGlobal._generate(symt)
Expand All @@ -328,20 +380,26 @@ def generate_nsp(symt: symtable.SymbolTable):
else:
if isinstance(child_symt, symtable.Function):
if child_symt.get_name() == "lambda":
update_globals_from_lambda_or_comp(child_symt, generate_stack)
continue
if sys.version_info < (3, 12):
if _comp_check(child_symt):
update_globals_from_lambda_or_comp(child_symt, generate_stack)
continue

generate_stack.append(
NamespaceFunction._generate(child_symt, generate_stack)
)
walk_stack.append(iter(child_symt.get_children()))
elif isinstance(child_symt, symtable.Class):
generate_stack.append(
NamespaceClass._generate(child_symt, generate_stack)
)
walk_stack.append(iter(child_symt.get_children()))
else: # pragma: no cover
raise RuntimeError("Unknown type of child symbol table")
walk_stack.append(iter(child_symt.get_children()))
return root


# fix error caused by circular import
import oneliner.pending_nodes as pending_nodes
import oneliner.expr_transform
import oneliner.pending_nodes
1 change: 0 additions & 1 deletion oneliner/pending_nodes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import typing
from ast import *
from ast import AST

import oneliner.utils as utils
from oneliner.expr_transform import expr_transf
Expand Down
8 changes: 8 additions & 0 deletions tests/oneliner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,14 @@ class TestNonlocal(test_utils.OnelinerTestCaseBase):
test_case_filename = "nonlocal.py"


class TestComprehension(test_utils.OnelinerTestCaseBase):
"""
Test if the namespace of comprehension expr is isolated
"""

test_case_filename = "comprehension.py"


class TestGlobal(test_utils.OnelinerTestCaseBase):
test_case_filename = "global.py"

Expand Down
102 changes: 102 additions & 0 deletions tests/test_cases/comprehension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Test if the namespace of comprehension expr is isolated
# type: ignore


def func():
i = 0

def func2():
nonlocal i
i = 2
[print(i) for i in range(10)] # listcomp
list(print(i) for i in range(10)) # genexpr
{print(i) for i in range(10)} # setcomp
{print(i): None for i in range(10)} # dictcomp

func2()
print(i)


func()


def func():
i = 0

class Foo:
nonlocal i
# test: comp inside a class
[print(i) for i in range(10)]

# test: nested comp inside a class, with global names used.
[[print(i), [print(bin(j)) for j in range(5)]] for i in range(10)]

print(i)


func()


def func():
i, j, k, m = 0, 0, 0, 0

def func2():
nonlocal i, j, k, m
i, j, k, m = 9, 9, 9, 9
lst = [(1, (2, 3)), (6, (7, 8))]

# test: multi generator + tuple target
[print(m, k, j, i) for (i, (j, k)) in lst for m in range(4)]

func2()
print(m, k, j, i)


func()


def func():
i, j, k = 0, 0, 0

def func2():
nonlocal i, j, k
i, j, k = 9, 9, 9

# test: nested comp
[
[
print(i),
[
[
print(j),
[print(k) for k in range(4)],
]
for j in range(4)
],
]
for i in range(4)
]

func2()
print(k, j, i)


func()


# test: a function is named "listcomp"
# "listcomp" is used as the name of listcomp symbol table. (py < 3.12)
# this conflict should be handled properly.
def listcomp(a=[i for i in range(3)]): # just for test, don't use kwarg like this.
b = 0

def func2():
nonlocal b
b = 1

func2()
print(b)
print(a)


listcomp()

0 comments on commit bacda4e

Please sign in to comment.