diff --git a/fastapi_crudrouter/core/_base.py b/fastapi_crudrouter/core/_base.py index 9b3067c..a218cc3 100644 --- a/fastapi_crudrouter/core/_base.py +++ b/fastapi_crudrouter/core/_base.py @@ -4,7 +4,7 @@ from fastapi.types import DecoratedCallable from ._types import T, DEPENDENCIES -from ._utils import pagination_factory, schema_factory +from ._utils import pagination_factory, schema_factory, query_factory, sort_factory NOT_FOUND = HTTPException(404, "Item not found") @@ -34,6 +34,9 @@ def __init__( self.schema = schema self.pagination = pagination_factory(max_limit=paginate) + self.filter = query_factory(self.schema) + self.sort = sort_factory(self.schema) + self._pk: str = self._pk if hasattr(self, "_pk") else "id" self.create_schema = ( create_schema diff --git a/fastapi_crudrouter/core/_types.py b/fastapi_crudrouter/core/_types.py index 959d5fe..a88ecee 100644 --- a/fastapi_crudrouter/core/_types.py +++ b/fastapi_crudrouter/core/_types.py @@ -1,9 +1,11 @@ -from typing import Dict, TypeVar, Optional, Sequence +from typing import Dict, TypeVar, Optional, Sequence, Union from fastapi.params import Depends from pydantic import BaseModel PAGINATION = Dict[str, Optional[int]] +FILTER = Dict[str, Optional[Union[int, float, str, bool]]] +SORT = Dict[str, str] PYDANTIC_SCHEMA = BaseModel T = TypeVar("T", bound=BaseModel) diff --git a/fastapi_crudrouter/core/_utils.py b/fastapi_crudrouter/core/_utils.py index c42e197..bffe81b 100644 --- a/fastapi_crudrouter/core/_utils.py +++ b/fastapi_crudrouter/core/_utils.py @@ -1,11 +1,18 @@ -from typing import Optional, Type, TypeVar, Any +from typing import Optional, Type, TypeVar, Any, List # noqa: F401 -from fastapi import Depends, HTTPException +from fastapi import Depends, HTTPException, Query # noqa: F401 from pydantic import create_model -from ._types import PAGINATION, PYDANTIC_SCHEMA +from ._types import PAGINATION, PYDANTIC_SCHEMA, SORT, FILTER # noqa: F401 T = TypeVar("T", bound=PYDANTIC_SCHEMA) +FILTER_MAPPING = { + "int": int, + "float": float, + "bool": bool, + "str": str, + "ConstrainedStrValue": str, +} def get_pk_type(schema: Type[PYDANTIC_SCHEMA], pk_field: str) -> Any: @@ -71,3 +78,49 @@ def pagination(skip: int = 0, limit: Optional[int] = max_limit) -> PAGINATION: return {"skip": skip, "limit": limit} return Depends(pagination) + + +def query_factory(schema: Type[T]) -> Any: + """ + Dynamically builds a Fastapi query dependency based on all available field in the + """ + + _str = "{}: Optional[{}] = Query(None)" + args_str = ", ".join( + [ + _str.format(name, FILTER_MAPPING[field.type_.__name__].__name__) + for name, field in schema.__fields__.items() + if field.type_.__name__ in FILTER_MAPPING + ] + ) + + _str = "{}={}" + return_str = ", ".join( + [ + _str.format(name, field.name) + for name, field in schema.__fields__.items() + if field.type_.__name__ in FILTER_MAPPING + ] + ) + + filter_func_src = f""" +def filter_func({args_str}) -> FILTER: + ret = dict({return_str}) + return {{k:v for k, v in ret.items() if v is not None}} +""" + + exec(filter_func_src, globals(), locals()) + return Depends(locals().get("filter_func")) + + +def sort_factory(schema: Type[T]) -> Any: + fields = [field.name for field in schema.__fields__.values()] + + def sort_func( + sort_: str = Query(None, alias="sort", enum=fields), + direction: str = Query(None, enum=["asc", "desc"]), + ) -> SORT: + ret = {"sort": sort_, "reverse": direction == "desc"} + return {k: v for k, v in ret.items() if v} # type: ignore + + return Depends(sort_func) diff --git a/fastapi_crudrouter/core/databases.py b/fastapi_crudrouter/core/databases.py index a761278..a8fab6a 100644 --- a/fastapi_crudrouter/core/databases.py +++ b/fastapi_crudrouter/core/databases.py @@ -12,10 +12,11 @@ from fastapi import HTTPException from . import CRUDGenerator, NOT_FOUND, _utils -from ._types import PAGINATION, PYDANTIC_SCHEMA, DEPENDENCIES +from ._types import PAGINATION, PYDANTIC_SCHEMA, DEPENDENCIES, FILTER, SORT try: from sqlalchemy.sql.schema import Table + from sqlalchemy import desc from databases.core import Database except ImportError: databases_installed = False @@ -75,10 +76,20 @@ def __init__( def _get_all(self, *args: Any, **kwargs: Any) -> CALLABLE_LIST: async def route( pagination: PAGINATION = self.pagination, + filter_: FILTER = self.filter, + sort_: SORT = self.sort, ) -> List[Model]: skip, limit = pagination.get("skip"), pagination.get("limit") - query = self.table.select().limit(limit).offset(skip) + + if sort_: + field = getattr(self.table.c, sort_.get("sort", self._pk)) + order = desc(field) if sort_.get("reverse", False) else field + query = query.order_by(order) + + for col, val in filter_.items(): + query = query.where(self.table.c[col] == val) + return await self.db.fetch_all(query) return route @@ -129,7 +140,9 @@ async def route() -> List[Model]: query = self.table.delete() await self.db.execute(query=query) - return await self._get_all()(pagination={"skip": 0, "limit": None}) + return await self._get_all()( + pagination={"skip": 0, "limit": None}, filter_={}, sort_={} + ) return route diff --git a/fastapi_crudrouter/core/mem.py b/fastapi_crudrouter/core/mem.py index d4e13c1..0e41326 100644 --- a/fastapi_crudrouter/core/mem.py +++ b/fastapi_crudrouter/core/mem.py @@ -1,7 +1,7 @@ from typing import Any, Callable, List, Type, cast, Optional, Union from . import CRUDGenerator, NOT_FOUND -from ._types import DEPENDENCIES, PAGINATION, PYDANTIC_SCHEMA as SCHEMA +from ._types import DEPENDENCIES, PAGINATION, PYDANTIC_SCHEMA as SCHEMA, FILTER, SORT CALLABLE = Callable[..., SCHEMA] CALLABLE_LIST = Callable[..., List[SCHEMA]] @@ -44,15 +44,18 @@ def __init__( self._id = 1 def _get_all(self, *args: Any, **kwargs: Any) -> CALLABLE_LIST: - def route(pagination: PAGINATION = self.pagination) -> List[SCHEMA]: + def route( + pagination: PAGINATION = self.pagination, + filter_: FILTER = self.filter, + sort_: SORT = self.sort, + ) -> List[SCHEMA]: skip, limit = pagination.get("skip"), pagination.get("limit") skip = cast(int, skip) - return ( - self.models[skip:] - if limit is None - else self.models[skip : skip + limit] - ) + models = self._get_filtered_list(self.models, filter_) + models = self._get_sorted_list(models, sort_) + + return models[skip:] if limit is None else models[skip : skip + limit] return route @@ -69,7 +72,7 @@ def route(item_id: int) -> SCHEMA: def _create(self, *args: Any, **kwargs: Any) -> CALLABLE: def route(model: self.create_schema) -> SCHEMA: # type: ignore model_dict = model.dict() - model_dict["id"] = self._get_next_id() + model_dict[self._pk] = self._get_next_id() ready_model = self.schema(**model_dict) self.models.append(ready_model) return ready_model @@ -112,3 +115,32 @@ def _get_next_id(self) -> int: self._id += 1 return id_ + + def _get_sorted_list(self, models: List[SCHEMA], sort_: SORT) -> List[SCHEMA]: + if not sort_: + return models + + field = sort_.get("sort", self._pk) + models.sort( + reverse=bool(sort_.get("reverse", False)), key=lambda x: getattr(x, field) # type: ignore + ) + + return models + + @staticmethod + def _get_filtered_list(models: List[SCHEMA], filters_: FILTER) -> List[SCHEMA]: + if not filters_: + return models + + return [ + model + for model in models + if MemoryCRUDRouter._check_filters(model, filters_) + ] + + @staticmethod + def _check_filters(model: SCHEMA, filters_: FILTER) -> bool: + for k, v in filters_.items(): + if getattr(model, k) != v: + return False + return True diff --git a/fastapi_crudrouter/core/ormar.py b/fastapi_crudrouter/core/ormar.py index 4586b5b..1304276 100644 --- a/fastapi_crudrouter/core/ormar.py +++ b/fastapi_crudrouter/core/ormar.py @@ -12,7 +12,7 @@ from fastapi import HTTPException from . import CRUDGenerator, NOT_FOUND, _utils -from ._types import DEPENDENCIES, PAGINATION +from ._types import DEPENDENCIES, PAGINATION, FILTER, SORT try: from ormar import Model, NoMatch @@ -41,7 +41,7 @@ def __init__( update_route: Union[bool, DEPENDENCIES] = True, delete_one_route: Union[bool, DEPENDENCIES] = True, delete_all_route: Union[bool, DEPENDENCIES] = True, - **kwargs: Any + **kwargs: Any, ) -> None: assert ormar_installed, "Ormar must be installed to use the OrmarCRUDRouter." @@ -61,7 +61,7 @@ def __init__( update_route=update_route, delete_one_route=delete_one_route, delete_all_route=delete_all_route, - **kwargs + **kwargs, ) self._INTEGRITY_ERROR = self._get_integrity_error_type() @@ -69,11 +69,19 @@ def __init__( def _get_all(self, *args: Any, **kwargs: Any) -> CALLABLE_LIST: async def route( pagination: PAGINATION = self.pagination, + filter_: FILTER = self.filter, + sort_: SORT = self.sort, ) -> List[Optional[Model]]: skip, limit = pagination.get("skip"), pagination.get("limit") - query = self.schema.objects.offset(cast(int, skip)) - if limit: - query = query.limit(limit) + query = self.schema.objects.filter(**filter_) # type: ignore + + if sort_: + field = sort_.get("sort", self._pk) + order = f"-{field}" if sort_.get("reverse", False) else field + query = query.order_by(order) + + query = query.limit(limit).offset(cast(int, skip)) # type: ignore + return await query.all() return route @@ -122,7 +130,9 @@ async def route( def _delete_all(self, *args: Any, **kwargs: Any) -> CALLABLE_LIST: async def route() -> List[Optional[Model]]: await self.schema.objects.delete(each=True) - return await self._get_all()(pagination={"skip": 0, "limit": None}) + return await self._get_all()( + pagination={"skip": 0, "limit": None}, filter_={}, sort_={} + ) return route diff --git a/fastapi_crudrouter/core/sqlalchemy.py b/fastapi_crudrouter/core/sqlalchemy.py index fa26f4a..b11fe06 100644 --- a/fastapi_crudrouter/core/sqlalchemy.py +++ b/fastapi_crudrouter/core/sqlalchemy.py @@ -1,9 +1,10 @@ -from typing import Any, Callable, List, Type, Generator, Optional, Union +from typing import Any, Callable, Generator, List +from typing import Optional, Type, Union from fastapi import Depends, HTTPException from . import CRUDGenerator, NOT_FOUND, _utils -from ._types import DEPENDENCIES, PAGINATION, PYDANTIC_SCHEMA as SCHEMA +from ._types import DEPENDENCIES, PAGINATION, PYDANTIC_SCHEMA as SCHEMA, FILTER, SORT try: from sqlalchemy.orm import Session @@ -65,19 +66,26 @@ def _get_all(self, *args: Any, **kwargs: Any) -> Callable[..., List[Model]]: def route( db: Session = Depends(self.db_func), pagination: PAGINATION = self.pagination, + filter_: FILTER = self.filter, + sort_: SORT = self.sort, ) -> List[Model]: skip, limit = pagination.get("skip"), pagination.get("limit") + query = db.query(self.db_model).filter_by(**filter_) - db_models: List[Model] = ( - db.query(self.db_model).limit(limit).offset(skip).all() - ) + if sort_: + field = getattr(self.db_model, sort_.get("sort", self._pk)) + order = field.desc() if sort_.get("reverse", False) else field + query = query.order_by(order) + + db_models: List[Model] = query.limit(limit).offset(skip).all() return db_models return route def _get_one(self, *args: Any, **kwargs: Any) -> Callable[..., Model]: def route( - item_id: self._pk_type, db: Session = Depends(self.db_func) # type: ignore + item_id: Optional[self._pk_type] = None, # type: ignore + db: Session = Depends(self.db_func), ) -> Model: model: Model = db.query(self.db_model).get(item_id) @@ -133,7 +141,9 @@ def route(db: Session = Depends(self.db_func)) -> List[Model]: db.query(self.db_model).delete() db.commit() - return self._get_all()(db=db, pagination={"skip": 0, "limit": None}) + return self._get_all()( + db=db, pagination={"skip": 0, "limit": None}, filter_={}, sort_={} + ) return route diff --git a/fastapi_crudrouter/core/tortoise.py b/fastapi_crudrouter/core/tortoise.py index 23c3926..675131c 100644 --- a/fastapi_crudrouter/core/tortoise.py +++ b/fastapi_crudrouter/core/tortoise.py @@ -1,7 +1,7 @@ from typing import Any, Callable, List, Type, cast, Coroutine, Optional, Union from . import CRUDGenerator, NOT_FOUND -from ._types import DEPENDENCIES, PAGINATION, PYDANTIC_SCHEMA as SCHEMA +from ._types import DEPENDENCIES, PAGINATION, PYDANTIC_SCHEMA as SCHEMA, FILTER, SORT try: from tortoise.models import Model @@ -11,7 +11,6 @@ else: tortoise_installed = True - CALLABLE = Callable[..., Coroutine[Any, Any, Model]] CALLABLE_LIST = Callable[..., Coroutine[Any, Any, List[Model]]] @@ -58,11 +57,21 @@ def __init__( ) def _get_all(self, *args: Any, **kwargs: Any) -> CALLABLE_LIST: - async def route(pagination: PAGINATION = self.pagination) -> List[Model]: + async def route( + pagination: PAGINATION = self.pagination, + filter_: FILTER = self.filter, + sort_: SORT = self.sort, + ) -> List[Model]: skip, limit = pagination.get("skip"), pagination.get("limit") - query = self.db_model.all().offset(cast(int, skip)) + query = self.db_model.filter(**filter_).offset(cast(int, skip)) if limit: query = query.limit(limit) + + if sort_: + field = sort_.get("sort", self._pk) + order = "-" + field if sort_.get("reverse", False) else field + query = query.order_by(order) + return await query return route @@ -101,7 +110,9 @@ async def route( def _delete_all(self, *args: Any, **kwargs: Any) -> CALLABLE_LIST: async def route() -> List[Model]: await self.db_model.all().delete() - return await self._get_all()(pagination={"skip": 0, "limit": None}) + return await self._get_all()( + pagination={"skip": 0, "limit": None}, filter_={}, sort_={} + ) return route diff --git a/setup.cfg b/setup.cfg index a2081ed..02371af 100644 --- a/setup.cfg +++ b/setup.cfg @@ -33,5 +33,3 @@ ignore_missing_imports = True [mypy-uvicorn.*] ignore_missing_imports = True - - diff --git a/tests/test_query_params.py b/tests/test_query_params.py new file mode 100644 index 0000000..a047021 --- /dev/null +++ b/tests/test_query_params.py @@ -0,0 +1,59 @@ +from operator import itemgetter + +from tests.test_router import test_post, test_get + + +def insert(_client): + test_post(_client) + test_post(_client, expected_length=2) + test_post(_client, expected_length=3) + + test_post( + _client, + model=dict(thickness=0.24, mass=1.1, color="red", type="Large"), + expected_length=4, + ) + test_post( + _client, + model=dict(thickness=0.10, mass=1.9, color="red", type="Small"), + expected_length=5, + ) + test_post( + _client, + model=dict(thickness=0.25, mass=1.9, color="red", type="Medium"), + expected_length=6, + ) + + +def test_simple(client): + insert(client) + + test_get(client, params={"color": "red"}, expected_length=3) + test_get(client, params={"color": "blue"}, expected_length=0) + test_get(client, params={"type": "Large"}, expected_length=1) + test_get(client, params={"thickness": 0.24}, expected_length=4) + + +def test_two_params(client): + insert(client) + + test_get(client, params={"color": "red", "type": "Large"}, expected_length=1) + test_get(client, params={"color": "red", "type": "Small"}, expected_length=1) + test_get(client, params={"color": "blue", "type": "Small"}, expected_length=0) + test_get(client, params={"thickness": 0.24, "mass": 1.2}, expected_length=3) + + +def test_sort_asc(client): + insert(client) + + data1 = test_get(client, params={"color": "red", "sort": "thickness"}, expected_length=3) + assert data1 == sorted(data1, key=itemgetter("thickness")) + + +def test_sort_desc(client): + insert(client) + + data = test_get(client, params={"color": "red", "sort": "thickness", "direction": "desc"}, expected_length=3) + assert data == sorted(data, key=itemgetter("thickness"), reverse=True) + +