From 7e5fd465ea936557e34095196745c7b14ab25c9f Mon Sep 17 00:00:00 2001 From: feluelle Date: Tue, 20 Aug 2024 14:42:03 +0200 Subject: [PATCH] Add update override arg to upsert_multi --- fastcrud/crud/fast_crud.py | 26 +++++++++-- tests/sqlalchemy/crud/test_upsert.py | 70 ++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+), 5 deletions(-) diff --git a/fastcrud/crud/fast_crud.py b/fastcrud/crud/fast_crud.py index a9333a6..099045c 100644 --- a/fastcrud/crud/fast_crud.py +++ b/fastcrud/crud/fast_crud.py @@ -778,6 +778,7 @@ async def upsert_multi( return_columns: Optional[list[str]] = None, schema_to_select: Optional[type[BaseModel]] = None, return_as_model: bool = False, + update_override: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> Optional[Dict[str, Any]]: """ @@ -789,6 +790,7 @@ async def upsert_multi( return_columns: Optional list of column names to return after the upsert operation. schema_to_select: Optional Pydantic schema for selecting specific columns. Required if return_as_model is True. return_as_model: If True, returns data as instances of the specified Pydantic model. + update_override: Optional dictionary to override the update values for the upsert operation. **kwargs: Filters to identify the record(s) to update on conflict, supporting advanced comparison operators for refined querying. Returns: @@ -798,12 +800,18 @@ async def upsert_multi( ValueError: If the MySQL dialect is used with filters, return_columns, schema_to_select, or return_as_model. NotImplementedError: If the database dialect is not supported for upsert multi. """ + if update_override is None: + update_override = {} filters = self._parse_filters(**kwargs) if db.bind.dialect.name == "postgresql": - statement, params = await self._upsert_multi_postgresql(instances, filters) + statement, params = await self._upsert_multi_postgresql( + instances, filters, update_override + ) elif db.bind.dialect.name == "sqlite": - statement, params = await self._upsert_multi_sqlite(instances, filters) + statement, params = await self._upsert_multi_sqlite( + instances, filters, update_override + ) elif db.bind.dialect.name in ["mysql", "mariadb"]: if filters: raise ValueError( @@ -813,7 +821,9 @@ async def upsert_multi( raise ValueError( "MySQL does not support the returning clause for insert operations." ) - statement, params = await self._upsert_multi_mysql(instances) + statement, params = await self._upsert_multi_mysql( + instances, update_override + ) else: # pragma: no cover raise NotImplementedError( f"Upsert multi is not implemented for {db.bind.dialect.name}" @@ -838,6 +848,7 @@ async def _upsert_multi_postgresql( self, instances: list[Union[UpdateSchemaType, CreateSchemaType]], filters: list[ColumnElement], + update_set_override: dict[str, Any], ) -> tuple[Insert, list[dict]]: statement = postgresql.insert(self.model) statement = statement.on_conflict_do_update( @@ -846,7 +857,8 @@ async def _upsert_multi_postgresql( column.name: getattr(statement.excluded, column.name) for column in self.model.__table__.columns if not column.primary_key and not column.unique - }, + } + | update_set_override, where=and_(*filters) if filters else None, ) params = [ @@ -858,6 +870,7 @@ async def _upsert_multi_sqlite( self, instances: list[Union[UpdateSchemaType, CreateSchemaType]], filters: list[ColumnElement], + update_set_override: dict[str, Any], ) -> tuple[Insert, list[dict]]: statement = sqlite.insert(self.model) statement = statement.on_conflict_do_update( @@ -866,7 +879,8 @@ async def _upsert_multi_sqlite( column.name: getattr(statement.excluded, column.name) for column in self.model.__table__.columns if not column.primary_key and not column.unique - }, + } + | update_set_override, where=and_(*filters) if filters else None, ) params = [ @@ -877,6 +891,7 @@ async def _upsert_multi_sqlite( async def _upsert_multi_mysql( self, instances: list[Union[UpdateSchemaType, CreateSchemaType]], + update_set_override: dict[str, Any], ) -> tuple[Insert, list[dict]]: statement = mysql.insert(self.model) statement = statement.on_duplicate_key_update( @@ -887,6 +902,7 @@ async def _upsert_multi_mysql( and not column.unique and column.name != self.deleted_at_column } + | update_set_override, ) params = [ self.model(**instance.model_dump()).__dict__ for instance in instances diff --git a/tests/sqlalchemy/crud/test_upsert.py b/tests/sqlalchemy/crud/test_upsert.py index 5cfedb3..74883da 100644 --- a/tests/sqlalchemy/crud/test_upsert.py +++ b/tests/sqlalchemy/crud/test_upsert.py @@ -58,6 +58,35 @@ async def test_upsert_successful(async_session, test_model, read_schema): marks=pytest.mark.dialect("postgresql"), id="postgresql-dict", ), + pytest.param( + { + "kwargs": {"return_columns": ["id", "name"]}, + "expected_result": { + "data": [ + { + "id": 1, + "name": "New Record", + } + ] + }, + }, + { + "kwargs": { + "return_columns": ["id", "name"], + "update_override": {"name": "New"}, + }, + "expected_result": { + "data": [ + { + "id": 1, + "name": "New", + } + ] + }, + }, + marks=pytest.mark.dialect("postgresql"), + id="postgresql-dict-update-override", + ), pytest.param( { "kwargs": {"return_columns": ["id", "name"]}, @@ -146,6 +175,35 @@ async def test_upsert_successful(async_session, test_model, read_schema): marks=pytest.mark.dialect("sqlite"), id="sqlite-dict", ), + pytest.param( + { + "kwargs": {"return_columns": ["id", "name"]}, + "expected_result": { + "data": [ + { + "id": 1, + "name": "New Record", + } + ] + }, + }, + { + "kwargs": { + "return_columns": ["id", "name"], + "update_override": {"name": "New"}, + }, + "expected_result": { + "data": [ + { + "id": 1, + "name": "New", + } + ] + }, + }, + marks=pytest.mark.dialect("sqlite"), + id="sqlite-dict-update-override", + ), pytest.param( { "kwargs": {"return_columns": ["id", "name"]}, @@ -208,6 +266,18 @@ async def test_upsert_successful(async_session, test_model, read_schema): marks=pytest.mark.dialect("mysql"), id="mysql-none", ), + pytest.param( + { + "kwargs": {}, + "expected_result": None, + }, + { + "kwargs": {"update_override": {"name": "New"}}, + "expected_result": None, + }, + marks=pytest.mark.dialect("mysql"), + id="mysql-dict-update-override", + ), ], ) @pytest.mark.asyncio