Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pydantic v2 support #245

Merged
merged 10 commits into from
Sep 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 28 additions & 22 deletions piccolo_api/crud/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from piccolo.query.methods.select import Select
from piccolo.table import Table
from piccolo.utils.encoding import dump_json
from pydantic.error_wrappers import ValidationError
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.routing import Route, Router
Expand All @@ -38,7 +37,7 @@
)

from .exceptions import MalformedQuery, db_exception_handler
from .serializers import Config, create_pydantic_model
from .serializers import create_pydantic_model
from .validators import Validators, apply_validators

if t.TYPE_CHECKING: # pragma: no cover
Expand Down Expand Up @@ -307,7 +306,7 @@ def _pydantic_model_output(
self,
include_readable: bool = False,
include_columns: t.Tuple[Column, ...] = (),
nested: t.Union[bool, t.Tuple[Column, ...]] = False,
nested: t.Union[bool, t.Tuple[ForeignKey, ...]] = False,
) -> t.Type[pydantic.BaseModel]:
return create_pydantic_model(
self.table,
Expand Down Expand Up @@ -343,8 +342,8 @@ def pydantic_model_plural(
self,
include_readable=False,
include_columns: t.Tuple[Column, ...] = (),
nested: t.Union[bool, t.Tuple[Column, ...]] = False,
):
nested: t.Union[bool, t.Tuple[ForeignKey, ...]] = False,
) -> t.Type[pydantic.BaseModel]:
"""
This is for when we want to serialise many copies of the model.
"""
Expand All @@ -358,7 +357,9 @@ def pydantic_model_plural(
)
return pydantic.create_model(
str(self.table.__name__) + "Plural",
__config__=Config,
__config__=pydantic.config.ConfigDict(
arbitrary_types_allowed=True
),
rows=(t.List[base_model], None),
)

Expand All @@ -367,7 +368,7 @@ async def get_schema(self, request: Request) -> JSONResponse:
"""
Return a representation of the model, so a UI can generate a form.
"""
return JSONResponse(self.pydantic_model.schema())
return JSONResponse(self.pydantic_model.model_json_schema())

###########################################################################

