diff --git a/fastcrud/endpoint/helper.py b/fastcrud/endpoint/helper.py index 32677d1..0e5c01e 100644 --- a/fastcrud/endpoint/helper.py +++ b/fastcrud/endpoint/helper.py @@ -1,11 +1,14 @@ import inspect +from uuid import UUID from typing import Optional, Union, Annotated, Sequence, Callable, TypeVar, Any from pydantic import BaseModel, Field from pydantic.functional_validators import field_validator -from fastapi import Depends, Query, params +from fastapi import Depends, Query, Path, params from sqlalchemy import Column, inspect as sa_inspect +from sqlalchemy.dialects.postgresql import UUID as PostgresUUID +from sqlalchemy.types import TypeEngine from sqlalchemy.sql.elements import KeyedColumnElement from fastcrud.types import ModelType @@ -87,12 +90,36 @@ def _get_primary_keys( return primary_key_columns +def _is_uuid_type(column_type: TypeEngine) -> bool: # pragma: no cover + """ + Check if a SQLAlchemy column type represents a UUID. + Handles various SQL dialects and common UUID implementations. + """ + if isinstance(column_type, PostgresUUID): + return True + + type_name = getattr(column_type, "__visit_name__", "").lower() + if "uuid" in type_name: + return True + + if hasattr(column_type, "impl"): + return _is_uuid_type(column_type.impl) + + return False + + def _get_python_type(column: Column) -> Optional[type]: + """Get the Python type for a SQLAlchemy column, with special handling for UUIDs.""" try: + if _is_uuid_type(column.type): + return UUID + direct_type: Optional[type] = column.type.python_type return direct_type except NotImplementedError: if hasattr(column.type, "impl") and hasattr(column.type.impl, "python_type"): + if _is_uuid_type(column.type.impl): # pragma: no cover + return UUID indirect_type: Optional[type] = column.type.impl.python_type return indirect_type else: # pragma: no cover @@ -110,7 +137,10 @@ def _get_column_types( raise ValueError("Model inspection failed, resulting in None.") column_types = {} for column in inspector_result.mapper.columns: - column_types[column.name] = _get_python_type(column) + column_type = _get_python_type(column) + if hasattr(column.type, "__visit_name__") and column.type.__visit_name__ == "uuid": + column_type = UUID + column_types[column.name] = column_type return column_types @@ -154,12 +184,24 @@ def wrapper(endpoint): for p in signature.parameters.values() if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD ] - extra_positional_params = [ - inspect.Parameter( - name=k, annotation=v, kind=inspect.Parameter.POSITIONAL_ONLY - ) - for k, v in pkeys.items() - ] + extra_positional_params = [] + for k, v in pkeys.items(): + if v == UUID: + extra_positional_params.append( + inspect.Parameter( + name=k, + annotation=Annotated[UUID, Path(...)], + kind=inspect.Parameter.POSITIONAL_ONLY + ) + ) + else: + extra_positional_params.append( + inspect.Parameter( + name=k, + annotation=v, + kind=inspect.Parameter.POSITIONAL_ONLY + ) + ) endpoint.__signature__ = signature.replace( parameters=extra_positional_params + parameters diff --git a/pyproject.toml b/pyproject.toml index 3ed971d..aaa99ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,8 @@ mypy = "^1.9.0" ruff = "^0.3.4" coverage = "^7.4.4" testcontainers = "^4.7.1" +asyncpg = "^0.30.0" +psycopg2-binary = "^2.9.10" psycopg = "^3.2.1" aiomysql = "^0.2.0" cryptography = "^43.0.1" diff --git a/tests/sqlalchemy/core/test_uuid.py b/tests/sqlalchemy/core/test_uuid.py new file mode 100644 index 0000000..ca6a439 --- /dev/null +++ b/tests/sqlalchemy/core/test_uuid.py @@ -0,0 +1,204 @@ +import pytest +from uuid import UUID, uuid4 + +from sqlalchemy import Column, String +from sqlalchemy.dialects.postgresql import UUID as PostgresUUID +from sqlalchemy.types import TypeDecorator +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from fastcrud import crud_router, FastCRUD +from pydantic import BaseModel + +from ..conftest import Base + + +class UUIDType(TypeDecorator): + """Platform-independent UUID type. + Uses PostgreSQL's UUID type, otherwise CHAR(36) + """ + + impl = String + cache_ok = True + + def __init__(self): + super().__init__(36) + + def load_dialect_impl(self, dialect): + if dialect.name == "postgresql": + return dialect.type_descriptor(PostgresUUID(as_uuid=True)) + else: + return dialect.type_descriptor(String(36)) + + def process_bind_param(self, value, dialect): + if value is None: # pragma: no cover + return value + elif dialect.name == "postgresql": # pragma: no cover + return value + else: + return str(value) + + def process_result_value(self, value, dialect): + if value is None: # pragma: no cover + return value + if not isinstance(value, UUID): + return UUID(value) + return value # pragma: no cover + + +class UUIDModel(Base): + __tablename__ = "uuid_test" + id = Column(UUIDType(), primary_key=True, default=uuid4) + name = Column(String(255)) + + +class CustomUUID(TypeDecorator): + """Custom UUID type for testing.""" + + impl = String + cache_ok = True + + def __init__(self): + super().__init__(36) + self.__visit_name__ = "uuid" + + def process_bind_param(self, value, dialect): + if value is None: # pragma: no cover + return value + return str(value) + + def process_result_value(self, value, dialect): + if value is None: # pragma: no cover + return value + return UUID(value) + + +class CustomUUIDModel(Base): + __tablename__ = "custom_uuid_test" + id = Column(CustomUUID(), primary_key=True, default=uuid4) + name = Column(String(255)) + + +class UUIDSchema(BaseModel): + id: UUID + name: str + + model_config = {"from_attributes": True} + + +class CreateUUIDSchema(BaseModel): + name: str + + model_config = {"from_attributes": True} + + +class UpdateUUIDSchema(BaseModel): + name: str + + model_config = {"from_attributes": True} + + +@pytest.fixture +def uuid_client(async_session): + app = FastAPI() + + app.include_router( + crud_router( + session=lambda: async_session, + model=UUIDModel, + crud=FastCRUD(UUIDModel), + create_schema=CreateUUIDSchema, + update_schema=UpdateUUIDSchema, + path="/uuid-test", + tags=["uuid-test"], + endpoint_names={ + "create": "create", + "read": "get", + "update": "update", + "delete": "delete", + "read_multi": "get_multi", + }, + ) + ) + + app.include_router( + crud_router( + session=lambda: async_session, + model=CustomUUIDModel, + crud=FastCRUD(CustomUUIDModel), + create_schema=CreateUUIDSchema, + update_schema=UpdateUUIDSchema, + path="/custom-uuid-test", + tags=["custom-uuid-test"], + endpoint_names={ + "create": "create", + "read": "get", + "update": "update", + "delete": "delete", + "read_multi": "get_multi", + }, + ) + ) + + return TestClient(app) + + +@pytest.mark.asyncio +@pytest.mark.dialect("sqlite") +async def test_custom_uuid_crud(uuid_client): + response = uuid_client.post("/custom-uuid-test/create", json={"name": "test"}) + assert ( + response.status_code == 200 + ), f"Creation failed with response: {response.text}" + + try: + data = response.json() + assert "id" in data, f"Response does not contain 'id': {data}" + uuid_id = data["id"] + except Exception as e: # pragma: no cover + pytest.fail(f"Failed to process response: {response.text}. Error: {str(e)}") + + try: + UUID(uuid_id) + except ValueError: # pragma: no cover + pytest.fail("Invalid UUID format") + + response = uuid_client.get(f"/custom-uuid-test/get/{uuid_id}") + assert response.status_code == 200 + assert response.json()["id"] == uuid_id + assert response.json()["name"] == "test" + + update_response = uuid_client.patch( + f"/custom-uuid-test/update/{uuid_id}", json={"name": "updated"} + ) + response = uuid_client.get(f"/custom-uuid-test/get/{uuid_id}") + + assert update_response.status_code == 200 + assert response.status_code == 200 + assert response.json()["name"] == "updated" + + response = uuid_client.delete(f"/custom-uuid-test/delete/{uuid_id}") + assert response.status_code == 200 + + response = uuid_client.get(f"/custom-uuid-test/get/{uuid_id}") + assert response.status_code == 404 + + +@pytest.mark.asyncio +async def test_uuid_list_endpoint(uuid_client): + created_ids = [] + for i in range(3): + response = uuid_client.post("/uuid-test/create", json={"name": f"test_{i}"}) + assert response.status_code == 200 + created_ids.append(response.json()["id"]) + + response = uuid_client.get("/uuid-test/get_multi") + assert response.status_code == 200 + data = response.json()["data"] + assert len(data) == 3 + + for item in data: + try: + UUID(item["id"]) + except ValueError: # pragma: no cover + pytest.fail("Invalid UUID format in list response") diff --git a/tests/sqlmodel/core/test_uuid.py b/tests/sqlmodel/core/test_uuid.py new file mode 100644 index 0000000..0378305 --- /dev/null +++ b/tests/sqlmodel/core/test_uuid.py @@ -0,0 +1,206 @@ +import pytest +from uuid import UUID, uuid4 + +from sqlalchemy import Column, String +from sqlalchemy.dialects.postgresql import UUID as PostgresUUID +from sqlalchemy.types import TypeDecorator +from fastapi import FastAPI +from fastapi.testclient import TestClient +from sqlmodel import Field, SQLModel + +from fastcrud import crud_router, FastCRUD +from pydantic import ConfigDict + + +class UUIDType(TypeDecorator): + """Platform-independent UUID type. + Uses PostgreSQL's UUID type, otherwise CHAR(36) + """ + + impl = String + cache_ok = True + + def __init__(self): + super().__init__(36) + + def load_dialect_impl(self, dialect): + if dialect.name == "postgresql": # pragma: no cover + return dialect.type_descriptor(PostgresUUID(as_uuid=True)) + else: + return dialect.type_descriptor(String(36)) + + def process_bind_param(self, value, dialect): + if value is None: # pragma: no cover + return value + elif dialect.name == "postgresql": # pragma: no cover + return value + else: + return str(value) + + def process_result_value(self, value, dialect): + if value is None: + return value # pragma: no cover + if not isinstance(value, UUID): + return UUID(value) + return value # pragma: no cover + + +class UUIDModel(SQLModel, table=True): + __tablename__ = "uuid_test" + id: UUID = Field( + default_factory=uuid4, sa_column=Column(UUIDType(), primary_key=True) + ) + name: str = Field(sa_column=Column(String(255))) + + +class CustomUUID(TypeDecorator): + """Custom UUID type for testing.""" + + impl = String + cache_ok = True + + def __init__(self): + super().__init__(36) + self.__visit_name__ = "uuid" + + def process_bind_param(self, value, dialect): + if value is None: # pragma: no cover + return value + return str(value) + + def process_result_value(self, value, dialect): + if value is None: # pragma: no cover + return value + return UUID(value) + + +class CustomUUIDModel(SQLModel, table=True): + __tablename__ = "custom_uuid_test" + id: UUID = Field( + default_factory=uuid4, sa_column=Column(CustomUUID(), primary_key=True) + ) + name: str = Field(sa_column=Column(String(255))) + + +class UUIDSchema(SQLModel): + id: UUID + name: str + + model_config = ConfigDict(from_attributes=True) + + +class CreateUUIDSchema(SQLModel): + name: str + + model_config = ConfigDict(from_attributes=True) + + +class UpdateUUIDSchema(SQLModel): + name: str + + model_config = ConfigDict(from_attributes=True) + + +@pytest.fixture +def uuid_client(async_session): + app = FastAPI() + + app.include_router( + crud_router( + session=lambda: async_session, + model=UUIDModel, + crud=FastCRUD(UUIDModel), + create_schema=CreateUUIDSchema, + update_schema=UpdateUUIDSchema, + path="/uuid-test", + tags=["uuid-test"], + endpoint_names={ + "create": "create", + "read": "get", + "update": "update", + "delete": "delete", + "read_multi": "get_multi", + }, + ) + ) + + app.include_router( + crud_router( + session=lambda: async_session, + model=CustomUUIDModel, + crud=FastCRUD(CustomUUIDModel), + create_schema=CreateUUIDSchema, + update_schema=UpdateUUIDSchema, + path="/custom-uuid-test", + tags=["custom-uuid-test"], + endpoint_names={ + "create": "create", + "read": "get", + "update": "update", + "delete": "delete", + "read_multi": "get_multi", + }, + ) + ) + + return TestClient(app) + + +@pytest.mark.asyncio +@pytest.mark.dialect("sqlite") +async def test_custom_uuid_crud(uuid_client): + response = uuid_client.post("/custom-uuid-test/create", json={"name": "test"}) + assert ( + response.status_code == 200 + ), f"Creation failed with response: {response.text}" + + try: + data = response.json() + assert "id" in data, f"Response does not contain 'id': {data}" + uuid_id = data["id"] + except Exception as e: # pragma: no cover + pytest.fail(f"Failed to process response: {response.text}. Error: {str(e)}") + + try: + UUID(uuid_id) + except ValueError: # pragma: no cover + pytest.fail("Invalid UUID format") + + response = uuid_client.get(f"/custom-uuid-test/get/{uuid_id}") + assert response.status_code == 200 + assert response.json()["id"] == uuid_id + assert response.json()["name"] == "test" + + update_response = uuid_client.patch( + f"/custom-uuid-test/update/{uuid_id}", json={"name": "updated"} + ) + response = uuid_client.get(f"/custom-uuid-test/get/{uuid_id}") + assert update_response.status_code == 200 + assert response.status_code == 200 + assert response.json()["name"] == "updated" + + response = uuid_client.delete(f"/custom-uuid-test/delete/{uuid_id}") + assert response.status_code == 200 + + response = uuid_client.get(f"/custom-uuid-test/get/{uuid_id}") + assert response.status_code == 404 + + +@pytest.mark.asyncio +async def test_uuid_list_endpoint(uuid_client): + created_ids = [] + for i in range(3): + response = uuid_client.post("/uuid-test/create", json={"name": f"test_{i}"}) + assert response.status_code == 200 + created_ids.append(response.json()["id"]) + + response = uuid_client.get("/uuid-test/get_multi") + assert response.status_code == 200 + data = response.json()["data"] + assert len(data) == 3 + + for item in data: + try: + UUID(item["id"]) + except ValueError: # pragma: no cover + pytest.fail("Invalid UUID format in list response")