Skip to content

Commit

Permalink
Enable debug flag making leaf nodes ast node objects instead of strin…
Browse files Browse the repository at this point in the history
…gs, adapt get_position to get position out of leaf nodes
  • Loading branch information
TimSangster committed Aug 7, 2019
1 parent 4439f71 commit 267aeb1
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 11 deletions.
26 changes: 18 additions & 8 deletions antlr_ast/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,12 +311,21 @@ def get_text(self, full_text=None):
def get_position(self):
ctx = self._ctx
if ctx is not None:
position = {
"line_start": ctx.start.line,
"column_start": ctx.start.column,
"line_end": ctx.stop.line,
"column_end": ctx.stop.column + (ctx.stop.stop - ctx.stop.start),
}
if hasattr(ctx, "symbol"):
position = {
"line_start": ctx.symbol.line,
"column_start": ctx.symbol.column,
"line_end": ctx.symbol.line,
"column_end": ctx.symbol.column + (ctx.symbol.stop - ctx.symbol.start),
}
else:
position = {
"line_start": ctx.start.line,
"column_start": ctx.start.column,
"line_end": ctx.stop.line,
"column_end": ctx.stop.column + (ctx.stop.stop - ctx.stop.start),
}

else:
position = self.position
return position
Expand All @@ -337,7 +346,7 @@ class Terminal(BaseNode):
"""

_fields = tuple(["value"])
DEBUG = False
DEBUG = True
DEBUG_INSTANCES = []

def __new__(cls, *args, **kwargs):
Expand Down Expand Up @@ -468,7 +477,8 @@ def visit(self, node):
if isinstance(alias, AliasNode) or alias == node:
# this prevents infinite recursion and visiting
# AliasNodes with a name that is also the name of a BaseNode
self.generic_visit(alias)
if isinstance(alias, BaseNode):
self.generic_visit(alias)
else:
# visit BaseNode (e.g. result of Transformer method)
if isinstance(alias, list):
Expand Down
50 changes: 47 additions & 3 deletions tests/test_expr_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
parse as parse_ast,
process_tree,
BaseNodeTransformer,
Terminal,
)

from . import grammar
Expand Down Expand Up @@ -31,9 +32,6 @@ def visit_SubExpr(self, node):
def visit_NotExpr(self, node):
return NotExpr.from_spec(node)

def visit_Terminal(self, node):
return node.get_text()


def parse(text, start="expr", **kwargs):
antlr_tree = parse_ast(grammar, text, start, upper=False, **kwargs)
Expand All @@ -60,6 +58,7 @@ def test_subexpr():
node = parse("(1 + 1)")
assert isinstance(node, SubExpr)
assert isinstance(node.expression, BinaryExpr)
assert isinstance(node.expression.left, Terminal)


def test_fields():
Expand Down Expand Up @@ -103,3 +102,48 @@ def test_speaker_node_cfg():
assert speaker.describe(node, str_tmp, "left") == str_tmp.format(
field_name="left part", node_name="binary expression"
)


# BaseNode.get_position -------------------------------------------------------


def test_get_position():
# Given
code = "1 + (2 + 2)"
correct_position = {
"line_start": 1,
"column_start": 4,
"line_end": 1,
"column_end": 10,
}

# When
result = parse(code)
positions = result.right.get_position()

# Then
assert positions["line_start"] == correct_position["line_start"]
assert positions["line_end"] == correct_position["line_end"]
assert positions["column_start"] == correct_position["column_start"]
assert positions["column_end"] == correct_position["column_end"]


def test_terminal_get_position():
# Given
code = "(2 + 2) + 1"
correct_position = {
"line_start": 1,
"column_start": 10,
"line_end": 1,
"column_end": 10,
}

# When
result = parse(code)
positions = result.right.get_position()

# Then
assert positions["line_start"] == correct_position["line_start"]
assert positions["line_end"] == correct_position["line_end"]
assert positions["column_start"] == correct_position["column_start"]
assert positions["column_end"] == correct_position["column_end"]

0 comments on commit 267aeb1

Please sign in to comment.