Skip to content

Commit

Permalink
Merge pull request #49 from dubusster/upsert
Browse files Browse the repository at this point in the history
feat: ✨ add upsert method in FastCRUD class
  • Loading branch information
igorbenav authored May 14, 2024
2 parents 8b23064 + e0d265a commit adf9ab0
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 3 deletions.
50 changes: 49 additions & 1 deletion fastcrud/crud/fast_crud.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Generic, TypeVar, Union, Optional
from typing import Any, Dict, Generic, TypeVar, Union, Optional
from datetime import datetime, timezone

from pydantic import BaseModel, ValidationError
Expand Down Expand Up @@ -190,6 +190,7 @@ def __init__(
self.is_deleted_column = is_deleted_column
self.deleted_at_column = deleted_at_column
self.updated_at_column = updated_at_column
self._primary_keys = _get_primary_keys(self.model)

def _parse_filters(
self, model: Optional[Union[type[ModelType], AliasedClass]] = None, **kwargs
Expand Down Expand Up @@ -453,6 +454,7 @@ async def get(
Args:
db: The database session to use for the operation.
schema_to_select: Optional Pydantic schema for selecting specific columns.
return_as_model: If True, converts the fetched data to Pydantic models based on schema_to_select. Defaults to False.
one_or_none: Flag to get strictly one or no result. Multiple results are not allowed.
**kwargs: Filters to apply to the query, using field names for direct matches or appending comparison operators for advanced queries.
Expand Down Expand Up @@ -498,6 +500,52 @@ async def get(
)
return schema_to_select(**out)

def _get_pk_dict(self, instance):
return {pk.name: getattr(instance, pk.name) for pk in self._primary_keys}

async def upsert(
self,
db: AsyncSession,
instance: Union[UpdateSchemaType, CreateSchemaType],
schema_to_select: Optional[type[BaseModel]] = None,
return_as_model: bool = False,
) -> Union[BaseModel, Dict[str, Any], None]:
"""Update the instance or create it if it doesn't exists.
Note: This method will perform two transactions to the database (get and create or update).
Args:
db (AsyncSession): The database session to use for the operation.
instance (Union[UpdateSchemaType, type[BaseModel]]): A Pydantic schema representing the instance.
schema_to_select (Optional[type[BaseModel]], optional): Optional Pydantic schema for selecting specific columns. Defaults to None.
return_as_model (bool, optional): If True, converts the fetched data to Pydantic models based on schema_to_select. Defaults to False.
Returns:
BaseModel: the created or updated instance
"""
_pks = self._get_pk_dict(instance)
schema_to_select = schema_to_select or type(instance)
db_instance = await self.get(
db,
schema_to_select=schema_to_select,
return_as_model=return_as_model,
**_pks,
)
if db_instance is None:
db_instance = await self.create(db, instance) # type: ignore
db_instance = schema_to_select.model_validate(
db_instance, from_attributes=True
)
else:
await self.update(db, instance) # type: ignore
db_instance = await self.get(
db,
schema_to_select=schema_to_select,
return_as_model=return_as_model,
**_pks,
)

return db_instance

async def exists(self, db: AsyncSession, **kwargs: Any) -> bool:
"""
Checks if any records exist that match the given filter conditions.
Expand Down
2 changes: 1 addition & 1 deletion tests/sqlalchemy/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ def client(


@pytest.fixture
def endpoint_creator() -> EndpointCreator:
def endpoint_creator(test_model) -> EndpointCreator:
"""Fixture to create an instance of EndpointCreator."""
return EndpointCreator(
session=get_session_local,
Expand Down
16 changes: 16 additions & 0 deletions tests/sqlalchemy/crud/test_upsert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import pytest

from fastcrud.crud.fast_crud import FastCRUD


@pytest.mark.asyncio
async def test_upsert_successful(async_session, test_model, read_schema):
crud = FastCRUD(test_model)
new_data = read_schema(id=1, name="New Record", tier_id=1, category_id=1)
fetched_record = await crud.upsert(async_session, new_data, return_as_model=True)
assert read_schema.model_validate(fetched_record) == new_data

fetched_record.name == "New name"

updated_fetched_record = await crud.upsert(async_session, fetched_record)
assert read_schema.model_validate(updated_fetched_record) == fetched_record
2 changes: 1 addition & 1 deletion tests/sqlmodel/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def client(


@pytest.fixture
def endpoint_creator() -> EndpointCreator:
def endpoint_creator(test_model) -> EndpointCreator:
"""Fixture to create an instance of EndpointCreator."""
return EndpointCreator(
session=get_session_local,
Expand Down
16 changes: 16 additions & 0 deletions tests/sqlmodel/crud/test_upsert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import pytest

from fastcrud.crud.fast_crud import FastCRUD


@pytest.mark.asyncio
async def test_upsert_successful(async_session, test_model, read_schema):
crud = FastCRUD(test_model)
new_data = read_schema(id=1, name="New Record", tier_id=1, category_id=1)
fetched_record = await crud.upsert(async_session, new_data, return_as_model=True)
assert read_schema.model_validate(fetched_record) == new_data

fetched_record.name == "New name"

updated_fetched_record = await crud.upsert(async_session, fetched_record)
assert read_schema.model_validate(updated_fetched_record) == fetched_record

0 comments on commit adf9ab0

Please sign in to comment.