diff --git a/fastapi_crudrouter/core/_utils.py b/fastapi_crudrouter/core/_utils.py index ef3562e4..dd0ea025 100644 --- a/fastapi_crudrouter/core/_utils.py +++ b/fastapi_crudrouter/core/_utils.py @@ -1,7 +1,7 @@ from typing import Optional, Type, Any from fastapi import Depends, HTTPException -from pydantic import create_model +from pydantic import create_model, __version__ as pydantic_version from ._types import T, PAGINATION, PYDANTIC_SCHEMA @@ -14,7 +14,11 @@ def __init__(self, *args, **kwargs) -> None: # type: ignore def get_pk_type(schema: Type[PYDANTIC_SCHEMA], pk_field: str) -> Any: try: - return schema.__fields__[pk_field].type_ + # for handle pydantic 2.x migration + if int(pydantic_version.split(".")[0]) >= 2: + return schema.model_fields[pk_field].annotation + else: + return schema.__fields__[pk_field].type_ except KeyError: return int @@ -26,11 +30,21 @@ def schema_factory( Is used to create a CreateSchema which does not contain pk """ - fields = { - f.name: (f.type_, ...) - for f in schema_cls.__fields__.values() - if f.name != pk_field_name - } + # for handle pydantic 2.x migration + if int(pydantic_version.split(".")[0]) >= 2: + # pydantic 2.x + fields = { + fk: (fv.annotation, ...) + for fk, fv in schema_cls.model_fields.items() + if fk != pk_field_name + } + else: + # pydantic 1.x + fields = { + f.name: (f.type_, ...) + for f in schema_cls.__fields__.values() + if f.name != pk_field_name + } name = schema_cls.__name__ + name schema: Type[T] = create_model(__model_name=name, **fields) # type: ignore @@ -75,3 +89,4 @@ def pagination(skip: int = 0, limit: Optional[int] = max_limit) -> PAGINATION: return {"skip": skip, "limit": limit} return Depends(pagination) +