diff --git a/antlr_ast/ast.py b/antlr_ast/ast.py index 2905c8e..68b74a4 100644 --- a/antlr_ast/ast.py +++ b/antlr_ast/ast.py @@ -311,12 +311,21 @@ def get_text(self, full_text=None): def get_position(self): ctx = self._ctx if ctx is not None: - position = { - "line_start": ctx.start.line, - "column_start": ctx.start.column, - "line_end": ctx.stop.line, - "column_end": ctx.stop.column + (ctx.stop.stop - ctx.stop.start), - } + if hasattr(ctx, "symbol"): + position = { + "line_start": ctx.symbol.line, + "column_start": ctx.symbol.column, + "line_end": ctx.symbol.line, + "column_end": ctx.symbol.column + (ctx.symbol.stop - ctx.symbol.start), + } + else: + position = { + "line_start": ctx.start.line, + "column_start": ctx.start.column, + "line_end": ctx.stop.line, + "column_end": ctx.stop.column + (ctx.stop.stop - ctx.stop.start), + } + else: position = self.position return position @@ -337,7 +346,7 @@ class Terminal(BaseNode): """ _fields = tuple(["value"]) - DEBUG = False + DEBUG = True DEBUG_INSTANCES = [] def __new__(cls, *args, **kwargs): @@ -468,7 +477,8 @@ def visit(self, node): if isinstance(alias, AliasNode) or alias == node: # this prevents infinite recursion and visiting # AliasNodes with a name that is also the name of a BaseNode - self.generic_visit(alias) + if isinstance(alias, BaseNode): + self.generic_visit(alias) else: # visit BaseNode (e.g. result of Transformer method) if isinstance(alias, list): diff --git a/tests/test_expr_ast.py b/tests/test_expr_ast.py index 76f0d20..87971b5 100644 --- a/tests/test_expr_ast.py +++ b/tests/test_expr_ast.py @@ -4,6 +4,7 @@ parse as parse_ast, process_tree, BaseNodeTransformer, + Terminal, ) from . import grammar @@ -31,9 +32,6 @@ def visit_SubExpr(self, node): def visit_NotExpr(self, node): return NotExpr.from_spec(node) - def visit_Terminal(self, node): - return node.get_text() - def parse(text, start="expr", **kwargs): antlr_tree = parse_ast(grammar, text, start, upper=False, **kwargs) @@ -60,6 +58,7 @@ def test_subexpr(): node = parse("(1 + 1)") assert isinstance(node, SubExpr) assert isinstance(node.expression, BinaryExpr) + assert isinstance(node.expression.left, Terminal) def test_fields(): @@ -103,3 +102,48 @@ def test_speaker_node_cfg(): assert speaker.describe(node, str_tmp, "left") == str_tmp.format( field_name="left part", node_name="binary expression" ) + + +# BaseNode.get_position ------------------------------------------------------- + + +def test_get_position(): + # Given + code = "1 + (2 + 2)" + correct_position = { + "line_start": 1, + "column_start": 4, + "line_end": 1, + "column_end": 10, + } + + # When + result = parse(code) + positions = result.right.get_position() + + # Then + assert positions["line_start"] == correct_position["line_start"] + assert positions["line_end"] == correct_position["line_end"] + assert positions["column_start"] == correct_position["column_start"] + assert positions["column_end"] == correct_position["column_end"] + + +def test_terminal_get_position(): + # Given + code = "(2 + 2) + 1" + correct_position = { + "line_start": 1, + "column_start": 10, + "line_end": 1, + "column_end": 10, + } + + # When + result = parse(code) + positions = result.right.get_position() + + # Then + assert positions["line_start"] == correct_position["line_start"] + assert positions["line_end"] == correct_position["line_end"] + assert positions["column_start"] == correct_position["column_start"] + assert positions["column_end"] == correct_position["column_end"]