Skip to content

Commit

Permalink
Support mapping types in OpenAPI schema
Browse files Browse the repository at this point in the history
It is useful to return some grouped data
```python
@DataClass
class User:
    id: str
    name: str
```

Example for `Dict[str, list[User]]`

```json
{
  "type": "object",
  "patternProperties": {
    "^[a-z0-9!\"#$%&'()*+,.\/:;<=>?@\[\] ^_`{|}~-]+$": {
      "type": "array",
      "items": {
        "type": "object",
        "properties": {
          "id": { "type": "string" },
          "name": { "type": "string" }
        },
        "required": ["id", "name"]
      }
    }
  }
}

```
  • Loading branch information
tyzhnenko committed Jan 27, 2024
1 parent bec8e14 commit 9e905a5
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 1 deletion.
61 changes: 61 additions & 0 deletions blacksheep/server/openapi/v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
import inspect
import warnings
from abc import ABC, abstractmethod
from collections import OrderedDict, defaultdict
from dataclasses import dataclass, fields, is_dataclass
from datetime import date, datetime
from decimal import Decimal
from enum import Enum, IntEnum
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union
from typing import _GenericAlias as GenericAlias
Expand Down Expand Up @@ -656,6 +658,11 @@ def _get_schema_by_type(
if schema:
return schema

# Dict, OrderedDict, defaultdict are handled first than GenericAlias
schema = self._try_get_schema_for_mapping(object_type, type_args)
if schema:
return schema

# List, Set, Tuple are handled first than GenericAlias
schema = self._try_get_schema_for_iterable(object_type, type_args)
if schema:
Expand Down Expand Up @@ -723,6 +730,60 @@ def _try_get_schema_for_iterable(
items=self.get_schema_by_type(item_type, context_type_args),
)

def _try_get_schema_for_mapping(
self, object_type: Type, context_type_args: Optional[Dict[Any, Type]] = None
) -> Optional[Schema]:
properties_regexp = {
str: r"^[a-z0-9!\"#$%&'()*+,.\/:;<=>?@\[\] ^_`{|}~-]+$",
int: r"^[0-9]+$",
float: r"^[0-9]+(?:\.[0-9]+)?$",
Decimal: r"^[0-9]+(?:\.[0-9]+)?$",
UUID: r"^[a-f0-9]{8}(?:-[a-f0-9]{4}){3}-[a-f0-9]{12}$",
bool: r"^(?:true|false)$",
}

if object_type in {dict, defaultdict, OrderedDict}:
# the user didn't specify the key and value types
return Schema(
type=ValueType.OBJECT,
properties={
properties_regexp[str]: Schema(
type=ValueType.STRING,
),
},
)

origin = get_origin(object_type)

if not origin or origin not in {
dict,
Dict,
collections_abc.Mapping,
}:
return None

# can be Dict, Dict[str, str] or dict[str, str] (Python 3.9),
# note: it could also be union if it wasn't handled above for dataclasses
try:
key_type, value_type = object_type.__args__ # type: ignore
except AttributeError: # pragma: no cover
key_type, value_type = str, str

if context_type_args and key_type in context_type_args:
key_type = context_type_args.get(key_type, key_type)

if context_type_args and value_type in context_type_args:
value_type = context_type_args.get(value_type, value_type)

return Schema(
type=ValueType.OBJECT,
properties={
properties_regexp.get(
key_type, "^[a-zA-Z0-9_]+$"
): self.get_schema_by_type(value_type, context_type_args)
},
)

def get_fields(self, object_type: Any) -> List[FieldInfo]:
for handler in self._object_types_handlers:
if handler.handles_type(object_type):
Expand Down
61 changes: 60 additions & 1 deletion tests/test_openapi_v3.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass
from datetime import date, datetime
from enum import IntEnum
from typing import Generic, List, Optional, Sequence, TypeVar, Union
from typing import Generic, List, Mapping, Optional, Sequence, TypeVar, Union
from uuid import UUID

import pytest
Expand Down Expand Up @@ -1312,6 +1312,65 @@ def home() -> Sequence[Cat]:
)


@pytest.mark.asyncio
async def test_handling_of_mapping(docs: OpenAPIHandler, serializer: Serializer):
app = get_app()

@app.router.route("/")
def home() -> Mapping[str, Mapping[int, Cat]]:
...

docs.bind_app(app)
await app.start()

yaml = serializer.to_yaml(docs.generate_documentation(app))

assert (
yaml.strip()
== r"""
openapi: 3.0.3
info:
title: Example
version: 0.0.1
paths:
/:
get:
responses:
'200':
description: Success response
content:
application/json:
schema:
type: object
properties:
^[a-z0-9!\"#$%&'()*+,.\/:;<=>?@\[\] ^_`{|}~-]+$:
type: object
properties:
^[0-9]+$:
$ref: '#/components/schemas/Cat'
nullable: false
nullable: false
operationId: home
components:
schemas:
Cat:
type: object
required:
- id
- name
properties:
id:
type: integer
format: int64
nullable: false
name:
type: string
nullable: false
tags: []
""".strip()
)


def test_handling_of_generic_with_forward_references(docs: OpenAPIHandler):
with pytest.warns(UserWarning):
docs.register_schema_for_type(GenericWithForwardRefExample[Cat])
Expand Down

0 comments on commit 9e905a5

Please sign in to comment.