Skip to content

Commit

Permalink
Add upsert_multi
Browse files Browse the repository at this point in the history
  • Loading branch information
feluelle committed Jul 5, 2024
1 parent e324631 commit bdaa922
Show file tree
Hide file tree
Showing 4 changed files with 185 additions and 7 deletions.
78 changes: 77 additions & 1 deletion fastcrud/crud/fast_crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,27 @@
from datetime import datetime, timezone

from pydantic import BaseModel, ValidationError
from sqlalchemy import Result, select, update, delete, func, inspect, asc, desc, or_, column
from sqlalchemy import (
Insert,
Result,
select,
update,
delete,
func,
inspect,
asc,
desc,
or_,
column,
)
from sqlalchemy.exc import ArgumentError, MultipleResultsFound, NoResultFound
from sqlalchemy.sql import Join
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.engine.row import Row
from sqlalchemy.orm.util import AliasedClass
from sqlalchemy.sql.elements import BinaryExpression, ColumnElement
from sqlalchemy.sql.selectable import Select
from sqlalchemy.dialects import postgresql

from fastcrud.types import (
CreateSchemaType,
Expand Down Expand Up @@ -567,6 +580,69 @@ async def upsert(

return db_instance

async def upsert_multi(
self,
db: AsyncSession,
instances: list[Union[UpdateSchemaType, CreateSchemaType]],
return_columns: Optional[list[str]] = None,
schema_to_select: Optional[type[BaseModel]] = None,
return_as_model: bool = False,
) -> Optional[Dict[str, Any]]:
"""
Upsert multiple records in the database. This method is currently only supported for PostgreSQL databases.
Args:
db: The database session to use for the operation.
instances: A list of Pydantic schemas representing the instances to upsert.
return_columns: Optional list of column names to return after the upsert operation.
schema_to_select: Optional Pydantic schema for selecting specific columns. Required if return_as_model is True.
return_as_model: If True, returns data as instances of the specified Pydantic model.
Returns:
The updated record(s) as a dictionary or Pydantic model instance or None, depending on the value of `return_as_model` and `return_columns`.
Raises:
NotImplementedError: If the database dialect is not PostgreSQL.
"""
if db.bind.dialect.name == "postgresql":
statement = await self._upsert_multi_postgresql(instances)
else:
raise NotImplementedError(
f"Upsert multi is not implemented for {db.bind.dialect.name}"
)

if return_as_model:
# All columns are returned to ensure the model can be constructed
return_columns = self.model_col_names

if return_columns:
statement = statement.returning(*[column(name) for name in return_columns])
db_row = await db.execute(statement)
return self._as_multi_response(
db_row,
schema_to_select=schema_to_select,
return_as_model=return_as_model,
)

await db.execute(statement)
return None

async def _upsert_multi_postgresql(
self,
instances: list[Union[UpdateSchemaType, CreateSchemaType]],
) -> Insert:
statement = postgresql.insert(self.model)
statement = statement.values([instance.model_dump() for instance in instances])
statement = statement.on_conflict_do_update(
index_elements=self._primary_keys,
set_={
column.name: column
for column in statement.excluded
if not column.primary_key and not column.unique
},
)
return statement

async def exists(self, db: AsyncSession, **kwargs: Any) -> bool:
"""
Checks if any records exist that match the given filter conditions.
Expand Down
7 changes: 7 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ SQLAlchemy = "^2.0.0"
pydantic = "^2.0.0"
SQLAlchemy-Utils = "^0.41.1"
fastapi = ">=0.100.0,<0.112.0"
psycopg = "^3.2.1"

[tool.poetry.dev-dependencies]
pytest = "^7.4.4"
Expand All @@ -44,7 +45,13 @@ sqlmodel = "^0.0.14"
mypy = "^1.9.0"
ruff = "^0.3.4"
coverage = "^7.4.4"
testcontainers = "^4.7.1"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

[tool.pytest.ini_options]
markers = [
"dialect(name): mark test to run only on specific SQL dialect",
]
28 changes: 22 additions & 6 deletions tests/sqlalchemy/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from contextlib import asynccontextmanager
from typing import Optional
from datetime import datetime

Expand All @@ -10,6 +11,7 @@
from fastapi import FastAPI
from fastapi.testclient import TestClient
from sqlalchemy.sql import func
from testcontainers.postgres import PostgresContainer

from fastcrud.crud.fast_crud import FastCRUD
from fastcrud.endpoint.crud_router import crud_router
Expand Down Expand Up @@ -272,13 +274,10 @@ class TaskRead(TaskReadSub):
client: Optional[ClientRead]


async_engine = create_async_engine(
"sqlite+aiosqlite:///:memory:", echo=True, future=True
)

@asynccontextmanager
async def _async_session(url: str) -> AsyncSession:
async_engine = create_async_engine(url, echo=True, future=True)

@pytest_asyncio.fixture(scope="function")
async def async_session() -> AsyncSession:
session = sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False)

async with session() as s:
Expand All @@ -293,6 +292,23 @@ async def async_session() -> AsyncSession:
await async_engine.dispose()


@pytest_asyncio.fixture(scope="function")
async def async_session(request: pytest.FixtureRequest) -> AsyncSession:
dialect_marker = request.node.get_closest_marker("dialect")
dialect = dialect_marker.args[0] if dialect_marker else "sqlite"
if dialect == "postgresql":
with PostgresContainer(driver="psycopg") as pg:
async with _async_session(
url=pg.get_connection_url(host=pg.get_container_host_ip())
) as session:
yield session
elif dialect == "sqlite":
async with _async_session(url="sqlite+aiosqlite:///:memory:") as session:
yield session
else:
raise ValueError(f"Unsupported dialect: {dialect}")


@pytest.fixture(scope="function")
def test_data() -> list[dict]:
return [
Expand Down
79 changes: 79 additions & 0 deletions tests/sqlalchemy/crud/test_upsert.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest

from fastcrud.crud.fast_crud import FastCRUD
from tests.sqlalchemy.conftest import CategoryModel, ReadSchemaTest, TierModel


@pytest.mark.asyncio
Expand All @@ -14,3 +15,81 @@ async def test_upsert_successful(async_session, test_model, read_schema):

updated_fetched_record = await crud.upsert(async_session, fetched_record)
assert read_schema.model_validate(updated_fetched_record) == fetched_record


@pytest.mark.parametrize(
["update_kwargs", "expected_insert_result", "expected_update_result"],
[
pytest.param(
{},
None,
None,
id="none",
),
pytest.param(
{"return_columns": ["id", "name"]},
{
"data": [
{
"id": 1,
"name": "New Record",
}
]
},
{
"data": [
{
"id": 1,
"name": "New name",
}
]
},
id="dict",
),
pytest.param(
{
"schema_to_select": ReadSchemaTest,
"return_as_model": True,
},
{
"data": [
ReadSchemaTest(id=1, name="New Record", tier_id=1, category_id=1)
]
},
{"data": [ReadSchemaTest(id=1, name="New name", tier_id=1, category_id=1)]},
id="model",
),
],
)
@pytest.mark.dialect("postgresql")
@pytest.mark.asyncio
async def test_upsert_multi_successful(
async_session,
test_model,
read_schema,
test_data_tier,
test_data_category,
update_kwargs,
expected_insert_result,
expected_update_result,
):
for tier_item in test_data_tier:
async_session.add(TierModel(**tier_item))
for category_item in test_data_category:
async_session.add(CategoryModel(**category_item))
await async_session.commit()

crud = FastCRUD(test_model)
new_data = read_schema(id=1, name="New Record", tier_id=1, category_id=1)
fetched_records = await crud.upsert_multi(
async_session, [new_data], **update_kwargs
)

assert fetched_records == expected_insert_result

updated_new_data = new_data.model_copy(update={"name": "New name"})
updated_fetched_records = await crud.upsert_multi(
async_session, [updated_new_data], **update_kwargs
)

assert updated_fetched_records == expected_update_result

0 comments on commit bdaa922

Please sign in to comment.