Skip to content

Commit

Permalink
Merge pull request #1467 from makukha/1466-add-tree-find-token
Browse files Browse the repository at this point in the history
Add Tree.find_token() method
  • Loading branch information
erezsh authored Jan 4, 2025
2 parents 2f7c9a4 + 971e418 commit 9ca0c5d
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 1 deletion.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
*.pyc
*.pyo
/.tox
/lark.egg-info/**
/lark_parser.egg-info/**
tags
.vscode
Expand Down
14 changes: 13 additions & 1 deletion lark/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@

from typing import List, Callable, Iterator, Union, Optional, Generic, TypeVar, TYPE_CHECKING

from .lexer import Token

if TYPE_CHECKING:
from .lexer import TerminalDef, Token
from .lexer import TerminalDef
try:
import rich
except ImportError:
Expand Down Expand Up @@ -171,6 +173,16 @@ def find_data(self, data: str) -> 'Iterator[Tree[_Leaf_T]]':

###}

def find_token(self, token_type: str) -> Iterator[_Leaf_T]:
"""Returns all tokens whose type equals the given token_type.
This is a recursive function that will find tokens in all the subtrees.
Example:
>>> term_tokens = tree.find_token('TERM')
"""
return self.scan_values(lambda v: isinstance(v, Token) and v.type == token_type)

def expand_kids_by_data(self, *data_values):
"""Expand (inline) children with any of the given data values. Returns True if anything changed"""
changed = False
Expand Down
10 changes: 10 additions & 0 deletions tests/test_trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@
class TestTrees(TestCase):
def setUp(self):
self.tree1 = Tree('a', [Tree(x, y) for x, y in zip('bcd', 'xyz')])
self.tree2 = Tree('a', [
Tree('b', [Token('T', 'x')]),
Tree('c', [Token('T', 'y')]),
Tree('d', [Tree('z', [Token('T', 'zz'), Tree('zzz', 'zzz')])]),
])

def test_eq(self):
assert self.tree1 == self.tree1
Expand Down Expand Up @@ -48,6 +53,11 @@ def test_iter_subtrees_topdown(self):
nodes = list(self.tree1.iter_subtrees_topdown())
self.assertEqual(nodes, expected)

def test_find_token(self):
expected = [Token('T', 'x'), Token('T', 'y'), Token('T', 'zz')]
tokens = list(self.tree2.find_token('T'))
self.assertEqual(tokens, expected)

def test_visitor(self):
class Visitor1(Visitor):
def __init__(self):
Expand Down

0 comments on commit 9ca0c5d

Please sign in to comment.