Skip to content

Commit

Permalink
Improve CaseTransformInputStream interface
Browse files Browse the repository at this point in the history
  • Loading branch information
hermansje committed Aug 20, 2019
1 parent 7abce7e commit b6a6783
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 19 deletions.
27 changes: 19 additions & 8 deletions antlr_ast/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,23 @@
from antlr4.tree.Tree import ErrorNode, TerminalNodeImpl, ParseTree

from antlr_ast.inputstream import CaseTransformInputStream
from antlr4.error.ErrorListener import ErrorListener
from antlr4.error.ErrorListener import ErrorListener, ConsoleErrorListener


def parse(
grammar,
text: str,
start: str,
strict=False,
upper=True,
transform: Union[str, Callable] = None,
error_listener: ErrorListener = None,
) -> ParseTree:
input_stream = CaseTransformInputStream(text, upper=upper)
input_stream = CaseTransformInputStream(text, transform=transform)

lexer = grammar.Lexer(input_stream)
lexer.removeErrorListeners()
lexer.addErrorListener(LexerErrorListener())

token_stream = CommonTokenStream(lexer)
parser = grammar.Parser(token_stream)
parser.buildParseTrees = True # default
Expand Down Expand Up @@ -148,6 +151,7 @@ def get_info(node_cfg):

# Error Listener ------------------------------------------------------------------


# from antlr4.error.Errors import RecognitionException


Expand All @@ -157,12 +161,12 @@ def __init__(self, msg, orig):


class StrictErrorListener(ErrorListener):
# The recognizer will be the parser instance
def syntaxError(self, recognizer, badSymbol, line, col, msg, e):
if e is not None:
msg = "line {line}: {col} {msg}".format(line=line, col=col, msg=msg)
raise AntlrException(msg, e)
else:
raise AntlrException(msg, None)
msg = "line {line}:{col} {msg}".format(
badSymbol=badSymbol, line=line, col=col, msg=msg
)
raise AntlrException(msg, e)

def reportAmbiguity(
self, recognizer, dfa, startIndex, stopIndex, exact, ambigAlts, configs
Expand All @@ -183,6 +187,13 @@ def reportContextSensitivity(
# raise Exception("TODO")


class LexerErrorListener(ConsoleErrorListener):
def syntaxError(self, recognizer, offendingSymbol, line, column, msg, e):
if isinstance(e.input, CaseTransformInputStream):
msg = msg + " " + repr(e.input)
super().syntaxError(recognizer, offendingSymbol, line, column, msg, e)


# Parse Tree Visitor ----------------------------------------------------------
# TODO: visitor inheritance not really needed, but indicates compatibility
# TODO: make general node (Terminal) accessible in class property (.subclasses)?
Expand Down
29 changes: 19 additions & 10 deletions antlr_ast/inputstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,28 @@ class CaseTransformInputStream(InputStream):
"""Support case insensitive languages
https://github.com/antlr/antlr4/blob/master/doc/case-insensitive-lexing.md#custom-character-streams-approach
"""
UPPER = "upper"
LOWER = "lower"

def __init__(self, *args, transform=None, **kwargs):
if transform is None:
self.transform = lambda x: x
elif transform == self.UPPER:
self.transform = methodcaller("upper")
elif transform == self.LOWER:
self.transform = methodcaller("lower")
elif callable(transform):
self.transform = transform
else:
raise ValueError("Invalid transform")

def __init__(self, *args, upper=None, **kwargs):
self.upper = upper
super().__init__(*args, **kwargs)

def _loadString(self):
self._index = 0
if self.upper:
transform = methodcaller("upper")
elif self.upper is False:
transform = methodcaller("lower")
elif self.upper is None:
transform = lambda x: x

self.data = [ord(transform(c)) for c in self.strdata]

self.data = [ord(self.transform(c)) for c in self.strdata]
self._size = len(self.data)

def __repr__(self):
return "<{} {}>".format(self.__class__.__name__, self.transform)
5 changes: 4 additions & 1 deletion tests/test_expr_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
BaseNodeTransformer,
Terminal,
)
from antlr_ast.inputstream import CaseTransformInputStream

from . import grammar

Expand Down Expand Up @@ -34,7 +35,9 @@ def visit_NotExpr(self, node):


def parse(text, start="expr", **kwargs):
antlr_tree = parse_ast(grammar, text, start, upper=False, **kwargs)
antlr_tree = parse_ast(
grammar, text, start, transform=CaseTransformInputStream.LOWER, **kwargs
)
simple_tree = process_tree(antlr_tree, transformer_cls=Transformer)

return simple_tree
Expand Down

0 comments on commit b6a6783

Please sign in to comment.