Skip to content

Commit

Permalink
feat: add a factory supporting pydantic dataclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
Slyces committed Dec 6, 2024
1 parent 1e0c847 commit 9fe676f
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 7 deletions.
4 changes: 4 additions & 0 deletions docs/usage/library_factories/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,13 @@ These include:

:class:`TypedDictFactory <polyfactory.factories.typed_dict_factory.TypedDictFactory>`
a base factory for typed-dicts

:class:`ModelFactory <polyfactory.factories.pydantic_factory.ModelFactory>`
a base factory for `pydantic <https://docs.pydantic.dev/>`_ models

:class:`PydanticDataclassFactory <polyfactory.factories.pydantic_factory.PydanticDataclassFactory>`
a base factory for `pydantic <https://docs.pydantic.dev/latest/concepts/dataclasses/>`_ dataclasses

:class:`BeanieDocumentFactory <polyfactory.factories.beanie_odm_factory.BeanieDocumentFactory>`
a base factory for `beanie <https://beanie-odm.dev/>`_ documents

Expand Down
48 changes: 42 additions & 6 deletions polyfactory/factories/pydantic_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
ModelField, # pyright: ignore[attr-defined,reportAttributeAccessIssue]
Undefined, # pyright: ignore[attr-defined,reportAttributeAccessIssue]
)
from pydantic.dataclasses import is_pydantic_dataclass

# Keep this import last to prevent warnings from pydantic if pydantic v2
# is installed.
Expand All @@ -68,6 +69,7 @@

# v2 specific imports
from pydantic import BaseModel as BaseModelV2
from pydantic.dataclasses import is_pydantic_dataclass
from pydantic_core import PydanticUndefined as UndefinedV2
from pydantic_core import to_json

Expand Down Expand Up @@ -99,7 +101,8 @@

from typing_extensions import NotRequired, TypeGuard

T = TypeVar("T", bound="BaseModelV1 | BaseModelV2") # pyright: ignore[reportInvalidTypeForm]
ModelT = TypeVar("ModelT", bound="BaseModelV1 | BaseModelV2") # pyright: ignore[reportInvalidTypeForm]
T = TypeVar("T")

_IS_PYDANTIC_V1 = VERSION.startswith("1")

Expand Down Expand Up @@ -370,7 +373,7 @@ def get_constraints_metadata(cls, annotation: Any) -> Sequence[Any]:
return metadata


class ModelFactory(Generic[T], BaseFactory[T]):
class ModelFactory(Generic[ModelT], BaseFactory[ModelT]):
"""Base factory for pydantic models"""

__forward_ref_resolution_type_mapping__: ClassVar[Mapping[str, type]] = {}
Expand All @@ -388,7 +391,7 @@ def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None:
cls.__model__.update_forward_refs(**cls.__forward_ref_resolution_type_mapping__) # type: ignore[attr-defined]

@classmethod
def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]:
def is_supported_type(cls, value: Any) -> TypeGuard[type[ModelT]]:
"""Determine whether the given value is supported by the factory.
:param value: An arbitrary value.
Expand Down Expand Up @@ -454,7 +457,7 @@ def build(
cls,
factory_use_construct: bool = False,
**kwargs: Any,
) -> T:
) -> ModelT:
"""Build an instance of the factory's __model__
:param factory_use_construct: A boolean that determines whether validations will be made when instantiating the
Expand Down Expand Up @@ -492,7 +495,7 @@ def _get_build_context(cls, build_context: BaseBuildContext | PydanticBuildConte
}

@classmethod
def _create_model(cls, _build_context: PydanticBuildContext, **kwargs: Any) -> T:
def _create_model(cls, _build_context: PydanticBuildContext, **kwargs: Any) -> ModelT:
"""Create an instance of the factory's __model__
:param _build_context: BuildContext instance.
Expand All @@ -508,7 +511,7 @@ def _create_model(cls, _build_context: PydanticBuildContext, **kwargs: Any) -> T
return cls.__model__(**kwargs) # type: ignore[return-value]

@classmethod
def coverage(cls, factory_use_construct: bool = False, **kwargs: Any) -> abc.Iterator[T]:
def coverage(cls, factory_use_construct: bool = False, **kwargs: Any) -> abc.Iterator[ModelT]:
"""Build a batch of the factory's Meta.model will full coverage of the sub-types of the model.
:param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used.
Expand Down Expand Up @@ -629,3 +632,36 @@ def _is_pydantic_v1_model(model: Any) -> TypeGuard[BaseModelV1]:

def _is_pydantic_v2_model(model: Any) -> TypeGuard[BaseModelV2]: # pyright: ignore[reportInvalidTypeForm]
return not _IS_PYDANTIC_V1 and is_safe_subclass(model, BaseModelV2)


class PydanticDataclassFactory(ModelFactory[T]): # type: ignore[type-var]
"""Base factory for pydantic dataclasses"""

__is_base_factory__ = True

@classmethod
def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]:
return is_pydantic_dataclass(value)

@classmethod
def get_model_fields(cls) -> list[FieldMeta]:
if not is_pydantic_dataclass(cls.__model__):
# This should be unreachable
return []

pydantic_fields = cls.__model__.__pydantic_fields__
pydantic_config = cls.__model__.__pydantic_config__
cls._fields_metadata = [
PydanticFieldMeta.from_field_info(
field_info=field_info,
field_name=field_name,
random=cls.__random__,
use_alias=not pydantic_config.get(
"populate_by_name",
False,
),
)
for field_name, field_info in pydantic_fields.items()
]

return cls._fields_metadata
50 changes: 49 additions & 1 deletion tests/test_pydantic_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,11 @@
constr,
validator,
)
from pydantic.dataclasses import dataclass as pydantic_dataclass

from polyfactory.exceptions import ConfigurationException
from polyfactory.factories import DataclassFactory
from polyfactory.factories.pydantic_factory import _IS_PYDANTIC_V1, ModelFactory
from polyfactory.factories.pydantic_factory import _IS_PYDANTIC_V1, ModelFactory, PydanticDataclassFactory
from tests.models import Person, PetFactory

IS_PYDANTIC_V1 = _IS_PYDANTIC_V1
Expand Down Expand Up @@ -1038,3 +1040,49 @@ class A(BaseModel):
AFactory = ModelFactory.create_factory(A)

assert AFactory.build()


def test_simple_pydantic_dataclass() -> None:
@pydantic_dataclass
class DataclassModel:
a: int
b: Annotated[str, MinLen(1)]

class DataclassModelFactory(PydanticDataclassFactory[DataclassModel]):
__model__ = DataclassModel

instance = DataclassModelFactory.build()
assert isinstance(instance, DataclassModel)
assert isinstance(instance.a, int)
assert isinstance(instance.b, str)
assert len(instance.b) >= 1


def test_nested_pydantic_dataclass() -> None:
@pydantic_dataclass
class FooDataclass:
content: int

@pydantic_dataclass
class NestedDataclassModel:
foo: FooDataclass

class DataclassModelFactory(PydanticDataclassFactory[NestedDataclassModel]):
__model__ = NestedDataclassModel

instance = DataclassModelFactory.build()
assert isinstance(instance, NestedDataclassModel)
assert isinstance(instance.foo, FooDataclass)
assert isinstance(instance.foo.content, int)


def test_pydantic_dataclass_factory_raises_for_std_dataclasses() -> None:
@dataclass
class DataclassModel:
a: int
b: str

with pytest.raises(ConfigurationException):

class DataclassModelFactory(PydanticDataclassFactory[DataclassModel]):
__model__ = DataclassModel

0 comments on commit 9fe676f

Please sign in to comment.