diff --git a/docs/usage/crud.md b/docs/usage/crud.md index 0e9e401..db2fc79 100644 --- a/docs/usage/crud.md +++ b/docs/usage/crud.md @@ -234,7 +234,7 @@ items = await item_crud.get_multi(db, offset=0, limit=10, sort_columns=['name'], ```python get_joined( db: AsyncSession, - join_model: Optional[type[DeclarativeBase]] = None, + join_model: Optional[ModelType] = None, join_prefix: Optional[str] = None, join_on: Optional[Union[Join, BinaryExpression]] = None, schema_to_select: Optional[type[BaseModel]] = None, diff --git a/fastcrud/crud/fast_crud.py b/fastcrud/crud/fast_crud.py index 82c3e3d..6814006 100644 --- a/fastcrud/crud/fast_crud.py +++ b/fastcrud/crud/fast_crud.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Generic, TypeVar, Union, Optional, Callable +from typing import Any, Dict, Generic, Union, Optional, Callable from datetime import datetime, timezone from pydantic import BaseModel, ValidationError @@ -7,11 +7,18 @@ from sqlalchemy.sql import Join from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.engine.row import Row -from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm.util import AliasedClass from sqlalchemy.sql.elements import BinaryExpression, ColumnElement from sqlalchemy.sql.selectable import Select +from fastcrud.types import ( + CreateSchemaType, + DeleteSchemaType, + ModelType, + UpdateSchemaInternalType, + UpdateSchemaType, +) + from .helper import ( _extract_matching_columns_from_schema, _auto_detect_join_condition, @@ -23,12 +30,6 @@ from ..endpoint.helper import _get_primary_keys -ModelType = TypeVar("ModelType", bound=DeclarativeBase) -CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel) -UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel) -UpdateSchemaInternalType = TypeVar("UpdateSchemaInternalType", bound=BaseModel) -DeleteSchemaType = TypeVar("DeleteSchemaType", bound=BaseModel) - class FastCRUD( Generic[ @@ -830,7 +831,7 @@ async def get_joined( self, db: AsyncSession, schema_to_select: Optional[type[BaseModel]] = None, - join_model: Optional[type[DeclarativeBase]] = None, + join_model: Optional[ModelType] = None, join_on: Optional[Union[Join, BinaryExpression]] = None, join_prefix: Optional[str] = None, join_schema_to_select: Optional[type[BaseModel]] = None, diff --git a/fastcrud/crud/helper.py b/fastcrud/crud/helper.py index d45e68e..6f07623 100644 --- a/fastcrud/crud/helper.py +++ b/fastcrud/crud/helper.py @@ -1,12 +1,13 @@ -from typing import Any, Optional, Union, Sequence +from typing import Any, Optional, Union, Sequence, cast from sqlalchemy import inspect -from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm.util import AliasedClass from sqlalchemy.sql import ColumnElement from pydantic import BaseModel, ConfigDict from pydantic.functional_validators import field_validator +from fastcrud.types import ModelType + from ..endpoint.helper import _get_primary_key @@ -38,7 +39,7 @@ def check_valid_join_type(cls, value): def _extract_matching_columns_from_schema( - model: Union[type[DeclarativeBase], AliasedClass], + model: Union[ModelType, AliasedClass], schema: Optional[type[BaseModel]], prefix: Optional[str] = None, alias: Optional[AliasedClass] = None, @@ -63,6 +64,9 @@ def _extract_matching_columns_from_schema( in the schema or all columns from the model if no schema is specified. These columns are correctly referenced through the provided alias if one is given. """ + if not hasattr(model, "__table__"): + raise AttributeError(f"{model.__name__} does not have a '__table__' attribute.") + model_or_alias = alias if alias else model columns = [] temp_prefix = ( @@ -96,7 +100,8 @@ def _extract_matching_columns_from_schema( def _auto_detect_join_condition( - base_model: type[DeclarativeBase], join_model: type[DeclarativeBase] + base_model: ModelType, + join_model: ModelType, ) -> Optional[ColumnElement]: """ Automatically detects the join condition for SQLAlchemy models based on foreign key relationships. @@ -110,18 +115,27 @@ def _auto_detect_join_condition( Raises: ValueError: If the join condition cannot be automatically determined. - - Example: - # Assuming User has a foreign key reference to Tier: - join_condition = auto_detect_join_condition(User, Tier) + AttributeError: If either base_model or join_model does not have a '__table__' attribute. """ + if not hasattr(base_model, "__table__"): + raise AttributeError( + f"{base_model.__name__} does not have a '__table__' attribute." + ) + if not hasattr(join_model, "__table__"): + raise AttributeError( + f"{join_model.__name__} does not have a '__table__' attribute." + ) + inspector = inspect(base_model) if inspector is not None: fk_columns = [col for col in inspector.c if col.foreign_keys] join_on = next( ( - base_model.__table__.c[col.name] - == join_model.__table__.c[list(col.foreign_keys)[0].column.name] + cast( + ColumnElement, + base_model.__table__.c[col.name] + == join_model.__table__.c[list(col.foreign_keys)[0].column.name], + ) for col in fk_columns if list(col.foreign_keys)[0].column.table == join_model.__table__ ), diff --git a/fastcrud/endpoint/crud_router.py b/fastcrud/endpoint/crud_router.py index 7bc777f..8b087e2 100644 --- a/fastcrud/endpoint/crud_router.py +++ b/fastcrud/endpoint/crud_router.py @@ -1,22 +1,22 @@ -from typing import Type, TypeVar, Optional, Union, Sequence, Callable +from typing import Type, Optional, Union, Sequence, Callable from enum import Enum from fastapi import APIRouter -from sqlalchemy.orm import DeclarativeBase -from pydantic import BaseModel +from fastcrud.crud.fast_crud import FastCRUD +from fastcrud.types import ( + CreateSchemaType, + DeleteSchemaType, + ModelType, + UpdateSchemaType, +) from .endpoint_creator import EndpointCreator -from ..crud.fast_crud import FastCRUD from .helper import FilterConfig -CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel) -UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel) -DeleteSchemaType = TypeVar("DeleteSchemaType", bound=BaseModel) - def crud_router( session: Callable, - model: type[DeclarativeBase], + model: ModelType, create_schema: Type[CreateSchemaType], update_schema: Type[UpdateSchemaType], crud: Optional[FastCRUD] = None, @@ -287,6 +287,7 @@ async def add_routes_to_router(self, ...): # Example GET request: /mymodel/get_multi?id=1&name=example ``` """ + crud = crud or FastCRUD( model=model, is_deleted_column=is_deleted_column, diff --git a/fastcrud/endpoint/endpoint_creator.py b/fastcrud/endpoint/endpoint_creator.py index a4591ee..4f15e1c 100644 --- a/fastcrud/endpoint/endpoint_creator.py +++ b/fastcrud/endpoint/endpoint_creator.py @@ -1,13 +1,18 @@ -import warnings -from typing import Type, TypeVar, Optional, Callable, Sequence, Union +from typing import Type, Optional, Callable, Sequence, Union from enum import Enum +import warnings from fastapi import Depends, Body, Query, APIRouter -from pydantic import BaseModel, ValidationError +from pydantic import ValidationError from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import DeclarativeBase -from ..crud.fast_crud import FastCRUD +from fastcrud.crud.fast_crud import FastCRUD +from fastcrud.types import ( + CreateSchemaType, + DeleteSchemaType, + ModelType, + UpdateSchemaType, +) from ..exceptions.http_exceptions import DuplicateValueException, NotFoundException from ..paginated.helper import compute_offset from ..paginated.response import paginated_response @@ -23,11 +28,6 @@ _get_column_types, ) -CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel) -UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel) -UpdateSchemaInternalType = TypeVar("UpdateSchemaInternalType", bound=BaseModel) -DeleteSchemaType = TypeVar("DeleteSchemaType", bound=BaseModel) - class EndpointCreator: """ @@ -208,7 +208,7 @@ def get_current_user(token: str = Depends(oauth2_scheme)): def __init__( self, session: Callable, - model: type[DeclarativeBase], + model: ModelType, create_schema: Type[CreateSchemaType], update_schema: Type[UpdateSchemaType], crud: Optional[FastCRUD] = None, @@ -262,7 +262,7 @@ def __init__( "resulting in plain endpoint names. " "For details see:" " https://github.com/igorbenav/fastcrud/issues/67", - DeprecationWarning + DeprecationWarning, ) if filter_config: if isinstance(filter_config, dict): @@ -320,17 +320,14 @@ def _read_items(self): async def endpoint( db: AsyncSession = Depends(self.session), - page: Optional[int] = Query( - None, alias="page", description="Page number" - ), + page: Optional[int] = Query(None, alias="page", description="Page number"), items_per_page: Optional[int] = Query( None, alias="itemsPerPage", description="Number of items per page" ), filters: dict = Depends(dynamic_filters), ): if not (page and items_per_page): - return await self.crud.get_multi(db, offset=0, limit=100, - **filters) + return await self.crud.get_multi(db, offset=0, limit=100, **filters) offset = compute_offset(page=page, items_per_page=items_per_page) crud_data = await self.crud.get_multi( @@ -352,7 +349,7 @@ def _read_paginated(self): "Please use _read_items with optional page and items_per_page " "query params instead, to achieve pagination as before." "Simple _read_items behaviour persists with no breaking changes.", - DeprecationWarning + DeprecationWarning, ) async def endpoint( @@ -366,8 +363,7 @@ async def endpoint( filters: dict = Depends(dynamic_filters), ): if not (page and items_per_page): # pragma: no cover - return await self.crud.get_multi(db, offset=0, limit=100, - **filters) + return await self.crud.get_multi(db, offset=0, limit=100, **filters) offset = compute_offset(page=page, items_per_page=items_per_page) crud_data = await self.crud.get_multi( @@ -427,11 +423,11 @@ def _get_endpoint_path(self, operation: str): ) path = f"{self.path}/{endpoint_name}" if endpoint_name else self.path - if operation in {'read', 'update', 'delete', 'db_delete'}: + if operation in {"read", "update", "delete", "db_delete"}: _primary_keys_path_suffix = "/".join( f"{{{n}}}" for n in self.primary_key_names ) - path = f'{path}/{_primary_keys_path_suffix}' + path = f"{path}/{_primary_keys_path_suffix}" return path @@ -540,7 +536,7 @@ def get_current_user(...): if ("create" in included_methods) and ("create" not in deleted_methods): self.router.add_api_route( - self._get_endpoint_path(operation='create'), + self._get_endpoint_path(operation="create"), self._create_item(), methods=["POST"], include_in_schema=self.include_in_schema, @@ -551,7 +547,7 @@ def get_current_user(...): if ("read" in included_methods) and ("read" not in deleted_methods): self.router.add_api_route( - self._get_endpoint_path(operation='read'), + self._get_endpoint_path(operation="read"), self._read_item(), methods=["GET"], include_in_schema=self.include_in_schema, @@ -562,7 +558,7 @@ def get_current_user(...): if ("read_multi" in included_methods) and ("read_multi" not in deleted_methods): self.router.add_api_route( - self._get_endpoint_path(operation='read_multi'), + self._get_endpoint_path(operation="read_multi"), self._read_items(), methods=["GET"], include_in_schema=self.include_in_schema, @@ -575,7 +571,7 @@ def get_current_user(...): "read_paginated" not in deleted_methods ): self.router.add_api_route( - self._get_endpoint_path(operation='read_paginated'), + self._get_endpoint_path(operation="read_paginated"), self._read_paginated(), methods=["GET"], include_in_schema=self.include_in_schema, @@ -586,7 +582,7 @@ def get_current_user(...): if ("update" in included_methods) and ("update" not in deleted_methods): self.router.add_api_route( - self._get_endpoint_path(operation='update'), + self._get_endpoint_path(operation="update"), self._update_item(), methods=["PATCH"], include_in_schema=self.include_in_schema, @@ -596,7 +592,7 @@ def get_current_user(...): ) if ("delete" in included_methods) and ("delete" not in deleted_methods): - path = self._get_endpoint_path(operation='delete') + path = self._get_endpoint_path(operation="delete") self.router.add_api_route( path, self._delete_item(), @@ -613,7 +609,7 @@ def get_current_user(...): and self.delete_schema ): self.router.add_api_route( - self._get_endpoint_path(operation='db_delete'), + self._get_endpoint_path(operation="db_delete"), self._db_delete(), methods=["DELETE"], include_in_schema=self.include_in_schema, diff --git a/fastcrud/endpoint/helper.py b/fastcrud/endpoint/helper.py index a525c5b..b19c2c7 100644 --- a/fastcrud/endpoint/helper.py +++ b/fastcrud/endpoint/helper.py @@ -7,9 +7,10 @@ from fastapi import Depends, Query, params from sqlalchemy import Column, inspect as sa_inspect -from sqlalchemy.orm import DeclarativeBase from sqlalchemy.sql.elements import KeyedColumnElement +from fastcrud.types import ModelType + F = TypeVar("F", bound=Callable[..., Any]) @@ -71,16 +72,20 @@ def get_params(self) -> dict[str, Any]: def _get_primary_key( - model: type[DeclarativeBase], + model: ModelType, ) -> Union[str, None]: # pragma: no cover key: Optional[str] = _get_primary_keys(model)[0].name return key -def _get_primary_keys(model: type[DeclarativeBase]) -> Sequence[Column]: +def _get_primary_keys( + model: ModelType, +) -> Sequence[Column]: """Get the primary key of a SQLAlchemy model.""" - inspector = sa_inspect(model).mapper - primary_key_columns: Sequence[Column] = inspector.primary_key + inspector_result = sa_inspect(model) + if inspector_result is None: + raise ValueError("Model inspection failed, resulting in None.") + primary_key_columns: Sequence[Column] = inspector_result.mapper.primary_key return primary_key_columns @@ -99,19 +104,25 @@ def _get_python_type(column: Column) -> Optional[type]: ) -def _get_column_types(model: type[DeclarativeBase]) -> dict[str, Union[type, None]]: +def _get_column_types( + model: ModelType, +) -> dict[str, Union[type, None]]: """Get a dictionary of column names and their corresponding Python types from a SQLAlchemy model.""" - inspector = sa_inspect(model).mapper + inspector_result = sa_inspect(model) + if inspector_result is None or inspector_result.mapper is None: + raise ValueError("Model inspection failed, resulting in None.") column_types = {} - for column in inspector.columns: + for column in inspector_result.mapper.columns: column_types[column.name] = _get_python_type(column) return column_types def _extract_unique_columns( - model: type[DeclarativeBase], + model: ModelType, ) -> Sequence[KeyedColumnElement]: """Extracts columns from a SQLAlchemy model that are marked as unique.""" + if not hasattr(model, "__table__"): + raise AttributeError(f"{model.__name__} does not have a '__table__' attribute.") unique_columns = [column for column in model.__table__.columns if column.unique] return unique_columns diff --git a/fastcrud/types.py b/fastcrud/types.py new file mode 100644 index 0000000..2987e2e --- /dev/null +++ b/fastcrud/types.py @@ -0,0 +1,10 @@ +from typing import TypeVar, Any + +from pydantic import BaseModel + +ModelType = TypeVar("ModelType", bound=Any) + +CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel) +UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel) +UpdateSchemaInternalType = TypeVar("UpdateSchemaInternalType", bound=BaseModel) +DeleteSchemaType = TypeVar("DeleteSchemaType", bound=BaseModel) diff --git a/tests/sqlalchemy/endpoint/test_endpoint_custom_names.py b/tests/sqlalchemy/endpoint/test_endpoint_custom_names.py index 6fc5c6e..2db7bc1 100644 --- a/tests/sqlalchemy/endpoint/test_endpoint_custom_names.py +++ b/tests/sqlalchemy/endpoint/test_endpoint_custom_names.py @@ -4,17 +4,23 @@ from fastcrud import crud_router -@pytest.mark.parametrize('custom_endpoint_names, endpoint_paths', [ - ( +@pytest.mark.parametrize( + "custom_endpoint_names, endpoint_paths", + [ + ( {"create": "", "read": "", "read_multi": ""}, - ["/test_custom_names", "/test_custom_names", - "/test_custom_names"] - ), - ( + ["/test_custom_names", "/test_custom_names", "/test_custom_names"], + ), + ( {"create": "add", "read": "fetch", "read_multi": "fetch_multi"}, - ["/test_custom_names/add", "/test_custom_names/fetch", - "/test_custom_names/fetch_multi"]), -]) + [ + "/test_custom_names/add", + "/test_custom_names/fetch", + "/test_custom_names/fetch_multi", + ], + ), + ], +) @pytest.mark.asyncio async def test_endpoint_custom_names( client: TestClient, @@ -54,19 +60,19 @@ async def test_endpoint_custom_names( item_id = create_response.json()["id"] - fetch_response = client.get(f'{read_path}/{item_id}') + fetch_response = client.get(f"{read_path}/{item_id}") assert ( fetch_response.status_code == 200 ), "Failed to fetch item with custom endpoint name" - assert ( - fetch_response.json()["id"] == item_id - ), (f"Fetched item ID does not match created item ID:" - f" {fetch_response.json()['id']} != {item_id}") + assert fetch_response.json()["id"] == item_id, ( + f"Fetched item ID does not match created item ID:" + f" {fetch_response.json()['id']} != {item_id}" + ) fetch_multi_response = client.get(read_multi_path) assert ( fetch_multi_response.status_code == 200 ), "Failed to fetch multi items with custom endpoint name" assert ( - len(fetch_multi_response.json()['data']) == 12 + len(fetch_multi_response.json()["data"]) == 12 ), "Fetched item list has incorrect length" diff --git a/tests/sqlalchemy/endpoint/test_get_paginated.py b/tests/sqlalchemy/endpoint/test_get_paginated.py index b3d9704..27785d1 100644 --- a/tests/sqlalchemy/endpoint/test_get_paginated.py +++ b/tests/sqlalchemy/endpoint/test_get_paginated.py @@ -91,7 +91,9 @@ async def test_read_paginated_with_filters( # ------ the following tests will completely replace the current ones in the next version of fastcrud ------ @pytest.mark.asyncio -async def test_read_items_with_pagination(client: TestClient, async_session, test_model, test_data): +async def test_read_items_with_pagination( + client: TestClient, async_session, test_model, test_data +): for data in test_data: new_item = test_model(**data) async_session.add(new_item) @@ -100,9 +102,7 @@ async def test_read_items_with_pagination(client: TestClient, async_session, tes page = 1 items_per_page = 5 - response = client.get( - f"/test/get_multi?page={page}&itemsPerPage={items_per_page}" - ) + response = client.get(f"/test/get_multi?page={page}&itemsPerPage={items_per_page}") assert response.status_code == 200 @@ -125,7 +125,9 @@ async def test_read_items_with_pagination(client: TestClient, async_session, tes @pytest.mark.asyncio -async def test_read_items_with_pagination_and_filters(filtered_client: TestClient, async_session, test_model, test_data): +async def test_read_items_with_pagination_and_filters( + filtered_client: TestClient, async_session, test_model, test_data +): for data in test_data: new_item = test_model(**data) async_session.add(new_item) diff --git a/tests/sqlmodel/endpoint/test_get_paginated.py b/tests/sqlmodel/endpoint/test_get_paginated.py index b3d9704..27785d1 100644 --- a/tests/sqlmodel/endpoint/test_get_paginated.py +++ b/tests/sqlmodel/endpoint/test_get_paginated.py @@ -91,7 +91,9 @@ async def test_read_paginated_with_filters( # ------ the following tests will completely replace the current ones in the next version of fastcrud ------ @pytest.mark.asyncio -async def test_read_items_with_pagination(client: TestClient, async_session, test_model, test_data): +async def test_read_items_with_pagination( + client: TestClient, async_session, test_model, test_data +): for data in test_data: new_item = test_model(**data) async_session.add(new_item) @@ -100,9 +102,7 @@ async def test_read_items_with_pagination(client: TestClient, async_session, tes page = 1 items_per_page = 5 - response = client.get( - f"/test/get_multi?page={page}&itemsPerPage={items_per_page}" - ) + response = client.get(f"/test/get_multi?page={page}&itemsPerPage={items_per_page}") assert response.status_code == 200 @@ -125,7 +125,9 @@ async def test_read_items_with_pagination(client: TestClient, async_session, tes @pytest.mark.asyncio -async def test_read_items_with_pagination_and_filters(filtered_client: TestClient, async_session, test_model, test_data): +async def test_read_items_with_pagination_and_filters( + filtered_client: TestClient, async_session, test_model, test_data +): for data in test_data: new_item = test_model(**data) async_session.add(new_item)