Expand Down Expand Up @@ -713,7 +714,7 @@ def _apply_filters(
"""
fields = params.fields
if fields:
model_dict = self.pydantic_model_optional(**fields).dict()
model_dict = self.pydantic_model_optional(**fields).model_dump()
for field_name in fields.keys():
value = model_dict.get(field_name, ...)
if value is ...:
Expand Down Expand Up @@ -778,7 +779,9 @@ async def get_all(
nested: t.Union[bool, t.Tuple[Column, ...]]
if visible_fields:
nested = tuple(
i for i in visible_fields if len(i._meta.call_chain) > 0
i._meta.call_chain[-1]
for i in visible_fields
if len(i._meta.call_chain) > 0
)
else:
visible_fields = self.table._meta.columns
Expand Down Expand Up @@ -865,7 +868,7 @@ async def get_all(
include_readable=include_readable,
include_columns=tuple(visible_fields),
nested=nested,
)(rows=rows).json()
)(rows=rows).model_dump_json()
return CustomJSONResponse(json, headers=headers)

###########################################################################
Expand Down Expand Up @@ -894,19 +897,19 @@ async def post_single(
cleaned_data = self._clean_data(data)
try:
model = self.pydantic_model(**cleaned_data)
except ValidationError as exception:
except pydantic.ValidationError as exception:
return Response(str(exception), status_code=400)

if issubclass(self.table, BaseUser):
try:
user = await self.table.create_user(**model.dict())
user = await self.table.create_user(**model.model_dump())
json = dump_json({"id": user.id})
return CustomJSONResponse(json, status_code=201)
except Exception as e:
return Response(f"Error: {e}", status_code=400)
else:
try:
row = self.table(**model.dict())
row = self.table(**model.model_dump())
if self._hook_map:
row = await execute_post_hooks(
hooks=self._hook_map,
Expand Down Expand Up @@ -969,7 +972,7 @@ async def get_new(self, request: Request) -> CustomJSONResponse:
row_dict.pop(column_name)

return CustomJSONResponse(
self.pydantic_model_optional(**row_dict).json()
self.pydantic_model_optional(**row_dict).model_dump_json()
)

###########################################################################
Expand Down Expand Up @@ -1053,11 +1056,13 @@ async def get_single(self, request: Request, row_id: PK_TYPES) -> Response:
return Response(str(exception), status_code=400)

# Visible fields
nested: t.Union[bool, t.Tuple[Column, ...]]
nested: t.Union[bool, t.Tuple[ForeignKey, ...]]
visible_fields = split_params.visible_fields
if visible_fields:
nested = tuple(
i for i in visible_fields if len(i._meta.call_chain) > 0
i._meta.call_chain[-1]
for i in visible_fields
if len(i._meta.call_chain) > 0
)
else:
visible_fields = self.table._meta.columns
Expand Down Expand Up @@ -1098,7 +1103,7 @@ async def get_single(self, request: Request, row_id: PK_TYPES) -> Response:
include_readable=split_params.include_readable,
include_columns=tuple(visible_fields),
nested=nested,
)(**row).json()
)(**row).model_dump_json()
)

@apply_validators
Expand All @@ -1113,7 +1118,7 @@ async def put_single(

try:
model = self.pydantic_model(**cleaned_data)
except ValidationError as exception:
except pydantic.ValidationError as exception:
return Response(str(exception), status_code=400)

cls = self.table
Expand All @@ -1123,7 +1128,6 @@ async def put_single(
}

try:

await cls.update(values).where(
cls._meta.primary_key == row_id
).run()
Expand All @@ -1143,7 +1147,7 @@ async def patch_single(

try:
model = self.pydantic_model_optional(**cleaned_data)
except ValidationError as exception:
except pydantic.ValidationError as exception:
return Response(str(exception), status_code=400)

cls = self.table
Expand All @@ -1168,7 +1172,9 @@ async def patch_single(
for key in data.keys()
}
except AttributeError:
unrecognised_keys = set(data.keys()) - set(model.dict().keys())
unrecognised_keys = set(data.keys()) - set(
model.model_dump().keys()
)
return Response(
f"Unrecognised keys - {unrecognised_keys}.",
status_code=400,
Expand All @@ -1195,7 +1201,7 @@ async def patch_single(
)
assert new_row
return CustomJSONResponse(
self.pydantic_model(**new_row).json()
self.pydantic_model(**new_row).model_dump_json()
)
except ValueError:
return Response(
Expand Down
4 changes: 2 additions & 2 deletions piccolo_api/crud/serializers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from piccolo.utils.pydantic import Config, create_pydantic_model # noqa
from piccolo.utils.pydantic import create_pydantic_model # noqa

__all__ = ["Config", "create_pydantic_model"]
__all__ = ["create_pydantic_model"]
2 changes: 1 addition & 1 deletion piccolo_api/crud/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def apply_validators(function):
:class:`PiccoloCRUD`.
"""

async def run_validators(*args, **kwargs):
async def run_validators(*args, **kwargs) -> None:
piccolo_crud: PiccoloCRUD = args[0]
validators = piccolo_crud.validators

Expand Down
37 changes: 35 additions & 2 deletions piccolo_api/fastapi/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,36 @@ class ReferencesModel(BaseModel):
references: t.List[ReferenceModel]


def _get_type(type_: t.Type) -> t.Type:
"""
Extract the inner type from an optional if necessary, otherwise return
the type as is.
For example::
>>> get_type(Optional[int])
int
>>> get_type(int)
int
>>> get_type(list[str])
list[str]
"""
origin = t.get_origin(type_)

# Note: even if `t.Optional` is passed in, the origin is still a
# `t.Union`.
if origin is t.Union:
args = t.get_args(type_)

if len(args) == 2 and None in args:
return [i for i in args if i is not None][0]

return type_


class FastAPIWrapper:
"""
Wraps ``PiccoloCRUD`` so it can easily be integrated into FastAPI.
Expand Down Expand Up @@ -413,8 +443,11 @@ def modify_signature(
),
]

for field_name, _field in model.__fields__.items():
type_ = _field.outer_type_
for field_name, _field in model.model_fields.items():
annotation = _field.annotation
assert annotation is not None
type_ = _get_type(annotation)

parameters.append(
Parameter(
name=field_name,
Expand Down
2 changes: 1 addition & 1 deletion piccolo_api/media/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ async def get_file_keys(self) -> t.List[str]:
Returns the file key for each file we have stored.
"""
file_keys = []
for (_, _, filenames) in os.walk(self.media_path):
for _, _, filenames in os.walk(self.media_path):
file_keys.extend(filenames)
break

Expand Down
14 changes: 6 additions & 8 deletions piccolo_api/media/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ def __init__(
column: t.Union[Text, Varchar, Array],
bucket_name: str,
folder_name: t.Optional[str] = None,
connection_kwargs: t.Dict[str, t.Any] = None,
connection_kwargs: t.Optional[t.Dict[str, t.Any]] = None,
sign_urls: bool = True,
signed_url_expiry: int = 3600,
upload_metadata: t.Dict[str, t.Any] = None,
upload_metadata: t.Optional[t.Dict[str, t.Any]] = None,
executor: t.Optional[Executor] = None,
allowed_extensions: t.Optional[t.Sequence[str]] = ALLOWED_EXTENSIONS,
allowed_characters: t.Optional[t.Sequence[str]] = ALLOWED_CHARACTERS,
Expand Down Expand Up @@ -130,9 +130,9 @@ def __init__(
self.boto3 = boto3

self.bucket_name = bucket_name
self.upload_metadata = upload_metadata
self.upload_metadata = upload_metadata or {}
self.folder_name = folder_name
self.connection_kwargs = connection_kwargs
self.connection_kwargs = connection_kwargs or {}
self.sign_urls = sign_urls
self.signed_url_expiry = signed_url_expiry
self.executor = executor or ThreadPoolExecutor(max_workers=10)
Expand Down Expand Up @@ -181,7 +181,7 @@ def store_file_sync(
file_key = self.generate_file_key(file_name=file_name, user=user)
extension = file_key.rsplit(".", 1)[-1]
client = self.get_client()
upload_metadata: t.Dict[str, t.Any] = self.upload_metadata or {}
upload_metadata: t.Dict[str, t.Any] = self.upload_metadata

if extension in CONTENT_TYPE:
upload_metadata["ContentType"] = CONTENT_TYPE[extension]
Expand Down Expand Up @@ -374,9 +374,7 @@ def __hash__(self):
return hash(
(
"s3",
self.connection_kwargs.get("endpoint_url")
if self.connection_kwargs
else None,
self.connection_kwargs.get("endpoint_url"),
self.bucket_name,
self.folder_name,
)
Expand Down
1 change: 0 additions & 1 deletion piccolo_api/shared/auth/junction.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ def __init__(self, backends: t.Sequence[AuthenticationBackend]):
async def authenticate(
self, conn: HTTPConnection
) -> t.Optional[t.Tuple[AuthCredentials, BaseUser]]:

for backend in self.backends:
try:
response = await backend.authenticate(conn=conn)
Expand Down
4 changes: 1 addition & 3 deletions piccolo_api/shared/middleware/junction.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,11 @@ def __init__(self, *routers: Router) -> None:
async def __call__(self, scope: Scope, receive: Receive, send: Send):
for router in self.routers:
try:
asgi = await router(scope, receive=receive, send=send)
await router(scope, receive=receive, send=send)
except HTTPException as exception:
if exception.status_code != 404:
raise exception
else:
if getattr(asgi, "status_code", None) == 404:
continue
return

raise HTTPException(status_code=404)
1 change: 0 additions & 1 deletion piccolo_api/token_auth/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ async def get_token(self, username: str, password: str) -> t.Optional[str]:


class TokenAuthLoginEndpoint(HTTPEndpoint):

token_provider: TokenProvider = PiccoloTokenProvider()

async def post(self, request: Request) -> Response:
Expand Down
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.black]
line-length = 79
target-version = ['py37', 'py38', 'py39']
target-version = ['py38', 'py39', 'py310', 'py311']

[tool.isort]
profile = "black"
Expand All @@ -9,7 +9,6 @@ line_length = 79
[tool.mypy]
[[tool.mypy.overrides]]
module = [
"asyncpg.pgproto.pgproto",
"asyncpg.exceptions",
"jinja2",
"uvicorn",
Expand Down
8 changes: 4 additions & 4 deletions requirements/dev-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
black>=21.7b0
isort==5.10.1
black==23.7.0
isort==5.12.0
twine==4.0.2
mypy==0.950
mypy==1.5.1
pip-upgrader==1.4.15
wheel==0.40.0
wheel==0.41.2
6 changes: 3 additions & 3 deletions requirements/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Jinja2>=2.11.0
piccolo[postgres]>=0.104.0
pydantic[email]>=1.6,<2.0
piccolo[postgres]>=1.0a1
pydantic[email]>=2.0
python-multipart>=0.0.5
fastapi>=0.87.0,<0.100.0
fastapi>=0.100.0
PyJWT>=2.0.0
httpx>=0.20.0
Loading
Loading