diff --git a/.gitignore b/.gitignore index 26275e4b..d4e64180 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ *.pyc *.pyo /.tox +/lark.egg-info/** /lark_parser.egg-info/** tags .vscode diff --git a/lark/tree.py b/lark/tree.py index 76f8738e..9dccadd7 100644 --- a/lark/tree.py +++ b/lark/tree.py @@ -3,8 +3,10 @@ from typing import List, Callable, Iterator, Union, Optional, Generic, TypeVar, TYPE_CHECKING +from .lexer import Token + if TYPE_CHECKING: - from .lexer import TerminalDef, Token + from .lexer import TerminalDef try: import rich except ImportError: @@ -171,6 +173,16 @@ def find_data(self, data: str) -> 'Iterator[Tree[_Leaf_T]]': ###} + def find_token(self, token_type: str) -> Iterator[_Leaf_T]: + """Returns all tokens whose type equals the given token_type. + + This is a recursive function that will find tokens in all the subtrees. + + Example: + >>> term_tokens = tree.find_token('TERM') + """ + return self.scan_values(lambda v: isinstance(v, Token) and v.type == token_type) + def expand_kids_by_data(self, *data_values): """Expand (inline) children with any of the given data values. Returns True if anything changed""" changed = False diff --git a/tests/test_trees.py b/tests/test_trees.py index 1f69869e..55fdae91 100644 --- a/tests/test_trees.py +++ b/tests/test_trees.py @@ -17,6 +17,11 @@ class TestTrees(TestCase): def setUp(self): self.tree1 = Tree('a', [Tree(x, y) for x, y in zip('bcd', 'xyz')]) + self.tree2 = Tree('a', [ + Tree('b', [Token('T', 'x')]), + Tree('c', [Token('T', 'y')]), + Tree('d', [Tree('z', [Token('T', 'zz'), Tree('zzz', 'zzz')])]), + ]) def test_eq(self): assert self.tree1 == self.tree1 @@ -48,6 +53,11 @@ def test_iter_subtrees_topdown(self): nodes = list(self.tree1.iter_subtrees_topdown()) self.assertEqual(nodes, expected) + def test_find_token(self): + expected = [Token('T', 'x'), Token('T', 'y'), Token('T', 'zz')] + tokens = list(self.tree2.find_token('T')) + self.assertEqual(tokens, expected) + def test_visitor(self): class Visitor1(Visitor): def __init__(self):