
pydantic v2 support (#245)
* pydantic_v2_support

* pin to Piccolo v1

* add comment

* remove KeyError Exception

* upgrade coverage

* fix type warnings with `nested`

* fix mypy warnings

* update black

* replacement for `outer_type`

* change assertion

---------

Co-authored-by: Daniel Townsend <[email protected]>
sinisaos and dantownsend authored Sep 4, 2023
1 parent 5a7ce77 commit fde7a4f
Showing 19 changed files with 226 additions and 134 deletions.
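
For reference, these are the pydantic v1 → v2 API renames applied throughout this commit — a minimal illustrative sketch (the `Movie` model is hypothetical, not part of the diff):

```python
import pydantic


class Movie(pydantic.BaseModel):
    title: str = "Star Wars"


movie = Movie()
movie.model_dump()          # was: movie.dict()
movie.model_dump_json()     # was: movie.json()
Movie.model_json_schema()   # was: Movie.schema()
Movie.model_fields          # was: Movie.__fields__
pydantic.ValidationError    # was: pydantic.error_wrappers.ValidationError
```
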
50 changes: 28 additions & 22 deletions piccolo_api/crud/endpoints.py
@@ -24,7 +24,6 @@
from piccolo.query.methods.select import Select
from piccolo.table import Table
from piccolo.utils.encoding import dump_json
from pydantic.error_wrappers import ValidationError
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.routing import Route, Router
@@ -38,7 +37,7 @@
)

from .exceptions import MalformedQuery, db_exception_handler
from .serializers import Config, create_pydantic_model
from .serializers import create_pydantic_model
from .validators import Validators, apply_validators

if t.TYPE_CHECKING: # pragma: no cover
@@ -307,7 +306,7 @@ def _pydantic_model_output(
self,
include_readable: bool = False,
include_columns: t.Tuple[Column, ...] = (),
nested: t.Union[bool, t.Tuple[Column, ...]] = False,
nested: t.Union[bool, t.Tuple[ForeignKey, ...]] = False,
) -> t.Type[pydantic.BaseModel]:
return create_pydantic_model(
self.table,
@@ -343,8 +342,8 @@ def pydantic_model_plural(
self,
include_readable=False,
include_columns: t.Tuple[Column, ...] = (),
nested: t.Union[bool, t.Tuple[Column, ...]] = False,
):
nested: t.Union[bool, t.Tuple[ForeignKey, ...]] = False,
) -> t.Type[pydantic.BaseModel]:
"""
This is for when we want to serialise many copies of the model.
"""
@@ -358,7 +357,9 @@ def pydantic_model_plural(
)
return pydantic.create_model(
str(self.table.__name__) + "Plural",
__config__=Config,
__config__=pydantic.config.ConfigDict(
arbitrary_types_allowed=True
),
rows=(t.List[base_model], None),
)
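
The old shared `Config` class (previously re-exported from `piccolo.utils.pydantic`) is replaced by passing a `ConfigDict` straight to `create_model`. A standalone sketch of the same pattern, using a hypothetical `MovieModel`:

```python
import typing as t

import pydantic


class MovieModel(pydantic.BaseModel):
    title: str


MoviePlural = pydantic.create_model(
    "MoviePlural",
    __config__=pydantic.config.ConfigDict(arbitrary_types_allowed=True),
    rows=(t.List[MovieModel], None),
)

print(MoviePlural(rows=[{"title": "Star Wars"}]).model_dump_json())
```
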

@@ -367,7 +368,7 @@ async def get_schema(self, request: Request) -> JSONResponse:
"""
Return a representation of the model, so a UI can generate a form.
"""
return JSONResponse(self.pydantic_model.schema())
return JSONResponse(self.pydantic_model.model_json_schema())

###########################################################################

@@ -713,7 +714,7 @@ def _apply_filters(
"""
fields = params.fields
if fields:
model_dict = self.pydantic_model_optional(**fields).dict()
model_dict = self.pydantic_model_optional(**fields).model_dump()
for field_name in fields.keys():
value = model_dict.get(field_name, ...)
if value is ...:
@@ -778,7 +779,9 @@ async def get_all(
nested: t.Union[bool, t.Tuple[Column, ...]]
if visible_fields:
nested = tuple(
i for i in visible_fields if len(i._meta.call_chain) > 0
i._meta.call_chain[-1]
for i in visible_fields
if len(i._meta.call_chain) > 0
)
else:
visible_fields = self.table._meta.columns
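
The `nested` tuple is now built from the last entry of each joined column's `call_chain`, so it holds `ForeignKey` columns rather than arbitrary columns, matching the `t.Tuple[ForeignKey, ...]` annotations introduced elsewhere in this diff. A rough sketch with hypothetical tables, assuming Piccolo records the traversed foreign keys in `call_chain`:

```python
from piccolo.columns import ForeignKey, Varchar
from piccolo.table import Table
from piccolo.utils.pydantic import create_pydantic_model


class Manager(Table):
    name = Varchar()


class Band(Table):
    name = Varchar()
    manager = ForeignKey(references=Manager)


visible_fields = (Band.name, Band.manager.name)
nested = tuple(
    i._meta.call_chain[-1]
    for i in visible_fields
    if len(i._meta.call_chain) > 0
)  # -> (Band.manager,), i.e. ForeignKey columns
BandModel = create_pydantic_model(Band, nested=nested)
```
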
@@ -865,7 +868,7 @@ async def get_all(
include_readable=include_readable,
include_columns=tuple(visible_fields),
nested=nested,
)(rows=rows).json()
)(rows=rows).model_dump_json()
return CustomJSONResponse(json, headers=headers)

###########################################################################
@@ -894,19 +897,19 @@ async def post_single(
cleaned_data = self._clean_data(data)
try:
model = self.pydantic_model(**cleaned_data)
except ValidationError as exception:
except pydantic.ValidationError as exception:
return Response(str(exception), status_code=400)

if issubclass(self.table, BaseUser):
try:
user = await self.table.create_user(**model.dict())
user = await self.table.create_user(**model.model_dump())
json = dump_json({"id": user.id})
return CustomJSONResponse(json, status_code=201)
except Exception as e:
return Response(f"Error: {e}", status_code=400)
else:
try:
row = self.table(**model.dict())
row = self.table(**model.model_dump())
if self._hook_map:
row = await execute_post_hooks(
hooks=self._hook_map,
@@ -969,7 +972,7 @@ async def get_new(self, request: Request) -> CustomJSONResponse:
row_dict.pop(column_name)

return CustomJSONResponse(
self.pydantic_model_optional(**row_dict).json()
self.pydantic_model_optional(**row_dict).model_dump_json()
)

###########################################################################
@@ -1053,11 +1056,13 @@ async def get_single(self, request: Request, row_id: PK_TYPES) -> Response:
return Response(str(exception), status_code=400)

# Visible fields
nested: t.Union[bool, t.Tuple[Column, ...]]
nested: t.Union[bool, t.Tuple[ForeignKey, ...]]
visible_fields = split_params.visible_fields
if visible_fields:
nested = tuple(
i for i in visible_fields if len(i._meta.call_chain) > 0
i._meta.call_chain[-1]
for i in visible_fields
if len(i._meta.call_chain) > 0
)
else:
visible_fields = self.table._meta.columns
Expand Down Expand Up @@ -1098,7 +1103,7 @@ async def get_single(self, request: Request, row_id: PK_TYPES) -> Response:
include_readable=split_params.include_readable,
include_columns=tuple(visible_fields),
nested=nested,
)(**row).json()
)(**row).model_dump_json()
)

@apply_validators
@@ -1113,7 +1118,7 @@ async def put_single(

try:
model = self.pydantic_model(**cleaned_data)
except ValidationError as exception:
except pydantic.ValidationError as exception:
return Response(str(exception), status_code=400)

cls = self.table
Expand All @@ -1123,7 +1128,6 @@ async def put_single(
}

try:

await cls.update(values).where(
cls._meta.primary_key == row_id
).run()
@@ -1143,7 +1147,7 @@ async def patch_single(

try:
model = self.pydantic_model_optional(**cleaned_data)
except ValidationError as exception:
except pydantic.ValidationError as exception:
return Response(str(exception), status_code=400)

cls = self.table
@@ -1168,7 +1172,9 @@ async def patch_single(
for key in data.keys()
}
except AttributeError:
unrecognised_keys = set(data.keys()) - set(model.dict().keys())
unrecognised_keys = set(data.keys()) - set(
model.model_dump().keys()
)
return Response(
f"Unrecognised keys - {unrecognised_keys}.",
status_code=400,
@@ -1195,7 +1201,7 @@ async def patch_single(
)
assert new_row
return CustomJSONResponse(
self.pydantic_model(**new_row).json()
self.pydantic_model(**new_row).model_dump_json()
)
except ValueError:
return Response(
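
The same error-handling pattern recurs in `post_single`, `put_single` and `patch_single`: pydantic v2 exposes `ValidationError` at the top level, and validated data is read back with `model_dump()`. A self-contained sketch with a hypothetical model and payloads:

```python
import typing as t

import pydantic


class MovieModel(pydantic.BaseModel):
    title: str
    rating: int


def validate(payload: t.Dict[str, t.Any]) -> t.Tuple[int, str]:
    try:
        model = MovieModel(**payload)
    except pydantic.ValidationError as exception:
        return 400, str(exception)
    return 201, str(model.model_dump())


print(validate({"title": "Star Wars", "rating": "not a number"}))  # 400, ...
print(validate({"title": "Star Wars", "rating": 10}))              # 201, ...
```
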
4 changes: 2 additions & 2 deletions piccolo_api/crud/serializers.py
@@ -1,3 +1,3 @@
from piccolo.utils.pydantic import Config, create_pydantic_model # noqa
from piccolo.utils.pydantic import create_pydantic_model # noqa

__all__ = ["Config", "create_pydantic_model"]
__all__ = ["create_pydantic_model"]
2 changes: 1 addition & 1 deletion piccolo_api/crud/validators.py
@@ -85,7 +85,7 @@ def apply_validators(function):
:class:`PiccoloCRUD`.
"""

async def run_validators(*args, **kwargs):
async def run_validators(*args, **kwargs) -> None:
piccolo_crud: PiccoloCRUD = args[0]
validators = piccolo_crud.validators

37 changes: 35 additions & 2 deletions piccolo_api/fastapi/endpoints.py
@@ -76,6 +76,36 @@ class ReferencesModel(BaseModel):
references: t.List[ReferenceModel]


def _get_type(type_: t.Type) -> t.Type:
"""
Extract the inner type from an optional if necessary, otherwise return
the type as is.
For example::
>>> get_type(Optional[int])
int
>>> get_type(int)
int
>>> get_type(list[str])
list[str]
"""
origin = t.get_origin(type_)

# Note: even if `t.Optional` is passed in, the origin is still a
# `t.Union`.
if origin is t.Union:
args = t.get_args(type_)

if len(args) == 2 and type(None) in args:
return [i for i in args if i is not type(None)][0]

return type_


class FastAPIWrapper:
"""
Wraps ``PiccoloCRUD`` so it can easily be integrated into FastAPI.
@@ -413,8 +443,11 @@ def modify_signature(
),
]

for field_name, _field in model.__fields__.items():
type_ = _field.outer_type_
for field_name, _field in model.model_fields.items():
annotation = _field.annotation
assert annotation is not None
type_ = _get_type(annotation)

parameters.append(
Parameter(
name=field_name,
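
pydantic v2 has no `outer_type_` attribute, hence the new `_get_type` helper above: iterate `model_fields`, read each `FieldInfo.annotation`, and unwrap `Optional[...]`. A sketch of the same idea with a hypothetical model:

```python
import typing as t

import pydantic


class MovieModel(pydantic.BaseModel):
    title: str
    rating: t.Optional[int] = None


for field_name, field in MovieModel.model_fields.items():
    annotation = field.annotation
    if t.get_origin(annotation) is t.Union:
        args = [a for a in t.get_args(annotation) if a is not type(None)]
        if len(args) == 1:
            annotation = args[0]
    print(field_name, annotation)  # title -> str, rating -> int
```
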
2 changes: 1 addition & 1 deletion piccolo_api/media/local.py
@@ -173,7 +173,7 @@ async def get_file_keys(self) -> t.List[str]:
Returns the file key for each file we have stored.
"""
file_keys = []
for (_, _, filenames) in os.walk(self.media_path):
for _, _, filenames in os.walk(self.media_path):
file_keys.extend(filenames)
break

14 changes: 6 additions & 8 deletions piccolo_api/media/s3.py
@@ -23,10 +23,10 @@ def __init__(
column: t.Union[Text, Varchar, Array],
bucket_name: str,
folder_name: t.Optional[str] = None,
connection_kwargs: t.Dict[str, t.Any] = None,
connection_kwargs: t.Optional[t.Dict[str, t.Any]] = None,
sign_urls: bool = True,
signed_url_expiry: int = 3600,
upload_metadata: t.Dict[str, t.Any] = None,
upload_metadata: t.Optional[t.Dict[str, t.Any]] = None,
executor: t.Optional[Executor] = None,
allowed_extensions: t.Optional[t.Sequence[str]] = ALLOWED_EXTENSIONS,
allowed_characters: t.Optional[t.Sequence[str]] = ALLOWED_CHARACTERS,
@@ -130,9 +130,9 @@ def __init__(
self.boto3 = boto3

self.bucket_name = bucket_name
self.upload_metadata = upload_metadata
self.upload_metadata = upload_metadata or {}
self.folder_name = folder_name
self.connection_kwargs = connection_kwargs
self.connection_kwargs = connection_kwargs or {}
self.sign_urls = sign_urls
self.signed_url_expiry = signed_url_expiry
self.executor = executor or ThreadPoolExecutor(max_workers=10)
@@ -181,7 +181,7 @@ def store_file_sync(
file_key = self.generate_file_key(file_name=file_name, user=user)
extension = file_key.rsplit(".", 1)[-1]
client = self.get_client()
upload_metadata: t.Dict[str, t.Any] = self.upload_metadata or {}
upload_metadata: t.Dict[str, t.Any] = self.upload_metadata

if extension in CONTENT_TYPE:
upload_metadata["ContentType"] = CONTENT_TYPE[extension]
@@ -374,9 +374,7 @@ def __hash__(self):
return hash(
(
"s3",
self.connection_kwargs.get("endpoint_url")
if self.connection_kwargs
else None,
self.connection_kwargs.get("endpoint_url"),
self.bucket_name,
self.folder_name,
)
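
`upload_metadata` and `connection_kwargs` are now annotated as `Optional` and normalised to empty dicts in `__init__`, which is why the later `endpoint_url` lookup in `__hash__` no longer needs a `None` check. The pattern in isolation, with a hypothetical class:

```python
import typing as t


class Storage:
    def __init__(
        self, connection_kwargs: t.Optional[t.Dict[str, t.Any]] = None
    ) -> None:
        # Normalise once, so later lookups never hit None.
        self.connection_kwargs = connection_kwargs or {}

    def endpoint_url(self) -> t.Optional[str]:
        return self.connection_kwargs.get("endpoint_url")


print(Storage().endpoint_url())  # None
print(Storage({"endpoint_url": "http://localhost:9000"}).endpoint_url())
```
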
1 change: 0 additions & 1 deletion piccolo_api/shared/auth/junction.py
@@ -21,7 +21,6 @@ def __init__(self, backends: t.Sequence[AuthenticationBackend]):
async def authenticate(
self, conn: HTTPConnection
) -> t.Optional[t.Tuple[AuthCredentials, BaseUser]]:

for backend in self.backends:
try:
response = await backend.authenticate(conn=conn)
4 changes: 1 addition & 3 deletions piccolo_api/shared/middleware/junction.py
@@ -15,13 +15,11 @@ def __init__(self, *routers: Router) -> None:
async def __call__(self, scope: Scope, receive: Receive, send: Send):
for router in self.routers:
try:
asgi = await router(scope, receive=receive, send=send)
await router(scope, receive=receive, send=send)
except HTTPException as exception:
if exception.status_code != 404:
raise exception
else:
if getattr(asgi, "status_code", None) == 404:
continue
return

raise HTTPException(status_code=404)
1 change: 0 additions & 1 deletion piccolo_api/token_auth/endpoints.py
@@ -42,7 +42,6 @@ async def get_token(self, username: str, password: str) -> t.Optional[str]:


class TokenAuthLoginEndpoint(HTTPEndpoint):

token_provider: TokenProvider = PiccoloTokenProvider()

async def post(self, request: Request) -> Response:
3 changes: 1 addition & 2 deletions pyproject.toml
@@ -1,6 +1,6 @@
[tool.black]
line-length = 79
target-version = ['py37', 'py38', 'py39']
target-version = ['py38', 'py39', 'py310', 'py311']

[tool.isort]
profile = "black"
@@ -9,7 +9,6 @@ line_length = 79
[tool.mypy]
[[tool.mypy.overrides]]
module = [
"asyncpg.pgproto.pgproto",
"asyncpg.exceptions",
"jinja2",
"uvicorn",
8 changes: 4 additions & 4 deletions requirements/dev-requirements.txt
@@ -1,6 +1,6 @@
black>=21.7b0
isort==5.10.1
black==23.7.0
isort==5.12.0
twine==4.0.2
mypy==0.950
mypy==1.5.1
pip-upgrader==1.4.15
wheel==0.40.0
wheel==0.41.2
6 changes: 3 additions & 3 deletions requirements/requirements.txt
@@ -1,7 +1,7 @@
Jinja2>=2.11.0
piccolo[postgres]>=0.104.0
pydantic[email]>=1.6,<2.0
piccolo[postgres]>=1.0a1
pydantic[email]>=2.0
python-multipart>=0.0.5
fastapi>=0.87.0,<0.100.0
fastapi>=0.100.0
PyJWT>=2.0.0
httpx>=0.20.0
