Skip to content

Commit

Permalink
Add update override arg to upsert_multi
Browse files Browse the repository at this point in the history
  • Loading branch information
feluelle committed Aug 20, 2024
1 parent 66f5a3a commit 7e5fd46
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 5 deletions.
26 changes: 21 additions & 5 deletions fastcrud/crud/fast_crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,7 @@ async def upsert_multi(
return_columns: Optional[list[str]] = None,
schema_to_select: Optional[type[BaseModel]] = None,
return_as_model: bool = False,
update_override: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> Optional[Dict[str, Any]]:
"""
Expand All @@ -789,6 +790,7 @@ async def upsert_multi(
return_columns: Optional list of column names to return after the upsert operation.
schema_to_select: Optional Pydantic schema for selecting specific columns. Required if return_as_model is True.
return_as_model: If True, returns data as instances of the specified Pydantic model.
update_override: Optional dictionary to override the update values for the upsert operation.
**kwargs: Filters to identify the record(s) to update on conflict, supporting advanced comparison operators for refined querying.
Returns:
Expand All @@ -798,12 +800,18 @@ async def upsert_multi(
ValueError: If the MySQL dialect is used with filters, return_columns, schema_to_select, or return_as_model.
NotImplementedError: If the database dialect is not supported for upsert multi.
"""
if update_override is None:
update_override = {}
filters = self._parse_filters(**kwargs)

if db.bind.dialect.name == "postgresql":
statement, params = await self._upsert_multi_postgresql(instances, filters)
statement, params = await self._upsert_multi_postgresql(
instances, filters, update_override
)
elif db.bind.dialect.name == "sqlite":
statement, params = await self._upsert_multi_sqlite(instances, filters)
statement, params = await self._upsert_multi_sqlite(
instances, filters, update_override
)
elif db.bind.dialect.name in ["mysql", "mariadb"]:
if filters:
raise ValueError(
Expand All @@ -813,7 +821,9 @@ async def upsert_multi(
raise ValueError(
"MySQL does not support the returning clause for insert operations."
)
statement, params = await self._upsert_multi_mysql(instances)
statement, params = await self._upsert_multi_mysql(
instances, update_override
)
else: # pragma: no cover
raise NotImplementedError(
f"Upsert multi is not implemented for {db.bind.dialect.name}"
Expand All @@ -838,6 +848,7 @@ async def _upsert_multi_postgresql(
self,
instances: list[Union[UpdateSchemaType, CreateSchemaType]],
filters: list[ColumnElement],
update_set_override: dict[str, Any],
) -> tuple[Insert, list[dict]]:
statement = postgresql.insert(self.model)
statement = statement.on_conflict_do_update(
Expand All @@ -846,7 +857,8 @@ async def _upsert_multi_postgresql(
column.name: getattr(statement.excluded, column.name)
for column in self.model.__table__.columns
if not column.primary_key and not column.unique
},
}
| update_set_override,
where=and_(*filters) if filters else None,
)
params = [
Expand All @@ -858,6 +870,7 @@ async def _upsert_multi_sqlite(
self,
instances: list[Union[UpdateSchemaType, CreateSchemaType]],
filters: list[ColumnElement],
update_set_override: dict[str, Any],
) -> tuple[Insert, list[dict]]:
statement = sqlite.insert(self.model)
statement = statement.on_conflict_do_update(
Expand All @@ -866,7 +879,8 @@ async def _upsert_multi_sqlite(
column.name: getattr(statement.excluded, column.name)
for column in self.model.__table__.columns
if not column.primary_key and not column.unique
},
}
| update_set_override,
where=and_(*filters) if filters else None,
)
params = [
Expand All @@ -877,6 +891,7 @@ async def _upsert_multi_sqlite(
async def _upsert_multi_mysql(
self,
instances: list[Union[UpdateSchemaType, CreateSchemaType]],
update_set_override: dict[str, Any],
) -> tuple[Insert, list[dict]]:
statement = mysql.insert(self.model)
statement = statement.on_duplicate_key_update(
Expand All @@ -887,6 +902,7 @@ async def _upsert_multi_mysql(
and not column.unique
and column.name != self.deleted_at_column
}
| update_set_override,
)
params = [
self.model(**instance.model_dump()).__dict__ for instance in instances
Expand Down
70 changes: 70 additions & 0 deletions tests/sqlalchemy/crud/test_upsert.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,35 @@ async def test_upsert_successful(async_session, test_model, read_schema):
marks=pytest.mark.dialect("postgresql"),
id="postgresql-dict",
),
pytest.param(
{
"kwargs": {"return_columns": ["id", "name"]},
"expected_result": {
"data": [
{
"id": 1,
"name": "New Record",
}
]
},
},
{
"kwargs": {
"return_columns": ["id", "name"],
"update_override": {"name": "New"},
},
"expected_result": {
"data": [
{
"id": 1,
"name": "New",
}
]
},
},
marks=pytest.mark.dialect("postgresql"),
id="postgresql-dict-update-override",
),
pytest.param(
{
"kwargs": {"return_columns": ["id", "name"]},
Expand Down Expand Up @@ -146,6 +175,35 @@ async def test_upsert_successful(async_session, test_model, read_schema):
marks=pytest.mark.dialect("sqlite"),
id="sqlite-dict",
),
pytest.param(
{
"kwargs": {"return_columns": ["id", "name"]},
"expected_result": {
"data": [
{
"id": 1,
"name": "New Record",
}
]
},
},
{
"kwargs": {
"return_columns": ["id", "name"],
"update_override": {"name": "New"},
},
"expected_result": {
"data": [
{
"id": 1,
"name": "New",
}
]
},
},
marks=pytest.mark.dialect("sqlite"),
id="sqlite-dict-update-override",
),
pytest.param(
{
"kwargs": {"return_columns": ["id", "name"]},
Expand Down Expand Up @@ -208,6 +266,18 @@ async def test_upsert_successful(async_session, test_model, read_schema):
marks=pytest.mark.dialect("mysql"),
id="mysql-none",
),
pytest.param(
{
"kwargs": {},
"expected_result": None,
},
{
"kwargs": {"update_override": {"name": "New"}},
"expected_result": None,
},
marks=pytest.mark.dialect("mysql"),
id="mysql-dict-update-override",
),
],
)
@pytest.mark.asyncio
Expand Down

0 comments on commit 7e5fd46

Please sign in to comment.