diff --git a/tests/sqlmodel/conftest.py b/tests/sqlmodel/conftest.py index b65d85b..aaa0120 100644 --- a/tests/sqlmodel/conftest.py +++ b/tests/sqlmodel/conftest.py @@ -1,5 +1,6 @@ from collections.abc import AsyncGenerator from typing import Optional +from contextlib import asynccontextmanager import pytest import pytest_asyncio @@ -7,11 +8,15 @@ from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.orm import sessionmaker +from sqlalchemy import make_url from pydantic import ConfigDict from sqlmodel import SQLModel, Field, Relationship from fastapi import FastAPI from fastapi.testclient import TestClient from sqlalchemy.sql import func +from testcontainers.postgres import PostgresContainer +from testcontainers.mysql import MySqlContainer +from testcontainers.core.docker_client import DockerClient from fastcrud.crud.fast_crud import FastCRUD from fastcrud.endpoint.crud_router import crud_router @@ -266,26 +271,55 @@ class TaskRead(TaskReadSub): assignee: Optional[UserReadSub] client: Optional[ClientRead] +def is_docker_running() -> bool: # pragma: no cover + try: + DockerClient() + return True + except Exception: + return False async_engine = create_async_engine( "sqlite+aiosqlite:///:memory:", echo=True, future=True ) -@pytest_asyncio.fixture(scope="function") -async def async_session() -> AsyncGenerator[AsyncSession]: - session = sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False) - - async with session() as s: - async with async_engine.begin() as conn: +@asynccontextmanager +async def _setup_database(url: str) -> AsyncGenerator[AsyncSession]: + engine = create_async_engine(url, echo=True, future=True) + session_maker = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) + async with session_maker() as session: + async with engine.begin() as conn: await conn.run_sync(SQLModel.metadata.create_all) + try: + yield session + finally: + async with engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.drop_all) + await engine.dispose() - yield s - async with async_engine.begin() as conn: - await conn.run_sync(SQLModel.metadata.drop_all) - - await async_engine.dispose() +@pytest_asyncio.fixture(scope="function") +async def async_session(request: pytest.FixtureRequest) -> AsyncGenerator[AsyncSession]: # pragma: no cover + dialect_marker = request.node.get_closest_marker("dialect") + dialect = dialect_marker.args[0] if dialect_marker else "sqlite" + + if dialect == "postgresql" or dialect == "mysql": + if not is_docker_running(): + pytest.skip("Docker is required, but not running") + + if dialect == "postgresql": + with PostgresContainer() as postgres: + url = postgres.get_connection_url() + async with _setup_database(url) as session: + yield session + elif dialect == "mysql": + with MySqlContainer() as mysql: + url = make_url(mysql.get_connection_url())._replace(drivername="mysql+aiomysql") + async with _setup_database(url) as session: + yield session + else: + async with _setup_database("sqlite+aiosqlite:///:memory:") as session: + yield session @pytest.fixture(scope="function")