Skip to content

Commit

Permalink
Fix DTO msgspec meta constraints not being included in transfer model
Browse files Browse the repository at this point in the history
  • Loading branch information
provinzkraut committed Feb 14, 2024
1 parent 11d79ce commit 8a63b25
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 7 deletions.
5 changes: 1 addition & 4 deletions docs/examples/contrib/piccolo/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,4 @@ async def on_startup():
await create_db_tables(Task, if_not_exists=True)


app = Litestar(
route_handlers=[tasks, create_task, delete_task, update_task],
on_startup=[on_startup],
)
app = Litestar(route_handlers=[tasks, create_task, delete_task, update_task], on_startup=[on_startup], debug=True)
17 changes: 14 additions & 3 deletions litestar/contrib/piccolo.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import warnings
from dataclasses import replace
from decimal import Decimal
from typing import Any, Generator, Generic, List, Optional, TypeVar
Expand All @@ -9,7 +10,7 @@

from litestar.dto import AbstractDTO, DTOField, Mark
from litestar.dto.data_structures import DTOFieldDefinition
from litestar.exceptions import MissingDependencyException
from litestar.exceptions import LitestarWarning, MissingDependencyException
from litestar.types import Empty
from litestar.typing import FieldDefinition
from litestar.utils import warn_deprecation
Expand Down Expand Up @@ -38,12 +39,22 @@ def __getattr__(name: str) -> Any:


def _parse_piccolo_type(column: Column, extra: dict[str, Any]) -> FieldDefinition:
is_optional = not column._meta.required

if isinstance(column, (column_types.Decimal, column_types.Numeric)):
column_type: Any = Decimal
meta = Meta(extra=extra)
elif isinstance(column, (column_types.Email, column_types.Varchar)):
column_type = str
meta = Meta(max_length=column.length, extra=extra)
if is_optional:
meta = Meta(extra=extra)
warnings.warn(
f"Dropping max_length constraint for column {column!r} because the " "column is optional",
category=LitestarWarning,
stacklevel=2,
)
else:
meta = Meta(max_length=column.length, extra=extra)
elif isinstance(column, column_types.Array):
column_type = List[column.base_column.value_type] # type: ignore
meta = Meta(extra=extra)
Expand All @@ -57,7 +68,7 @@ def _parse_piccolo_type(column: Column, extra: dict[str, Any]) -> FieldDefinitio
column_type = column.value_type
meta = Meta(extra=extra)

if not column._meta.required:
if is_optional:
column_type = Optional[column_type]

return FieldDefinition.from_annotation(Annotated[column_type, meta])
Expand Down
24 changes: 24 additions & 0 deletions litestar/dto/_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
cast,
)

import msgspec
from msgspec import UNSET, Struct, UnsetType, convert, defstruct, field
from typing_extensions import Annotated

from litestar.dto._types import (
CollectionType,
Expand All @@ -33,6 +35,7 @@
from litestar.dto.data_structures import DTOData, DTOFieldDefinition
from litestar.dto.field import Mark
from litestar.enums import RequestEncodingType
from litestar.params import KwargDefinition
from litestar.serialization import decode_json, decode_msgpack
from litestar.types import Empty
from litestar.typing import FieldDefinition
Expand Down Expand Up @@ -740,6 +743,24 @@ def _create_msgspec_field(field_definition: TransferDTOFieldDefinition) -> Any:
return field(**kwargs)


def _create_struct_field_meta_for_field_definition(field_definition: TransferDTOFieldDefinition) -> msgspec.Meta | None:
if (kwarg_definition := field_definition.kwarg_definition) is None or not isinstance(
kwarg_definition, KwargDefinition
):
return None

return msgspec.Meta(
gt=kwarg_definition.gt,
ge=kwarg_definition.ge,
lt=kwarg_definition.lt,
le=kwarg_definition.le,
multiple_of=kwarg_definition.multiple_of,
min_length=kwarg_definition.min_length if not field_definition.is_partial else None,
max_length=kwarg_definition.max_length if not field_definition.is_partial else None,
pattern=kwarg_definition.pattern,
)


def _create_struct_for_field_definitions(
model_name: str,
field_definitions: tuple[TransferDTOFieldDefinition, ...],
Expand All @@ -755,6 +776,9 @@ def _create_struct_for_field_definitions(
if field_definition.is_partial:
field_type = Union[field_type, UnsetType]

if (field_meta := _create_struct_field_meta_for_field_definition(field_definition)) is not None:
field_type = Annotated[field_type, field_meta]

struct_fields.append(
(
field_definition.name,
Expand Down
50 changes: 50 additions & 0 deletions tests/unit/test_dto/test_factory/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,3 +848,53 @@ def test(data: Optional[Foo] = None) -> dict:
with create_test_client([test]) as client:
response = client.post("/")
assert response.json() == {"foo": None}


@pytest.mark.parametrize(
"field_type, constraint_name, constraint_value, request_data",
[
(int, "gt", 2, 2),
(int, "ge", 2, 1),
(int, "lt", 2, 2),
(int, "le", 2, 3),
(int, "multiple_of", 2, 3),
(str, "min_length", 2, "1"),
(str, "max_length", 1, "12"),
(str, "pattern", r"\d", "a"),
],
)
def test_msgspec_dto_copies_constraints(
field_type: Any, constraint_name: str, constraint_value: Any, request_data: Any, use_experimental_dto_backend: bool
) -> None:
# https://github.com/litestar-org/litestar/issues/3026
struct = msgspec.defstruct(
"Foo",
fields=[("bar", Annotated[field_type, msgspec.Meta(**{constraint_name: constraint_value})])], # type: ignore[list-item]
)

@post(
"/",
dto=Annotated[MsgspecDTO[struct], DTOConfig(experimental_codegen_backend=use_experimental_dto_backend)], # type: ignore[arg-type, valid-type]
signature_namespace={"struct": struct},
)
def handler(data: struct) -> None: # type: ignore[valid-type]
pass

with create_test_client([handler]) as client:
assert client.post("/", json={"bar": request_data}).status_code == 400


def test_msgspec_dto_dont_copy_length_constraint_for_partial_dto() -> None:
class Foo(msgspec.Struct):
bar: Annotated[str, msgspec.Meta(min_length=2)]
baz: Annotated[str, msgspec.Meta(max_length=2)]

class FooDTO(MsgspecDTO[Foo]):
config = DTOConfig(partial=True)

@post("/", dto=FooDTO, signature_types={Foo})
def handler(data: Foo) -> None:
pass

with create_test_client([handler]) as client:
assert client.post("/", json={"bar": "1", "baz": "123"}).status_code == 201

0 comments on commit 8a63b25

Please sign in to comment.