diff --git a/litestar/contrib/pydantic/__init__.py b/litestar/contrib/pydantic/__init__.py index 9bab707c31..94114806c6 100644 --- a/litestar/contrib/pydantic/__init__.py +++ b/litestar/contrib/pydantic/__init__.py @@ -14,6 +14,8 @@ from pydantic.v1 import BaseModel as BaseModelV1 from litestar.config.app import AppConfig + from litestar.types.serialization import PydanticV1FieldsListType, PydanticV2FieldsListType + __all__ = ( "PydanticDTO", @@ -43,14 +45,39 @@ def _model_dump_json(model: BaseModel | BaseModelV1, by_alias: bool = False) -> class PydanticPlugin(InitPluginProtocol): """A plugin that provides Pydantic integration.""" - __slots__ = ("prefer_alias",) + __slots__ = ( + "exclude", + "exclude_defaults", + "exclude_none", + "exclude_unset", + "include", + "prefer_alias", + ) - def __init__(self, prefer_alias: bool = False) -> None: + def __init__( + self, + exclude: PydanticV1FieldsListType | PydanticV2FieldsListType = None, + exclude_defaults: bool = False, + exclude_none: bool = False, + exclude_unset: bool = False, + include: PydanticV1FieldsListType | PydanticV2FieldsListType = None, + prefer_alias: bool = False, + ) -> None: """Initialize ``PydanticPlugin``. Args: + exclude: ``type_encoders`` will exclude specified fields + exclude_defaults: ``type_encoders`` will exclude default fields + exclude_none: ``type_encoders`` will exclude ``None`` fields + exclude_unset: ``type_encoders`` will exclude not set fields + include: ``type_encoders`` will include only specified fields prefer_alias: OpenAPI and ``type_encoders`` will export by alias """ + self.exclude = exclude + self.exclude_defaults = exclude_defaults + self.exclude_none = exclude_none + self.exclude_unset = exclude_unset + self.include = include self.prefer_alias = prefer_alias def on_app_init(self, app_config: AppConfig) -> AppConfig: @@ -61,7 +88,14 @@ def on_app_init(self, app_config: AppConfig) -> AppConfig: """ app_config.plugins.extend( [ - PydanticInitPlugin(prefer_alias=self.prefer_alias), + PydanticInitPlugin( + exclude=self.exclude, + exclude_defaults=self.exclude_defaults, + exclude_none=self.exclude_none, + exclude_unset=self.exclude_unset, + include=self.include, + prefer_alias=self.prefer_alias, + ), PydanticSchemaPlugin(prefer_alias=self.prefer_alias), PydanticDIPlugin(), ] diff --git a/litestar/contrib/pydantic/pydantic_init_plugin.py b/litestar/contrib/pydantic/pydantic_init_plugin.py index 1261cd8ebc..c06f99dfe5 100644 --- a/litestar/contrib/pydantic/pydantic_init_plugin.py +++ b/litestar/contrib/pydantic/pydantic_init_plugin.py @@ -29,7 +29,10 @@ if TYPE_CHECKING: + from pydantic.typing import AbstractSetIntStr, MappingIntStrAny + from litestar.config.app import AppConfig + from litestar.types.serialization import PydanticV1FieldsListType, PydanticV2FieldsListType T = TypeVar("T") @@ -119,16 +122,63 @@ def extract(annotation: Any, default: Any) -> Any: class PydanticInitPlugin(InitPluginProtocol): - __slots__ = ("prefer_alias",) + __slots__ = ( + "exclude", + "exclude_defaults", + "exclude_none", + "exclude_unset", + "include", + "prefer_alias", + ) - def __init__(self, prefer_alias: bool = False) -> None: + def __init__( + self, + exclude: PydanticV1FieldsListType | PydanticV2FieldsListType = None, + exclude_defaults: bool = False, + exclude_none: bool = False, + exclude_unset: bool = False, + include: PydanticV1FieldsListType | PydanticV2FieldsListType = None, + prefer_alias: bool = False, + ) -> None: + self.exclude = exclude + self.exclude_defaults = exclude_defaults + self.exclude_none = exclude_none + self.exclude_unset = exclude_unset + self.include = include self.prefer_alias = prefer_alias @classmethod - def encoders(cls, prefer_alias: bool = False) -> dict[Any, Callable[[Any], Any]]: - encoders = {**_base_encoders, **cls._create_pydantic_v1_encoders(prefer_alias)} + def encoders( + cls, + exclude: PydanticV1FieldsListType | PydanticV2FieldsListType = None, + exclude_defaults: bool = False, + exclude_none: bool = False, + exclude_unset: bool = False, + include: PydanticV1FieldsListType | PydanticV2FieldsListType = None, + prefer_alias: bool = False, + ) -> dict[Any, Callable[[Any], Any]]: + encoders = { + **_base_encoders, + **cls._create_pydantic_v1_encoders( + prefer_alias=prefer_alias, + exclude=exclude, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + exclude_unset=exclude_unset, + include=include, + ), + } if pydantic_v2 is not None: # pragma: no cover - encoders.update(cls._create_pydantic_v2_encoders(prefer_alias)) + encoders.update( + cls._create_pydantic_v2_encoders( + prefer_alias=prefer_alias, + exclude=exclude, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + exclude_unset=exclude_unset, + include=include, + ) + ) return encoders @classmethod @@ -145,10 +195,25 @@ def decoders(cls) -> list[tuple[Callable[[Any], bool], Callable[[Any, Any], Any] return decoders @staticmethod - def _create_pydantic_v1_encoders(prefer_alias: bool = False) -> dict[Any, Callable[[Any], Any]]: # pragma: no cover + def _create_pydantic_v1_encoders( + exclude: AbstractSetIntStr | MappingIntStrAny | None = None, + exclude_defaults: bool = False, + exclude_none: bool = False, + exclude_unset: bool = False, + include: AbstractSetIntStr | MappingIntStrAny | None = None, + prefer_alias: bool = False, + ) -> dict[Any, Callable[[Any], Any]]: # pragma: no cover return { pydantic_v1.BaseModel: lambda model: { - k: v.decode() if isinstance(v, bytes) else v for k, v in model.dict(by_alias=prefer_alias).items() + k: v.decode() if isinstance(v, bytes) else v + for k, v in model.dict( + by_alias=prefer_alias, + exclude=exclude, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + exclude_unset=exclude_unset, + include=include, + ).items() }, pydantic_v1.SecretField: str, pydantic_v1.StrictBool: int, @@ -159,9 +224,24 @@ def _create_pydantic_v1_encoders(prefer_alias: bool = False) -> dict[Any, Callab } @staticmethod - def _create_pydantic_v2_encoders(prefer_alias: bool = False) -> dict[Any, Callable[[Any], Any]]: + def _create_pydantic_v2_encoders( + exclude: PydanticV2FieldsListType = None, + exclude_defaults: bool = False, + exclude_none: bool = False, + exclude_unset: bool = False, + include: PydanticV2FieldsListType = None, + prefer_alias: bool = False, + ) -> dict[Any, Callable[[Any], Any]]: encoders: dict[Any, Callable[[Any], Any]] = { - pydantic_v2.BaseModel: lambda model: model.model_dump(mode="json", by_alias=prefer_alias), + pydantic_v2.BaseModel: lambda model: model.model_dump( + by_alias=prefer_alias, + exclude=exclude, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + exclude_unset=exclude_unset, + include=include, + mode="json", + ), pydantic_v2.types.SecretStr: lambda val: "**********" if val else "", pydantic_v2.types.SecretBytes: lambda val: "**********" if val else "", pydantic_v2.AnyUrl: str, @@ -175,7 +255,17 @@ def _create_pydantic_v2_encoders(prefer_alias: bool = False) -> dict[Any, Callab return encoders def on_app_init(self, app_config: AppConfig) -> AppConfig: - app_config.type_encoders = {**self.encoders(self.prefer_alias), **(app_config.type_encoders or {})} + app_config.type_encoders = { + **self.encoders( + prefer_alias=self.prefer_alias, + exclude=self.exclude, + exclude_defaults=self.exclude_defaults, + exclude_none=self.exclude_none, + exclude_unset=self.exclude_unset, + include=self.include, + ), + **(app_config.type_encoders or {}), + } app_config.type_decoders = [*self.decoders(), *(app_config.type_decoders or [])] _KWARG_META_EXTRACTORS.add(ConstrainedFieldMetaExtractor) diff --git a/litestar/types/serialization.py b/litestar/types/serialization.py index 0f61e10533..ed9f75893d 100644 --- a/litestar/types/serialization.py +++ b/litestar/types/serialization.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Dict, Set if TYPE_CHECKING: from collections import deque @@ -28,8 +28,13 @@ try: from pydantic import BaseModel + from pydantic.main import IncEx + from pydantic.typing import AbstractSetIntStr, MappingIntStrAny except ImportError: BaseModel = Any # type: ignore[assignment, misc] + IncEx = Any # type: ignore[misc] + AbstractSetIntStr = Any + MappingIntStrAny = Any try: from attrs import AttrsInstance @@ -57,3 +62,5 @@ EncodableMsgSpecType: TypeAlias = "Ext | Raw | Struct" LitestarEncodableType: TypeAlias = "EncodableBuiltinType | EncodableBuiltinCollectionType | EncodableStdLibType | EncodableStdLibIPType | EncodableMsgSpecType | BaseModel | AttrsInstance" # pyright: ignore DataContainerType: TypeAlias = "Struct | BaseModel | AttrsInstance | TypedDictClass | DataclassProtocol" # pyright: ignore +PydanticV2FieldsListType: TypeAlias = "Set[int] | Set[str] | Dict[int, Any] | Dict[str, Any] | None" +PydanticV1FieldsListType: TypeAlias = "IncEx | AbstractSetIntStr | MappingIntStrAny" # pyright: ignore diff --git a/litestar/typing.py b/litestar/typing.py index 1149926be2..d6816796f8 100644 --- a/litestar/typing.py +++ b/litestar/typing.py @@ -339,7 +339,7 @@ def is_required(self) -> bool: if isinstance(self.kwarg_definition, ParameterKwarg) and self.kwarg_definition.required is not None: return self.kwarg_definition.required - return not self.is_optional and not self.is_any and (not self.has_default or self.default is None) + return not self.is_optional and not self.is_any and (not self.has_default) @property def is_annotated(self) -> bool: diff --git a/tests/e2e/test_pydantic.py b/tests/e2e/test_pydantic.py index ad95441ae6..6faf4e7caf 100644 --- a/tests/e2e/test_pydantic.py +++ b/tests/e2e/test_pydantic.py @@ -1,24 +1,60 @@ import pydantic +import pytest +from pydantic import Field as FieldV2 +from pydantic.v1.fields import Field as FieldV1 from litestar import get +from litestar.contrib.pydantic import PydanticPlugin from litestar.testing import create_test_client -def test_app_with_v1_and_v2_models() -> None: +@pytest.mark.parametrize( + "plugin_params, response", + ( + ( + {"exclude": {"alias"}}, + { + "none": None, + "default": "default", + }, + ), + ({"exclude_defaults": True}, {"alias": "prefer_alias"}), + ({"exclude_none": True}, {"alias": "prefer_alias", "default": "default"}), + ({"exclude_unset": True}, {"alias": "prefer_alias"}), + ({"include": {"alias"}}, {"alias": "prefer_alias"}), + ({"prefer_alias": True}, {"prefer_alias": "prefer_alias", "default": "default", "none": None}), + ), + ids=( + "Exclude alias field", + "Exclude default fields", + "Exclude None field", + "Exclude unset fields", + "Include alias field", + "Use alias in response", + ), +) +def test_app_with_v1_and_v2_models(plugin_params: dict, response: dict) -> None: class ModelV1(pydantic.v1.BaseModel): # pyright: ignore - foo: str + alias: str = FieldV1(alias="prefer_alias") + default: str = "default" + none: None = None + + class Config: + allow_population_by_field_name = True class ModelV2(pydantic.BaseModel): - foo: str + alias: str = FieldV2(serialization_alias="prefer_alias") + default: str = "default" + none: None = None @get("/v1") def handler_v1() -> ModelV1: - return ModelV1(foo="bar") + return ModelV1(alias="prefer_alias") # type: ignore[call-arg] @get("/v2") def handler_v2() -> ModelV2: - return ModelV2(foo="bar") + return ModelV2(alias="prefer_alias") - with create_test_client([handler_v1, handler_v2]) as client: - assert client.get("/v1").json() == {"foo": "bar"} - assert client.get("/v2").json() == {"foo": "bar"} + with create_test_client([handler_v1, handler_v2], plugins=[PydanticPlugin(**plugin_params)]) as client: + assert client.get("/v1").json() == response + assert client.get("/v2").json() == response