diff --git a/polyfactory/factories/pydantic_factory.py b/polyfactory/factories/pydantic_factory.py index e29348f6..ab1ebf70 100644 --- a/polyfactory/factories/pydantic_factory.py +++ b/polyfactory/factories/pydantic_factory.py @@ -2,6 +2,7 @@ import copy from contextlib import suppress +from dataclasses import is_dataclass from datetime import timezone from functools import partial from os.path import realpath @@ -55,7 +56,6 @@ ModelField, # pyright: ignore[attr-defined,reportAttributeAccessIssue] Undefined, # pyright: ignore[attr-defined,reportAttributeAccessIssue] ) - from pydantic.dataclasses import is_pydantic_dataclass # Keep this import last to prevent warnings from pydantic if pydantic v2 # is installed. @@ -69,7 +69,6 @@ # v2 specific imports from pydantic import BaseModel as BaseModelV2 - from pydantic.dataclasses import is_pydantic_dataclass from pydantic_core import PydanticUndefined as UndefinedV2 from pydantic_core import to_json @@ -101,6 +100,8 @@ from typing_extensions import NotRequired, TypeGuard + from pydantic.dataclasses import PydanticDataclass # pyright: ignore[reportPrivateImportUsage] + ModelT = TypeVar("ModelT", bound="BaseModelV1 | BaseModelV2") # pyright: ignore[reportInvalidTypeForm] T = TypeVar("T") @@ -634,6 +635,11 @@ def _is_pydantic_v2_model(model: Any) -> TypeGuard[BaseModelV2]: # pyright: ign return not _IS_PYDANTIC_V1 and is_safe_subclass(model, BaseModelV2) +def is_pydantic_dataclass(cls: type[Any]) -> TypeGuard[PydanticDataclass]: + # This method is available in the `pydantic.dataclasses` module for python >= 3.9 + return is_dataclass(cls) and "__pydantic_validator__" in cls.__dict__ + + class PydanticDataclassFactory(ModelFactory[T]): # type: ignore[type-var] """Base factory for pydantic dataclasses"""