diff --git a/fastcrud/crud/fast_crud.py b/fastcrud/crud/fast_crud.py index a200ea7..498c5ab 100644 --- a/fastcrud/crud/fast_crud.py +++ b/fastcrud/crud/fast_crud.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Generic, Union, Optional, Callable +from typing import Any, Generic, Union, Optional, Callable from datetime import datetime, timezone from pydantic import ValidationError @@ -793,7 +793,7 @@ async def upsert( instance: Union[UpdateSchemaType, CreateSchemaType], schema_to_select: Optional[type[SelectSchemaType]] = None, return_as_model: bool = False, - ) -> Union[SelectSchemaType, Dict[str, Any], None]: + ) -> Union[SelectSchemaType, dict[str, Any], None]: """Update the instance or create it if it doesn't exists. Note: This method will perform two transactions to the database (get and create or update). @@ -841,7 +841,7 @@ async def upsert_multi( return_as_model: bool = False, update_override: Optional[dict[str, Any]] = None, **kwargs: Any, - ) -> Optional[Dict[str, Any]]: + ) -> Optional[dict[str, Any]]: """ Upsert multiple records in the database. The underlying implementation varies based on the database dialect. diff --git a/fastcrud/endpoint/endpoint_creator.py b/fastcrud/endpoint/endpoint_creator.py index c9ebbfc..212f786 100644 --- a/fastcrud/endpoint/endpoint_creator.py +++ b/fastcrud/endpoint/endpoint_creator.py @@ -1,11 +1,12 @@ -from typing import Type, Optional, Callable, Sequence, Union +from typing import Type, Optional, Callable, Sequence, Union, Any, cast from enum import Enum from fastapi import Depends, Body, Query, APIRouter -from pydantic import ValidationError +from pydantic import ValidationError, BaseModel from sqlalchemy.ext.asyncio import AsyncSession from fastcrud.crud.fast_crud import FastCRUD +from fastcrud.paginated import ListResponse, PaginatedListResponse from fastcrud.types import ( CreateSchemaType, DeleteSchemaType, @@ -295,6 +296,31 @@ def __init__( self.filter_config = filter_config self.column_types = _get_column_types(model) + if select_schema is not None: + self.list_response_model: Optional[Type[ListResponse[Any]]] = type( + "DynamicListResponse", + (ListResponse[BaseModel],), + {"__annotations__": {"data": list[select_schema]}}, # type: ignore + ) + self.paginated_response_model: Optional[ + Type[PaginatedListResponse[Any]] + ] = type( + "DynamicPaginatedResponse", + (PaginatedListResponse[BaseModel],), + { + "__annotations__": { + "data": list[select_schema], # type: ignore + "total_count": int, + "has_more": bool, + "page": Optional[int], + "items_per_page": Optional[int], + } + }, + ) + else: + self.list_response_model = None + self.paginated_response_model = None + def _validate_filter_config(self, filter_config: FilterConfig) -> None: model_columns = self.crud.model_col_names for key in filter_config.filters.keys(): @@ -334,7 +360,7 @@ async def endpoint(db: AsyncSession = Depends(self.session), **pkeys): if self.select_schema is not None: item = await self.crud.get( db, - schema_to_select=self.select_schema, + schema_to_select=cast(Type[BaseModel], self.select_schema), return_as_model=True, **pkeys, ) @@ -363,7 +389,7 @@ async def endpoint( None, alias="itemsPerPage", description="Number of items per page" ), filters: dict = Depends(dynamic_filters), - ): + ) -> Union[dict[str, Any], PaginatedListResponse, ListResponse]: is_paginated = (page is not None) or (items_per_page is not None) has_offset_limit = (offset is not None) and (limit is not None) @@ -599,10 +625,21 @@ def get_current_user(...): include_in_schema=self.include_in_schema, tags=self.tags, dependencies=_inject_dependencies(read_deps), + response_model=self.select_schema if self.select_schema else None, description=f"Read a single {self.model.__name__} row from the database by its primary keys: {self.primary_key_names}.", ) if ("read_multi" in included_methods) and ("read_multi" not in deleted_methods): + if self.select_schema is not None: + response_model: Optional[ + Type[Union[PaginatedListResponse[Any], ListResponse[Any]]] + ] = Union[ + self.paginated_response_model, # type: ignore + self.list_response_model, # type: ignore + ] + else: + response_model = None + self.router.add_api_route( self._get_endpoint_path(operation="read_multi"), self._read_items(), @@ -610,7 +647,14 @@ def get_current_user(...): include_in_schema=self.include_in_schema, tags=self.tags, dependencies=_inject_dependencies(read_multi_deps), - description=f"Read multiple {self.model.__name__} rows from the database with a limit and an offset.", + response_model=response_model, + description=( + f"Read multiple {self.model.__name__} rows from the database.\n\n" + f"- Use page & itemsPerPage for paginated results\n" + f"- Use offset & limit for specific ranges\n" + f"- Returns paginated response when using page/itemsPerPage\n" + f"- Returns simple list response when using offset/limit" + ), ) if ("update" in included_methods) and ("update" not in deleted_methods): diff --git a/fastcrud/types.py b/fastcrud/types.py index 24af6ac..ccce731 100644 --- a/fastcrud/types.py +++ b/fastcrud/types.py @@ -1,4 +1,4 @@ -from typing import TypeVar, Any, Dict, Union, List +from typing import TypeVar, Any, Union from pydantic import BaseModel @@ -10,5 +10,5 @@ UpdateSchemaInternalType = TypeVar("UpdateSchemaInternalType", bound=BaseModel) DeleteSchemaType = TypeVar("DeleteSchemaType", bound=BaseModel) -GetMultiResponseDict = Dict[str, Union[List[Dict[str, Any]], int]] -GetMultiResponseModel = Dict[str, Union[List[SelectSchemaType], int]] +GetMultiResponseDict = dict[str, Union[list[dict[str, Any]], int]] +GetMultiResponseModel = dict[str, Union[list[SelectSchemaType], int]]