diff --git a/funnel/forms/account.py b/funnel/forms/account.py index 027368c5d..0060c8be5 100644 --- a/funnel/forms/account.py +++ b/funnel/forms/account.py @@ -496,7 +496,7 @@ def validate_username(self, field: forms.Field) -> None: raise_username_error(reason) -class EnableNotificationsDescriptionMixin: +class EnableNotificationsDescriptionProtoMixin: """Mixin to add a link in the description for enabling notifications.""" enable_notifications: forms.Field @@ -513,7 +513,9 @@ def set_queries(self) -> None: @Account.forms('email_add') -class NewEmailAddressForm(EnableNotificationsDescriptionMixin, forms.RecaptchaForm): +class NewEmailAddressForm( + EnableNotificationsDescriptionProtoMixin, forms.RecaptchaForm +): """Form to add a new email address to an account.""" __expects__ = ('edit_user',) @@ -560,7 +562,7 @@ class EmailPrimaryForm(forms.Form): @Account.forms('phone_add') -class NewPhoneForm(EnableNotificationsDescriptionMixin, forms.RecaptchaForm): +class NewPhoneForm(EnableNotificationsDescriptionProtoMixin, forms.RecaptchaForm): """Form to add a new mobile number (SMS-capable) to an account.""" __expects__ = ('edit_user',) diff --git a/funnel/models/account.py b/funnel/models/account.py index bbc34da48..994b1593b 100644 --- a/funnel/models/account.py +++ b/funnel/models/account.py @@ -6,7 +6,7 @@ import itertools from collections.abc import Iterable, Iterator from datetime import datetime, timedelta -from typing import ClassVar, Literal, Union, cast, overload +from typing import ClassVar, Literal, cast, overload from uuid import UUID import phonenumbers @@ -259,7 +259,7 @@ class Account(UuidMixin, BaseMixin, Model): sa.orm.mapped_column(sa.Integer, nullable=False), read={'all'} ) - search_vector: Mapped[str] = sa.orm.mapped_column( + search_vector: Mapped[TSVectorType] = sa.orm.mapped_column( TSVectorType( 'title', 'name', @@ -1227,11 +1227,10 @@ def organization_links(self) -> list: add_search_trigger(Account, 'name_vector') -class AccountOldId(UuidMixin, BaseMixin, Model): +class AccountOldId(UuidMixin, BaseMixin[UUID], Model): """Record of an older UUID for an account, after account merger.""" __tablename__ = 'account_oldid' - __uuid_primary_key__ = True #: Old account, if still present old_account: Mapped[Account] = relationship( @@ -2178,7 +2177,7 @@ def get( ) #: Anchor type -Anchor = Union[AccountEmail, AccountEmailClaim, AccountPhone, EmailAddress, PhoneNumber] +Anchor = AccountEmail | AccountEmailClaim | AccountPhone | EmailAddress | PhoneNumber # Tail imports # pylint: disable=wrong-import-position diff --git a/funnel/models/comment.py b/funnel/models/comment.py index c0361d9e6..e6f3e4240 100644 --- a/funnel/models/comment.py +++ b/funnel/models/comment.py @@ -4,7 +4,7 @@ from collections.abc import Sequence from datetime import datetime -from typing import Any +from typing import TYPE_CHECKING, Any from werkzeug.utils import cached_property @@ -134,7 +134,7 @@ def __init__(self, **kwargs) -> None: self.count = 0 @cached_property - def parent(self) -> BaseMixin: + def parent(self) -> Project | Proposal | Update: # FIXME: Move this to a CommentMixin that uses a registry, like EmailAddress if self.project is not None: return self.project @@ -147,11 +147,8 @@ def parent(self) -> BaseMixin: with_roles(parent, read={'all'}, datasets={'primary'}) @cached_property - def parent_type(self) -> str | None: - parent = self.parent - if parent is not None: - return parent.__tablename__ - return None + def parent_type(self) -> str: + return self.parent.__tablename__ with_roles(parent_type, read={'all'}) @@ -363,6 +360,7 @@ def _message_expression(cls): @property def title(self) -> str: + """A made-up title referring to the context for the comment.""" obj = self.commentset.parent if obj is not None: return _("{user} commented on {obj}").format( @@ -439,3 +437,10 @@ class __Commentset: ), viewonly=True, ) + + +# Tail imports for type checking +if TYPE_CHECKING: + from .project import Project + from .proposal import Proposal + from .update import Update diff --git a/funnel/models/membership_mixin.py b/funnel/models/membership_mixin.py index 345b9ac89..37e61d46b 100644 --- a/funnel/models/membership_mixin.py +++ b/funnel/models/membership_mixin.py @@ -5,6 +5,7 @@ from collections.abc import Callable, Iterable from datetime import datetime as datetime_type from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar +from uuid import UUID from sqlalchemy import event from sqlalchemy.sql.expression import ColumnElement @@ -27,7 +28,7 @@ sa, ) from .account import Account -from .reorder_mixin import ReorderMixin +from .reorder_mixin import ReorderProtoMixin # Export only symbols needed in views. __all__ = [ @@ -40,7 +41,9 @@ # --- Typing --------------------------------------------------------------------------- MembershipType = TypeVar('MembershipType', bound='ImmutableMembershipMixin') -FrozenAttributionType = TypeVar('FrozenAttributionType', bound='FrozenAttributionMixin') +FrozenAttributionType = TypeVar( + 'FrozenAttributionType', bound='FrozenAttributionProtoMixin' +) # --- Enum ----------------------------------------------------------------------------- @@ -82,10 +85,9 @@ class MembershipRecordTypeError(MembershipError): @declarative_mixin -class ImmutableMembershipMixin(UuidMixin, BaseMixin): +class ImmutableMembershipMixin(UuidMixin, BaseMixin[UUID]): """Support class for immutable memberships.""" - __uuid_primary_key__ = True #: Can granted_by be null? Only in memberships based on legacy data __null_granted_by__: ClassVar[bool] = False #: List of columns that will be copied into a new row when a membership is amended @@ -383,7 +385,7 @@ def member_id(cls) -> Mapped[int]: @with_roles(read={'member', 'editor'}, grants_via={None: {'admin': 'member'}}) @declared_attr @classmethod - def member(cls) -> Mapped[Account]: + def member(cls) -> Mapped[Account]: # type: ignore[override] """Member in this membership record.""" return relationship(Account, foreign_keys=[cls.member_id]) @@ -490,7 +492,7 @@ def migrate_account(cls, old_account: Account, new_account: Account) -> None: @declarative_mixin -class ReorderMembershipMixin(ReorderMixin): +class ReorderMembershipProtoMixin(ReorderProtoMixin): """Customizes ReorderMixin for membership models.""" if TYPE_CHECKING: @@ -558,7 +560,7 @@ def parent_scoped_reorder_query_filter(self) -> ColumnElement: @declarative_mixin -class FrozenAttributionMixin: +class FrozenAttributionProtoMixin: """Provides a `title` data column and support method to freeze it.""" if TYPE_CHECKING: @@ -580,7 +582,7 @@ def _title(cls) -> Mapped[str | None]: def title(self) -> str: """Attribution title for this record.""" if self._local_data_only: - return self._title # This may be None + return self._title # This may be None # type: ignore[return-value] return self._title or self.member.title @title.setter diff --git a/funnel/models/moderation.py b/funnel/models/moderation.py index 0ba9f69e2..7ec499771 100644 --- a/funnel/models/moderation.py +++ b/funnel/models/moderation.py @@ -2,6 +2,8 @@ from __future__ import annotations +from uuid import UUID + from baseframe import __ from coaster.sqlalchemy import StateManager, with_roles from coaster.utils import LabeledEnum @@ -20,9 +22,8 @@ class MODERATOR_REPORT_TYPE(LabeledEnum): # noqa: N801 SPAM = (2, 'spam', __("Spam")) -class CommentModeratorReport(UuidMixin, BaseMixin, Model): +class CommentModeratorReport(UuidMixin, BaseMixin[UUID], Model): __tablename__ = 'comment_moderator_report' - __uuid_primary_key__ = True comment_id = sa.orm.mapped_column( sa.Integer, sa.ForeignKey('comment.id'), nullable=False, index=True diff --git a/funnel/models/notification.py b/funnel/models/notification.py index 165b32007..f6d3661b4 100644 --- a/funnel/models/notification.py +++ b/funnel/models/notification.py @@ -636,7 +636,7 @@ def allow_transport(cls, transport: str) -> bool: @property def role_provider_obj(self) -> _F | _D: """Return fragment if exists, document otherwise, indicating role provider.""" - return cast(Union[_F, _D], self.fragment or self.document) + return cast(_F | _D, self.fragment or self.document) def dispatch(self) -> Generator[NotificationRecipient, None, None]: """ @@ -733,7 +733,7 @@ def __getattr__(self, attr: str) -> Any: return getattr(self.cls, attr) -class NotificationRecipientMixin: +class NotificationRecipientProtoMixin: """Shared mixin for :class:`NotificationRecipient` and :class:`NotificationFor`.""" notification: Mapped[Notification] | Notification | PreviewNotification @@ -788,7 +788,7 @@ def is_not_deleted(self, revoke: bool = False) -> bool: return False -class NotificationRecipient(NotificationRecipientMixin, NoIdMixin, Model): +class NotificationRecipient(NoIdMixin, NotificationRecipientProtoMixin, Model): """ The recipient of a notification. @@ -1201,7 +1201,7 @@ def migrate_account(cls, old_account: Account, new_account: Account) -> None: ) -class NotificationFor(NotificationRecipientMixin): +class NotificationFor(NotificationRecipientProtoMixin): """View-only wrapper to mimic :class:`UserNotification`.""" notification: Notification | PreviewNotification diff --git a/funnel/models/proposal.py b/funnel/models/proposal.py index 3f504f63f..3b79396de 100644 --- a/funnel/models/proposal.py +++ b/funnel/models/proposal.py @@ -33,7 +33,7 @@ ) from .project import Project from .project_membership import project_child_role_map -from .reorder_mixin import ReorderMixin +from .reorder_mixin import ReorderProtoMixin from .video_mixin import VideoMixin __all__ = ['PROPOSAL_STATE', 'Proposal', 'ProposalSuuidRedirect'] @@ -119,7 +119,7 @@ class PROPOSAL_STATE(LabeledEnum): # noqa: N801 class Proposal( # type: ignore[misc] - UuidMixin, BaseScopedIdNameMixin, VideoMixin, ReorderMixin, Model + UuidMixin, BaseScopedIdNameMixin, VideoMixin, ReorderProtoMixin, Model ): __tablename__ = 'proposal' diff --git a/funnel/models/proposal_membership.py b/funnel/models/proposal_membership.py index 1813a6138..115d086bc 100644 --- a/funnel/models/proposal_membership.py +++ b/funnel/models/proposal_membership.py @@ -10,9 +10,9 @@ from .account import Account from .helpers import reopen from .membership_mixin import ( - FrozenAttributionMixin, + FrozenAttributionProtoMixin, ImmutableUserMembershipMixin, - ReorderMembershipMixin, + ReorderMembershipProtoMixin, ) from .project import Project from .proposal import Proposal @@ -21,7 +21,10 @@ class ProposalMembership( # type: ignore[misc] - FrozenAttributionMixin, ReorderMembershipMixin, ImmutableUserMembershipMixin, Model + ImmutableUserMembershipMixin, + FrozenAttributionProtoMixin, + ReorderMembershipProtoMixin, + Model, ): """Users can be presenters or reviewers on proposals.""" diff --git a/funnel/models/reorder_mixin.py b/funnel/models/reorder_mixin.py index b8367035e..1453252ce 100644 --- a/funnel/models/reorder_mixin.py +++ b/funnel/models/reorder_mixin.py @@ -8,23 +8,24 @@ from . import Mapped, QueryProperty, db, declarative_mixin, sa -__all__ = ['ReorderMixin'] +__all__ = ['ReorderProtoMixin'] -# Use of TypeVar for subclasses of ReorderMixin as defined in this mypy ticket: +# Use of TypeVar for subclasses of ReorderMixin as defined in these mypy tickets: # https://github.com/python/mypy/issues/1212 -Reorderable = TypeVar('Reorderable', bound='ReorderMixin') +# https://github.com/python/mypy/issues/7191 +Reorderable = TypeVar('Reorderable', bound='ReorderProtoMixin') @declarative_mixin -class ReorderMixin: +class ReorderProtoMixin: """Adds support for re-ordering sequences within a parent container.""" if TYPE_CHECKING: #: Subclasses must have a created_at column created_at: Mapped[datetime] #: Subclass must have a primary key that is int or uuid - id: Mapped[int] # noqa: A001 + id: Mapped[int | UUID] # noqa: A001 #: Subclass must declare a parent_id synonym to the parent model fkey column parent_id: Mapped[int | UUID] #: Subclass must declare a seq column or synonym, holding a sequence id. It @@ -36,7 +37,7 @@ class ReorderMixin: query: ClassVar[QueryProperty] @property - def parent_scoped_reorder_query_filter(self: Reorderable): + def parent_scoped_reorder_query_filter(self: Reorderable) -> sa.ColumnElement[bool]: """ Return a query filter that includes a scope limitation to the parent. @@ -80,6 +81,7 @@ def reorder_item(self: Reorderable, other: Reorderable, before: bool) -> None: cls.seq >= min(self.seq, other.seq), cls.seq <= max(self.seq, other.seq), ) + .with_for_update(of=cls) # Lock these rows to prevent a parallel update .options(sa.orm.load_only(cls.id, cls.seq)) .order_by(*order_columns) .all() @@ -99,7 +101,9 @@ def reorder_item(self: Reorderable, other: Reorderable, before: bool) -> None: new_seq_number = self.seq # Temporarily give self an out-of-bounds number self.seq = ( - sa.select(sa.func.coalesce(sa.func.max(cls.seq) + 1, 1)) + sa.select( # type: ignore[assignment] + sa.func.coalesce(sa.func.max(cls.seq) + 1, 1) + ) .where(self.parent_scoped_reorder_query_filter) .scalar_subquery() ) @@ -109,7 +113,7 @@ def reorder_item(self: Reorderable, other: Reorderable, before: bool) -> None: for reorderable_item in items_to_reorder[1:]: # Skip 0, which is self reorderable_item.seq, new_seq_number = new_seq_number, reorderable_item.seq # Flush to force execution order. This does not expunge SQLAlchemy cache as - # of SQLAlchemy 1.3.x. Should that behaviour change, a switch to + # of SQLAlchemy 2.0.x. Should that behaviour change, a switch to # bulk_update_mappings will be required db.session.flush() if reorderable_item.id == other.id: diff --git a/funnel/models/sponsor_membership.py b/funnel/models/sponsor_membership.py index 3bfb4ff34..52b8bd9de 100644 --- a/funnel/models/sponsor_membership.py +++ b/funnel/models/sponsor_membership.py @@ -10,9 +10,9 @@ from .account import Account from .helpers import reopen from .membership_mixin import ( - FrozenAttributionMixin, + FrozenAttributionProtoMixin, ImmutableUserMembershipMixin, - ReorderMembershipMixin, + ReorderMembershipProtoMixin, ) from .project import Project from .proposal import Proposal @@ -21,9 +21,9 @@ class ProjectSponsorMembership( # type: ignore[misc] - FrozenAttributionMixin, - ReorderMembershipMixin, ImmutableUserMembershipMixin, + FrozenAttributionProtoMixin, + ReorderMembershipProtoMixin, Model, ): """Sponsor of a project.""" @@ -151,8 +151,8 @@ def has_sponsors(self) -> bool: # FIXME: Replace this with existing proposal collaborator as they're now both related # to "account" class ProposalSponsorMembership( # type: ignore[misc] - FrozenAttributionMixin, - ReorderMembershipMixin, + FrozenAttributionProtoMixin, + ReorderMembershipProtoMixin, ImmutableUserMembershipMixin, Model, ): diff --git a/funnel/typing.py b/funnel/typing.py index 52446c431..4e0aced06 100644 --- a/funnel/typing.py +++ b/funnel/typing.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Optional, ParamSpec, TypeAlias, TypeVar, Union +from typing import ParamSpec, TypeAlias, TypeVar from flask.typing import ResponseReturnValue from werkzeug.wrappers import Response # Base class for Flask Response @@ -32,7 +32,7 @@ ReturnView: TypeAlias = ResponseReturnValue #: Return type of the `migrate_user` and `migrate_profile` methods -OptionalMigratedTables: TypeAlias = Optional[Union[list[str], tuple[str], set[str]]] +OptionalMigratedTables: TypeAlias = list[str] | tuple[str] | set[str] | None #: Return type for Response objects ReturnResponse: TypeAlias = Response diff --git a/funnel/views/mixins.py b/funnel/views/mixins.py index 49127da29..9ec94ab2f 100644 --- a/funnel/views/mixins.py +++ b/funnel/views/mixins.py @@ -24,7 +24,7 @@ VenueRoom, db, ) -from ..typing import ReturnRenderWith, ReturnView +from ..typing import ReturnView from .helpers import render_redirect @@ -250,7 +250,7 @@ def get_draft_data( return draft.revision, draft.formdata return None, None - def autosave_post(self, obj: UuidModelUnion | None = None) -> ReturnRenderWith: + def autosave_post(self, obj: UuidModelUnion | None = None) -> ReturnView: """Handle autosave POST requests.""" obj = obj if obj is not None else self.obj if 'form.revision' not in request.form: diff --git a/funnel/views/search.py b/funnel/views/search.py index f518001aa..dc14c3fae 100644 --- a/funnel/views/search.py +++ b/funnel/views/search.py @@ -86,7 +86,11 @@ def regconfig(self) -> str: @property def title_column(self) -> sa.ColumnElement[str]: """Return a column or column expression representing the object's title.""" - return self.model.title + # `Comment.title` is a property not a column, as comments don't have titles. + # That makes this return value incorrect, but here we ignore the error as + # class:`CommentSearch` explicitly overrides :meth:`hltitle_column`, and that is + # the only place this property is accessed + return self.model.title # type: ignore[return-value] @property def hltext(self) -> sa.ColumnElement[str]: diff --git a/migrations/script.py.mako b/migrations/script.py.mako index af188f4d1..88b3d025d 100644 --- a/migrations/script.py.mako +++ b/migrations/script.py.mako @@ -9,8 +9,6 @@ Create Date: ${create_date} """ -from typing import Optional, Tuple, Union - from alembic import op import sqlalchemy as sa ${imports if imports else ""} @@ -18,8 +16,8 @@ ${imports if imports else ""} # revision identifiers, used by Alembic. revision: str = ${repr(up_revision)} down_revision: str = ${repr(down_revision)} -branch_labels: Optional[Union[str, Tuple[str, ...]]] = ${repr(branch_labels)} -depends_on: Optional[Union[str, Tuple[str, ...]]] = ${repr(depends_on)} +branch_labels: str | tuple[str, ...] | None = ${repr(branch_labels)} +depends_on: str, tuple[str, ...] | None = ${repr(depends_on)} def upgrade(engine_name: str = '') -> None: diff --git a/tests/conftest.py b/tests/conftest.py index f5055d380..fc4a58083 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,13 +5,14 @@ import re import time -import typing as t import warnings +from collections.abc import Callable, Iterator from contextlib import ExitStack from dataclasses import dataclass from datetime import datetime, timezone from difflib import unified_diff from types import MethodType, ModuleType, SimpleNamespace +from typing import TYPE_CHECKING, Any, NamedTuple, get_type_hints from unittest.mock import patch import flask_wtf.csrf @@ -23,7 +24,7 @@ from flask_sqlalchemy.session import Session as FsaSession from sqlalchemy.orm import Session as DatabaseSessionClass -if t.TYPE_CHECKING: +if TYPE_CHECKING: from flask import Flask from flask.testing import FlaskClient, TestResponse from rich.console import Console @@ -60,7 +61,7 @@ def firefox_options(firefox_options): return firefox_options -def pytest_collection_modifyitems(items: t.List[pytest.Function]) -> None: +def pytest_collection_modifyitems(items: list[pytest.Function]) -> None: """Sort tests to run lower level before higher level.""" test_order = ( 'tests/unit/models', @@ -79,7 +80,7 @@ def pytest_collection_modifyitems(items: t.List[pytest.Function]) -> None: 'tests/features', ) - def sort_key(item: pytest.Function) -> t.Tuple[int, str]: + def sort_key(item: pytest.Function) -> tuple[int, str]: # pytest.Function's base class pytest.Item reports the file containing the test # as item.location == (file_path, line_no, function_name). However, pytest-bdd # reports itself for file_path, so we can't use that and must extract the path @@ -96,10 +97,10 @@ def sort_key(item: pytest.Function) -> t.Tuple[int, str]: # Adapted from https://github.com/untitaker/pytest-fixture-typecheck def pytest_runtest_call(item: pytest.Function) -> None: try: - annotations = t.get_type_hints( + annotations = get_type_hints( item.obj, globalns=item.obj.__globals__, - localns={'Any': t.Any}, # pytest-bdd appears to insert an `Any` annotation + localns={'Any': Any}, # pytest-bdd appears to insert an `Any` annotation ) except TypeError: # get_type_hints may fail on Python <3.10 because pytest-bdd appears to have @@ -154,7 +155,7 @@ def funnel_devtest() -> ModuleType: @pytest.fixture(scope='session') -def response_with_forms() -> t.Any: # Since the actual return type is defined within +def response_with_forms() -> Any: # Since the actual return type is defined within from flask.wrappers import Response from lxml.html import FormElement, HtmlElement, fromstring # nosec @@ -175,11 +176,11 @@ def response_with_forms() -> t.Any: # Since the actual return type is defined w re.ASCII | re.IGNORECASE | re.VERBOSE, ) - class MetaRefreshContent(t.NamedTuple): + class MetaRefreshContent(NamedTuple): """Timeout and optional URL in a Meta Refresh tag.""" timeout: int - url: t.Optional[str] = None + url: str | None = None class ResponseWithForms(Response): """ @@ -196,7 +197,7 @@ def test_mytest(client) -> None: next_response = form.submit(client) """ - _parsed_html: t.Optional[HtmlElement] = None + _parsed_html: HtmlElement | None = None @property def html(self) -> HtmlElement: @@ -206,7 +207,7 @@ def html(self) -> HtmlElement: # add click method to all links def _click( - self: HtmlElement, client: FlaskClient, **kwargs: t.Any + self: HtmlElement, client: FlaskClient, **kwargs: Any ) -> TestResponse: # `self` is the `a` element here path = self.attrib['href'] @@ -219,8 +220,8 @@ def _click( def _submit( self: FormElement, client: FlaskClient, - path: t.Optional[str] = None, - **kwargs: t.Any, + path: str | None = None, + **kwargs: Any, ) -> TestResponse: # `self` is the `form` element here data = dict(self.form_values()) @@ -238,7 +239,7 @@ def _submit( return self._parsed_html @property - def forms(self) -> t.List[FormElement]: + def forms(self) -> list[FormElement]: """ Return list of all forms in the document. @@ -249,8 +250,8 @@ def forms(self) -> t.List[FormElement]: return self.html.forms def form( - self, id_: t.Optional[str] = None, name: t.Optional[str] = None - ) -> t.Optional[FormElement]: + self, id_: str | None = None, name: str | None = None + ) -> FormElement | None: """Return the first form matching given id or name in the document.""" if id_: forms = self.html.cssselect(f'form#{id_}') @@ -262,11 +263,11 @@ def form( return forms[0] return None - def links(self, selector: str = 'a') -> t.List[HtmlElement]: + def links(self, selector: str = 'a') -> list[HtmlElement]: """Get all the links matching the given CSS selector.""" return self.html.cssselect(selector) - def link(self, selector: str = 'a') -> t.Optional[HtmlElement]: + def link(self, selector: str = 'a') -> HtmlElement | None: """Get first link matching the given CSS selector.""" links = self.links(selector) if links: @@ -274,7 +275,7 @@ def link(self, selector: str = 'a') -> t.Optional[HtmlElement]: return None @property - def metarefresh(self) -> t.Optional[MetaRefreshContent]: + def metarefresh(self) -> MetaRefreshContent | None: """Get content of Meta Refresh tag if present.""" meta_elements = self.html.cssselect('meta[http-equiv="refresh"]') if not meta_elements: @@ -299,7 +300,7 @@ def rich_console() -> Console: @pytest.fixture(scope='session') -def colorama() -> t.Iterator[SimpleNamespace]: +def colorama() -> Iterator[SimpleNamespace]: """Provide the colorama print colorizer.""" from colorama import Back, Fore, Style, deinit, init @@ -309,10 +310,10 @@ def colorama() -> t.Iterator[SimpleNamespace]: @pytest.fixture(scope='session') -def colorize_code(rich_console: Console) -> t.Callable[[str, t.Optional[str]], str]: +def colorize_code(rich_console: Console) -> Callable[[str, str | None], str]: """Return colorized output for a string of code, for current terminal's colors.""" - def no_colorize(code_string: str, lang: t.Optional[str] = 'python') -> str: + def no_colorize(code_string: str, lang: str | None = 'python') -> str: # Pygments is not available or terminal does not support colour output return code_string @@ -337,7 +338,7 @@ def no_colorize(code_string: str, lang: t.Optional[str] = 'python') -> str: # color_system is `None` or `'windows'` or something unrecognised. No colours. return no_colorize - def colorize(code_string: str, lang: t.Optional[str] = 'python') -> str: + def colorize(code_string: str, lang: str | None = 'python') -> str: if lang in (None, 'auto'): lexer = guess_lexer(code_string) else: @@ -348,7 +349,7 @@ def colorize(code_string: str, lang: t.Optional[str] = 'python') -> str: @pytest.fixture(scope='session') -def print_stack(pytestconfig, colorama, colorize_code) -> t.Callable[[int, int], None]: +def print_stack(pytestconfig, colorama, colorize_code) -> Callable[[int, int], None]: """Print a stack trace up to an outbound call from within this repository.""" import os.path from inspect import stack as inspect_stack @@ -419,20 +420,20 @@ def unsubscribeapp(funnel) -> Flask: @pytest.fixture() -def app_context(app) -> t.Iterator: +def app_context(app) -> Iterator: """Create an app context for the test.""" with app.app_context() as ctx: yield ctx @pytest.fixture() -def request_context(app) -> t.Iterator: +def request_context(app) -> Iterator: """Create a request context with default values for the test.""" with app.test_request_context() as ctx: yield ctx -config_test_keys: t.Dict[str, t.Set[str]] = { +config_test_keys: dict[str, set[str]] = { 'recaptcha': {'RECAPTCHA_PUBLIC_KEY', 'RECAPTCHA_PRIVATE_KEY'}, 'twilio': {'SMS_TWILIO_SID', 'SMS_TWILIO_TOKEN'}, 'exotel': {'SMS_EXOTEL_SID', 'SMS_EXOTEL_TOKEN'}, @@ -459,11 +460,11 @@ def request_context(app) -> t.Iterator: @pytest.fixture(autouse=True) -def _mock_config(request: pytest.FixtureRequest) -> t.Iterator: +def _mock_config(request: pytest.FixtureRequest) -> Iterator: """Mock app config (using ``mock_config`` mark).""" def backup_and_apply_config( - app_name: str, app_fixture: Flask, saved_config: dict, key: str, value: t.Any + app_name: str, app_fixture: Flask, saved_config: dict, key: str, value: Any ) -> None: if key in saved_config: pytest.fail(f"Duplicate mock for {app_name}.config[{key!r}]") @@ -479,7 +480,7 @@ def backup_and_apply_config( app_fixture.config[key] = value if request.node.get_closest_marker('mock_config'): - saved_app_config: t.Dict[str, t.Any] = {} + saved_app_config: dict[str, Any] = {} for mark in request.node.iter_markers('mock_config'): if len(mark.args) < 1: pytest.fail(_mock_config_syntax) @@ -534,7 +535,7 @@ def _requires_config(request: pytest.FixtureRequest) -> None: @pytest.fixture(scope='session') -def _app_events(colorama, print_stack, app) -> t.Iterator: +def _app_events(colorama, print_stack, app) -> Iterator: """Fixture to report Flask signals with a stack trace when debugging a test.""" from functools import partial @@ -572,7 +573,7 @@ def signal_handler(signal_name, *args, **kwargs): @pytest.fixture() -def _database_events(models, colorama, colorize_code, print_stack) -> t.Iterator: +def _database_events(models, colorama, colorize_code, print_stack) -> Iterator: """ Fixture to report database session events for debugging a test. @@ -911,7 +912,7 @@ def drop_tables(): @pytest.fixture() def db_session_truncate( funnel, app, database, app_context -) -> t.Iterator[DatabaseSessionClass]: +) -> Iterator[DatabaseSessionClass]: """Empty the database after each use of the fixture.""" yield database.session sa.orm.close_all_sessions() @@ -927,27 +928,27 @@ def db_session_truncate( @dataclass class BindConnectionTransaction: engine: sa.engine.Engine - connection: t.Any - transaction: t.Any + connection: Any + transaction: Any class BoundSession(FsaSession): def __init__( self, db: SQLAlchemy, - bindcts: t.Dict[t.Optional[str], BindConnectionTransaction], - **kwargs: t.Any, + bindcts: dict[str | None, BindConnectionTransaction], + **kwargs: Any, ) -> None: super().__init__(db, **kwargs) self.bindcts = bindcts def get_bind( self, - mapper: t.Optional[t.Any] = None, - clause: t.Optional[t.Any] = None, - bind: t.Optional[t.Union[sa.engine.Engine, sa.engine.Connection]] = None, - **kwargs: t.Any, - ) -> t.Union[sa.engine.Engine, sa.engine.Connection]: + mapper: Any | None = None, + clause: Any | None = None, + bind: sa.engine.Engine | sa.engine.Connection | None = None, + **kwargs: Any, + ) -> sa.engine.Engine | sa.engine.Connection: if bind is not None: return bind if mapper is not None: @@ -964,11 +965,11 @@ def get_bind( @pytest.fixture() def db_session_rollback( funnel, app, database, app_context -) -> t.Iterator[DatabaseSessionClass]: +) -> Iterator[DatabaseSessionClass]: """Create a nested transaction for the test and rollback after.""" original_session = database.session - bindcts: t.Dict[t.Optional[str], BindConnectionTransaction] = {} + bindcts: dict[str | None, BindConnectionTransaction] = {} for bind, engine in database.engines.items(): connection = engine.connect() transaction = connection.begin() @@ -1152,7 +1153,7 @@ def logout() -> None: @pytest.fixture() -def getuser(request) -> t.Callable[[str], funnel_models.User]: +def getuser(request) -> Callable[[str], funnel_models.User]: """Get a user fixture by their name.""" usermap = { "Twoflower": 'user_twoflower', diff --git a/tests/integration/views/conftest.py b/tests/integration/views/conftest.py index 843909c00..979f9bb36 100644 --- a/tests/integration/views/conftest.py +++ b/tests/integration/views/conftest.py @@ -2,11 +2,11 @@ from __future__ import annotations -import typing as t +from typing import TYPE_CHECKING from pytest_bdd import given, parsers, when -if t.TYPE_CHECKING: +if TYPE_CHECKING: from funnel import models