Skip to content

Commit

Permalink
Merge pull request #169 from ljmc-github/feature-add-response-model
Browse files Browse the repository at this point in the history
Implement `select_schema` on `EndpointCreator` and `crud_router`
  • Loading branch information
igorbenav authored Dec 23, 2024
2 parents 3984fc4 + 0a8fce8 commit ef49d95
Show file tree
Hide file tree
Showing 8 changed files with 155 additions and 39 deletions.
55 changes: 27 additions & 28 deletions fastcrud/crud/fast_crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from datetime import datetime, timezone
import warnings

from pydantic import BaseModel, ValidationError
from pydantic import ValidationError
from sqlalchemy import (
Insert,
Result,
Expand Down Expand Up @@ -30,6 +30,7 @@
CreateSchemaType,
DeleteSchemaType,
ModelType,
SelectSchemaType,
UpdateSchemaInternalType,
UpdateSchemaType,
)
Expand All @@ -53,6 +54,7 @@ class FastCRUD(
UpdateSchemaType,
UpdateSchemaInternalType,
DeleteSchemaType,
SelectSchemaType,
]
):
"""
Expand Down Expand Up @@ -650,7 +652,7 @@ async def create(

async def select(
self,
schema_to_select: Optional[type[BaseModel]] = None,
schema_to_select: Optional[type[SelectSchemaType]] = None,
sort_columns: Optional[Union[str, list[str]]] = None,
sort_orders: Optional[Union[str, list[str]]] = None,
**kwargs: Any,
Expand Down Expand Up @@ -716,11 +718,11 @@ async def select(
async def get(
self,
db: AsyncSession,
schema_to_select: Optional[type[BaseModel]] = None,
schema_to_select: Optional[type[SelectSchemaType]] = None,
return_as_model: bool = False,
one_or_none: bool = False,
**kwargs: Any,
) -> Optional[Union[dict, BaseModel]]:
) -> Optional[Union[dict, SelectSchemaType]]:
"""
Fetches a single record based on specified filters.
Expand Down Expand Up @@ -788,9 +790,9 @@ async def upsert(
self,
db: AsyncSession,
instance: Union[UpdateSchemaType, CreateSchemaType],
schema_to_select: Optional[type[BaseModel]] = None,
schema_to_select: Optional[type[SelectSchemaType]] = None,
return_as_model: bool = False,
) -> Union[BaseModel, Dict[str, Any], None]:
) -> Union[SelectSchemaType, Dict[str, Any], None]:
"""Update the instance or create it if it doesn't exists.
Note: This method will perform two transactions to the database (get and create or update).
Expand All @@ -805,23 +807,23 @@ async def upsert(
The created or updated instance
"""
_pks = self._get_pk_dict(instance)
schema_to_select = schema_to_select or type(instance)
schema_to_select = schema_to_select or type(instance) # type: ignore
db_instance = await self.get(
db,
schema_to_select=schema_to_select,
schema_to_select=schema_to_select, # type: ignore
return_as_model=return_as_model,
**_pks,
)
if db_instance is None:
db_instance = await self.create(db, instance) # type: ignore
db_instance = schema_to_select.model_validate(
db_instance = schema_to_select.model_validate( # type: ignore
db_instance, from_attributes=True
)
else:
await self.update(db, instance) # type: ignore
db_instance = await self.get(
db,
schema_to_select=schema_to_select,
schema_to_select=schema_to_select, # type: ignore
return_as_model=return_as_model,
**_pks,
)
Expand All @@ -834,7 +836,7 @@ async def upsert_multi(
instances: list[Union[UpdateSchemaType, CreateSchemaType]],
commit: bool = False,
return_columns: Optional[list[str]] = None,
schema_to_select: Optional[type[BaseModel]] = None,
schema_to_select: Optional[type[SelectSchemaType]] = None,
return_as_model: bool = False,
update_override: Optional[dict[str, Any]] = None,
**kwargs: Any,
Expand Down Expand Up @@ -1141,7 +1143,7 @@ async def get_multi(
db: AsyncSession,
offset: int = 0,
limit: Optional[int] = 100,
schema_to_select: Optional[type[BaseModel]] = None,
schema_to_select: Optional[type[SelectSchemaType]] = None,
sort_columns: Optional[Union[str, list[str]]] = None,
sort_orders: Optional[Union[str, list[str]]] = None,
return_as_model: bool = False,
Expand Down Expand Up @@ -1284,11 +1286,11 @@ async def get_multi(
async def get_joined(
self,
db: AsyncSession,
schema_to_select: Optional[type[BaseModel]] = None,
schema_to_select: Optional[type[SelectSchemaType]] = None,
join_model: Optional[ModelType] = None,
join_on: Optional[Union[Join, BinaryExpression]] = None,
join_prefix: Optional[str] = None,
join_schema_to_select: Optional[type[BaseModel]] = None,
join_schema_to_select: Optional[type[SelectSchemaType]] = None,
join_type: str = "left",
alias: Optional[AliasedClass] = None,
join_filters: Optional[dict] = None,
Expand Down Expand Up @@ -1600,11 +1602,11 @@ async def get_joined(
async def get_multi_joined(
self,
db: AsyncSession,
schema_to_select: Optional[type[BaseModel]] = None,
schema_to_select: Optional[type[SelectSchemaType]] = None,
join_model: Optional[type[ModelType]] = None,
join_on: Optional[Any] = None,
join_prefix: Optional[str] = None,
join_schema_to_select: Optional[type[BaseModel]] = None,
join_schema_to_select: Optional[type[SelectSchemaType]] = None,
join_type: str = "left",
alias: Optional[AliasedClass[Any]] = None,
join_filters: Optional[dict] = None,
Expand Down Expand Up @@ -1948,7 +1950,7 @@ async def get_multi_joined(
stmt = stmt.limit(limit)

result = await db.execute(stmt)
data: list[Union[dict, BaseModel]] = []
data: list[Union[dict, SelectSchemaType]] = []

for row in result.mappings().all():
row_dict = dict(row)
Expand Down Expand Up @@ -1978,7 +1980,7 @@ async def get_multi_joined(
join.relationship_type == "one-to-many" for join in join_definitions
):
nested_data = _nest_multi_join_data(
base_primary_key=self._primary_keys[0].name,
base_primary_key=self._primary_keys[0].name, # type: ignore[misc]
data=data,
joins_config=join_definitions,
return_as_model=return_as_model,
Expand Down Expand Up @@ -2011,7 +2013,7 @@ async def get_multi_by_cursor(
db: AsyncSession,
cursor: Any = None,
limit: int = 100,
schema_to_select: Optional[type[BaseModel]] = None,
schema_to_select: Optional[type[SelectSchemaType]] = None,
sort_column: str = "id",
sort_order: str = "asc",
**kwargs: Any,
Expand Down Expand Up @@ -2087,10 +2089,7 @@ async def get_multi_by_cursor(
if limit == 0:
return {"data": [], "next_cursor": None}

stmt = await self.select(
schema_to_select=schema_to_select,
**kwargs,
)
stmt = await self.select(schema_to_select=schema_to_select, **kwargs)

if cursor:
if sort_order == "asc":
Expand Down Expand Up @@ -2124,11 +2123,11 @@ async def update(
allow_multiple: bool = False,
commit: bool = True,
return_columns: Optional[list[str]] = None,
schema_to_select: Optional[type[BaseModel]] = None,
schema_to_select: Optional[type[SelectSchemaType]] = None,
return_as_model: bool = False,
one_or_none: bool = False,
**kwargs: Any,
) -> Optional[Union[dict, BaseModel]]:
) -> Optional[Union[dict, SelectSchemaType]]:
"""
Updates an existing record or multiple records in the database based on specified filters. This method allows for precise targeting of records to update.
Expand Down Expand Up @@ -2265,10 +2264,10 @@ async def update(
def _as_single_response(
self,
db_row: Result,
schema_to_select: Optional[type[BaseModel]] = None,
schema_to_select: Optional[type[SelectSchemaType]] = None,
return_as_model: bool = False,
one_or_none: bool = False,
) -> Optional[Union[dict, BaseModel]]:
) -> Optional[Union[dict, SelectSchemaType]]:
result: Optional[Row] = db_row.one_or_none() if one_or_none else db_row.first()
if result is None: # pragma: no cover
return None
Expand All @@ -2284,7 +2283,7 @@ def _as_single_response(
def _as_multi_response(
self,
db_row: Result,
schema_to_select: Optional[type[BaseModel]] = None,
schema_to_select: Optional[type[SelectSchemaType]] = None,
return_as_model: bool = False,
) -> dict:
data = [dict(row) for row in db_row.mappings()]
Expand Down
19 changes: 10 additions & 9 deletions fastcrud/crud/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pydantic import BaseModel, ConfigDict
from pydantic.functional_validators import field_validator

from fastcrud.types import ModelType
from fastcrud.types import ModelType, SelectSchemaType

from ..endpoint.helper import _get_primary_key

Expand Down Expand Up @@ -40,7 +40,7 @@ def check_valid_join_type(cls, value):

def _extract_matching_columns_from_schema(
model: Union[ModelType, AliasedClass],
schema: Optional[type[BaseModel]],
schema: Optional[type[SelectSchemaType]],
prefix: Optional[str] = None,
alias: Optional[AliasedClass] = None,
use_temporary_prefix: Optional[bool] = False,
Expand Down Expand Up @@ -443,12 +443,12 @@ def _nest_join_data(

def _nest_multi_join_data(
base_primary_key: str,
data: list[Union[dict, BaseModel]],
data: Sequence[Union[dict, BaseModel]],
joins_config: Sequence[JoinConfig],
return_as_model: bool = False,
schema_to_select: Optional[type[BaseModel]] = None,
nested_schema_to_select: Optional[dict[str, type[BaseModel]]] = None,
) -> Sequence[Union[dict, BaseModel]]:
schema_to_select: Optional[type[SelectSchemaType]] = None,
nested_schema_to_select: Optional[dict[str, type[SelectSchemaType]]] = None,
) -> Sequence[Union[dict, SelectSchemaType]]:
"""
Nests joined data based on join definitions provided for multiple records. This function processes the input list of
dictionaries, identifying keys that correspond to joined tables using the provided `joins_config`, and nests them
Expand All @@ -464,7 +464,7 @@ def _nest_multi_join_data(
nested_schema_to_select: A dictionary mapping join prefixes to their corresponding Pydantic schemas.
Returns:
Sequence[Union[dict, BaseModel]]: A list of dictionaries with nested structures for joined table data or Pydantic models.
Sequence[Union[dict, SelectSchemaType]]: A list of dictionaries with nested structures for joined table data or Pydantic models.
Example:
Expand Down Expand Up @@ -616,8 +616,9 @@ def _nest_multi_join_data(


def _handle_null_primary_key_multi_join(
data: list[Union[dict[str, Any], BaseModel]], join_definitions: list[JoinConfig]
) -> list[Union[dict[str, Any], BaseModel]]:
data: list[Union[dict[str, Any], SelectSchemaType]],
join_definitions: list[JoinConfig],
) -> list[Union[dict[str, Any], SelectSchemaType]]:
for item in data:
item_dict = item if isinstance(item, dict) else item.model_dump()

Expand Down
4 changes: 4 additions & 0 deletions fastcrud/endpoint/crud_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
DeleteSchemaType,
ModelType,
UpdateSchemaType,
SelectSchemaType,
)
from .endpoint_creator import EndpointCreator
from .helper import FilterConfig
Expand Down Expand Up @@ -38,6 +39,7 @@ def crud_router(
updated_at_column: str = "updated_at",
endpoint_names: Optional[dict[str, str]] = None,
filter_config: Optional[Union[FilterConfig, dict]] = None,
select_schema: Optional[Type[SelectSchemaType]] = None,
) -> APIRouter:
"""
Creates and configures a FastAPI router with CRUD endpoints for a given model.
Expand Down Expand Up @@ -72,6 +74,7 @@ def crud_router(
(`"create"`, `"read"`, `"update"`, `"delete"`, `"db_delete"`, `"read_multi"`), and
values are the custom names to use. Unspecified operations will use default names.
filter_config: Optional `FilterConfig` instance or dictionary to configure filters for the `read_multi` endpoint.
select_schema: Optional Pydantic schema for selecting an item.
Returns:
Configured `APIRouter` instance with the CRUD endpoints.
Expand Down Expand Up @@ -541,6 +544,7 @@ async def add_routes_to_router(self, ...):
updated_at_column=updated_at_column,
endpoint_names=endpoint_names,
filter_config=filter_config,
select_schema=select_schema, # type: ignore
)

endpoint_creator_instance.add_routes_to_router(
Expand Down
35 changes: 33 additions & 2 deletions fastcrud/endpoint/endpoint_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
CreateSchemaType,
DeleteSchemaType,
ModelType,
SelectSchemaType,
UpdateSchemaType,
)
from ..exceptions.http_exceptions import (
Expand Down Expand Up @@ -58,6 +59,7 @@ class EndpointCreator:
(`"create"`, `"read"`, `"update"`, `"delete"`, `"db_delete"`, `"read_multi"`), and
values are the custom names to use. Unspecified operations will use default names.
filter_config: Optional `FilterConfig` instance or dictionary to configure filters for the `read_multi` endpoint.
select_schema: Optional Pydantic schema for selecting an item.
Raises:
ValueError: If both `included_methods` and `deleted_methods` are provided.
Expand Down Expand Up @@ -251,6 +253,7 @@ def __init__(
updated_at_column: str = "updated_at",
endpoint_names: Optional[dict[str, str]] = None,
filter_config: Optional[Union[FilterConfig, dict]] = None,
select_schema: Optional[Type[SelectSchemaType]] = None,
) -> None:
self._primary_keys = _get_primary_keys(model)
self._primary_keys_types = {
Expand All @@ -268,6 +271,7 @@ def __init__(
self.create_schema = create_schema
self.update_schema = update_schema
self.delete_schema = delete_schema
self.select_schema = select_schema
self.include_in_schema = include_in_schema
self.path = path
self.tags = tags or []
Expand Down Expand Up @@ -327,7 +331,15 @@ def _read_item(self):

@_apply_model_pk(**self._primary_keys_types)
async def endpoint(db: AsyncSession = Depends(self.session), **pkeys):
item = await self.crud.get(db, **pkeys)
if self.select_schema is not None:
item = await self.crud.get(
db,
schema_to_select=self.select_schema,
return_as_model=True,
**pkeys,
)
else:
item = await self.crud.get(db, **pkeys)
if not item: # pragma: no cover
raise NotFoundException(detail="Item not found")
return item # pragma: no cover
Expand Down Expand Up @@ -367,9 +379,28 @@ async def endpoint(
items_per_page = 10
offset = compute_offset(page=page, items_per_page=items_per_page) # type: ignore
limit = items_per_page
elif not has_offset_limit:
offset = 0
limit = 100

if self.select_schema is not None:
crud_data = await self.crud.get_multi(
db, offset=offset, limit=limit, **filters
db,
offset=offset, # type: ignore
limit=limit, # type: ignore
schema_to_select=self.select_schema,
return_as_model=True,
**filters,
)
else:
crud_data = await self.crud.get_multi(
db,
offset=offset, # type: ignore
limit=limit, # type: ignore
**filters,
)

if is_paginated:
return paginated_response(
crud_data=crud_data,
page=page, # type: ignore
Expand Down
1 change: 1 addition & 0 deletions fastcrud/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

ModelType = TypeVar("ModelType", bound=Any)

SelectSchemaType = TypeVar("SelectSchemaType", bound=BaseModel)
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
UpdateSchemaInternalType = TypeVar("UpdateSchemaInternalType", bound=BaseModel)
Expand Down
Loading

0 comments on commit ef49d95

Please sign in to comment.