Skip to content

Commit

Permalink
fix(2971): OpenAPI schema generation fails for Union of Structs when …
Browse files Browse the repository at this point in the history
…including `None` (#2982)

* Fix 2971

Signed-off-by: Janek Nouvertné <[email protected]>

* Add more tests

Signed-off-by: Janek Nouvertné <[email protected]>

* Fix typing

Signed-off-by: Janek Nouvertné <[email protected]>

---------

Signed-off-by: Janek Nouvertné <[email protected]>
  • Loading branch information
provinzkraut authored Jan 14, 2024
1 parent 93b7db6 commit 7c4dd4b
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 1 deletion.
2 changes: 1 addition & 1 deletion litestar/_openapi/schema_generation/plugins/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

class StructSchemaPlugin(OpenAPISchemaPlugin):
def is_plugin_supported_field(self, field_definition: FieldDefinition) -> bool:
return field_definition.is_subclass_of(Struct)
return not field_definition.is_union and field_definition.is_subclass_of(Struct)

def to_openapi_schema(self, field_definition: FieldDefinition, schema_creator: SchemaCreator) -> Schema:
def is_field_required(field: FieldInfo) -> bool:
Expand Down
58 changes: 58 additions & 0 deletions tests/unit/test_openapi/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,3 +491,61 @@ def test_process_schema_result_with_unregistered_object_schema() -> None:
schema = Schema(title="has title", type=OpenAPIType.OBJECT)
field_definition = FieldDefinition.from_annotation(dict)
assert SchemaCreator().process_schema_result(field_definition, schema) is schema


@pytest.mark.parametrize("base_type", [msgspec.Struct, TypedDict, dataclass])
def test_type_union(base_type: type) -> None:
if base_type is dataclass: # type: ignore[comparison-overlap]

@dataclass
class ModelA: # pyright: ignore
pass

@dataclass
class ModelB: # pyright: ignore
pass
else:

class ModelA(base_type): # type: ignore[no-redef, misc]
pass

class ModelB(base_type): # type: ignore[no-redef, misc]
pass

schema = get_schema_for_field_definition(
FieldDefinition.from_kwarg(name="Lookup", annotation=Union[ModelA, ModelB])
)
assert schema.one_of == [
Reference(ref="#/components/schemas/tests_unit_test_openapi_test_schema_test_type_union.ModelA"),
Reference(ref="#/components/schemas/tests_unit_test_openapi_test_schema_test_type_union.ModelB"),
]


@pytest.mark.parametrize("base_type", [msgspec.Struct, TypedDict, dataclass])
def test_type_union_with_none(base_type: type) -> None:
# https://github.com/litestar-org/litestar/issues/2971
if base_type is dataclass: # type: ignore[comparison-overlap]

@dataclass
class ModelA: # pyright: ignore
pass

@dataclass
class ModelB: # pyright: ignore
pass
else:

class ModelA(base_type): # type: ignore[no-redef, misc]
pass

class ModelB(base_type): # type: ignore[no-redef, misc]
pass

schema = get_schema_for_field_definition(
FieldDefinition.from_kwarg(name="Lookup", annotation=Union[ModelA, ModelB, None])
)
assert schema.one_of == [
Schema(type=OpenAPIType.NULL),
Reference(ref="#/components/schemas/tests_unit_test_openapi_test_schema_test_type_union_with_none.ModelA"),
Reference("#/components/schemas/tests_unit_test_openapi_test_schema_test_type_union_with_none.ModelB"),
]

0 comments on commit 7c4dd4b

Please sign in to comment.