diff --git a/stdlib/email/message.pyi b/stdlib/email/message.pyi index 18852f4d3bb2..7c00a756edf9 100644 --- a/stdlib/email/message.pyi +++ b/stdlib/email/message.pyi @@ -13,6 +13,7 @@ _T = TypeVar("_T") _PayloadType: TypeAlias = list[Message] | str | bytes | bytearray _CharsetType: TypeAlias = Charset | str | None +# Type returned by Policy.header_fetch_parse, AnyOf[str | Header] _HeaderType: TypeAlias = Any class Message: @@ -31,7 +32,8 @@ class Message: def __len__(self) -> int: ... def __contains__(self, name: str) -> bool: ... def __iter__(self) -> Iterator[str]: ... - def __getitem__(self, name: str) -> _HeaderType: ... + # Same as get with failobj=None + def __getitem__(self, name: str) -> _HeaderType | None: ... def __setitem__(self, name: str, val: _HeaderType) -> None: ... def __delitem__(self, name: str) -> None: ... def keys(self) -> list[str]: ... diff --git a/stdlib/email/policy.pyi b/stdlib/email/policy.pyi index 804044031fcd..344f07d6e2a5 100644 --- a/stdlib/email/policy.pyi +++ b/stdlib/email/policy.pyi @@ -1,14 +1,23 @@ +from _typeshed import Unused from abc import ABCMeta, abstractmethod from collections.abc import Callable from email.contentmanager import ContentManager from email.errors import MessageDefect from email.header import Header +from email.headerregistry import BaseHeader from email.message import Message -from typing import Any +from typing import Any, Protocol, TypeVar, overload from typing_extensions import Self __all__ = ["Compat32", "compat32", "Policy", "EmailPolicy", "default", "strict", "SMTP", "HTTP"] +class _HasName(Protocol): + name: str + +_HasNameT = TypeVar("_HasNameT", bound=_HasName) +_HeaderT = TypeVar("_HeaderT", bound=Header) +_StrBaseHeader = TypeVar("_StrBaseHeader", bound=str) # BaseHeader matches this bound + class Policy(metaclass=ABCMeta): max_line_length: int | None linesep: str @@ -35,7 +44,7 @@ class Policy(metaclass=ABCMeta): @abstractmethod def header_store_parse(self, name: str, value: str) -> tuple[str, str]: ... @abstractmethod - def header_fetch_parse(self, name: str, value: str) -> str: ... + def header_fetch_parse(self, name: str, value: str) -> str | Header: ... @abstractmethod def fold(self, name: str, value: str) -> str: ... @abstractmethod @@ -44,7 +53,10 @@ class Policy(metaclass=ABCMeta): class Compat32(Policy): def header_source_parse(self, sourcelines: list[str]) -> tuple[str, str]: ... def header_store_parse(self, name: str, value: str) -> tuple[str, str]: ... - def header_fetch_parse(self, name: str, value: str) -> str | Header: ... # type: ignore[override] + @overload + def header_fetch_parse(self, name: Unused, value: _HeaderT) -> _HeaderT: ... + @overload + def header_fetch_parse(self, name: str, value: _StrBaseHeader) -> _StrBaseHeader | Header: ... def fold(self, name: str, value: str) -> str: ... def fold_binary(self, name: str, value: str) -> bytes: ... @@ -71,7 +83,10 @@ class EmailPolicy(Policy): ) -> None: ... def header_source_parse(self, sourcelines: list[str]) -> tuple[str, str]: ... def header_store_parse(self, name: str, value: Any) -> tuple[str, Any]: ... - def header_fetch_parse(self, name: str, value: str) -> Any: ... + @overload + def header_fetch_parse(self, name: Unused, value: _HasNameT) -> _HasNameT: ... + @overload + def header_fetch_parse(self, name: str, value: str) -> BaseHeader: ... def fold(self, name: str, value: str) -> Any: ... def fold_binary(self, name: str, value: str) -> bytes: ...