diff --git a/fastcrud/crud/fast_crud.py b/fastcrud/crud/fast_crud.py index 0f4f8a7..d4060c5 100644 --- a/fastcrud/crud/fast_crud.py +++ b/fastcrud/crud/fast_crud.py @@ -1,4 +1,4 @@ -from typing import Any, Generic, TypeVar, Union, Optional +from typing import Any, Dict, Generic, TypeVar, Union, Optional from datetime import datetime, timezone from pydantic import BaseModel, ValidationError @@ -190,6 +190,7 @@ def __init__( self.is_deleted_column = is_deleted_column self.deleted_at_column = deleted_at_column self.updated_at_column = updated_at_column + self._primary_keys = _get_primary_keys(self.model) def _parse_filters( self, model: Optional[Union[type[ModelType], AliasedClass]] = None, **kwargs @@ -453,6 +454,7 @@ async def get( Args: db: The database session to use for the operation. schema_to_select: Optional Pydantic schema for selecting specific columns. + return_as_model: If True, converts the fetched data to Pydantic models based on schema_to_select. Defaults to False. one_or_none: Flag to get strictly one or no result. Multiple results are not allowed. **kwargs: Filters to apply to the query, using field names for direct matches or appending comparison operators for advanced queries. @@ -498,6 +500,52 @@ async def get( ) return schema_to_select(**out) + def _get_pk_dict(self, instance): + return {pk.name: getattr(instance, pk.name) for pk in self._primary_keys} + + async def upsert( + self, + db: AsyncSession, + instance: Union[UpdateSchemaType, CreateSchemaType], + schema_to_select: Optional[type[BaseModel]] = None, + return_as_model: bool = False, + ) -> Union[BaseModel, Dict[str, Any], None]: + """Update the instance or create it if it doesn't exists. + Note: This method will perform two transactions to the database (get and create or update). + + Args: + db (AsyncSession): The database session to use for the operation. + instance (Union[UpdateSchemaType, type[BaseModel]]): A Pydantic schema representing the instance. + schema_to_select (Optional[type[BaseModel]], optional): Optional Pydantic schema for selecting specific columns. Defaults to None. + return_as_model (bool, optional): If True, converts the fetched data to Pydantic models based on schema_to_select. Defaults to False. + + Returns: + BaseModel: the created or updated instance + """ + _pks = self._get_pk_dict(instance) + schema_to_select = schema_to_select or type(instance) + db_instance = await self.get( + db, + schema_to_select=schema_to_select, + return_as_model=return_as_model, + **_pks, + ) + if db_instance is None: + db_instance = await self.create(db, instance) # type: ignore + db_instance = schema_to_select.model_validate( + db_instance, from_attributes=True + ) + else: + await self.update(db, instance) # type: ignore + db_instance = await self.get( + db, + schema_to_select=schema_to_select, + return_as_model=return_as_model, + **_pks, + ) + + return db_instance + async def exists(self, db: AsyncSession, **kwargs: Any) -> bool: """ Checks if any records exist that match the given filter conditions. diff --git a/tests/sqlalchemy/conftest.py b/tests/sqlalchemy/conftest.py index e52fc45..18b325b 100644 --- a/tests/sqlalchemy/conftest.py +++ b/tests/sqlalchemy/conftest.py @@ -370,7 +370,7 @@ def client( @pytest.fixture -def endpoint_creator() -> EndpointCreator: +def endpoint_creator(test_model) -> EndpointCreator: """Fixture to create an instance of EndpointCreator.""" return EndpointCreator( session=get_session_local, diff --git a/tests/sqlalchemy/crud/test_upsert.py b/tests/sqlalchemy/crud/test_upsert.py new file mode 100644 index 0000000..7fdaba2 --- /dev/null +++ b/tests/sqlalchemy/crud/test_upsert.py @@ -0,0 +1,16 @@ +import pytest + +from fastcrud.crud.fast_crud import FastCRUD + + +@pytest.mark.asyncio +async def test_upsert_successful(async_session, test_model, read_schema): + crud = FastCRUD(test_model) + new_data = read_schema(id=1, name="New Record", tier_id=1, category_id=1) + fetched_record = await crud.upsert(async_session, new_data, return_as_model=True) + assert read_schema.model_validate(fetched_record) == new_data + + fetched_record.name == "New name" + + updated_fetched_record = await crud.upsert(async_session, fetched_record) + assert read_schema.model_validate(updated_fetched_record) == fetched_record diff --git a/tests/sqlmodel/conftest.py b/tests/sqlmodel/conftest.py index 0054b2d..85ad173 100644 --- a/tests/sqlmodel/conftest.py +++ b/tests/sqlmodel/conftest.py @@ -360,7 +360,7 @@ def client( @pytest.fixture -def endpoint_creator() -> EndpointCreator: +def endpoint_creator(test_model) -> EndpointCreator: """Fixture to create an instance of EndpointCreator.""" return EndpointCreator( session=get_session_local, diff --git a/tests/sqlmodel/crud/test_upsert.py b/tests/sqlmodel/crud/test_upsert.py new file mode 100644 index 0000000..7fdaba2 --- /dev/null +++ b/tests/sqlmodel/crud/test_upsert.py @@ -0,0 +1,16 @@ +import pytest + +from fastcrud.crud.fast_crud import FastCRUD + + +@pytest.mark.asyncio +async def test_upsert_successful(async_session, test_model, read_schema): + crud = FastCRUD(test_model) + new_data = read_schema(id=1, name="New Record", tier_id=1, category_id=1) + fetched_record = await crud.upsert(async_session, new_data, return_as_model=True) + assert read_schema.model_validate(fetched_record) == new_data + + fetched_record.name == "New name" + + updated_fetched_record = await crud.upsert(async_session, fetched_record) + assert read_schema.model_validate(updated_fetched_record) == fetched_record