From 2601c433ed83794ded0e7a25d45679c77fb8ade0 Mon Sep 17 00:00:00 2001 From: Greg Kuwaye Date: Tue, 6 Aug 2024 17:18:54 -1000 Subject: [PATCH] Fix: Handle `type` when it is a list --- .../swagger_generation/generator_utils.py | 13 +++++-- .../test_generator_utils.py | 38 +++++++++++++++++++ 2 files changed, 48 insertions(+), 3 deletions(-) diff --git a/flask_rebar/swagger_generation/generator_utils.py b/flask_rebar/swagger_generation/generator_utils.py index fc83991..ce64cc1 100644 --- a/flask_rebar/swagger_generation/generator_utils.py +++ b/flask_rebar/swagger_generation/generator_utils.py @@ -140,15 +140,22 @@ def flatten(schema: Dict[str, Any], base: str) -> Tuple[Dict[str, str], Dict[str def _flatten( schema: Dict[str, Any], definitions: Dict[str, Any], base: str ) -> Dict[str, str]: - schema_type = schema.get(sw.type_) + # With OpenAPI 3.1, this will be a list of allowed types that includes sw.null if allow_none=True. + schema_type: str | list[str] | None = schema.get(sw.type_) + schema_types = [] + if type(schema_type) is str: + schema_types = [schema_type] + elif type(schema_type) is list: + schema_types = schema_type + subschema_keyword = _get_subschema_keyword(schema) - if schema_type == sw.object_: + if sw.object_ in schema_types: properties = schema.get(sw.properties, {}) for key, prop in properties.items(): properties[key] = _flatten(schema=prop, definitions=definitions, base=base) - elif schema_type == sw.array: + elif sw.array in schema_types: schema[sw.items] = _flatten( schema=schema[sw.items], definitions=definitions, base=base ) diff --git a/tests/swagger_generation/test_generator_utils.py b/tests/swagger_generation/test_generator_utils.py index 099667e..c2b4e78 100644 --- a/tests/swagger_generation/test_generator_utils.py +++ b/tests/swagger_generation/test_generator_utils.py @@ -188,6 +188,44 @@ def test_flatten_subschemas(self): self.assertEqual(schema, expected_schema) self.assertEqual(definitions, expected_definitions) + def test_flatten_creates_refs_when_type_is_list(self): + self.maxDiff = None + input_ = { + 'properties': { + 'data': { + 'items': { + 'properties': {'name': {'type': 'string'}}, + 'title': 'NestedSchema', + 'type': 'object', + }, + 'type': ['array', 'null'], + }, + }, + 'title': 'ParentAllowNoneTrueSchema', + 'type': 'object', + } + + expected_schema = { + '$ref': '#/definitions/ParentAllowNoneTrueSchema' + } + + expected_definitions = { + 'NestedSchema': { + 'properties': {'name': {'type': 'string'}}, + 'title': 'NestedSchema', + 'type': 'object', + }, + 'ParentAllowNoneTrueSchema': { + 'properties': {'data': {'items': {'$ref': '#/definitions/NestedSchema'}, 'type': ['array', 'null']}}, + 'title': 'ParentAllowNoneTrueSchema', + 'type': 'object', + }, + } + + schema, definitions = flatten(input_, base="#/definitions") + self.assertEqual(schema, expected_schema) + self.assertEqual(definitions, expected_definitions) + class TestFormatPathForSwagger(unittest.TestCase): def test_format_path(self):