diff --git a/fastcrud/crud/fast_crud.py b/fastcrud/crud/fast_crud.py index 465a812..29b5f4b 100644 --- a/fastcrud/crud/fast_crud.py +++ b/fastcrud/crud/fast_crud.py @@ -2,7 +2,19 @@ from datetime import datetime, timezone from pydantic import BaseModel, ValidationError -from sqlalchemy import Result, select, update, delete, func, inspect, asc, desc, or_, column +from sqlalchemy import ( + Insert, + Result, + select, + update, + delete, + func, + inspect, + asc, + desc, + or_, + column, +) from sqlalchemy.exc import ArgumentError, MultipleResultsFound, NoResultFound from sqlalchemy.sql import Join from sqlalchemy.ext.asyncio import AsyncSession @@ -10,6 +22,7 @@ from sqlalchemy.orm.util import AliasedClass from sqlalchemy.sql.elements import BinaryExpression, ColumnElement from sqlalchemy.sql.selectable import Select +from sqlalchemy.dialects import postgresql from fastcrud.types import ( CreateSchemaType, @@ -567,6 +580,69 @@ async def upsert( return db_instance + async def upsert_multi( + self, + db: AsyncSession, + instances: list[Union[UpdateSchemaType, CreateSchemaType]], + return_columns: Optional[list[str]] = None, + schema_to_select: Optional[type[BaseModel]] = None, + return_as_model: bool = False, + ) -> Optional[Dict[str, Any]]: + """ + Upsert multiple records in the database. This method is currently only supported for PostgreSQL databases. + + Args: + db: The database session to use for the operation. + instances: A list of Pydantic schemas representing the instances to upsert. + return_columns: Optional list of column names to return after the upsert operation. + schema_to_select: Optional Pydantic schema for selecting specific columns. Required if return_as_model is True. + return_as_model: If True, returns data as instances of the specified Pydantic model. + + Returns: + The updated record(s) as a dictionary or Pydantic model instance or None, depending on the value of `return_as_model` and `return_columns`. + + Raises: + NotImplementedError: If the database dialect is not PostgreSQL. + """ + if db.bind.dialect.name == "postgresql": + statement = await self._upsert_multi_postgresql(instances) + else: + raise NotImplementedError( + f"Upsert multi is not implemented for {db.bind.dialect.name}" + ) + + if return_as_model: + # All columns are returned to ensure the model can be constructed + return_columns = self.model_col_names + + if return_columns: + statement = statement.returning(*[column(name) for name in return_columns]) + db_row = await db.execute(statement) + return self._as_multi_response( + db_row, + schema_to_select=schema_to_select, + return_as_model=return_as_model, + ) + + await db.execute(statement) + return None + + async def _upsert_multi_postgresql( + self, + instances: list[Union[UpdateSchemaType, CreateSchemaType]], + ) -> Insert: + statement = postgresql.insert(self.model) + statement = statement.values([instance.model_dump() for instance in instances]) + statement = statement.on_conflict_do_update( + index_elements=self._primary_keys, + set_={ + column.name: column + for column in statement.excluded + if not column.primary_key and not column.unique + }, + ) + return statement + async def exists(self, db: AsyncSession, **kwargs: Any) -> bool: """ Checks if any records exist that match the given filter conditions. diff --git a/pyproject.toml b/pyproject.toml index 3e73fb5..5d5cf9a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ SQLAlchemy = "^2.0.0" pydantic = "^2.0.0" SQLAlchemy-Utils = "^0.41.1" fastapi = ">=0.100.0,<0.112.0" +psycopg = "^3.2.1" [tool.poetry.dev-dependencies] pytest = "^7.4.4" @@ -44,7 +45,13 @@ sqlmodel = "^0.0.14" mypy = "^1.9.0" ruff = "^0.3.4" coverage = "^7.4.4" +testcontainers = "^4.7.1" [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" + +[tool.pytest.ini_options] +markers = [ + "dialect(name): mark test to run only on specific SQL dialect", +] diff --git a/tests/sqlalchemy/conftest.py b/tests/sqlalchemy/conftest.py index b2a896c..ea98521 100644 --- a/tests/sqlalchemy/conftest.py +++ b/tests/sqlalchemy/conftest.py @@ -1,3 +1,4 @@ +from contextlib import asynccontextmanager from typing import Optional from datetime import datetime @@ -10,6 +11,7 @@ from fastapi import FastAPI from fastapi.testclient import TestClient from sqlalchemy.sql import func +from testcontainers.postgres import PostgresContainer from fastcrud.crud.fast_crud import FastCRUD from fastcrud.endpoint.crud_router import crud_router @@ -272,13 +274,10 @@ class TaskRead(TaskReadSub): client: Optional[ClientRead] -async_engine = create_async_engine( - "sqlite+aiosqlite:///:memory:", echo=True, future=True -) - +@asynccontextmanager +async def _async_session(url: str) -> AsyncSession: + async_engine = create_async_engine(url, echo=True, future=True) -@pytest_asyncio.fixture(scope="function") -async def async_session() -> AsyncSession: session = sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False) async with session() as s: @@ -293,6 +292,23 @@ async def async_session() -> AsyncSession: await async_engine.dispose() +@pytest_asyncio.fixture(scope="function") +async def async_session(request: pytest.FixtureRequest) -> AsyncSession: + dialect_marker = request.node.get_closest_marker("dialect") + dialect = dialect_marker.args[0] if dialect_marker else "sqlite" + if dialect == "postgresql": + with PostgresContainer(driver="psycopg") as pg: + async with _async_session( + url=pg.get_connection_url(host=pg.get_container_host_ip()) + ) as session: + yield session + elif dialect == "sqlite": + async with _async_session(url="sqlite+aiosqlite:///:memory:") as session: + yield session + else: + raise ValueError(f"Unsupported dialect: {dialect}") + + @pytest.fixture(scope="function") def test_data() -> list[dict]: return [ diff --git a/tests/sqlalchemy/crud/test_upsert.py b/tests/sqlalchemy/crud/test_upsert.py index 7fdaba2..c0c30cd 100644 --- a/tests/sqlalchemy/crud/test_upsert.py +++ b/tests/sqlalchemy/crud/test_upsert.py @@ -1,6 +1,7 @@ import pytest from fastcrud.crud.fast_crud import FastCRUD +from tests.sqlalchemy.conftest import CategoryModel, ReadSchemaTest, TierModel @pytest.mark.asyncio @@ -14,3 +15,81 @@ async def test_upsert_successful(async_session, test_model, read_schema): updated_fetched_record = await crud.upsert(async_session, fetched_record) assert read_schema.model_validate(updated_fetched_record) == fetched_record + + +@pytest.mark.parametrize( + ["update_kwargs", "expected_insert_result", "expected_update_result"], + [ + pytest.param( + {}, + None, + None, + id="none", + ), + pytest.param( + {"return_columns": ["id", "name"]}, + { + "data": [ + { + "id": 1, + "name": "New Record", + } + ] + }, + { + "data": [ + { + "id": 1, + "name": "New name", + } + ] + }, + id="dict", + ), + pytest.param( + { + "schema_to_select": ReadSchemaTest, + "return_as_model": True, + }, + { + "data": [ + ReadSchemaTest(id=1, name="New Record", tier_id=1, category_id=1) + ] + }, + {"data": [ReadSchemaTest(id=1, name="New name", tier_id=1, category_id=1)]}, + id="model", + ), + ], +) +@pytest.mark.dialect("postgresql") +@pytest.mark.asyncio +async def test_upsert_multi_successful( + async_session, + test_model, + read_schema, + test_data_tier, + test_data_category, + update_kwargs, + expected_insert_result, + expected_update_result, +): + for tier_item in test_data_tier: + async_session.add(TierModel(**tier_item)) + for category_item in test_data_category: + async_session.add(CategoryModel(**category_item)) + await async_session.commit() + + crud = FastCRUD(test_model) + new_data = read_schema(id=1, name="New Record", tier_id=1, category_id=1) + fetched_records = await crud.upsert_multi( + async_session, [new_data], **update_kwargs + ) + + assert fetched_records == expected_insert_result + + updated_new_data = new_data.model_copy(update={"name": "New name"}) + updated_fetched_records = await crud.upsert_multi( + async_session, [updated_new_data], **update_kwargs + ) + + assert updated_fetched_records == expected_update_result