diff --git a/blacksheep/server/openapi/v3.py b/blacksheep/server/openapi/v3.py index 3ee247c..0e191fe 100644 --- a/blacksheep/server/openapi/v3.py +++ b/blacksheep/server/openapi/v3.py @@ -2,8 +2,10 @@ import inspect import warnings from abc import ABC, abstractmethod +from collections import OrderedDict, defaultdict from dataclasses import dataclass, fields, is_dataclass from datetime import date, datetime +from decimal import Decimal from enum import Enum, IntEnum from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union from typing import _GenericAlias as GenericAlias @@ -656,6 +658,11 @@ def _get_schema_by_type( if schema: return schema + # Dict, OrderedDict, defaultdict are handled first than GenericAlias + schema = self._try_get_schema_for_mapping(object_type, type_args) + if schema: + return schema + # List, Set, Tuple are handled first than GenericAlias schema = self._try_get_schema_for_iterable(object_type, type_args) if schema: @@ -723,6 +730,60 @@ def _try_get_schema_for_iterable( items=self.get_schema_by_type(item_type, context_type_args), ) + def _try_get_schema_for_mapping( + self, object_type: Type, context_type_args: Optional[Dict[Any, Type]] = None + ) -> Optional[Schema]: + properties_regexp = { + str: r"^[a-z0-9!\"#$%&'()*+,.\/:;<=>?@\[\] ^_`{|}~-]+$", + int: r"^[0-9]+$", + float: r"^[0-9]+(?:\.[0-9]+)?$", + Decimal: r"^[0-9]+(?:\.[0-9]+)?$", + UUID: r"^[a-f0-9]{8}(?:-[a-f0-9]{4}){3}-[a-f0-9]{12}$", + bool: r"^(?:true|false)$", + } + + if object_type in {dict, defaultdict, OrderedDict}: + # the user didn't specify the key and value types + return Schema( + type=ValueType.OBJECT, + properties={ + properties_regexp[str]: Schema( + type=ValueType.STRING, + ), + }, + ) + + origin = get_origin(object_type) + + if not origin or origin not in { + dict, + Dict, + collections_abc.Mapping, + }: + return None + + # can be Dict, Dict[str, str] or dict[str, str] (Python 3.9), + # note: it could also be union if it wasn't handled above for dataclasses + try: + key_type, value_type = object_type.__args__ # type: ignore + except AttributeError: # pragma: no cover + key_type, value_type = str, str + + if context_type_args and key_type in context_type_args: + key_type = context_type_args.get(key_type, key_type) + + if context_type_args and value_type in context_type_args: + value_type = context_type_args.get(value_type, value_type) + + return Schema( + type=ValueType.OBJECT, + properties={ + properties_regexp.get( + key_type, "^[a-zA-Z0-9_]+$" + ): self.get_schema_by_type(value_type, context_type_args) + }, + ) + def get_fields(self, object_type: Any) -> List[FieldInfo]: for handler in self._object_types_handlers: if handler.handles_type(object_type): diff --git a/tests/test_openapi_v3.py b/tests/test_openapi_v3.py index 53389e6..bf9623f 100644 --- a/tests/test_openapi_v3.py +++ b/tests/test_openapi_v3.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from datetime import date, datetime from enum import IntEnum -from typing import Generic, List, Optional, Sequence, TypeVar, Union +from typing import Generic, List, Mapping, Optional, Sequence, TypeVar, Union from uuid import UUID import pytest @@ -1312,6 +1312,65 @@ def home() -> Sequence[Cat]: ) +@pytest.mark.asyncio +async def test_handling_of_mapping(docs: OpenAPIHandler, serializer: Serializer): + app = get_app() + + @app.router.route("/") + def home() -> Mapping[str, Mapping[int, Cat]]: + ... + + docs.bind_app(app) + await app.start() + + yaml = serializer.to_yaml(docs.generate_documentation(app)) + + assert ( + yaml.strip() + == r""" +openapi: 3.0.3 +info: + title: Example + version: 0.0.1 +paths: + /: + get: + responses: + '200': + description: Success response + content: + application/json: + schema: + type: object + properties: + ^[a-z0-9!\"#$%&'()*+,.\/:;<=>?@\[\] ^_`{|}~-]+$: + type: object + properties: + ^[0-9]+$: + $ref: '#/components/schemas/Cat' + nullable: false + nullable: false + operationId: home +components: + schemas: + Cat: + type: object + required: + - id + - name + properties: + id: + type: integer + format: int64 + nullable: false + name: + type: string + nullable: false +tags: [] +""".strip() + ) + + def test_handling_of_generic_with_forward_references(docs: OpenAPIHandler): with pytest.warns(UserWarning): docs.register_schema_for_type(GenericWithForwardRefExample[Cat])