Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: Support simple type field matching query #61

Closed
wants to merge 17 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion fastapi_crudrouter/core/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from fastapi.types import DecoratedCallable

from ._types import T, DEPENDENCIES
from ._utils import pagination_factory, schema_factory
from ._utils import pagination_factory, schema_factory, query_factory, sort_factory

NOT_FOUND = HTTPException(404, "Item not found")

Expand Down Expand Up @@ -34,6 +34,9 @@ def __init__(

self.schema = schema
self.pagination = pagination_factory(max_limit=paginate)
self.filter = query_factory(self.schema)
self.sort = sort_factory(self.schema)

self._pk: str = self._pk if hasattr(self, "_pk") else "id"
self.create_schema = (
create_schema
Expand Down
4 changes: 3 additions & 1 deletion fastapi_crudrouter/core/_types.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import Dict, TypeVar, Optional, Sequence
from typing import Dict, TypeVar, Optional, Sequence, Union

from fastapi.params import Depends
from pydantic import BaseModel

PAGINATION = Dict[str, Optional[int]]
FILTER = Dict[str, Optional[Union[int, float, str, bool]]]
SORT = Dict[str, str]
PYDANTIC_SCHEMA = BaseModel

T = TypeVar("T", bound=BaseModel)
Expand Down
59 changes: 56 additions & 3 deletions fastapi_crudrouter/core/_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
from typing import Optional, Type, TypeVar, Any
from typing import Optional, Type, TypeVar, Any, List # noqa: F401

from fastapi import Depends, HTTPException
from fastapi import Depends, HTTPException, Query # noqa: F401
from pydantic import create_model

from ._types import PAGINATION, PYDANTIC_SCHEMA
from ._types import PAGINATION, PYDANTIC_SCHEMA, SORT, FILTER # noqa: F401

T = TypeVar("T", bound=PYDANTIC_SCHEMA)
FILTER_MAPPING = {
"int": int,
"float": float,
"bool": bool,
"str": str,
"ConstrainedStrValue": str,
}


def get_pk_type(schema: Type[PYDANTIC_SCHEMA], pk_field: str) -> Any:
Expand Down Expand Up @@ -71,3 +78,49 @@ def pagination(skip: int = 0, limit: Optional[int] = max_limit) -> PAGINATION:
return {"skip": skip, "limit": limit}

return Depends(pagination)


def query_factory(schema: Type[T]) -> Any:
"""
Dynamically builds a Fastapi query dependency based on all available field in the
"""

_str = "{}: Optional[{}] = Query(None)"
args_str = ", ".join(
[
_str.format(name, FILTER_MAPPING[field.type_.__name__].__name__)
for name, field in schema.__fields__.items()
if field.type_.__name__ in FILTER_MAPPING
]
)

_str = "{}={}"
return_str = ", ".join(
[
_str.format(name, field.name)
for name, field in schema.__fields__.items()
if field.type_.__name__ in FILTER_MAPPING
]
)

filter_func_src = f"""
def filter_func({args_str}) -> FILTER:
ret = dict({return_str})
return {{k:v for k, v in ret.items() if v is not None}}
"""

exec(filter_func_src, globals(), locals())
return Depends(locals().get("filter_func"))


def sort_factory(schema: Type[T]) -> Any:
fields = [field.name for field in schema.__fields__.values()]

def sort_func(
sort_: str = Query(None, alias="sort", enum=fields),
direction: str = Query(None, enum=["asc", "desc"]),
) -> SORT:
ret = {"sort": sort_, "reverse": direction == "desc"}
return {k: v for k, v in ret.items() if v} # type: ignore

return Depends(sort_func)
19 changes: 16 additions & 3 deletions fastapi_crudrouter/core/databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
from fastapi import HTTPException

from . import CRUDGenerator, NOT_FOUND, _utils
from ._types import PAGINATION, PYDANTIC_SCHEMA, DEPENDENCIES
from ._types import PAGINATION, PYDANTIC_SCHEMA, DEPENDENCIES, FILTER, SORT

try:
from sqlalchemy.sql.schema import Table
from sqlalchemy import desc
from databases.core import Database
except ImportError:
databases_installed = False
Expand Down Expand Up @@ -75,10 +76,20 @@ def __init__(
def _get_all(self, *args: Any, **kwargs: Any) -> CALLABLE_LIST:
async def route(
pagination: PAGINATION = self.pagination,
filter_: FILTER = self.filter,
sort_: SORT = self.sort,
) -> List[Model]:
skip, limit = pagination.get("skip"), pagination.get("limit")

query = self.table.select().limit(limit).offset(skip)

if sort_:
field = getattr(self.table.c, sort_.get("sort", self._pk))
order = desc(field) if sort_.get("reverse", False) else field
query = query.order_by(order)

for col, val in filter_.items():
query = query.where(self.table.c[col] == val)

return await self.db.fetch_all(query)

return route
Expand Down Expand Up @@ -129,7 +140,9 @@ async def route() -> List[Model]:
query = self.table.delete()
await self.db.execute(query=query)

return await self._get_all()(pagination={"skip": 0, "limit": None})
return await self._get_all()(
pagination={"skip": 0, "limit": None}, filter_={}, sort_={}
)

return route

Expand Down
48 changes: 40 additions & 8 deletions fastapi_crudrouter/core/mem.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, Callable, List, Type, cast, Optional, Union

from . import CRUDGenerator, NOT_FOUND
from ._types import DEPENDENCIES, PAGINATION, PYDANTIC_SCHEMA as SCHEMA
from ._types import DEPENDENCIES, PAGINATION, PYDANTIC_SCHEMA as SCHEMA, FILTER, SORT

CALLABLE = Callable[..., SCHEMA]
CALLABLE_LIST = Callable[..., List[SCHEMA]]
Expand Down Expand Up @@ -44,15 +44,18 @@ def __init__(
self._id = 1

def _get_all(self, *args: Any, **kwargs: Any) -> CALLABLE_LIST:
def route(pagination: PAGINATION = self.pagination) -> List[SCHEMA]:
def route(
pagination: PAGINATION = self.pagination,
filter_: FILTER = self.filter,
sort_: SORT = self.sort,
) -> List[SCHEMA]:
skip, limit = pagination.get("skip"), pagination.get("limit")
skip = cast(int, skip)

return (
self.models[skip:]
if limit is None
else self.models[skip : skip + limit]
)
models = self._get_filtered_list(self.models, filter_)
models = self._get_sorted_list(models, sort_)

return models[skip:] if limit is None else models[skip : skip + limit]

return route

Expand All @@ -69,7 +72,7 @@ def route(item_id: int) -> SCHEMA:
def _create(self, *args: Any, **kwargs: Any) -> CALLABLE:
def route(model: self.create_schema) -> SCHEMA: # type: ignore
model_dict = model.dict()
model_dict["id"] = self._get_next_id()
model_dict[self._pk] = self._get_next_id()
ready_model = self.schema(**model_dict)
self.models.append(ready_model)
return ready_model
Expand Down Expand Up @@ -112,3 +115,32 @@ def _get_next_id(self) -> int:
self._id += 1

return id_

def _get_sorted_list(self, models: List[SCHEMA], sort_: SORT) -> List[SCHEMA]:
if not sort_:
return models

field = sort_.get("sort", self._pk)
models.sort(
reverse=bool(sort_.get("reverse", False)), key=lambda x: getattr(x, field) # type: ignore
)

return models

@staticmethod
def _get_filtered_list(models: List[SCHEMA], filters_: FILTER) -> List[SCHEMA]:
if not filters_:
return models

return [
model
for model in models
if MemoryCRUDRouter._check_filters(model, filters_)
]

@staticmethod
def _check_filters(model: SCHEMA, filters_: FILTER) -> bool:
for k, v in filters_.items():
if getattr(model, k) != v:
return False
return True
24 changes: 17 additions & 7 deletions fastapi_crudrouter/core/ormar.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from fastapi import HTTPException

from . import CRUDGenerator, NOT_FOUND, _utils
from ._types import DEPENDENCIES, PAGINATION
from ._types import DEPENDENCIES, PAGINATION, FILTER, SORT

try:
from ormar import Model, NoMatch
Expand Down Expand Up @@ -41,7 +41,7 @@ def __init__(
update_route: Union[bool, DEPENDENCIES] = True,
delete_one_route: Union[bool, DEPENDENCIES] = True,
delete_all_route: Union[bool, DEPENDENCIES] = True,
**kwargs: Any
**kwargs: Any,
) -> None:
assert ormar_installed, "Ormar must be installed to use the OrmarCRUDRouter."

Expand All @@ -61,19 +61,27 @@ def __init__(
update_route=update_route,
delete_one_route=delete_one_route,
delete_all_route=delete_all_route,
**kwargs
**kwargs,
)

self._INTEGRITY_ERROR = self._get_integrity_error_type()

def _get_all(self, *args: Any, **kwargs: Any) -> CALLABLE_LIST:
async def route(
pagination: PAGINATION = self.pagination,
filter_: FILTER = self.filter,
sort_: SORT = self.sort,
) -> List[Optional[Model]]:
skip, limit = pagination.get("skip"), pagination.get("limit")
query = self.schema.objects.offset(cast(int, skip))
if limit:
query = query.limit(limit)
query = self.schema.objects.filter(**filter_) # type: ignore

if sort_:
field = sort_.get("sort", self._pk)
order = f"-{field}" if sort_.get("reverse", False) else field
query = query.order_by(order)

query = query.limit(limit).offset(cast(int, skip)) # type: ignore

return await query.all()

return route
Expand Down Expand Up @@ -122,7 +130,9 @@ async def route(
def _delete_all(self, *args: Any, **kwargs: Any) -> CALLABLE_LIST:
async def route() -> List[Optional[Model]]:
await self.schema.objects.delete(each=True)
return await self._get_all()(pagination={"skip": 0, "limit": None})
return await self._get_all()(
pagination={"skip": 0, "limit": None}, filter_={}, sort_={}
)

return route

Expand Down
24 changes: 17 additions & 7 deletions fastapi_crudrouter/core/sqlalchemy.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import Any, Callable, List, Type, Generator, Optional, Union
from typing import Any, Callable, Generator, List
from typing import Optional, Type, Union

from fastapi import Depends, HTTPException

from . import CRUDGenerator, NOT_FOUND, _utils
from ._types import DEPENDENCIES, PAGINATION, PYDANTIC_SCHEMA as SCHEMA
from ._types import DEPENDENCIES, PAGINATION, PYDANTIC_SCHEMA as SCHEMA, FILTER, SORT

try:
from sqlalchemy.orm import Session
Expand Down Expand Up @@ -65,19 +66,26 @@ def _get_all(self, *args: Any, **kwargs: Any) -> Callable[..., List[Model]]:
def route(
db: Session = Depends(self.db_func),
pagination: PAGINATION = self.pagination,
filter_: FILTER = self.filter,
sort_: SORT = self.sort,
) -> List[Model]:
skip, limit = pagination.get("skip"), pagination.get("limit")
query = db.query(self.db_model).filter_by(**filter_)

db_models: List[Model] = (
db.query(self.db_model).limit(limit).offset(skip).all()
)
if sort_:
field = getattr(self.db_model, sort_.get("sort", self._pk))
order = field.desc() if sort_.get("reverse", False) else field
query = query.order_by(order)

db_models: List[Model] = query.limit(limit).offset(skip).all()
return db_models

return route

def _get_one(self, *args: Any, **kwargs: Any) -> Callable[..., Model]:
def route(
item_id: self._pk_type, db: Session = Depends(self.db_func) # type: ignore
item_id: Optional[self._pk_type] = None, # type: ignore
db: Session = Depends(self.db_func),
) -> Model:
model: Model = db.query(self.db_model).get(item_id)

Expand Down Expand Up @@ -133,7 +141,9 @@ def route(db: Session = Depends(self.db_func)) -> List[Model]:
db.query(self.db_model).delete()
db.commit()

return self._get_all()(db=db, pagination={"skip": 0, "limit": None})
return self._get_all()(
db=db, pagination={"skip": 0, "limit": None}, filter_={}, sort_={}
)

return route

Expand Down
21 changes: 16 additions & 5 deletions fastapi_crudrouter/core/tortoise.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, Callable, List, Type, cast, Coroutine, Optional, Union

from . import CRUDGenerator, NOT_FOUND
from ._types import DEPENDENCIES, PAGINATION, PYDANTIC_SCHEMA as SCHEMA
from ._types import DEPENDENCIES, PAGINATION, PYDANTIC_SCHEMA as SCHEMA, FILTER, SORT

try:
from tortoise.models import Model
Expand All @@ -11,7 +11,6 @@
else:
tortoise_installed = True


CALLABLE = Callable[..., Coroutine[Any, Any, Model]]
CALLABLE_LIST = Callable[..., Coroutine[Any, Any, List[Model]]]

Expand Down Expand Up @@ -58,11 +57,21 @@ def __init__(
)

def _get_all(self, *args: Any, **kwargs: Any) -> CALLABLE_LIST:
async def route(pagination: PAGINATION = self.pagination) -> List[Model]:
async def route(
pagination: PAGINATION = self.pagination,
filter_: FILTER = self.filter,
sort_: SORT = self.sort,
) -> List[Model]:
skip, limit = pagination.get("skip"), pagination.get("limit")
query = self.db_model.all().offset(cast(int, skip))
query = self.db_model.filter(**filter_).offset(cast(int, skip))
if limit:
query = query.limit(limit)

if sort_:
field = sort_.get("sort", self._pk)
order = "-" + field if sort_.get("reverse", False) else field
query = query.order_by(order)

return await query

return route
Expand Down Expand Up @@ -101,7 +110,9 @@ async def route(
def _delete_all(self, *args: Any, **kwargs: Any) -> CALLABLE_LIST:
async def route() -> List[Model]:
await self.db_model.all().delete()
return await self._get_all()(pagination={"skip": 0, "limit": None})
return await self._get_all()(
pagination={"skip": 0, "limit": None}, filter_={}, sort_={}
)

return route

Expand Down
2 changes: 0 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,3 @@ ignore_missing_imports = True

[mypy-uvicorn.*]
ignore_missing_imports = True


Loading