Skip to content

Commit

Permalink
UUID support (#188)
Browse files Browse the repository at this point in the history
* uuid support fix

* new tests for uuid

* code coverage
  • Loading branch information
igorbenav authored Dec 23, 2024
1 parent a2bb6fc commit 5126244
Show file tree
Hide file tree
Showing 4 changed files with 462 additions and 8 deletions.
58 changes: 50 additions & 8 deletions fastcrud/endpoint/helper.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import inspect
from uuid import UUID
from typing import Optional, Union, Annotated, Sequence, Callable, TypeVar, Any

from pydantic import BaseModel, Field
from pydantic.functional_validators import field_validator
from fastapi import Depends, Query, params
from fastapi import Depends, Query, Path, params

from sqlalchemy import Column, inspect as sa_inspect
from sqlalchemy.dialects.postgresql import UUID as PostgresUUID
from sqlalchemy.types import TypeEngine
from sqlalchemy.sql.elements import KeyedColumnElement

from fastcrud.types import ModelType
Expand Down Expand Up @@ -87,12 +90,36 @@ def _get_primary_keys(
return primary_key_columns


def _is_uuid_type(column_type: TypeEngine) -> bool: # pragma: no cover
"""
Check if a SQLAlchemy column type represents a UUID.
Handles various SQL dialects and common UUID implementations.
"""
if isinstance(column_type, PostgresUUID):
return True

type_name = getattr(column_type, "__visit_name__", "").lower()
if "uuid" in type_name:
return True

if hasattr(column_type, "impl"):
return _is_uuid_type(column_type.impl)

return False


def _get_python_type(column: Column) -> Optional[type]:
"""Get the Python type for a SQLAlchemy column, with special handling for UUIDs."""
try:
if _is_uuid_type(column.type):
return UUID

direct_type: Optional[type] = column.type.python_type
return direct_type
except NotImplementedError:
if hasattr(column.type, "impl") and hasattr(column.type.impl, "python_type"):
if _is_uuid_type(column.type.impl): # pragma: no cover
return UUID
indirect_type: Optional[type] = column.type.impl.python_type
return indirect_type
else: # pragma: no cover
Expand All @@ -110,7 +137,10 @@ def _get_column_types(
raise ValueError("Model inspection failed, resulting in None.")
column_types = {}
for column in inspector_result.mapper.columns:
column_types[column.name] = _get_python_type(column)
column_type = _get_python_type(column)
if hasattr(column.type, "__visit_name__") and column.type.__visit_name__ == "uuid":
column_type = UUID
column_types[column.name] = column_type
return column_types


Expand Down Expand Up @@ -154,12 +184,24 @@ def wrapper(endpoint):
for p in signature.parameters.values()
if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
]
extra_positional_params = [
inspect.Parameter(
name=k, annotation=v, kind=inspect.Parameter.POSITIONAL_ONLY
)
for k, v in pkeys.items()
]
extra_positional_params = []
for k, v in pkeys.items():
if v == UUID:
extra_positional_params.append(
inspect.Parameter(
name=k,
annotation=Annotated[UUID, Path(...)],
kind=inspect.Parameter.POSITIONAL_ONLY
)
)
else:
extra_positional_params.append(
inspect.Parameter(
name=k,
annotation=v,
kind=inspect.Parameter.POSITIONAL_ONLY
)
)

endpoint.__signature__ = signature.replace(
parameters=extra_positional_params + parameters
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ mypy = "^1.9.0"
ruff = "^0.3.4"
coverage = "^7.4.4"
testcontainers = "^4.7.1"
asyncpg = "^0.30.0"
psycopg2-binary = "^2.9.10"
psycopg = "^3.2.1"
aiomysql = "^0.2.0"
cryptography = "^43.0.1"
Expand Down
204 changes: 204 additions & 0 deletions tests/sqlalchemy/core/test_uuid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
import pytest
from uuid import UUID, uuid4

from sqlalchemy import Column, String
from sqlalchemy.dialects.postgresql import UUID as PostgresUUID
from sqlalchemy.types import TypeDecorator
from fastapi import FastAPI
from fastapi.testclient import TestClient

from fastcrud import crud_router, FastCRUD
from pydantic import BaseModel

from ..conftest import Base


class UUIDType(TypeDecorator):
"""Platform-independent UUID type.
Uses PostgreSQL's UUID type, otherwise CHAR(36)
"""

impl = String
cache_ok = True

def __init__(self):
super().__init__(36)

def load_dialect_impl(self, dialect):
if dialect.name == "postgresql":
return dialect.type_descriptor(PostgresUUID(as_uuid=True))
else:
return dialect.type_descriptor(String(36))

def process_bind_param(self, value, dialect):
if value is None: # pragma: no cover
return value
elif dialect.name == "postgresql": # pragma: no cover
return value
else:
return str(value)

def process_result_value(self, value, dialect):
if value is None: # pragma: no cover
return value
if not isinstance(value, UUID):
return UUID(value)
return value # pragma: no cover


class UUIDModel(Base):
__tablename__ = "uuid_test"
id = Column(UUIDType(), primary_key=True, default=uuid4)
name = Column(String(255))


class CustomUUID(TypeDecorator):
"""Custom UUID type for testing."""

impl = String
cache_ok = True

def __init__(self):
super().__init__(36)
self.__visit_name__ = "uuid"

def process_bind_param(self, value, dialect):
if value is None: # pragma: no cover
return value
return str(value)

def process_result_value(self, value, dialect):
if value is None: # pragma: no cover
return value
return UUID(value)


class CustomUUIDModel(Base):
__tablename__ = "custom_uuid_test"
id = Column(CustomUUID(), primary_key=True, default=uuid4)
name = Column(String(255))


class UUIDSchema(BaseModel):
id: UUID
name: str

model_config = {"from_attributes": True}


class CreateUUIDSchema(BaseModel):
name: str

model_config = {"from_attributes": True}


class UpdateUUIDSchema(BaseModel):
name: str

model_config = {"from_attributes": True}


@pytest.fixture
def uuid_client(async_session):
app = FastAPI()

app.include_router(
crud_router(
session=lambda: async_session,
model=UUIDModel,
crud=FastCRUD(UUIDModel),
create_schema=CreateUUIDSchema,
update_schema=UpdateUUIDSchema,
path="/uuid-test",
tags=["uuid-test"],
endpoint_names={
"create": "create",
"read": "get",
"update": "update",
"delete": "delete",
"read_multi": "get_multi",
},
)
)

app.include_router(
crud_router(
session=lambda: async_session,
model=CustomUUIDModel,
crud=FastCRUD(CustomUUIDModel),
create_schema=CreateUUIDSchema,
update_schema=UpdateUUIDSchema,
path="/custom-uuid-test",
tags=["custom-uuid-test"],
endpoint_names={
"create": "create",
"read": "get",
"update": "update",
"delete": "delete",
"read_multi": "get_multi",
},
)
)

return TestClient(app)


@pytest.mark.asyncio
@pytest.mark.dialect("sqlite")
async def test_custom_uuid_crud(uuid_client):
response = uuid_client.post("/custom-uuid-test/create", json={"name": "test"})
assert (
response.status_code == 200
), f"Creation failed with response: {response.text}"

try:
data = response.json()
assert "id" in data, f"Response does not contain 'id': {data}"
uuid_id = data["id"]
except Exception as e: # pragma: no cover
pytest.fail(f"Failed to process response: {response.text}. Error: {str(e)}")

try:
UUID(uuid_id)
except ValueError: # pragma: no cover
pytest.fail("Invalid UUID format")

response = uuid_client.get(f"/custom-uuid-test/get/{uuid_id}")
assert response.status_code == 200
assert response.json()["id"] == uuid_id
assert response.json()["name"] == "test"

update_response = uuid_client.patch(
f"/custom-uuid-test/update/{uuid_id}", json={"name": "updated"}
)
response = uuid_client.get(f"/custom-uuid-test/get/{uuid_id}")

assert update_response.status_code == 200
assert response.status_code == 200
assert response.json()["name"] == "updated"

response = uuid_client.delete(f"/custom-uuid-test/delete/{uuid_id}")
assert response.status_code == 200

response = uuid_client.get(f"/custom-uuid-test/get/{uuid_id}")
assert response.status_code == 404


@pytest.mark.asyncio
async def test_uuid_list_endpoint(uuid_client):
created_ids = []
for i in range(3):
response = uuid_client.post("/uuid-test/create", json={"name": f"test_{i}"})
assert response.status_code == 200
created_ids.append(response.json()["id"])

response = uuid_client.get("/uuid-test/get_multi")
assert response.status_code == 200
data = response.json()["data"]
assert len(data) == 3

for item in data:
try:
UUID(item["id"])
except ValueError: # pragma: no cover
pytest.fail("Invalid UUID format in list response")
Loading

0 comments on commit 5126244

Please sign in to comment.