Skip to content

Commit

Permalink
Remove Context class, use NamespaceGlobal
Browse files Browse the repository at this point in the history
  • Loading branch information
yunline committed Dec 11, 2024
1 parent 84c7f67 commit 01bbb02
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 64 deletions.
19 changes: 0 additions & 19 deletions oneliner/contex.py

This file was deleted.

8 changes: 3 additions & 5 deletions oneliner/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import oneliner.utils as utils
from oneliner.config import Configs
from oneliner.contex import Context
from oneliner.namespaces import Namespace, generate_nsp
from oneliner.pending_nodes import *

Expand Down Expand Up @@ -33,9 +32,8 @@ def convert(
ast_root: ast.Module, symtable_root: symtable.SymbolTable, configs: Configs
) -> ast.expr:
pending_node_stack: list[PendingNode] = []
nsp_global = generate_nsp(symtable_root)
nsp_global = generate_nsp(symtable_root, configs)
nsp_stack: list[Namespace] = [nsp_global]
ctx = Context(nsp_global, configs)

def pending_top() -> PendingNode:
"""Get the stack top of self.pending_node_stack"""
Expand All @@ -46,7 +44,7 @@ def get_pending_node(node: ast.AST) -> PendingNode:
return ast2pending[type(node)](
node,
nsp=nsp_stack[-1],
context=ctx,
nsp_global=nsp_global,
)
except KeyError as err:
raise RuntimeError(
Expand Down Expand Up @@ -80,4 +78,4 @@ def get_pending_node(node: ast.AST) -> PendingNode:

if len(pending_node_stack) == 0:
assert len(nsp_stack) == 1
return ctx.expr_wraper(result_nodes)
return nsp_global.expr_wraper(result_nodes)
11 changes: 10 additions & 1 deletion oneliner/namespaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import typing
from ast import *

from oneliner.config import Configs
from oneliner.reserved_identifiers import *

__all__ = [
Expand Down Expand Up @@ -52,6 +53,13 @@ class NamespaceGlobal(Namespace[symtable.SymbolTable]):
use_importlib: bool = False
use_preset_iter_wrapper: bool = False

configs: Configs
expr_wraper: typing.Callable[[list[expr]], expr]

def load_configs(self, configs: Configs):
self.configs = configs
self.expr_wraper = utils.get_expr_wrapper(configs)

def get_assign(self, name: str, value_expr: expr) -> NamedExpr:
return NamedExpr(target=Name(id=name, ctx=Store()), value=value_expr)

Expand Down Expand Up @@ -332,10 +340,11 @@ def update_globals_from_lambda_or_comp(symt: symtable.Function, stack: list[Name
stack[-1].globals_used_in_comp.update(_globals)


def generate_nsp(symt: symtable.SymbolTable):
def generate_nsp(symt: symtable.SymbolTable, configs: Configs):
walk_stack = []
generate_stack: list[Namespace] = []
root = NamespaceGlobal(symt, generate_stack)
root.load_configs(configs)
generate_stack.append(root)

walk_stack.append(iter(symt.get_children()))
Expand Down
81 changes: 42 additions & 39 deletions oneliner/pending_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@
from ast import *

import oneliner.utils as utils
from oneliner.contex import Context
from oneliner.expr_transform import expr_transf
from oneliner.namespaces import Namespace, NamespaceClass, NamespaceFunction
from oneliner.namespaces import (
Namespace,
NamespaceClass,
NamespaceFunction,
NamespaceGlobal,
)
from oneliner.reserved_identifiers import *

__all__ = [
Expand Down Expand Up @@ -54,11 +58,10 @@
class PendingNode(typing.Generic[T]):
node: T

def __init__(self, node: T, nsp: Namespace, context: "Context"):
def __init__(self, node: T, nsp: Namespace, nsp_global: NamespaceGlobal):
self.iter_node = self._iter_nodes()
self.nsp = nsp
self.nsp_global = context.nsp_global
self.context = context
self.nsp_global = nsp_global
self.node = node

def get_result(self) -> list[expr]:
Expand All @@ -77,8 +80,8 @@ def get_internal_namespace(self) -> Namespace:
class PendingModule(PendingNode[Module]):
converted_body: list[expr]

def __init__(self, node: Module, nsp: Namespace, context: Context):
super().__init__(node, nsp, context)
def __init__(self, node: Module, nsp: Namespace, nsp_global: NamespaceGlobal):
super().__init__(node, nsp, nsp_global)
self.converted_body = []

def _iter_nodes(self) -> typing.Generator[AST, list[expr], None]:
Expand Down Expand Up @@ -159,7 +162,7 @@ def _iter_branch(
stack[-1].append(
IfExp(
test=UnaryOp(op=Not(), operand=get_flow_control_expr()),
body=self.context.expr_wraper(wrapped),
body=self.nsp_global.expr_wraper(wrapped),
orelse=Constant(value=...),
)
)
Expand All @@ -170,8 +173,8 @@ class PendingIf(_PendingCompoundStmt[If]):
converted_body: list[expr]
converted_orelse: list[expr]

def __init__(self, node: If, nsp: Namespace, context: Context):
super().__init__(node, nsp, context)
def __init__(self, node: If, nsp: Namespace, nsp_global: NamespaceGlobal):
super().__init__(node, nsp, nsp_global)

self.converted_body = []
self.converted_orelse = []
Expand All @@ -180,8 +183,8 @@ def get_result(self) -> list[expr]:
return [
IfExp(
test=expr_transf(self.nsp, self.node.test),
body=self.context.expr_wraper(self.converted_body),
orelse=self.context.expr_wraper(self.converted_orelse),
body=self.nsp_global.expr_wraper(self.converted_body),
orelse=self.nsp_global.expr_wraper(self.converted_orelse),
)
]

Expand Down Expand Up @@ -262,8 +265,8 @@ def _iter_nodes(self) -> typing.Generator[AST, list[expr], None]:

class PendingWhile(_PendingLoop[While]):

def __init__(self, node: While, nsp: Namespace, context: Context):
super().__init__(node, nsp, context)
def __init__(self, node: While, nsp: Namespace, nsp_global: NamespaceGlobal):
super().__init__(node, nsp, nsp_global)

self.converted_body = []
self.converted_orelse = []
Expand Down Expand Up @@ -327,15 +330,15 @@ def get_result(self) -> list[expr]:
if self.break_cnt:
while_loop_orelse = IfExp(
test=UnaryOp(op=Not(), operand=self.flow_ctrl_break_expr),
body=self.context.expr_wraper(self.converted_orelse),
body=self.nsp_global.expr_wraper(self.converted_orelse),
orelse=Constant(value=...),
)
else:
while_loop_orelse = self.context.expr_wraper(self.converted_orelse)
while_loop_orelse = self.nsp_global.expr_wraper(self.converted_orelse)

# the main body of the oneliner while loop
while_loop_body = ListComp(
elt=self.context.expr_wraper(self.converted_body),
elt=self.nsp_global.expr_wraper(self.converted_body),
generators=[
comprehension(
target=Name(id="_", ctx=Store()),
Expand Down Expand Up @@ -385,8 +388,8 @@ def get_result(self) -> list[expr]:

class PendingFor(_PendingLoop[For]):

def __init__(self, node: For, nsp: Namespace, context: Context):
super().__init__(node, nsp, context)
def __init__(self, node: For, nsp: Namespace, nsp_global: NamespaceGlobal):
super().__init__(node, nsp, nsp_global)

self.converted_body = []
self.converted_orelse = []
Expand All @@ -405,7 +408,7 @@ def get_result(self) -> list[expr]:
if self.interrupt_cnt == 0 and len(self.node.orelse) == 0:
return [
ListComp(
elt=self.context.expr_wraper(self.converted_body),
elt=self.nsp_global.expr_wraper(self.converted_body),
generators=[
comprehension(
target=self.node.target,
Expand Down Expand Up @@ -468,15 +471,15 @@ def get_result(self) -> list[expr]:
ctx=Load(),
),
),
body=self.context.expr_wraper(self.converted_orelse),
body=self.nsp_global.expr_wraper(self.converted_orelse),
orelse=Constant(value=...),
)
else:
for_loop_orelse = self.context.expr_wraper(self.converted_orelse)
for_loop_orelse = self.nsp_global.expr_wraper(self.converted_orelse)

# the main body of the oneliner for loop
for_loop_body = ListComp(
elt=self.context.expr_wraper(self.converted_body),
elt=self.nsp_global.expr_wraper(self.converted_body),
generators=[
comprehension(
target=self.node.target,
Expand All @@ -498,8 +501,8 @@ def get_result(self) -> list[expr]:


class PendingBreak(PendingNode[Break]):
def __init__(self, node: Break, nsp: Namespace, context: Context):
super().__init__(node, nsp, context)
def __init__(self, node: Break, nsp: Namespace, nsp_global: NamespaceGlobal):
super().__init__(node, nsp, nsp_global)
if len(self.nsp.loop_stack) == 0:
raise SyntaxError(
utils.ast_debug_info(node) + "'break' is not inside a loop"
Expand Down Expand Up @@ -535,8 +538,8 @@ def get_result(self) -> list[expr]:


class PeindingContinue(PendingNode[Continue]):
def __init__(self, node: Continue, nsp: Namespace, context: Context):
super().__init__(node, nsp, context)
def __init__(self, node: Continue, nsp: Namespace, nsp_global: NamespaceGlobal):
super().__init__(node, nsp, nsp_global)
if len(self.nsp.loop_stack) == 0:
raise SyntaxError(
utils.ast_debug_info(node) + "'continue' is not inside a loop"
Expand Down Expand Up @@ -827,8 +830,8 @@ class PendingFunctionDef(_PendingCompoundStmt[FunctionDef]):
internal_nsp: NamespaceFunction
converted_body: list[expr]

def __init__(self, node: FunctionDef, nsp: Namespace, context: Context):
super().__init__(node, nsp, context)
def __init__(self, node: FunctionDef, nsp: Namespace, nsp_global: NamespaceGlobal):
super().__init__(node, nsp, nsp_global)

for tmp_nsp in self.nsp.inner_nsp:
if (
Expand Down Expand Up @@ -936,10 +939,10 @@ def get_result(self) -> list[expr]:
)
)

if self.context.configs.expr_wrapper == "list":
if self.nsp_global.configs.expr_wrapper == "list":
body.extend(self.converted_body)
else:
body.append(self.context.expr_wraper(self.converted_body))
body.append(self.nsp_global.expr_wraper(self.converted_body))
body.append(self.internal_nsp.return_value_expr)
body_expr = utils.list_wrapper(body)

Expand Down Expand Up @@ -971,11 +974,11 @@ def get_result(self) -> list[expr]:


class PendingReturn(PendingNode[Return]):
def __init__(self, node: Return, nsp: Namespace, context: Context):
def __init__(self, node: Return, nsp: Namespace, nsp_global: NamespaceGlobal):
if not isinstance(nsp, NamespaceFunction):
raise SyntaxError(utils.ast_debug_info(node) + "'return' outside function")

super().__init__(node, nsp, context)
super().__init__(node, nsp, nsp_global)
self.nsp: NamespaceFunction

self.nsp.return_cnt += 1
Expand Down Expand Up @@ -1034,8 +1037,8 @@ class PendingClassDef(_PendingCompoundStmt[ClassDef]):
internal_nsp: NamespaceClass
converted_body: list[expr]

def __init__(self, node: ClassDef, nsp: Namespace, context: Context):
super().__init__(node, nsp, context)
def __init__(self, node: ClassDef, nsp: Namespace, nsp_global: NamespaceGlobal):
super().__init__(node, nsp, nsp_global)

for tmp_nsp in self.nsp.inner_nsp:
if (
Expand Down Expand Up @@ -1175,8 +1178,8 @@ def get_result(self) -> list[expr]:


class PendingImport(PendingNode[Import]):
def __init__(self, node: Import, nsp: Namespace, context: Context):
super().__init__(node, nsp, context)
def __init__(self, node: Import, nsp: Namespace, nsp_global: NamespaceGlobal):
super().__init__(node, nsp, nsp_global)
self.nsp_global.use_importlib = True

def get_result(self) -> list[expr]:
Expand Down Expand Up @@ -1205,8 +1208,8 @@ def get_result(self) -> list[expr]:


class PendingImportFrom(PendingNode[ImportFrom]):
def __init__(self, node: ImportFrom, nsp: Namespace, context: Context):
super().__init__(node, nsp, context)
def __init__(self, node: ImportFrom, nsp: Namespace, nsp_global: NamespaceGlobal):
super().__init__(node, nsp, nsp_global)

def get_result(self) -> list[expr]:
result: list[expr] = []
Expand Down

0 comments on commit 01bbb02

Please sign in to comment.