diff --git a/src/prompt_toolkit/contrib/regular_languages/compiler.py b/src/prompt_toolkit/contrib/regular_languages/compiler.py index dd558a68a..699a600f6 100644 --- a/src/prompt_toolkit/contrib/regular_languages/compiler.py +++ b/src/prompt_toolkit/contrib/regular_languages/compiler.py @@ -42,7 +42,7 @@ from __future__ import annotations import re -from typing import Callable, Dict, Iterable, Iterator, Pattern +from typing import Callable, Dict, Iterable, Iterator, Pattern, TypeVar, overload from typing import Match as RegexMatch from .regex_parser import ( @@ -57,9 +57,7 @@ tokenize_regex, ) -__all__ = [ - "compile", -] +__all__ = ["compile", "Match", "Variables"] # Name of the named group in the regex, matching trailing input. @@ -491,6 +489,9 @@ def end_nodes(self) -> Iterable[MatchVariable]: yield MatchVariable(varname, value, (reg[0], reg[1])) +_T = TypeVar("_T") + + class Variables: def __init__(self, tuples: list[tuple[str, str, tuple[int, int]]]) -> None: #: List of (varname, value, slice) tuples. @@ -502,7 +503,13 @@ def __repr__(self) -> str: ", ".join(f"{k}={v!r}" for k, v, _ in self._tuples), ) - def get(self, key: str, default: str | None = None) -> str | None: + @overload + def get(self, key: str) -> str | None: ... + + @overload + def get(self, key: str, default: _T = None) -> str | _T: ... + + def get(self, key: str, default: _T = None) -> str | _T: items = self.getall(key) return items[0] if items else default