Skip to content

Commit

Permalink
Merge pull request #112 from kdcokenny/main
Browse files Browse the repository at this point in the history
Add type-checking support for SQLModel types.
  • Loading branch information
igorbenav authored Jul 3, 2024
2 parents dbc8d45 + 9bbe6cc commit 698c27a
Show file tree
Hide file tree
Showing 10 changed files with 135 additions and 92 deletions.
2 changes: 1 addition & 1 deletion docs/usage/crud.md
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ items = await item_crud.get_multi(db, offset=0, limit=10, sort_columns=['name'],
```python
get_joined(
db: AsyncSession,
join_model: Optional[type[DeclarativeBase]] = None,
join_model: Optional[ModelType] = None,
join_prefix: Optional[str] = None,
join_on: Optional[Union[Join, BinaryExpression]] = None,
schema_to_select: Optional[type[BaseModel]] = None,
Expand Down
19 changes: 10 additions & 9 deletions fastcrud/crud/fast_crud.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Generic, TypeVar, Union, Optional, Callable
from typing import Any, Dict, Generic, Union, Optional, Callable
from datetime import datetime, timezone

from pydantic import BaseModel, ValidationError
Expand All @@ -7,11 +7,18 @@
from sqlalchemy.sql import Join
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.engine.row import Row
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm.util import AliasedClass
from sqlalchemy.sql.elements import BinaryExpression, ColumnElement
from sqlalchemy.sql.selectable import Select

from fastcrud.types import (
CreateSchemaType,
DeleteSchemaType,
ModelType,
UpdateSchemaInternalType,
UpdateSchemaType,
)

from .helper import (
_extract_matching_columns_from_schema,
_auto_detect_join_condition,
Expand All @@ -23,12 +30,6 @@

from ..endpoint.helper import _get_primary_keys

ModelType = TypeVar("ModelType", bound=DeclarativeBase)
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
UpdateSchemaInternalType = TypeVar("UpdateSchemaInternalType", bound=BaseModel)
DeleteSchemaType = TypeVar("DeleteSchemaType", bound=BaseModel)


class FastCRUD(
Generic[
Expand Down Expand Up @@ -830,7 +831,7 @@ async def get_joined(
self,
db: AsyncSession,
schema_to_select: Optional[type[BaseModel]] = None,
join_model: Optional[type[DeclarativeBase]] = None,
join_model: Optional[ModelType] = None,
join_on: Optional[Union[Join, BinaryExpression]] = None,
join_prefix: Optional[str] = None,
join_schema_to_select: Optional[type[BaseModel]] = None,
Expand Down
34 changes: 24 additions & 10 deletions fastcrud/crud/helper.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from typing import Any, Optional, Union, Sequence
from typing import Any, Optional, Union, Sequence, cast

from sqlalchemy import inspect
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm.util import AliasedClass
from sqlalchemy.sql import ColumnElement
from pydantic import BaseModel, ConfigDict
from pydantic.functional_validators import field_validator

from fastcrud.types import ModelType

from ..endpoint.helper import _get_primary_key


Expand Down Expand Up @@ -38,7 +39,7 @@ def check_valid_join_type(cls, value):


def _extract_matching_columns_from_schema(
model: Union[type[DeclarativeBase], AliasedClass],
model: Union[ModelType, AliasedClass],
schema: Optional[type[BaseModel]],
prefix: Optional[str] = None,
alias: Optional[AliasedClass] = None,
Expand All @@ -63,6 +64,9 @@ def _extract_matching_columns_from_schema(
in the schema or all columns from the model if no schema is specified. These columns are correctly referenced
through the provided alias if one is given.
"""
if not hasattr(model, "__table__"):
raise AttributeError(f"{model.__name__} does not have a '__table__' attribute.")

model_or_alias = alias if alias else model
columns = []
temp_prefix = (
Expand Down Expand Up @@ -96,7 +100,8 @@ def _extract_matching_columns_from_schema(


def _auto_detect_join_condition(
base_model: type[DeclarativeBase], join_model: type[DeclarativeBase]
base_model: ModelType,
join_model: ModelType,
) -> Optional[ColumnElement]:
"""
Automatically detects the join condition for SQLAlchemy models based on foreign key relationships.
Expand All @@ -110,18 +115,27 @@ def _auto_detect_join_condition(
Raises:
ValueError: If the join condition cannot be automatically determined.
Example:
# Assuming User has a foreign key reference to Tier:
join_condition = auto_detect_join_condition(User, Tier)
AttributeError: If either base_model or join_model does not have a '__table__' attribute.
"""
if not hasattr(base_model, "__table__"):
raise AttributeError(
f"{base_model.__name__} does not have a '__table__' attribute."
)
if not hasattr(join_model, "__table__"):
raise AttributeError(
f"{join_model.__name__} does not have a '__table__' attribute."
)

inspector = inspect(base_model)
if inspector is not None:
fk_columns = [col for col in inspector.c if col.foreign_keys]
join_on = next(
(
base_model.__table__.c[col.name]
== join_model.__table__.c[list(col.foreign_keys)[0].column.name]
cast(
ColumnElement,
base_model.__table__.c[col.name]
== join_model.__table__.c[list(col.foreign_keys)[0].column.name],
)
for col in fk_columns
if list(col.foreign_keys)[0].column.table == join_model.__table__
),
Expand Down
19 changes: 10 additions & 9 deletions fastcrud/endpoint/crud_router.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
from typing import Type, TypeVar, Optional, Union, Sequence, Callable
from typing import Type, Optional, Union, Sequence, Callable
from enum import Enum

from fastapi import APIRouter
from sqlalchemy.orm import DeclarativeBase
from pydantic import BaseModel

from fastcrud.crud.fast_crud import FastCRUD
from fastcrud.types import (
CreateSchemaType,
DeleteSchemaType,
ModelType,
UpdateSchemaType,
)
from .endpoint_creator import EndpointCreator
from ..crud.fast_crud import FastCRUD
from .helper import FilterConfig

CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
DeleteSchemaType = TypeVar("DeleteSchemaType", bound=BaseModel)


def crud_router(
session: Callable,
model: type[DeclarativeBase],
model: ModelType,
create_schema: Type[CreateSchemaType],
update_schema: Type[UpdateSchemaType],
crud: Optional[FastCRUD] = None,
Expand Down Expand Up @@ -287,6 +287,7 @@ async def add_routes_to_router(self, ...):
# Example GET request: /mymodel/get_multi?id=1&name=example
```
"""

crud = crud or FastCRUD(
model=model,
is_deleted_column=is_deleted_column,
Expand Down
54 changes: 25 additions & 29 deletions fastcrud/endpoint/endpoint_creator.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
import warnings
from typing import Type, TypeVar, Optional, Callable, Sequence, Union
from typing import Type, Optional, Callable, Sequence, Union
from enum import Enum
import warnings

from fastapi import Depends, Body, Query, APIRouter
from pydantic import BaseModel, ValidationError
from pydantic import ValidationError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase

from ..crud.fast_crud import FastCRUD
from fastcrud.crud.fast_crud import FastCRUD
from fastcrud.types import (
CreateSchemaType,
DeleteSchemaType,
ModelType,
UpdateSchemaType,
)
from ..exceptions.http_exceptions import DuplicateValueException, NotFoundException
from ..paginated.helper import compute_offset
from ..paginated.response import paginated_response
Expand All @@ -23,11 +28,6 @@
_get_column_types,
)

CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
UpdateSchemaInternalType = TypeVar("UpdateSchemaInternalType", bound=BaseModel)
DeleteSchemaType = TypeVar("DeleteSchemaType", bound=BaseModel)


class EndpointCreator:
"""
Expand Down Expand Up @@ -208,7 +208,7 @@ def get_current_user(token: str = Depends(oauth2_scheme)):
def __init__(
self,
session: Callable,
model: type[DeclarativeBase],
model: ModelType,
create_schema: Type[CreateSchemaType],
update_schema: Type[UpdateSchemaType],
crud: Optional[FastCRUD] = None,
Expand Down Expand Up @@ -262,7 +262,7 @@ def __init__(
"resulting in plain endpoint names. "
"For details see:"
" https://github.com/igorbenav/fastcrud/issues/67",
DeprecationWarning
DeprecationWarning,
)
if filter_config:
if isinstance(filter_config, dict):
Expand Down Expand Up @@ -320,17 +320,14 @@ def _read_items(self):

async def endpoint(
db: AsyncSession = Depends(self.session),
page: Optional[int] = Query(
None, alias="page", description="Page number"
),
page: Optional[int] = Query(None, alias="page", description="Page number"),
items_per_page: Optional[int] = Query(
None, alias="itemsPerPage", description="Number of items per page"
),
filters: dict = Depends(dynamic_filters),
):
if not (page and items_per_page):
return await self.crud.get_multi(db, offset=0, limit=100,
**filters)
return await self.crud.get_multi(db, offset=0, limit=100, **filters)

offset = compute_offset(page=page, items_per_page=items_per_page)
crud_data = await self.crud.get_multi(
Expand All @@ -352,7 +349,7 @@ def _read_paginated(self):
"Please use _read_items with optional page and items_per_page "
"query params instead, to achieve pagination as before."
"Simple _read_items behaviour persists with no breaking changes.",
DeprecationWarning
DeprecationWarning,
)

async def endpoint(
Expand All @@ -366,8 +363,7 @@ async def endpoint(
filters: dict = Depends(dynamic_filters),
):
if not (page and items_per_page): # pragma: no cover
return await self.crud.get_multi(db, offset=0, limit=100,
**filters)
return await self.crud.get_multi(db, offset=0, limit=100, **filters)

offset = compute_offset(page=page, items_per_page=items_per_page)
crud_data = await self.crud.get_multi(
Expand Down Expand Up @@ -427,11 +423,11 @@ def _get_endpoint_path(self, operation: str):
)
path = f"{self.path}/{endpoint_name}" if endpoint_name else self.path

if operation in {'read', 'update', 'delete', 'db_delete'}:
if operation in {"read", "update", "delete", "db_delete"}:
_primary_keys_path_suffix = "/".join(
f"{{{n}}}" for n in self.primary_key_names
)
path = f'{path}/{_primary_keys_path_suffix}'
path = f"{path}/{_primary_keys_path_suffix}"

return path

Expand Down Expand Up @@ -540,7 +536,7 @@ def get_current_user(...):

if ("create" in included_methods) and ("create" not in deleted_methods):
self.router.add_api_route(
self._get_endpoint_path(operation='create'),
self._get_endpoint_path(operation="create"),
self._create_item(),
methods=["POST"],
include_in_schema=self.include_in_schema,
Expand All @@ -551,7 +547,7 @@ def get_current_user(...):

if ("read" in included_methods) and ("read" not in deleted_methods):
self.router.add_api_route(
self._get_endpoint_path(operation='read'),
self._get_endpoint_path(operation="read"),
self._read_item(),
methods=["GET"],
include_in_schema=self.include_in_schema,
Expand All @@ -562,7 +558,7 @@ def get_current_user(...):

if ("read_multi" in included_methods) and ("read_multi" not in deleted_methods):
self.router.add_api_route(
self._get_endpoint_path(operation='read_multi'),
self._get_endpoint_path(operation="read_multi"),
self._read_items(),
methods=["GET"],
include_in_schema=self.include_in_schema,
Expand All @@ -575,7 +571,7 @@ def get_current_user(...):
"read_paginated" not in deleted_methods
):
self.router.add_api_route(
self._get_endpoint_path(operation='read_paginated'),
self._get_endpoint_path(operation="read_paginated"),
self._read_paginated(),
methods=["GET"],
include_in_schema=self.include_in_schema,
Expand All @@ -586,7 +582,7 @@ def get_current_user(...):

if ("update" in included_methods) and ("update" not in deleted_methods):
self.router.add_api_route(
self._get_endpoint_path(operation='update'),
self._get_endpoint_path(operation="update"),
self._update_item(),
methods=["PATCH"],
include_in_schema=self.include_in_schema,
Expand All @@ -596,7 +592,7 @@ def get_current_user(...):
)

if ("delete" in included_methods) and ("delete" not in deleted_methods):
path = self._get_endpoint_path(operation='delete')
path = self._get_endpoint_path(operation="delete")
self.router.add_api_route(
path,
self._delete_item(),
Expand All @@ -613,7 +609,7 @@ def get_current_user(...):
and self.delete_schema
):
self.router.add_api_route(
self._get_endpoint_path(operation='db_delete'),
self._get_endpoint_path(operation="db_delete"),
self._db_delete(),
methods=["DELETE"],
include_in_schema=self.include_in_schema,
Expand Down
Loading

0 comments on commit 698c27a

Please sign in to comment.