diff --git a/docs/usage/library_factories/index.rst b/docs/usage/library_factories/index.rst index 20911acf..8ccfebb4 100644 --- a/docs/usage/library_factories/index.rst +++ b/docs/usage/library_factories/index.rst @@ -11,9 +11,13 @@ These include: :class:`TypedDictFactory ` a base factory for typed-dicts + :class:`ModelFactory ` a base factory for `pydantic `_ models +:class:`PydanticDataclassFactory ` + a base factory for `pydantic `_ dataclasses + :class:`BeanieDocumentFactory ` a base factory for `beanie `_ documents diff --git a/polyfactory/factories/pydantic_factory.py b/polyfactory/factories/pydantic_factory.py index 0d53f341..e29348f6 100644 --- a/polyfactory/factories/pydantic_factory.py +++ b/polyfactory/factories/pydantic_factory.py @@ -55,6 +55,7 @@ ModelField, # pyright: ignore[attr-defined,reportAttributeAccessIssue] Undefined, # pyright: ignore[attr-defined,reportAttributeAccessIssue] ) + from pydantic.dataclasses import is_pydantic_dataclass # Keep this import last to prevent warnings from pydantic if pydantic v2 # is installed. @@ -68,6 +69,7 @@ # v2 specific imports from pydantic import BaseModel as BaseModelV2 + from pydantic.dataclasses import is_pydantic_dataclass from pydantic_core import PydanticUndefined as UndefinedV2 from pydantic_core import to_json @@ -99,7 +101,8 @@ from typing_extensions import NotRequired, TypeGuard -T = TypeVar("T", bound="BaseModelV1 | BaseModelV2") # pyright: ignore[reportInvalidTypeForm] +ModelT = TypeVar("ModelT", bound="BaseModelV1 | BaseModelV2") # pyright: ignore[reportInvalidTypeForm] +T = TypeVar("T") _IS_PYDANTIC_V1 = VERSION.startswith("1") @@ -370,7 +373,7 @@ def get_constraints_metadata(cls, annotation: Any) -> Sequence[Any]: return metadata -class ModelFactory(Generic[T], BaseFactory[T]): +class ModelFactory(Generic[ModelT], BaseFactory[ModelT]): """Base factory for pydantic models""" __forward_ref_resolution_type_mapping__: ClassVar[Mapping[str, type]] = {} @@ -388,7 +391,7 @@ def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None: cls.__model__.update_forward_refs(**cls.__forward_ref_resolution_type_mapping__) # type: ignore[attr-defined] @classmethod - def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]: + def is_supported_type(cls, value: Any) -> TypeGuard[type[ModelT]]: """Determine whether the given value is supported by the factory. :param value: An arbitrary value. @@ -454,7 +457,7 @@ def build( cls, factory_use_construct: bool = False, **kwargs: Any, - ) -> T: + ) -> ModelT: """Build an instance of the factory's __model__ :param factory_use_construct: A boolean that determines whether validations will be made when instantiating the @@ -492,7 +495,7 @@ def _get_build_context(cls, build_context: BaseBuildContext | PydanticBuildConte } @classmethod - def _create_model(cls, _build_context: PydanticBuildContext, **kwargs: Any) -> T: + def _create_model(cls, _build_context: PydanticBuildContext, **kwargs: Any) -> ModelT: """Create an instance of the factory's __model__ :param _build_context: BuildContext instance. @@ -508,7 +511,7 @@ def _create_model(cls, _build_context: PydanticBuildContext, **kwargs: Any) -> T return cls.__model__(**kwargs) # type: ignore[return-value] @classmethod - def coverage(cls, factory_use_construct: bool = False, **kwargs: Any) -> abc.Iterator[T]: + def coverage(cls, factory_use_construct: bool = False, **kwargs: Any) -> abc.Iterator[ModelT]: """Build a batch of the factory's Meta.model will full coverage of the sub-types of the model. :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used. @@ -629,3 +632,36 @@ def _is_pydantic_v1_model(model: Any) -> TypeGuard[BaseModelV1]: def _is_pydantic_v2_model(model: Any) -> TypeGuard[BaseModelV2]: # pyright: ignore[reportInvalidTypeForm] return not _IS_PYDANTIC_V1 and is_safe_subclass(model, BaseModelV2) + + +class PydanticDataclassFactory(ModelFactory[T]): # type: ignore[type-var] + """Base factory for pydantic dataclasses""" + + __is_base_factory__ = True + + @classmethod + def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]: + return is_pydantic_dataclass(value) + + @classmethod + def get_model_fields(cls) -> list[FieldMeta]: + if not is_pydantic_dataclass(cls.__model__): + # This should be unreachable + return [] + + pydantic_fields = cls.__model__.__pydantic_fields__ + pydantic_config = cls.__model__.__pydantic_config__ + cls._fields_metadata = [ + PydanticFieldMeta.from_field_info( + field_info=field_info, + field_name=field_name, + random=cls.__random__, + use_alias=not pydantic_config.get( + "populate_by_name", + False, + ), + ) + for field_name, field_info in pydantic_fields.items() + ] + + return cls._fields_metadata diff --git a/tests/test_pydantic_factory.py b/tests/test_pydantic_factory.py index 2c24db78..1f859566 100644 --- a/tests/test_pydantic_factory.py +++ b/tests/test_pydantic_factory.py @@ -63,9 +63,11 @@ constr, validator, ) +from pydantic.dataclasses import dataclass as pydantic_dataclass +from polyfactory.exceptions import ConfigurationException from polyfactory.factories import DataclassFactory -from polyfactory.factories.pydantic_factory import _IS_PYDANTIC_V1, ModelFactory +from polyfactory.factories.pydantic_factory import _IS_PYDANTIC_V1, ModelFactory, PydanticDataclassFactory from tests.models import Person, PetFactory IS_PYDANTIC_V1 = _IS_PYDANTIC_V1 @@ -1038,3 +1040,49 @@ class A(BaseModel): AFactory = ModelFactory.create_factory(A) assert AFactory.build() + + +def test_simple_pydantic_dataclass() -> None: + @pydantic_dataclass + class DataclassModel: + a: int + b: Annotated[str, MinLen(1)] + + class DataclassModelFactory(PydanticDataclassFactory[DataclassModel]): + __model__ = DataclassModel + + instance = DataclassModelFactory.build() + assert isinstance(instance, DataclassModel) + assert isinstance(instance.a, int) + assert isinstance(instance.b, str) + assert len(instance.b) >= 1 + + +def test_nested_pydantic_dataclass() -> None: + @pydantic_dataclass + class FooDataclass: + content: int + + @pydantic_dataclass + class NestedDataclassModel: + foo: FooDataclass + + class DataclassModelFactory(PydanticDataclassFactory[NestedDataclassModel]): + __model__ = NestedDataclassModel + + instance = DataclassModelFactory.build() + assert isinstance(instance, NestedDataclassModel) + assert isinstance(instance.foo, FooDataclass) + assert isinstance(instance.foo.content, int) + + +def test_pydantic_dataclass_factory_raises_for_std_dataclasses() -> None: + @dataclass + class DataclassModel: + a: int + b: str + + with pytest.raises(ConfigurationException): + + class DataclassModelFactory(PydanticDataclassFactory[DataclassModel]): + __model__ = DataclassModel