Skip to content

Commit

Permalink
more strict uuid validation
Browse files Browse the repository at this point in the history
  • Loading branch information
igorbenav committed Dec 23, 2024
1 parent 7ae5b1f commit c493f6a
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 15 deletions.
20 changes: 15 additions & 5 deletions fastcrud/endpoint/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,10 @@ def _get_column_types(
column_types = {}
for column in inspector_result.mapper.columns:
column_type = _get_python_type(column)
if hasattr(column.type, "__visit_name__") and column.type.__visit_name__ == "uuid":
if (
hasattr(column.type, "__visit_name__")
and column.type.__visit_name__ == "uuid"
):
column_type = UUID
column_types[column.name] = column_type
return column_types
Expand Down Expand Up @@ -189,16 +192,23 @@ def wrapper(endpoint):
extra_positional_params.append(
inspect.Parameter(
name=k,
annotation=Annotated[UUID, Path(title=k)],
kind=inspect.Parameter.POSITIONAL_ONLY,
annotation=Annotated[
UUID,
Path(
...,
description=f"The {k} must be a valid UUID",
examples=['123e4567-e89b-12d3-a456-426614174000']
)
],
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD
)
)
else:
extra_positional_params.append(
inspect.Parameter(
name=k,
annotation=Annotated[v, Path(title=k)],
kind=inspect.Parameter.POSITIONAL_ONLY,
annotation=Annotated[v, Path(...)],
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD
)
)

Expand Down
9 changes: 6 additions & 3 deletions tests/sqlalchemy/core/test_uuid.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
from uuid import UUID, uuid4

from sqlalchemy import Column, String, Text
from sqlalchemy import Column, String
from sqlalchemy.dialects.postgresql import UUID as PostgresUUID
from sqlalchemy.types import TypeDecorator
from fastapi import FastAPI
Expand All @@ -12,26 +12,28 @@

from ..conftest import Base


class UUIDType(TypeDecorator):
"""Platform-independent UUID type.
Uses PostgreSQL's UUID type, otherwise CHAR(36)
"""

impl = String
cache_ok = True

def __init__(self):
super().__init__(36)

def load_dialect_impl(self, dialect):
if dialect.name == 'postgresql':
if dialect.name == "postgresql":
return dialect.type_descriptor(PostgresUUID(as_uuid=True))
else:
return dialect.type_descriptor(String(36))

def process_bind_param(self, value, dialect):
if value is None:
return value
elif dialect.name == 'postgresql':
elif dialect.name == "postgresql":
return value
else:
return str(value)
Expand All @@ -52,6 +54,7 @@ class UUIDModel(Base):

class CustomUUID(TypeDecorator):
"""Custom UUID type for testing."""

impl = String
cache_ok = True

Expand Down
15 changes: 8 additions & 7 deletions tests/sqlmodel/core/test_uuid.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
from uuid import UUID, uuid4

from sqlalchemy import Column, String, Text
from sqlalchemy import Column, String
from sqlalchemy.dialects.postgresql import UUID as PostgresUUID
from sqlalchemy.types import TypeDecorator
from fastapi import FastAPI
Expand All @@ -11,26 +11,28 @@
from fastcrud import crud_router, FastCRUD
from pydantic import ConfigDict


class UUIDType(TypeDecorator):
"""Platform-independent UUID type.
Uses PostgreSQL's UUID type, otherwise CHAR(36)
"""

impl = String
cache_ok = True

def __init__(self):
super().__init__(36)

def load_dialect_impl(self, dialect):
if dialect.name == 'postgresql':
if dialect.name == "postgresql":
return dialect.type_descriptor(PostgresUUID(as_uuid=True))
else:
return dialect.type_descriptor(String(36))

def process_bind_param(self, value, dialect):
if value is None:
return value
elif dialect.name == 'postgresql':
elif dialect.name == "postgresql":
return value
else:
return str(value)
Expand All @@ -46,14 +48,14 @@ def process_result_value(self, value, dialect):
class UUIDModel(SQLModel, table=True):
__tablename__ = "uuid_test"
id: UUID = Field(
default_factory=uuid4,
sa_column=Column(UUIDType(), primary_key=True)
default_factory=uuid4, sa_column=Column(UUIDType(), primary_key=True)
)
name: str = Field(sa_column=Column(String(255)))


class CustomUUID(TypeDecorator):
"""Custom UUID type for testing."""

impl = String
cache_ok = True

Expand All @@ -75,8 +77,7 @@ def process_result_value(self, value, dialect):
class CustomUUIDModel(SQLModel, table=True):
__tablename__ = "custom_uuid_test"
id: UUID = Field(
default_factory=uuid4,
sa_column=Column(CustomUUID(), primary_key=True)
default_factory=uuid4, sa_column=Column(CustomUUID(), primary_key=True)
)
name: str = Field(sa_column=Column(String(255)))

Expand Down

0 comments on commit c493f6a

Please sign in to comment.