Skip to content

Commit

Permalink
feat: Add support for v2 model_dump and v1 dict to pydantic plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
kedod committed Apr 6, 2024
1 parent b5d9c6f commit 8046b36
Show file tree
Hide file tree
Showing 5 changed files with 190 additions and 23 deletions.
40 changes: 37 additions & 3 deletions litestar/contrib/pydantic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from pydantic.v1 import BaseModel as BaseModelV1

from litestar.config.app import AppConfig
from litestar.types.serialization import PydanticV1FieldsListType, PydanticV2FieldsListType


__all__ = (
"PydanticDTO",
Expand Down Expand Up @@ -43,14 +45,39 @@ def _model_dump_json(model: BaseModel | BaseModelV1, by_alias: bool = False) ->
class PydanticPlugin(InitPluginProtocol):
"""A plugin that provides Pydantic integration."""

__slots__ = ("prefer_alias",)
__slots__ = (
"exclude",
"exclude_defaults",
"exclude_none",
"exclude_unset",
"include",
"prefer_alias",
)

def __init__(self, prefer_alias: bool = False) -> None:
def __init__(
self,
exclude: PydanticV1FieldsListType | PydanticV2FieldsListType = None,
exclude_defaults: bool = False,
exclude_none: bool = False,
exclude_unset: bool = False,
include: PydanticV1FieldsListType | PydanticV2FieldsListType = None,
prefer_alias: bool = False,
) -> None:
"""Initialize ``PydanticPlugin``.
Args:
exclude: ``type_encoders`` will exclude specified fields
exclude_defaults: ``type_encoders`` will exclude default fields
exclude_none: ``type_encoders`` will exclude ``None`` fields
exclude_unset: ``type_encoders`` will exclude not set fields
include: ``type_encoders`` will include only specified fields
prefer_alias: OpenAPI and ``type_encoders`` will export by alias
"""
self.exclude = exclude
self.exclude_defaults = exclude_defaults
self.exclude_none = exclude_none
self.exclude_unset = exclude_unset
self.include = include
self.prefer_alias = prefer_alias

def on_app_init(self, app_config: AppConfig) -> AppConfig:
Expand All @@ -61,7 +88,14 @@ def on_app_init(self, app_config: AppConfig) -> AppConfig:
"""
app_config.plugins.extend(
[
PydanticInitPlugin(prefer_alias=self.prefer_alias),
PydanticInitPlugin(
exclude=self.exclude,
exclude_defaults=self.exclude_defaults,
exclude_none=self.exclude_none,
exclude_unset=self.exclude_unset,
include=self.include,
prefer_alias=self.prefer_alias,
),
PydanticSchemaPlugin(prefer_alias=self.prefer_alias),
PydanticDIPlugin(),
]
Expand Down
110 changes: 100 additions & 10 deletions litestar/contrib/pydantic/pydantic_init_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@


if TYPE_CHECKING:
from pydantic.typing import AbstractSetIntStr, MappingIntStrAny

from litestar.config.app import AppConfig
from litestar.types.serialization import PydanticV1FieldsListType, PydanticV2FieldsListType


T = TypeVar("T")
Expand Down Expand Up @@ -119,16 +122,63 @@ def extract(annotation: Any, default: Any) -> Any:


class PydanticInitPlugin(InitPluginProtocol):
__slots__ = ("prefer_alias",)
__slots__ = (
"exclude",
"exclude_defaults",
"exclude_none",
"exclude_unset",
"include",
"prefer_alias",
)

def __init__(self, prefer_alias: bool = False) -> None:
def __init__(
self,
exclude: PydanticV1FieldsListType | PydanticV2FieldsListType = None,
exclude_defaults: bool = False,
exclude_none: bool = False,
exclude_unset: bool = False,
include: PydanticV1FieldsListType | PydanticV2FieldsListType = None,
prefer_alias: bool = False,
) -> None:
self.exclude = exclude
self.exclude_defaults = exclude_defaults
self.exclude_none = exclude_none
self.exclude_unset = exclude_unset
self.include = include
self.prefer_alias = prefer_alias

@classmethod
def encoders(cls, prefer_alias: bool = False) -> dict[Any, Callable[[Any], Any]]:
encoders = {**_base_encoders, **cls._create_pydantic_v1_encoders(prefer_alias)}
def encoders(
cls,
exclude: PydanticV1FieldsListType | PydanticV2FieldsListType = None,
exclude_defaults: bool = False,
exclude_none: bool = False,
exclude_unset: bool = False,
include: PydanticV1FieldsListType | PydanticV2FieldsListType = None,
prefer_alias: bool = False,
) -> dict[Any, Callable[[Any], Any]]:
encoders = {
**_base_encoders,
**cls._create_pydantic_v1_encoders(
prefer_alias=prefer_alias,
exclude=exclude,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
exclude_unset=exclude_unset,
include=include,
),
}
if pydantic_v2 is not None: # pragma: no cover
encoders.update(cls._create_pydantic_v2_encoders(prefer_alias))
encoders.update(
cls._create_pydantic_v2_encoders(
prefer_alias=prefer_alias,
exclude=exclude,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
exclude_unset=exclude_unset,
include=include,
)
)
return encoders

@classmethod
Expand All @@ -145,10 +195,25 @@ def decoders(cls) -> list[tuple[Callable[[Any], bool], Callable[[Any, Any], Any]
return decoders

@staticmethod
def _create_pydantic_v1_encoders(prefer_alias: bool = False) -> dict[Any, Callable[[Any], Any]]: # pragma: no cover
def _create_pydantic_v1_encoders(
exclude: AbstractSetIntStr | MappingIntStrAny | None = None,
exclude_defaults: bool = False,
exclude_none: bool = False,
exclude_unset: bool = False,
include: AbstractSetIntStr | MappingIntStrAny | None = None,
prefer_alias: bool = False,
) -> dict[Any, Callable[[Any], Any]]: # pragma: no cover
return {
pydantic_v1.BaseModel: lambda model: {
k: v.decode() if isinstance(v, bytes) else v for k, v in model.dict(by_alias=prefer_alias).items()
k: v.decode() if isinstance(v, bytes) else v
for k, v in model.dict(
by_alias=prefer_alias,
exclude=exclude,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
exclude_unset=exclude_unset,
include=include,
).items()
},
pydantic_v1.SecretField: str,
pydantic_v1.StrictBool: int,
Expand All @@ -159,9 +224,24 @@ def _create_pydantic_v1_encoders(prefer_alias: bool = False) -> dict[Any, Callab
}

@staticmethod
def _create_pydantic_v2_encoders(prefer_alias: bool = False) -> dict[Any, Callable[[Any], Any]]:
def _create_pydantic_v2_encoders(
exclude: PydanticV2FieldsListType = None,
exclude_defaults: bool = False,
exclude_none: bool = False,
exclude_unset: bool = False,
include: PydanticV2FieldsListType = None,
prefer_alias: bool = False,
) -> dict[Any, Callable[[Any], Any]]:
encoders: dict[Any, Callable[[Any], Any]] = {
pydantic_v2.BaseModel: lambda model: model.model_dump(mode="json", by_alias=prefer_alias),
pydantic_v2.BaseModel: lambda model: model.model_dump(
by_alias=prefer_alias,
exclude=exclude,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
exclude_unset=exclude_unset,
include=include,
mode="json",
),
pydantic_v2.types.SecretStr: lambda val: "**********" if val else "",
pydantic_v2.types.SecretBytes: lambda val: "**********" if val else "",
pydantic_v2.AnyUrl: str,
Expand All @@ -175,7 +255,17 @@ def _create_pydantic_v2_encoders(prefer_alias: bool = False) -> dict[Any, Callab
return encoders

def on_app_init(self, app_config: AppConfig) -> AppConfig:
app_config.type_encoders = {**self.encoders(self.prefer_alias), **(app_config.type_encoders or {})}
app_config.type_encoders = {
**self.encoders(
prefer_alias=self.prefer_alias,
exclude=self.exclude,
exclude_defaults=self.exclude_defaults,
exclude_none=self.exclude_none,
exclude_unset=self.exclude_unset,
include=self.include,
),
**(app_config.type_encoders or {}),
}
app_config.type_decoders = [*self.decoders(), *(app_config.type_decoders or [])]

_KWARG_META_EXTRACTORS.add(ConstrainedFieldMetaExtractor)
Expand Down
9 changes: 8 additions & 1 deletion litestar/types/serialization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Dict, Set

if TYPE_CHECKING:
from collections import deque
Expand Down Expand Up @@ -28,8 +28,13 @@

try:
from pydantic import BaseModel
from pydantic.main import IncEx
from pydantic.typing import AbstractSetIntStr, MappingIntStrAny
except ImportError:
BaseModel = Any # type: ignore[assignment, misc]
IncEx = Any # type: ignore[misc]
AbstractSetIntStr = Any
MappingIntStrAny = Any

try:
from attrs import AttrsInstance
Expand Down Expand Up @@ -57,3 +62,5 @@
EncodableMsgSpecType: TypeAlias = "Ext | Raw | Struct"
LitestarEncodableType: TypeAlias = "EncodableBuiltinType | EncodableBuiltinCollectionType | EncodableStdLibType | EncodableStdLibIPType | EncodableMsgSpecType | BaseModel | AttrsInstance" # pyright: ignore
DataContainerType: TypeAlias = "Struct | BaseModel | AttrsInstance | TypedDictClass | DataclassProtocol" # pyright: ignore
PydanticV2FieldsListType: TypeAlias = "Set[int] | Set[str] | Dict[int, Any] | Dict[str, Any] | None"
PydanticV1FieldsListType: TypeAlias = "IncEx | AbstractSetIntStr | MappingIntStrAny" # pyright: ignore
2 changes: 1 addition & 1 deletion litestar/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ def is_required(self) -> bool:
if isinstance(self.kwarg_definition, ParameterKwarg) and self.kwarg_definition.required is not None:
return self.kwarg_definition.required

return not self.is_optional and not self.is_any and (not self.has_default or self.default is None)
return not self.is_optional and not self.is_any and (not self.has_default)

@property
def is_annotated(self) -> bool:
Expand Down
52 changes: 44 additions & 8 deletions tests/e2e/test_pydantic.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,60 @@
import pydantic
import pytest
from pydantic import Field as FieldV2
from pydantic.v1.fields import Field as FieldV1

from litestar import get
from litestar.contrib.pydantic import PydanticPlugin
from litestar.testing import create_test_client


def test_app_with_v1_and_v2_models() -> None:
@pytest.mark.parametrize(
"plugin_params, response",
(
(
{"exclude": {"alias"}},
{
"none": None,
"default": "default",
},
),
({"exclude_defaults": True}, {"alias": "prefer_alias"}),
({"exclude_none": True}, {"alias": "prefer_alias", "default": "default"}),
({"exclude_unset": True}, {"alias": "prefer_alias"}),
({"include": {"alias"}}, {"alias": "prefer_alias"}),
({"prefer_alias": True}, {"prefer_alias": "prefer_alias", "default": "default", "none": None}),
),
ids=(
"Exclude alias field",
"Exclude default fields",
"Exclude None field",
"Exclude unset fields",
"Include alias field",
"Use alias in response",
),
)
def test_app_with_v1_and_v2_models(plugin_params: dict, response: dict) -> None:
class ModelV1(pydantic.v1.BaseModel): # pyright: ignore
foo: str
alias: str = FieldV1(alias="prefer_alias")
default: str = "default"
none: None = None

class Config:
allow_population_by_field_name = True

class ModelV2(pydantic.BaseModel):
foo: str
alias: str = FieldV2(serialization_alias="prefer_alias")
default: str = "default"
none: None = None

@get("/v1")
def handler_v1() -> ModelV1:
return ModelV1(foo="bar")
return ModelV1(alias="prefer_alias") # type: ignore[call-arg]

@get("/v2")
def handler_v2() -> ModelV2:
return ModelV2(foo="bar")
return ModelV2(alias="prefer_alias")

with create_test_client([handler_v1, handler_v2]) as client:
assert client.get("/v1").json() == {"foo": "bar"}
assert client.get("/v2").json() == {"foo": "bar"}
with create_test_client([handler_v1, handler_v2], plugins=[PydanticPlugin(**plugin_params)]) as client:
assert client.get("/v1").json() == response
assert client.get("/v2").json() == response

0 comments on commit 8046b36

Please sign in to comment.