From c58b4f1e694a4934a05f29e5fe2e57c26e7d5811 Mon Sep 17 00:00:00 2001 From: Louis Cochen Date: Wed, 9 Oct 2024 23:22:36 +0100 Subject: [PATCH 1/2] Implement `select_schema` on `EndpointCreator` and `crud_router` --- fastcrud/crud/fast_crud.py | 53 +++++++++++------------ fastcrud/crud/helper.py | 19 ++++---- fastcrud/endpoint/crud_router.py | 4 ++ fastcrud/endpoint/endpoint_creator.py | 50 ++++++++++++++------- fastcrud/types.py | 1 + tests/sqlmodel/conftest.py | 33 ++++++++++++++ tests/sqlmodel/endpoint/test_get_item.py | 23 ++++++++++ tests/sqlmodel/endpoint/test_get_items.py | 24 ++++++++++ 8 files changed, 156 insertions(+), 51 deletions(-) diff --git a/fastcrud/crud/fast_crud.py b/fastcrud/crud/fast_crud.py index d3c876b..9be0a3a 100644 --- a/fastcrud/crud/fast_crud.py +++ b/fastcrud/crud/fast_crud.py @@ -1,7 +1,7 @@ from typing import Any, Dict, Generic, Union, Optional, Callable from datetime import datetime, timezone -from pydantic import BaseModel, ValidationError +from pydantic import ValidationError from sqlalchemy import ( Insert, Result, @@ -29,6 +29,7 @@ CreateSchemaType, DeleteSchemaType, ModelType, + SelectSchemaType, UpdateSchemaInternalType, UpdateSchemaType, ) @@ -52,6 +53,7 @@ class FastCRUD( UpdateSchemaType, UpdateSchemaInternalType, DeleteSchemaType, + SelectSchemaType, ] ): """ @@ -649,7 +651,7 @@ async def create( async def select( self, - schema_to_select: Optional[type[BaseModel]] = None, + schema_to_select: Optional[type[SelectSchemaType]] = None, sort_columns: Optional[Union[str, list[str]]] = None, sort_orders: Optional[Union[str, list[str]]] = None, **kwargs: Any, @@ -715,11 +717,11 @@ async def select( async def get( self, db: AsyncSession, - schema_to_select: Optional[type[BaseModel]] = None, + schema_to_select: Optional[type[SelectSchemaType]] = None, return_as_model: bool = False, one_or_none: bool = False, **kwargs: Any, - ) -> Optional[Union[dict, BaseModel]]: + ) -> Optional[Union[dict, SelectSchemaType]]: """ Fetches a single record based on specified filters. @@ -787,9 +789,9 @@ async def upsert( self, db: AsyncSession, instance: Union[UpdateSchemaType, CreateSchemaType], - schema_to_select: Optional[type[BaseModel]] = None, + schema_to_select: Optional[type[SelectSchemaType]] = None, return_as_model: bool = False, - ) -> Union[BaseModel, Dict[str, Any], None]: + ) -> Union[SelectSchemaType, Dict[str, Any], None]: """Update the instance or create it if it doesn't exists. Note: This method will perform two transactions to the database (get and create or update). @@ -804,23 +806,23 @@ async def upsert( The created or updated instance """ _pks = self._get_pk_dict(instance) - schema_to_select = schema_to_select or type(instance) + schema_to_select = schema_to_select or type(instance) # type: ignore db_instance = await self.get( db, - schema_to_select=schema_to_select, + schema_to_select=schema_to_select, # type: ignore return_as_model=return_as_model, **_pks, ) if db_instance is None: db_instance = await self.create(db, instance) # type: ignore - db_instance = schema_to_select.model_validate( + db_instance = schema_to_select.model_validate( # type: ignore db_instance, from_attributes=True ) else: await self.update(db, instance) # type: ignore db_instance = await self.get( db, - schema_to_select=schema_to_select, + schema_to_select=schema_to_select, # type: ignore return_as_model=return_as_model, **_pks, ) @@ -832,7 +834,7 @@ async def upsert_multi( db: AsyncSession, instances: list[Union[UpdateSchemaType, CreateSchemaType]], return_columns: Optional[list[str]] = None, - schema_to_select: Optional[type[BaseModel]] = None, + schema_to_select: Optional[type[SelectSchemaType]] = None, return_as_model: bool = False, update_override: Optional[dict[str, Any]] = None, **kwargs: Any, @@ -1134,7 +1136,7 @@ async def get_multi( db: AsyncSession, offset: int = 0, limit: Optional[int] = 100, - schema_to_select: Optional[type[BaseModel]] = None, + schema_to_select: Optional[type[SelectSchemaType]] = None, sort_columns: Optional[Union[str, list[str]]] = None, sort_orders: Optional[Union[str, list[str]]] = None, return_as_model: bool = False, @@ -1277,11 +1279,11 @@ async def get_multi( async def get_joined( self, db: AsyncSession, - schema_to_select: Optional[type[BaseModel]] = None, + schema_to_select: Optional[type[SelectSchemaType]] = None, join_model: Optional[ModelType] = None, join_on: Optional[Union[Join, BinaryExpression]] = None, join_prefix: Optional[str] = None, - join_schema_to_select: Optional[type[BaseModel]] = None, + join_schema_to_select: Optional[type[SelectSchemaType]] = None, join_type: str = "left", alias: Optional[AliasedClass] = None, join_filters: Optional[dict] = None, @@ -1593,11 +1595,11 @@ async def get_joined( async def get_multi_joined( self, db: AsyncSession, - schema_to_select: Optional[type[BaseModel]] = None, + schema_to_select: Optional[type[SelectSchemaType]] = None, join_model: Optional[type[ModelType]] = None, join_on: Optional[Any] = None, join_prefix: Optional[str] = None, - join_schema_to_select: Optional[type[BaseModel]] = None, + join_schema_to_select: Optional[type[SelectSchemaType]] = None, join_type: str = "left", alias: Optional[AliasedClass[Any]] = None, join_filters: Optional[dict] = None, @@ -1937,7 +1939,7 @@ async def get_multi_joined( stmt = stmt.limit(limit) result = await db.execute(stmt) - data: list[Union[dict, BaseModel]] = [] + data: list[Union[dict, SelectSchemaType]] = [] for row in result.mappings().all(): row_dict = dict(row) @@ -2000,7 +2002,7 @@ async def get_multi_by_cursor( db: AsyncSession, cursor: Any = None, limit: int = 100, - schema_to_select: Optional[type[BaseModel]] = None, + schema_to_select: Optional[type[SelectSchemaType]] = None, sort_column: str = "id", sort_order: str = "asc", **kwargs: Any, @@ -2076,10 +2078,7 @@ async def get_multi_by_cursor( if limit == 0: return {"data": [], "next_cursor": None} - stmt = await self.select( - schema_to_select=schema_to_select, - **kwargs, - ) + stmt = await self.select(schema_to_select=schema_to_select, **kwargs) if cursor: if sort_order == "asc": @@ -2113,11 +2112,11 @@ async def update( allow_multiple: bool = False, commit: bool = True, return_columns: Optional[list[str]] = None, - schema_to_select: Optional[type[BaseModel]] = None, + schema_to_select: Optional[type[SelectSchemaType]] = None, return_as_model: bool = False, one_or_none: bool = False, **kwargs: Any, - ) -> Optional[Union[dict, BaseModel]]: + ) -> Optional[Union[dict, SelectSchemaType]]: """ Updates an existing record or multiple records in the database based on specified filters. This method allows for precise targeting of records to update. @@ -2243,10 +2242,10 @@ async def update( def _as_single_response( self, db_row: Result, - schema_to_select: Optional[type[BaseModel]] = None, + schema_to_select: Optional[type[SelectSchemaType]] = None, return_as_model: bool = False, one_or_none: bool = False, - ) -> Optional[Union[dict, BaseModel]]: + ) -> Optional[Union[dict, SelectSchemaType]]: result: Optional[Row] = db_row.one_or_none() if one_or_none else db_row.first() if result is None: # pragma: no cover return None @@ -2262,7 +2261,7 @@ def _as_single_response( def _as_multi_response( self, db_row: Result, - schema_to_select: Optional[type[BaseModel]] = None, + schema_to_select: Optional[type[SelectSchemaType]] = None, return_as_model: bool = False, ) -> dict: data = [dict(row) for row in db_row.mappings()] diff --git a/fastcrud/crud/helper.py b/fastcrud/crud/helper.py index 402ee3e..0524f5c 100644 --- a/fastcrud/crud/helper.py +++ b/fastcrud/crud/helper.py @@ -6,7 +6,7 @@ from pydantic import BaseModel, ConfigDict from pydantic.functional_validators import field_validator -from fastcrud.types import ModelType +from fastcrud.types import ModelType, SelectSchemaType from ..endpoint.helper import _get_primary_key @@ -40,7 +40,7 @@ def check_valid_join_type(cls, value): def _extract_matching_columns_from_schema( model: Union[ModelType, AliasedClass], - schema: Optional[type[BaseModel]], + schema: Optional[type[SelectSchemaType]], prefix: Optional[str] = None, alias: Optional[AliasedClass] = None, use_temporary_prefix: Optional[bool] = False, @@ -443,12 +443,12 @@ def _nest_join_data( def _nest_multi_join_data( base_primary_key: str, - data: list[Union[dict, BaseModel]], + data: Sequence[Union[dict, BaseModel]], joins_config: Sequence[JoinConfig], return_as_model: bool = False, - schema_to_select: Optional[type[BaseModel]] = None, - nested_schema_to_select: Optional[dict[str, type[BaseModel]]] = None, -) -> Sequence[Union[dict, BaseModel]]: + schema_to_select: Optional[type[SelectSchemaType]] = None, + nested_schema_to_select: Optional[dict[str, type[SelectSchemaType]]] = None, +) -> Sequence[Union[dict, SelectSchemaType]]: """ Nests joined data based on join definitions provided for multiple records. This function processes the input list of dictionaries, identifying keys that correspond to joined tables using the provided `joins_config`, and nests them @@ -464,7 +464,7 @@ def _nest_multi_join_data( nested_schema_to_select: A dictionary mapping join prefixes to their corresponding Pydantic schemas. Returns: - Sequence[Union[dict, BaseModel]]: A list of dictionaries with nested structures for joined table data or Pydantic models. + Sequence[Union[dict, SelectSchemaType]]: A list of dictionaries with nested structures for joined table data or Pydantic models. Example: @@ -584,8 +584,9 @@ def _nest_multi_join_data( def _handle_null_primary_key_multi_join( - data: list[Union[dict[str, Any], BaseModel]], join_definitions: list[JoinConfig] -) -> list[Union[dict[str, Any], BaseModel]]: + data: list[Union[dict[str, Any], SelectSchemaType]], + join_definitions: list[JoinConfig], +) -> list[Union[dict[str, Any], SelectSchemaType]]: for item in data: item_dict = item if isinstance(item, dict) else item.model_dump() diff --git a/fastcrud/endpoint/crud_router.py b/fastcrud/endpoint/crud_router.py index c8cc0cd..7d87e21 100644 --- a/fastcrud/endpoint/crud_router.py +++ b/fastcrud/endpoint/crud_router.py @@ -9,6 +9,7 @@ DeleteSchemaType, ModelType, UpdateSchemaType, + SelectSchemaType, ) from .endpoint_creator import EndpointCreator from .helper import FilterConfig @@ -38,6 +39,7 @@ def crud_router( updated_at_column: str = "updated_at", endpoint_names: Optional[dict[str, str]] = None, filter_config: Optional[Union[FilterConfig, dict]] = None, + select_schema: Optional[Type[SelectSchemaType]] = None, ) -> APIRouter: """ Creates and configures a FastAPI router with CRUD endpoints for a given model. @@ -72,6 +74,7 @@ def crud_router( (`"create"`, `"read"`, `"update"`, `"delete"`, `"db_delete"`, `"read_multi"`), and values are the custom names to use. Unspecified operations will use default names. filter_config: Optional `FilterConfig` instance or dictionary to configure filters for the `read_multi` endpoint. + select_schema: Optional Pydantic schema for selecting an item. Returns: Configured `APIRouter` instance with the CRUD endpoints. @@ -541,6 +544,7 @@ async def add_routes_to_router(self, ...): updated_at_column=updated_at_column, endpoint_names=endpoint_names, filter_config=filter_config, + select_schema=select_schema, # type: ignore ) endpoint_creator_instance.add_routes_to_router( diff --git a/fastcrud/endpoint/endpoint_creator.py b/fastcrud/endpoint/endpoint_creator.py index f4f9694..b0d0bb8 100644 --- a/fastcrud/endpoint/endpoint_creator.py +++ b/fastcrud/endpoint/endpoint_creator.py @@ -10,6 +10,7 @@ CreateSchemaType, DeleteSchemaType, ModelType, + SelectSchemaType, UpdateSchemaType, ) from ..exceptions.http_exceptions import ( @@ -58,6 +59,7 @@ class EndpointCreator: (`"create"`, `"read"`, `"update"`, `"delete"`, `"db_delete"`, `"read_multi"`), and values are the custom names to use. Unspecified operations will use default names. filter_config: Optional `FilterConfig` instance or dictionary to configure filters for the `read_multi` endpoint. + select_schema: Optional Pydantic schema for selecting an item. Raises: ValueError: If both `included_methods` and `deleted_methods` are provided. @@ -251,6 +253,7 @@ def __init__( updated_at_column: str = "updated_at", endpoint_names: Optional[dict[str, str]] = None, filter_config: Optional[Union[FilterConfig, dict]] = None, + select_schema: Optional[Type[SelectSchemaType]] = None, ) -> None: self._primary_keys = _get_primary_keys(model) self._primary_keys_types = { @@ -268,6 +271,7 @@ def __init__( self.create_schema = create_schema self.update_schema = update_schema self.delete_schema = delete_schema + self.select_schema = select_schema self.include_in_schema = include_in_schema self.path = path self.tags = tags or [] @@ -327,7 +331,15 @@ def _read_item(self): @_apply_model_pk(**self._primary_keys_types) async def endpoint(db: AsyncSession = Depends(self.session), **pkeys): - item = await self.crud.get(db, **pkeys) + if self.select_schema is not None: + item = await self.crud.get( + db, + schema_to_select=self.select_schema, + return_as_model=True, + **pkeys, + ) + else: + item = await self.crud.get(db, **pkeys) if not item: # pragma: no cover raise NotFoundException(detail="Item not found") return item # pragma: no cover @@ -359,29 +371,37 @@ async def endpoint( raise BadRequestException( detail="Conflicting parameters: Use either 'page' and 'itemsPerPage' for paginated results or 'offset' and 'limit' for specific range queries." ) - - if is_paginated: + elif is_paginated: offset = compute_offset(page=page, items_per_page=items_per_page) # type: ignore limit = items_per_page + elif not has_offset_limit: + offset = 0 + limit = 100 + + if self.select_schema is not None: + crud_data = await self.crud.get_multi( + db, + offset=offset, # type: ignore + limit=limit, # type: ignore + schema_to_select=self.select_schema, + return_as_model=True, + **filters, + ) + else: crud_data = await self.crud.get_multi( - db, offset=offset, limit=limit, **filters + db, + offset=offset, # type: ignore + limit=limit, # type: ignore + **filters, ) + + if is_paginated: return paginated_response( crud_data=crud_data, - page=page, # type: ignore + page=page, # type: ignore items_per_page=items_per_page, # type: ignore ) - if not has_offset_limit: - offset = 0 - limit = 100 - - crud_data = await self.crud.get_multi( - db, - offset=offset, # type: ignore - limit=limit, # type: ignore - **filters, - ) return crud_data # pragma: no cover return endpoint diff --git a/fastcrud/types.py b/fastcrud/types.py index 2987e2e..548af9c 100644 --- a/fastcrud/types.py +++ b/fastcrud/types.py @@ -4,6 +4,7 @@ ModelType = TypeVar("ModelType", bound=Any) +SelectSchemaType = TypeVar("SelectSchemaType", bound=BaseModel) CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel) UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel) UpdateSchemaInternalType = TypeVar("UpdateSchemaInternalType", bound=BaseModel) diff --git a/tests/sqlmodel/conftest.py b/tests/sqlmodel/conftest.py index 95232a2..e53fabc 100644 --- a/tests/sqlmodel/conftest.py +++ b/tests/sqlmodel/conftest.py @@ -650,3 +650,36 @@ def endpoint_creator(test_model, async_session) -> EndpointCreator: "read_multi": "get_multi", }, ) + + +@pytest.fixture +def client_with_select_schema( + test_model, + create_schema, + update_schema, + read_schema, + async_session, +): + app = FastAPI() + + app.include_router( + crud_router( + session=lambda: async_session, + model=test_model, + select_schema=read_schema, + create_schema=create_schema, + update_schema=update_schema, + path="/test", + tags=["test"], + endpoint_names={ + "create": "create", + "read": "get", + "update": "update", + "delete": "delete", + "db_delete": "db_delete", + "read_multi": "get_multi", + }, + ) + ) + + return TestClient(app) diff --git a/tests/sqlmodel/endpoint/test_get_item.py b/tests/sqlmodel/endpoint/test_get_item.py index b822c31..6d8b2e7 100644 --- a/tests/sqlmodel/endpoint/test_get_item.py +++ b/tests/sqlmodel/endpoint/test_get_item.py @@ -51,3 +51,26 @@ async def test_read_multi_primary_key_item_success( assert data["name"] == tester_data["name"] assert data["id"] == tester_data["id"] assert data["uuid"] == tester_data["uuid"] + + +@pytest.mark.asyncio +async def test_read_item_with_schema( + client_with_select_schema: TestClient, + async_session, + test_model, + test_data, + read_schema, +): + tester_data = {"name": test_data[0]["name"], "tier_id": test_data[0]["tier_id"]} + new_item = test_model(**tester_data) + async_session.add(new_item) + await async_session.commit() + await async_session.refresh(new_item) + + response = client_with_select_schema.get(f"/test/get/{new_item.id}") + + assert response.status_code == 200 + data = response.json() + assert read_schema.model_validate(data) + assert data["name"] == tester_data["name"] + assert data["tier_id"] == tester_data["tier_id"] diff --git a/tests/sqlmodel/endpoint/test_get_items.py b/tests/sqlmodel/endpoint/test_get_items.py index 8184a9a..eb42b52 100644 --- a/tests/sqlmodel/endpoint/test_get_items.py +++ b/tests/sqlmodel/endpoint/test_get_items.py @@ -92,3 +92,27 @@ async def test_read_items_with_dict_filter_config( @pytest.mark.asyncio async def test_invalid_filter_column(invalid_filtered_client): pass + + +@pytest.mark.asyncio +async def test_read_items_with_schema( + client_with_select_schema: TestClient, + async_session, + test_model, + test_data, + read_schema, +): + for data in test_data: + new_item = test_model(**data) + async_session.add(new_item) + await async_session.commit() + + response = client_with_select_schema.get("/test/get_multi") + + assert response.status_code == 200 + data = response.json() + + assert "data" in data + assert len(data["data"]) > 0 + + assert all(read_schema.model_validate(item) for item in data["data"]) From 6c3bee3ccfcaf38e960a2a8308fc9a84b1d5a413 Mon Sep 17 00:00:00 2001 From: Louis Cochen Date: Wed, 9 Oct 2024 23:27:54 +0100 Subject: [PATCH 2/2] Add type ignore for type that cannot be inferred --- fastcrud/crud/fast_crud.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastcrud/crud/fast_crud.py b/fastcrud/crud/fast_crud.py index 9be0a3a..f41aa61 100644 --- a/fastcrud/crud/fast_crud.py +++ b/fastcrud/crud/fast_crud.py @@ -1969,7 +1969,7 @@ async def get_multi_joined( join.relationship_type == "one-to-many" for join in join_definitions ): nested_data = _nest_multi_join_data( - base_primary_key=self._primary_keys[0].name, + base_primary_key=self._primary_keys[0].name, # type: ignore[misc] data=data, joins_config=join_definitions, return_as_model=return_as_model,