diff --git a/litestar/contrib/pydantic/utils.py b/litestar/contrib/pydantic/utils.py index 88332529a8..cae63d58e8 100644 --- a/litestar/contrib/pydantic/utils.py +++ b/litestar/contrib/pydantic/utils.py @@ -12,8 +12,6 @@ get_origin_or_inner_type, get_type_hints_with_generics_resolved, instantiable_type_mapping, - unwrap_annotation, - wrapper_type_set, ) # isort: off @@ -116,9 +114,6 @@ def pydantic_unwrap_and_get_origin(annotation: Any) -> Any | None: return get_origin_or_inner_type(annotation) origin = annotation.__pydantic_generic_metadata__["origin"] - if origin in wrapper_type_set: - inner, _, _ = unwrap_annotation(annotation) - origin = get_origin_or_inner_type(inner) return instantiable_type_mapping.get(origin, origin) diff --git a/tests/unit/test_contrib/test_pydantic/test_dto.py b/tests/unit/test_contrib/test_pydantic/test_dto.py index 27c051d9ae..aee52fe405 100644 --- a/tests/unit/test_contrib/test_pydantic/test_dto.py +++ b/tests/unit/test_contrib/test_pydantic/test_dto.py @@ -1,11 +1,16 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional + +import pytest +from pydantic import v1 as pydantic_v1 from litestar import Request, post from litestar.contrib.pydantic import PydanticDTO, _model_dump_json from litestar.dto import DTOConfig from litestar.testing import create_test_client +from litestar.types import Empty +from litestar.typing import FieldDefinition if TYPE_CHECKING: from pydantic import BaseModel @@ -36,3 +41,24 @@ def handler(data: PydanticUser, request: Request) -> dict: ) required = next(iter(received.json()["components"]["schemas"].values()))["required"] assert len(required) == 2 + + +def test_field_definition_implicit_optional_default(base_model: type[BaseModel]) -> None: + class Model(base_model): # type: ignore[misc, valid-type] + a: Optional[str] # noqa: UP007 + + dto_type = PydanticDTO[Model] + field_defs = list(dto_type.generate_field_definitions(Model)) + assert len(field_defs) == 1 + assert field_defs[0].default is None + + +def test_detect_nested_field_pydantic_v1(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("litestar.contrib.pydantic.pydantic_dto_factory.pydantic_v2", Empty) + + class Model(pydantic_v1.BaseModel): + a: str + + dto_type = PydanticDTO[Model] + assert dto_type.detect_nested_field(FieldDefinition.from_annotation(Model)) is True + assert dto_type.detect_nested_field(FieldDefinition.from_annotation(int)) is False diff --git a/tests/unit/test_contrib/test_pydantic/test_integration.py b/tests/unit/test_contrib/test_pydantic/test_integration.py index 1705d84050..2df020302e 100644 --- a/tests/unit/test_contrib/test_pydantic/test_integration.py +++ b/tests/unit/test_contrib/test_pydantic/test_integration.py @@ -1,11 +1,14 @@ from typing import Any, Dict, List, Type, Union import pydantic as pydantic_v2 +import pytest from pydantic import v1 as pydantic_v1 +from typing_extensions import Annotated from litestar import post from litestar.contrib.pydantic.pydantic_dto_factory import PydanticDTO -from litestar.params import Parameter +from litestar.enums import RequestEncodingType +from litestar.params import Body, Parameter from litestar.status_codes import HTTP_400_BAD_REQUEST from litestar.testing import create_test_client from tests.unit.test_contrib.test_pydantic.models import PydanticPerson, PydanticV1Person @@ -13,14 +16,18 @@ from . import PydanticVersion -def test_pydantic_v1_validation_error_raises_400() -> None: +@pytest.mark.parametrize(("meta",), [(None,), (Body(media_type=RequestEncodingType.URL_ENCODED),)]) +def test_pydantic_v1_validation_error_raises_400(meta: Any) -> None: class Model(pydantic_v1.BaseModel): foo: str = pydantic_v1.Field(max_length=2) ModelDTO = PydanticDTO[Model] - @post(dto=ModelDTO, signature_types=[Model]) - def handler(data: Model) -> Model: + annotation: Any + annotation = Annotated[Model, meta] if meta is not None else Model + + @post(dto=ModelDTO, signature_namespace={"annotation": annotation}) + def handler(data: annotation) -> Any: # pyright: ignore return data model_json = {"foo": "too long"} @@ -36,21 +43,26 @@ def handler(data: Model) -> Model: ] with create_test_client(route_handlers=handler) as client: - response = client.post("/", json=model_json) + kws = {"data": model_json} if meta else {"json": model_json} + response = client.post("/", **kws) # type: ignore[arg-type] extra = response.json()["extra"] assert response.status_code == 400 assert extra == expected_errors -def test_pydantic_v2_validation_error_raises_400() -> None: +@pytest.mark.parametrize(("meta",), [(None,), (Body(media_type=RequestEncodingType.URL_ENCODED),)]) +def test_pydantic_v2_validation_error_raises_400(meta: Any) -> None: class Model(pydantic_v2.BaseModel): foo: str = pydantic_v2.Field(max_length=2) ModelDTO = PydanticDTO[Model] - @post(dto=ModelDTO, signature_types=[Model]) - def handler(data: Model) -> Model: + annotation: Any + annotation = Annotated[Model, meta] if meta is not None else Model + + @post(dto=ModelDTO, signature_namespace={"annotation": annotation}) + def handler(data: annotation) -> Any: # pyright: ignore return data model_json = {"foo": "too long"} @@ -67,7 +79,8 @@ def handler(data: Model) -> Model: ] with create_test_client(route_handlers=handler) as client: - response = client.post("/", json=model_json) + kws = {"data": model_json} if meta else {"json": model_json} + response = client.post("/", **kws) # type: ignore[arg-type] extra = response.json()["extra"] extra[0].pop("url") diff --git a/tests/unit/test_contrib/test_pydantic/test_utils.py b/tests/unit/test_contrib/test_pydantic/test_utils.py index 2e7854ce4a..470a42807b 100644 --- a/tests/unit/test_contrib/test_pydantic/test_utils.py +++ b/tests/unit/test_contrib/test_pydantic/test_utils.py @@ -1,8 +1,7 @@ -from typing import Dict, Generic, Tuple +from typing import Any, Dict, Generic, Tuple, TypeVar import pytest from pydantic import BaseModel -from typing_extensions import Any, TypeVar from litestar.contrib.pydantic.utils import pydantic_get_type_hints_with_generics_resolved