-
-
Notifications
You must be signed in to change notification settings - Fork 65
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #111 from feluelle/feature/update-returning
Add returning clause to update
- Loading branch information
Showing
3 changed files
with
188 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
from datetime import datetime, timezone | ||
|
||
from pydantic import BaseModel, ValidationError | ||
from sqlalchemy import select, update, delete, func, inspect, asc, desc, or_ | ||
from sqlalchemy import Result, select, update, delete, func, inspect, asc, desc, or_, column | ||
from sqlalchemy.exc import ArgumentError, MultipleResultsFound, NoResultFound | ||
from sqlalchemy.sql import Join | ||
from sqlalchemy.ext.asyncio import AsyncSession | ||
|
@@ -1596,8 +1596,12 @@ async def update( | |
object: Union[UpdateSchemaType, dict[str, Any]], | ||
allow_multiple: bool = False, | ||
commit: bool = True, | ||
return_columns: Optional[list[str]] = None, | ||
schema_to_select: Optional[type[BaseModel]] = None, | ||
return_as_model: bool = False, | ||
one_or_none: bool = False, | ||
**kwargs: Any, | ||
) -> None: | ||
) -> Optional[Union[dict, BaseModel]]: | ||
""" | ||
Updates an existing record or multiple records in the database based on specified filters. This method allows for precise targeting of records to update. | ||
For filtering details see: | ||
|
@@ -1608,14 +1612,19 @@ async def update( | |
object: A Pydantic schema or dictionary containing the update data. | ||
allow_multiple: If True, allows updating multiple records that match the filters. If False, raises an error if more than one record matches the filters. | ||
commit: If True, commits the transaction immediately. Default is True. | ||
return_columns: A list of column names to return after the update. If return_as_model is True, all columns are returned. | ||
schema_to_select: Pydantic schema for selecting specific columns from the updated record(s). Required if `return_as_model` is True. | ||
return_as_model: If True, returns the updated record(s) as Pydantic model instances based on `schema_to_select`. Default is False. | ||
one_or_none: If True, returns a single record if only one record matches the filters. Default is False. | ||
**kwargs: Filters to identify the record(s) to update, supporting advanced comparison operators for refined querying. | ||
Returns: | ||
None | ||
The updated record(s) as a dictionary or Pydantic model instance or None, depending on the value of `return_as_model` and `return_columns`. | ||
Raises: | ||
MultipleResultsFound: If `allow_multiple` is False and more than one record matches the filters. | ||
ValueError: If extra fields not present in the model are provided in the update data. | ||
ValueError: If `return_as_model` is True but `schema_to_select` is not provided. | ||
Examples: | ||
Update a user's email based on their ID: | ||
|
@@ -1632,6 +1641,16 @@ async def update( | |
```python | ||
update(db, {'username': 'new_username'}, id__ne=1, allow_multiple=False) | ||
``` | ||
Update a user's email and return the updated record as a Pydantic model instance: | ||
```python | ||
update(db, {'email': '[email protected]'}, id=1, schema_to_select=UserSchema, return_as_model=True) | ||
``` | ||
Update a user's email and return the updated record as a dictionary: | ||
```python | ||
update(db, {'email': '[email protected]'}, id=1, return_columns=['id', 'email']) | ||
``` | ||
""" | ||
if not allow_multiple and (total_count := await self.count(db, **kwargs)) > 1: | ||
raise MultipleResultsFound( | ||
|
@@ -1656,9 +1675,74 @@ async def update( | |
filters = self._parse_filters(**kwargs) | ||
stmt = update(self.model).filter(*filters).values(update_data) | ||
|
||
if return_as_model: | ||
# All columns are returned to ensure the model can be constructed | ||
return_columns = self.model_col_names | ||
|
||
if return_columns: | ||
stmt = stmt.returning(*[column(name) for name in return_columns]) | ||
db_row = await db.execute(stmt) | ||
if allow_multiple: | ||
return self._as_multi_response( | ||
db_row, | ||
schema_to_select=schema_to_select, | ||
return_as_model=return_as_model, | ||
) | ||
return self._as_single_response( | ||
db_row, | ||
schema_to_select=schema_to_select, | ||
return_as_model=return_as_model, | ||
one_or_none=one_or_none, | ||
) | ||
|
||
await db.execute(stmt) | ||
if commit: | ||
await db.commit() | ||
return None | ||
|
||
def _as_single_response( | ||
self, | ||
db_row: Result, | ||
schema_to_select: Optional[type[BaseModel]] = None, | ||
return_as_model: bool = False, | ||
one_or_none: bool = False, | ||
) -> Optional[Union[dict, BaseModel]]: | ||
result: Optional[Row] = db_row.one_or_none() if one_or_none else db_row.first() | ||
if result is None: | ||
return None | ||
out: dict = dict(result._mapping) | ||
if not return_as_model: | ||
return out | ||
if not schema_to_select: | ||
raise ValueError( | ||
"schema_to_select must be provided when return_as_model is True." | ||
) | ||
return schema_to_select(**out) | ||
|
||
def _as_multi_response( | ||
self, | ||
db_row: Result, | ||
schema_to_select: Optional[type[BaseModel]] = None, | ||
return_as_model: bool = False, | ||
) -> dict: | ||
data = [dict(row) for row in db_row.mappings()] | ||
|
||
response: dict[str, Any] = {"data": data} | ||
|
||
if return_as_model: | ||
if not schema_to_select: | ||
raise ValueError( | ||
"schema_to_select must be provided when return_as_model is True." | ||
) | ||
try: | ||
model_data = [schema_to_select(**row) for row in data] | ||
response["data"] = model_data | ||
except ValidationError as e: | ||
raise ValueError( | ||
f"Data validation error for schema {schema_to_select.__name__}: {e}" | ||
) | ||
|
||
return response | ||
|
||
async def db_delete( | ||
self, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters