Skip to content

Commit

Permalink
Merge pull request #190 from igorbenav/some-fixes
Browse files Browse the repository at this point in the history
Some fixes
  • Loading branch information
igorbenav authored Dec 26, 2024
2 parents 7b9a4d1 + 37e92bf commit 8e58b8c
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 75 deletions.
17 changes: 8 additions & 9 deletions fastcrud/crud/fast_crud.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Any, Dict, Generic, Union, Optional, Callable
from datetime import datetime, timezone
import warnings

from pydantic import ValidationError
from sqlalchemy import (
Expand Down Expand Up @@ -33,6 +32,8 @@
SelectSchemaType,
UpdateSchemaInternalType,
UpdateSchemaType,
GetMultiResponseModel,
GetMultiResponseDict,
)

from .helper import (
Expand Down Expand Up @@ -1149,7 +1150,7 @@ async def get_multi(
return_as_model: bool = False,
return_total_count: bool = True,
**kwargs: Any,
) -> dict[str, Any]:
) -> Union[GetMultiResponseModel[SelectSchemaType], GetMultiResponseDict]:
"""
Fetches multiple records based on filters, supporting sorting, pagination.
Expand All @@ -1167,7 +1168,10 @@ async def get_multi(
**kwargs: Filters to apply to the query, including advanced comparison operators for more detailed querying.
Returns:
A dictionary containing `"data"` with fetched records and `"total_count"` indicating the total number of records matching the filters.
A dictionary containing the data list and optionally the total count:
- With return_as_model=True: Dict with "data": List[SelectSchemaType]
- With return_as_model=False: Dict with "data": List[Dict[str, Any]]
- If return_total_count=True, includes "total_count": int
Raises:
ValueError: If `limit` or `offset` is negative, or if `schema_to_select` is required but not provided or invalid.
Expand Down Expand Up @@ -2206,12 +2210,7 @@ async def update(
"""
total_count = await self.count(db, **kwargs)
if total_count == 0:
warnings.warn(
"Passing non-existing records to `update` will raise NoResultFound on version 0.15.3.",
DeprecationWarning,
stacklevel=2,
)
# raise NoResultFound("No record found to update.")
raise NoResultFound("No record found to update.")
if not allow_multiple and total_count > 1:
raise MultipleResultsFound(
f"Expected exactly one record to update, found {total_count}."
Expand Down
20 changes: 14 additions & 6 deletions fastcrud/endpoint/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,10 @@ def _get_column_types(
column_types = {}
for column in inspector_result.mapper.columns:
column_type = _get_python_type(column)
if hasattr(column.type, "__visit_name__") and column.type.__visit_name__ == "uuid":
if (
hasattr(column.type, "__visit_name__")
and column.type.__visit_name__ == "uuid"
):
column_type = UUID
column_types[column.name] = column_type
return column_types
Expand Down Expand Up @@ -191,15 +194,13 @@ def wrapper(endpoint):
inspect.Parameter(
name=k,
annotation=Annotated[UUID, Path(...)],
kind=inspect.Parameter.POSITIONAL_ONLY
kind=inspect.Parameter.POSITIONAL_ONLY,
)
)
else:
extra_positional_params.append(
inspect.Parameter(
name=k,
annotation=v,
kind=inspect.Parameter.POSITIONAL_ONLY
name=k, annotation=v, kind=inspect.Parameter.POSITIONAL_ONLY
)
)

Expand All @@ -223,7 +224,14 @@ def filters(
filtered_params = {}
for key, value in kwargs.items():
if value is not None:
filtered_params[key] = value
parse_func = column_types.get(key)
if parse_func:
try:
filtered_params[key] = parse_func(value)
except (ValueError, TypeError):
filtered_params[key] = value
else:
filtered_params[key] = value
return filtered_params

params = []
Expand Down
5 changes: 4 additions & 1 deletion fastcrud/types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TypeVar, Any
from typing import TypeVar, Any, Dict, Union, List

from pydantic import BaseModel

Expand All @@ -9,3 +9,6 @@
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
UpdateSchemaInternalType = TypeVar("UpdateSchemaInternalType", bound=BaseModel)
DeleteSchemaType = TypeVar("DeleteSchemaType", bound=BaseModel)

GetMultiResponseDict = Dict[str, Union[List[Dict[str, Any]], int]]
GetMultiResponseModel = Dict[str, Union[List[SelectSchemaType], int]]
42 changes: 42 additions & 0 deletions tests/sqlalchemy/core/test_uuid.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from fastapi.testclient import TestClient

from fastcrud import crud_router, FastCRUD
from fastcrud import FilterConfig
from fastcrud.endpoint.helper import _create_dynamic_filters
from pydantic import BaseModel

from ..conftest import Base
Expand Down Expand Up @@ -202,3 +204,43 @@ async def test_uuid_list_endpoint(uuid_client):
UUID(item["id"])
except ValueError: # pragma: no cover
pytest.fail("Invalid UUID format in list response")


def test_create_dynamic_filters_type_conversion():
filter_config = FilterConfig(uuid_field=None, int_field=None, str_field=None)
column_types = {
"uuid_field": UUID,
"int_field": int,
"str_field": str,
}

filters_func = _create_dynamic_filters(filter_config, column_types)

test_uuid = "123e4567-e89b-12d3-a456-426614174000"
result = filters_func(uuid_field=test_uuid, int_field="123", str_field=456)

assert isinstance(result["uuid_field"], UUID)
assert result["uuid_field"] == UUID(test_uuid)
assert isinstance(result["int_field"], int)
assert result["int_field"] == 123
assert isinstance(result["str_field"], str)
assert result["str_field"] == "456"

result = filters_func(
uuid_field="not-a-uuid", int_field="not-an-int", str_field=456
)

assert result["uuid_field"] == "not-a-uuid"
assert result["int_field"] == "not-an-int"
assert isinstance(result["str_field"], str)

result = filters_func(uuid_field=None, int_field="123", str_field=None)
assert "uuid_field" not in result
assert result["int_field"] == 123
assert "str_field" not in result

result = filters_func(unknown_field="test")
assert result["unknown_field"] == "test"

empty_filters_func = _create_dynamic_filters(None, {})
assert empty_filters_func() == {}
31 changes: 1 addition & 30 deletions tests/sqlalchemy/crud/test_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest

from sqlalchemy import select
from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.exc import MultipleResultsFound, NoResultFound

from fastcrud.crud.fast_crud import FastCRUD
from ...sqlalchemy.conftest import ModelTest, UpdateSchemaTest, ModelTestWithTimestamp
Expand Down Expand Up @@ -51,24 +51,11 @@ async def test_update_non_existent_record(async_session, test_data):
crud = FastCRUD(ModelTest)
non_existent_id = 99999
updated_data = {"name": "New Name"}
"""
In version 0.15.3, the `update` method will raise a `NoResultFound` exception:

```
with pytest.raises(NoResultFound) as exc_info:
await crud.update(db=async_session, object=updated_data, id=non_existent_id)

assert "No record found to update" in str(exc_info.value)
```
For 0.15.2, the test will check if the record is not updated.
"""
await crud.update(db=async_session, object=updated_data, id=non_existent_id)

record = await async_session.execute(
select(ModelTest).where(ModelTest.id == non_existent_id)
)
assert record.scalar_one_or_none() is None


@pytest.mark.asyncio
Expand All @@ -81,26 +68,10 @@ async def test_update_invalid_filters(async_session, test_data):
updated_data = {"name": "New Name"}

non_matching_filter = {"name": "NonExistingName"}
"""
In version 0.15.3, the `update` method will raise a `NoResultFound` exception:
```
with pytest.raises(NoResultFound) as exc_info:
await crud.update(db=async_session, object=updated_data, **non_matching_filter)

assert "No record found to update" in str(exc_info.value)
```
For 0.15.2, the test will check if the record is not updated.
"""
await crud.update(db=async_session, object=updated_data, **non_matching_filter)

for item in test_data:
record = await async_session.execute(
select(ModelTest).where(ModelTest.id == item["id"])
)
fetched_record = record.scalar_one()
assert fetched_record.name != "New Name"


@pytest.mark.asyncio
Expand Down
42 changes: 42 additions & 0 deletions tests/sqlmodel/core/test_uuid.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from sqlmodel import Field, SQLModel

from fastcrud import crud_router, FastCRUD
from fastcrud import FilterConfig
from fastcrud.endpoint.helper import _create_dynamic_filters
from pydantic import ConfigDict


Expand Down Expand Up @@ -204,3 +206,43 @@ async def test_uuid_list_endpoint(uuid_client):
UUID(item["id"])
except ValueError: # pragma: no cover
pytest.fail("Invalid UUID format in list response")


def test_create_dynamic_filters_type_conversion():
filter_config = FilterConfig(uuid_field=None, int_field=None, str_field=None)
column_types = {
"uuid_field": UUID,
"int_field": int,
"str_field": str,
}

filters_func = _create_dynamic_filters(filter_config, column_types)

test_uuid = "123e4567-e89b-12d3-a456-426614174000"
result = filters_func(uuid_field=test_uuid, int_field="123", str_field=456)

assert isinstance(result["uuid_field"], UUID)
assert result["uuid_field"] == UUID(test_uuid)
assert isinstance(result["int_field"], int)
assert result["int_field"] == 123
assert isinstance(result["str_field"], str)
assert result["str_field"] == "456"

result = filters_func(
uuid_field="not-a-uuid", int_field="not-an-int", str_field=456
)

assert result["uuid_field"] == "not-a-uuid"
assert result["int_field"] == "not-an-int"
assert isinstance(result["str_field"], str)

result = filters_func(uuid_field=None, int_field="123", str_field=None)
assert "uuid_field" not in result
assert result["int_field"] == 123
assert "str_field" not in result

result = filters_func(unknown_field="test")
assert result["unknown_field"] == "test"

empty_filters_func = _create_dynamic_filters(None, {})
assert empty_filters_func() == {}
30 changes: 1 addition & 29 deletions tests/sqlmodel/crud/test_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest

from sqlalchemy import select
from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.exc import MultipleResultsFound, NoResultFound

from fastcrud.crud.fast_crud import FastCRUD
from ...sqlmodel.conftest import ModelTest, UpdateSchemaTest, ModelTestWithTimestamp
Expand Down Expand Up @@ -51,24 +51,11 @@ async def test_update_non_existent_record(async_session, test_data):
crud = FastCRUD(ModelTest)
non_existent_id = 99999
updated_data = {"name": "New Name"}
"""
In version 0.15.3, the `update` method will raise a `NoResultFound` exception:

```
with pytest.raises(NoResultFound) as exc_info:
await crud.update(db=async_session, object=updated_data, id=non_existent_id)

assert "No record found to update" in str(exc_info.value)
```
For 0.15.2, the test will check if the record is not updated.
"""
await crud.update(db=async_session, object=updated_data, id=non_existent_id)

record = await async_session.execute(
select(ModelTest).where(ModelTest.id == non_existent_id)
)
assert record.scalar_one_or_none() is None


@pytest.mark.asyncio
Expand All @@ -81,26 +68,11 @@ async def test_update_invalid_filters(async_session, test_data):
updated_data = {"name": "New Name"}

non_matching_filter = {"name": "NonExistingName"}
"""
In version 0.15.3, the `update` method will raise a `NoResultFound` exception:

```
with pytest.raises(NoResultFound) as exc_info:
await crud.update(db=async_session, object=updated_data, **non_matching_filter)

assert "No record found to update" in str(exc_info.value)
```
For 0.15.2, the test will check if the record is not updated.
"""
await crud.update(db=async_session, object=updated_data, **non_matching_filter)

for item in test_data:
record = await async_session.execute(
select(ModelTest).where(ModelTest.id == item["id"])
)
fetched_record = record.scalar_one()
assert fetched_record.name != "New Name"


@pytest.mark.asyncio
Expand Down

0 comments on commit 8e58b8c

Please sign in to comment.