Skip to content

Commit

Permalink
Merge pull request #111 from feluelle/feature/update-returning
Browse files Browse the repository at this point in the history
Add returning clause to update
  • Loading branch information
igorbenav authored Jul 5, 2024
2 parents 698c27a + 6fbb474 commit e324631
Show file tree
Hide file tree
Showing 3 changed files with 188 additions and 3 deletions.
40 changes: 40 additions & 0 deletions docs/advanced/crud.md
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,46 @@ await crud_items.delete(
# this will not actually delete until you run a db.commit()
```

## Returning clause in `update`

In `update` method, you can pass `return_columns` parameter containing a list of columns you want to return after the update.

```python
from fastcrud import FastCRUD

from .models.item import Item
from .database import session as db

crud_items = FastCRUD(Item)
item = await crud_items.update(
db=db,
object={"price": 9.99},
price__lt=10
return_columns=["price"]
)
# this will return the updated price
```

You can also pass `schema_to_select` parameter and `return_as_model` to return the updated data in the form of a Pydantic schema.

```python
from fastcrud import FastCRUD

from .models.item import Item
from .schemas.item import ItemSchema
from .database import session as db

crud_items = FastCRUD(Item)
item = await crud_items.update(
db=db,
object={"price": 9.99},
price__lt=10
schema_to_select=ItemSchema,
return_as_model=True
)
# this will return the updated data in the form of ItemSchema
```

## Unpaginated `get_multi` and `get_multi_joined`

If you pass `None` to `limit` in `get_multi` and `get_multi_joined`, you get the whole unpaginated set of data that matches the filters. Use this with caution.
Expand Down
90 changes: 87 additions & 3 deletions fastcrud/crud/fast_crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from datetime import datetime, timezone

from pydantic import BaseModel, ValidationError
from sqlalchemy import select, update, delete, func, inspect, asc, desc, or_
from sqlalchemy import Result, select, update, delete, func, inspect, asc, desc, or_, column
from sqlalchemy.exc import ArgumentError, MultipleResultsFound, NoResultFound
from sqlalchemy.sql import Join
from sqlalchemy.ext.asyncio import AsyncSession
Expand Down Expand Up @@ -1596,8 +1596,12 @@ async def update(
object: Union[UpdateSchemaType, dict[str, Any]],
allow_multiple: bool = False,
commit: bool = True,
return_columns: Optional[list[str]] = None,
schema_to_select: Optional[type[BaseModel]] = None,
return_as_model: bool = False,
one_or_none: bool = False,
**kwargs: Any,
) -> None:
) -> Optional[Union[dict, BaseModel]]:
"""
Updates an existing record or multiple records in the database based on specified filters. This method allows for precise targeting of records to update.
For filtering details see:
Expand All @@ -1608,14 +1612,19 @@ async def update(
object: A Pydantic schema or dictionary containing the update data.
allow_multiple: If True, allows updating multiple records that match the filters. If False, raises an error if more than one record matches the filters.
commit: If True, commits the transaction immediately. Default is True.
return_columns: A list of column names to return after the update. If return_as_model is True, all columns are returned.
schema_to_select: Pydantic schema for selecting specific columns from the updated record(s). Required if `return_as_model` is True.
return_as_model: If True, returns the updated record(s) as Pydantic model instances based on `schema_to_select`. Default is False.
one_or_none: If True, returns a single record if only one record matches the filters. Default is False.
**kwargs: Filters to identify the record(s) to update, supporting advanced comparison operators for refined querying.
Returns:
None
The updated record(s) as a dictionary or Pydantic model instance or None, depending on the value of `return_as_model` and `return_columns`.
Raises:
MultipleResultsFound: If `allow_multiple` is False and more than one record matches the filters.
ValueError: If extra fields not present in the model are provided in the update data.
ValueError: If `return_as_model` is True but `schema_to_select` is not provided.
Examples:
Update a user's email based on their ID:
Expand All @@ -1632,6 +1641,16 @@ async def update(
```python
update(db, {'username': 'new_username'}, id__ne=1, allow_multiple=False)
```
Update a user's email and return the updated record as a Pydantic model instance:
```python
update(db, {'email': '[email protected]'}, id=1, schema_to_select=UserSchema, return_as_model=True)
```
Update a user's email and return the updated record as a dictionary:
```python
update(db, {'email': '[email protected]'}, id=1, return_columns=['id', 'email'])
```
"""
if not allow_multiple and (total_count := await self.count(db, **kwargs)) > 1:
raise MultipleResultsFound(
Expand All @@ -1656,9 +1675,74 @@ async def update(
filters = self._parse_filters(**kwargs)
stmt = update(self.model).filter(*filters).values(update_data)

if return_as_model:
# All columns are returned to ensure the model can be constructed
return_columns = self.model_col_names

if return_columns:
stmt = stmt.returning(*[column(name) for name in return_columns])
db_row = await db.execute(stmt)
if allow_multiple:
return self._as_multi_response(
db_row,
schema_to_select=schema_to_select,
return_as_model=return_as_model,
)
return self._as_single_response(
db_row,
schema_to_select=schema_to_select,
return_as_model=return_as_model,
one_or_none=one_or_none,
)

await db.execute(stmt)
if commit:
await db.commit()
return None

def _as_single_response(
self,
db_row: Result,
schema_to_select: Optional[type[BaseModel]] = None,
return_as_model: bool = False,
one_or_none: bool = False,
) -> Optional[Union[dict, BaseModel]]:
result: Optional[Row] = db_row.one_or_none() if one_or_none else db_row.first()
if result is None:
return None
out: dict = dict(result._mapping)
if not return_as_model:
return out
if not schema_to_select:
raise ValueError(
"schema_to_select must be provided when return_as_model is True."
)
return schema_to_select(**out)

def _as_multi_response(
self,
db_row: Result,
schema_to_select: Optional[type[BaseModel]] = None,
return_as_model: bool = False,
) -> dict:
data = [dict(row) for row in db_row.mappings()]

response: dict[str, Any] = {"data": data}

if return_as_model:
if not schema_to_select:
raise ValueError(
"schema_to_select must be provided when return_as_model is True."
)
try:
model_data = [schema_to_select(**row) for row in data]
response["data"] = model_data
except ValidationError as e:
raise ValueError(
f"Data validation error for schema {schema_to_select.__name__}: {e}"
)

return response

async def db_delete(
self,
Expand Down
61 changes: 61 additions & 0 deletions tests/sqlalchemy/crud/test_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,3 +215,64 @@ async def test_update_auto_updates_updated_at(async_session, test_data):
assert (
updated.updated_at > initial_time
), "updated_at should be later than the initial timestamp."


@pytest.mark.parametrize(
["update_kwargs", "expected_result"],
[
pytest.param(
{"return_columns": ["id", "name"]},
{
"id": 1,
"name": "Updated Name",
},
id="dict",
),
pytest.param(
{"schema_to_select": UpdateSchemaTest, "return_as_model": True},
UpdateSchemaTest(id=1, name="Updated Name"),
id="model",
),
pytest.param(
{"allow_multiple": True, "return_columns": ["id", "name"]},
{
"data": [
{
"id": 1,
"name": "Updated Name",
}
]
},
id="multiple_dict",
),
pytest.param(
{
"allow_multiple": True,
"schema_to_select": UpdateSchemaTest,
"return_as_model": True,
},
{"data": [UpdateSchemaTest(id=1, name="Updated Name")]},
id="multiple_model",
),
],
)
@pytest.mark.asyncio
async def test_update_with_returning(
async_session, test_data, update_kwargs, expected_result
):
for item in test_data:
async_session.add(ModelTest(**item))
await async_session.commit()

crud = FastCRUD(ModelTest)
target_id = test_data[0]["id"]
updated_data = {"name": "Updated Name"}

updated_record = await crud.update(
db=async_session,
object=updated_data,
id=target_id,
**update_kwargs,
)

assert updated_record == expected_result

0 comments on commit e324631

Please sign in to comment.