Skip to content

Commit

Permalink
refactor: openapi default values (#2587)
Browse files Browse the repository at this point in the history
`Litestar.update_openapi_schema()` is called whenever we first generate the openapi schema for an app object, and whenever we register a handler on the app.

Inside, we perform checks to ensure that the `OpenAPI` object has a not-none `components` attribute, and that the components attribute has a not-none schema attribute.

This PR modifies the definition of the `OpenAPI` type, and the `Component` type to remove the `None` union and instantiate an instance of the respective instance on construction.

Background to the change is that when addressing uncovered lines of code in the `app.py` module, I was unsure of how to hit https://github.com/litestar-org/litestar/blob/fd06486e2ad4ed0a41636659fec4f093a09e3dd0/litestar/app.py#L835. Then realised that perhaps this whole block of code: https://github.com/litestar-org/litestar/blob/fd06486e2ad4ed0a41636659fec4f093a09e3dd0/litestar/app.py#L829-L835 might be unnecessary if we adjust these default values.
  • Loading branch information
peterschutt authored Nov 1, 2023
1 parent 3f58ff1 commit 47281b2
Show file tree
Hide file tree
Showing 7 changed files with 16 additions and 22 deletions.
11 changes: 1 addition & 10 deletions litestar/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from litestar.logging.config import LoggingConfig, get_logger_placeholder
from litestar.middleware.cors import CORSMiddleware
from litestar.openapi.config import OpenAPIConfig
from litestar.openapi.spec.components import Components
from litestar.plugins import (
CLIPluginProtocol,
InitPluginProtocol,
Expand Down Expand Up @@ -826,14 +825,6 @@ def update_openapi_schema(self) -> None:

operation_ids: list[str] = []

if not self._openapi_schema.components:
self._openapi_schema.components = Components()
schemas = self._openapi_schema.components.schemas = {}
elif not self._openapi_schema.components.schemas:
schemas = self._openapi_schema.components.schemas = {}
else:
schemas = {}

for route in self.routes:
if (
isinstance(route, HTTPRoute)
Expand All @@ -848,7 +839,7 @@ def update_openapi_schema(self) -> None:
plugins=self.plugins.openapi,
use_handler_docstrings=self.openapi_config.use_handler_docstrings,
operation_id_creator=self.openapi_config.operation_id_creator,
schemas=schemas,
schemas=self._openapi_schema.components.schemas,
)
self._openapi_schema.paths[route.path_format or "/"] = path_item

Expand Down
7 changes: 4 additions & 3 deletions litestar/openapi/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

from copy import deepcopy
from dataclasses import dataclass, field, fields
from typing import TYPE_CHECKING, Literal, cast
from typing import TYPE_CHECKING, Literal

from litestar._openapi.utils import default_operation_id_creator
from litestar.openapi.controller import OpenAPIController
Expand Down Expand Up @@ -67,7 +68,7 @@ class OpenAPIConfig:
Should be an instance of
:data:`SecurityRequirement <.openapi.spec.SecurityRequirement>`.
"""
components: Components | list[Components] | None = field(default=None)
components: Components | list[Components] = field(default_factory=Components)
"""API Components information.
Should be an instance of :class:`Components <litestar.openapi.spec.components.Components>` or a list thereof.
Expand Down Expand Up @@ -133,7 +134,7 @@ def to_openapi_schema(self) -> OpenAPI:
return OpenAPI(
external_docs=self.external_docs,
security=self.security,
components=cast("Components", self.components),
components=deepcopy(self.components), # deepcopy prevents mutation of the config's components
servers=self.servers,
tags=self.tags,
webhooks=self.webhooks,
Expand Down
4 changes: 2 additions & 2 deletions litestar/openapi/spec/components.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

from litestar.openapi.spec.base import BaseSchemaObject
Expand Down Expand Up @@ -31,7 +31,7 @@ class Components(BaseSchemaObject):
outside the components object.
"""

schemas: dict[str, Schema] | None = None
schemas: dict[str, Schema] = field(default_factory=dict)
"""An object to hold reusable
`Schema Objects <https://spec.openapis.org/oas/v3.1.0#schemaObject>`_"""

Expand Down
4 changes: 2 additions & 2 deletions litestar/openapi/spec/open_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
from typing import TYPE_CHECKING

from litestar.openapi.spec.base import BaseSchemaObject
from litestar.openapi.spec.components import Components
from litestar.openapi.spec.server import Server

if TYPE_CHECKING:
from litestar.openapi.spec.components import Components
from litestar.openapi.spec.external_documentation import ExternalDocumentation
from litestar.openapi.spec.info import Info
from litestar.openapi.spec.path_item import PathItem
Expand Down Expand Up @@ -64,7 +64,7 @@ class OpenAPI(BaseSchemaObject):
`example <https://github.com/OAI/OpenAPI-Specification/blob/main/examples/v3.1/webhook-example.yaml>`_ is available.
"""

components: Components | None = None
components: Components = field(default_factory=Components)
"""An element to hold various schemas for the document."""

security: list[SecurityRequirement] | None = None
Expand Down
4 changes: 3 additions & 1 deletion tests/unit/test_contrib/test_jwt/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ def west_handler() -> None:
def test_jwt_auth_openapi() -> None:
jwt_auth = JWTAuth[Any](token_secret="abc123", retrieve_user_handler=lambda _: None) # type: ignore
assert jwt_auth.openapi_components.to_schema() == {
"schemas": {},
"securitySchemes": {
"BearerToken": {
"type": "http",
Expand All @@ -308,7 +309,7 @@ def test_jwt_auth_openapi() -> None:
"scheme": "Bearer",
"bearerFormat": "JWT",
}
}
},
}
assert jwt_auth.security_requirement == {"BearerToken": []}
app = Litestar(on_app_init=[jwt_auth.on_app_init])
Expand Down Expand Up @@ -363,6 +364,7 @@ def login_custom_handler() -> Response["User"]:
assert response.content != response_custom.content

assert jwt_auth.openapi_components.to_schema() == {
"schemas": {},
"securitySchemes": {
"BearerToken": {
"type": "oauth2",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def test_piccolo_dto_openapi_spec_generation() -> None:
== "#/components/schemas/RetrieveVenuesVenueResponseBody"
)

concert_schema = schema.components.schemas["CreateConcertConcertRequestBody"] # type: ignore
concert_schema = schema.components.schemas["CreateConcertConcertRequestBody"]
assert concert_schema
assert concert_schema.to_schema() == {
"properties": {
Expand All @@ -124,7 +124,7 @@ def test_piccolo_dto_openapi_spec_generation() -> None:
"type": "object",
}

record_studio_schema = schema.components.schemas["RetrieveStudioRecordingStudioResponseBody"] # type: ignore
record_studio_schema = schema.components.schemas["RetrieveStudioRecordingStudioResponseBody"]
assert record_studio_schema
assert record_studio_schema.to_schema() == {
"properties": {
Expand All @@ -138,7 +138,7 @@ def test_piccolo_dto_openapi_spec_generation() -> None:
"type": "object",
}

venue_schema = schema.components.schemas["RetrieveVenuesVenueResponseBody"] # type: ignore
venue_schema = schema.components.schemas["RetrieveVenuesVenueResponseBody"]
assert venue_schema
assert venue_schema.to_schema() == {
"properties": {
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_security/test_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def handler() -> dict:
(
(None, None),
(
OpenAPIConfig(title="Litestar API", version="1.0.0", components=None),
OpenAPIConfig(title="Litestar API", version="1.0.0"),
{
"schemas": {},
"securitySchemes": {
Expand Down

0 comments on commit 47281b2

Please sign in to comment.