Skip to content

Commit

Permalink
fix: pydantic version selection in tests (#2580)
Browse files Browse the repository at this point in the history
* fix: pydantic version selection in tests

The `pydantic_version` fixture returns strings of "v1" and "v2", however, we're using `if pydantic_version == "1"` as the branching condition in the tests.

This PR ensures that we are comparing the correct literal value in the tests.

* fix type error
  • Loading branch information
peterschutt authored Nov 1, 2023
1 parent 7443afb commit 57a8411
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 25 deletions.
3 changes: 3 additions & 0 deletions tests/unit/test_contrib/test_pydantic/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from typing import Literal

PydanticVersion = Literal["v1", "v2"]
8 changes: 5 additions & 3 deletions tests/unit/test_contrib/test_pydantic/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
from pydantic import v1 as pydantic_v1
from pytest import FixtureRequest

from . import PydanticVersion


@pytest.fixture(params=["v1", "v2"])
def pydantic_version(request: FixtureRequest) -> str:
def pydantic_version(request: FixtureRequest) -> PydanticVersion:
return request.param # type: ignore[no-any-return]


@pytest.fixture()
def base_model(pydantic_version: str) -> type[pydantic.BaseModel | pydantic_v1.BaseModel]:
return pydantic_v1.BaseModel if pydantic_version == "1" else pydantic.BaseModel
def base_model(pydantic_version: PydanticVersion) -> type[pydantic.BaseModel | pydantic_v1.BaseModel]:
return pydantic_v1.BaseModel if pydantic_version == "v1" else pydantic.BaseModel
6 changes: 4 additions & 2 deletions tests/unit/test_contrib/test_pydantic/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from litestar.testing import create_test_client
from tests.unit.test_contrib.test_pydantic.models import PydanticPerson, PydanticV1Person

from . import PydanticVersion


def test_pydantic_v1_validation_error_raises_400() -> None:
class Model(pydantic_v1.BaseModel):
Expand Down Expand Up @@ -99,7 +101,7 @@ def my_route_handler(param: int, data: PydanticV1Person) -> None:


def test_signature_model_invalid_input(
base_model: Type[Union[pydantic_v2.BaseModel, pydantic_v1.BaseModel]], pydantic_version: str
base_model: Type[Union[pydantic_v2.BaseModel, pydantic_v1.BaseModel]], pydantic_version: PydanticVersion
) -> None:
class OtherChild(base_model): # type: ignore[misc, valid-type]
val: List[int]
Expand Down Expand Up @@ -136,7 +138,7 @@ def test(
data = response.json()

assert data
if pydantic_version == "1":
if pydantic_version == "v1":
assert data["extra"] == [
{"key": "child.val", "message": "value is not a valid integer"},
{"key": "child.other_val", "message": "value is not a valid integer"},
Expand Down
30 changes: 16 additions & 14 deletions tests/unit/test_contrib/test_pydantic/test_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
PydanticV1Person,
)

from . import PydanticVersion

AnyBaseModelType = Type[Union[pydantic_v1.BaseModel, pydantic_v2.BaseModel]]


Expand Down Expand Up @@ -141,20 +143,20 @@ def test_create_collection_constrained_field_schema_pydantic_v2(annotation: Any)


@pytest.fixture()
def conset(pydantic_version: str) -> Any:
return pydantic_v1.conset if pydantic_version == "1" else pydantic_v2.conset
def conset(pydantic_version: PydanticVersion) -> Any:
return pydantic_v1.conset if pydantic_version == "v1" else pydantic_v2.conset


@pytest.fixture()
def conlist(pydantic_version: str) -> Any:
return pydantic_v1.conlist if pydantic_version == "1" else pydantic_v2.conlist
def conlist(pydantic_version: PydanticVersion) -> Any:
return pydantic_v1.conlist if pydantic_version == "v1" else pydantic_v2.conlist


def test_create_collection_constrained_field_schema_sub_fields(
pydantic_version: str, conset: Any, conlist: Any
pydantic_version: PydanticVersion, conset: Any, conlist: Any
) -> None:
for pydantic_fn in [conset, conlist]:
if pydantic_version == "1":
if pydantic_version == "v1":
annotation = pydantic_fn(Union[str, int], min_items=1, max_items=10)
else:
annotation = pydantic_fn(Union[str, int], min_length=1, max_length=10)
Expand Down Expand Up @@ -414,12 +416,12 @@ async def example_route() -> Lookup:
}


def test_schema_by_alias(base_model: AnyBaseModelType, pydantic_version: str) -> None:
def test_schema_by_alias(base_model: AnyBaseModelType, pydantic_version: PydanticVersion) -> None:
class RequestWithAlias(base_model): # type: ignore[misc, valid-type]
first: str = (pydantic_v1.Field if pydantic_version == "1" else pydantic_v2.Field)(alias="second") # type: ignore[operator]
first: str = (pydantic_v1.Field if pydantic_version == "v1" else pydantic_v2.Field)(alias="second") # type: ignore[operator]

class ResponseWithAlias(base_model): # type: ignore[misc, valid-type]
first: str = (pydantic_v1.Field if pydantic_version == "1" else pydantic_v2.Field)(alias="second") # type: ignore[operator]
first: str = (pydantic_v1.Field if pydantic_version == "v1" else pydantic_v2.Field)(alias="second") # type: ignore[operator]

@post("/", signature_types=[RequestWithAlias, ResponseWithAlias])
def handler(data: RequestWithAlias) -> ResponseWithAlias:
Expand Down Expand Up @@ -449,12 +451,12 @@ def handler(data: RequestWithAlias) -> ResponseWithAlias:
assert response.json() == {response_key: "foo"}


def test_schema_by_alias_plugin_override(base_model: AnyBaseModelType, pydantic_version: str) -> None:
def test_schema_by_alias_plugin_override(base_model: AnyBaseModelType, pydantic_version: PydanticVersion) -> None:
class RequestWithAlias(base_model): # type: ignore[misc, valid-type]
first: str = (pydantic_v1.Field if pydantic_version == "1" else pydantic_v2.Field)(alias="second") # type: ignore[operator]
first: str = (pydantic_v1.Field if pydantic_version == "v1" else pydantic_v2.Field)(alias="second") # type: ignore[operator]

class ResponseWithAlias(base_model): # type: ignore[misc, valid-type]
first: str = (pydantic_v1.Field if pydantic_version == "1" else pydantic_v2.Field)(alias="second") # type: ignore[operator]
first: str = (pydantic_v1.Field if pydantic_version == "v1" else pydantic_v2.Field)(alias="second") # type: ignore[operator]

@post("/", signature_types=[RequestWithAlias, ResponseWithAlias])
def handler(data: RequestWithAlias) -> ResponseWithAlias:
Expand Down Expand Up @@ -522,14 +524,14 @@ class Model(pydantic_v2.BaseModel):

@pytest.mark.parametrize("with_future_annotations", [True, False])
def test_create_schema_for_pydantic_model_with_annotated_model_attribute(
with_future_annotations: bool, create_module: "Callable[[str], ModuleType]", pydantic_version: str
with_future_annotations: bool, create_module: "Callable[[str], ModuleType]", pydantic_version: PydanticVersion
) -> None:
"""Test that a model with an annotated attribute is correctly handled."""
module = create_module(
f"""
{'from __future__ import annotations' if with_future_annotations else ''}
from typing_extensions import Annotated
{'from pydantic import BaseModel' if pydantic_version == '1' else 'from pydantic.v1 import BaseModel'}
{'from pydantic import BaseModel' if pydantic_version == 'v1' else 'from pydantic.v1 import BaseModel'}
class Foo(BaseModel):
foo: Annotated[int, "Foo description"]
Expand Down
16 changes: 10 additions & 6 deletions tests/unit/test_contrib/test_pydantic/test_plugin_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
get_serializer,
)

from . import PydanticVersion


class CustomStr(str):
pass
Expand Down Expand Up @@ -118,13 +120,13 @@ class ModelV2(pydantic_v2.BaseModel):


@pytest.fixture()
def model_type(pydantic_version: str) -> type[ModelV1 | ModelV2]:
return ModelV1 if pydantic_version == "1" else ModelV2
def model_type(pydantic_version: PydanticVersion) -> type[ModelV1 | ModelV2]:
return ModelV1 if pydantic_version == "v1" else ModelV2


@pytest.fixture()
def model(pydantic_version: str) -> ModelV1 | ModelV2:
if pydantic_version == "1":
def model(pydantic_version: PydanticVersion) -> ModelV1 | ModelV2:
if pydantic_version == "v1":
return ModelV1(
path=Path("example"),
email_str=pydantic_v1.parse_obj_as(pydantic_v1.EmailStr, "[email protected]"),
Expand Down Expand Up @@ -195,14 +197,16 @@ def test_serialization_of_model_instance(model: ModelV1 | ModelV2) -> None:


@pytest.mark.parametrize("prefer_alias", [False, True])
def test_pydantic_json_compatibility(model: ModelV1 | ModelV2, prefer_alias: bool, pydantic_version: str) -> None:
def test_pydantic_json_compatibility(
model: ModelV1 | ModelV2, prefer_alias: bool, pydantic_version: PydanticVersion
) -> None:
raw = _model_dump_json(model, by_alias=prefer_alias)
encoded_json = encode_json(model, serializer=get_serializer(PydanticInitPlugin.encoders(prefer_alias=prefer_alias)))

raw_result = json.loads(raw)
encoded_result = json.loads(encoded_json)

if pydantic_version == "1":
if pydantic_version == "v1":
# pydantic v1 dumps decimals into floats as json, we therefore regard this as an error
assert raw_result.get("condecimal") == float(encoded_result.get("condecimal"))
del raw_result["condecimal"]
Expand Down

0 comments on commit 57a8411

Please sign in to comment.