Skip to content

Commit

Permalink
Overload Variables class for better typing experience
Browse files Browse the repository at this point in the history
  • Loading branch information
tchalupnik committed Sep 24, 2024
1 parent 8045b8f commit bde7db8
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions src/prompt_toolkit/contrib/regular_languages/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from __future__ import annotations

import re
from typing import Callable, Dict, Iterable, Iterator, Pattern
from typing import Callable, Dict, Iterable, Iterator, Pattern, TypeVar, overload
from typing import Match as RegexMatch

from .regex_parser import (
Expand All @@ -57,9 +57,7 @@
tokenize_regex,
)

__all__ = [
"compile",
]
__all__ = ["compile", "Match", "Variables"]


# Name of the named group in the regex, matching trailing input.
Expand Down Expand Up @@ -491,6 +489,9 @@ def end_nodes(self) -> Iterable[MatchVariable]:
yield MatchVariable(varname, value, (reg[0], reg[1]))


_T = TypeVar("_T")


class Variables:
def __init__(self, tuples: list[tuple[str, str, tuple[int, int]]]) -> None:
#: List of (varname, value, slice) tuples.
Expand All @@ -502,7 +503,13 @@ def __repr__(self) -> str:
", ".join(f"{k}={v!r}" for k, v, _ in self._tuples),
)

def get(self, key: str, default: str | None = None) -> str | None:
@overload
def get(self, key: str) -> str | None: ...

@overload
def get(self, key: str, default: _T = None) -> str | _T: ...

def get(self, key: str, default: _T = None) -> str | _T:
items = self.getall(key)
return items[0] if items else default

Expand Down

0 comments on commit bde7db8

Please sign in to comment.