Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added custom relationships statements form #802

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 52 additions & 123 deletions sqladmin/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@
)

import anyio
from sqlalchemy import Boolean, select
from sqlalchemy import (
Boolean,
Select,
select,
)
from sqlalchemy import inspect as sqlalchemy_inspect
from sqlalchemy.orm import (
ColumnProperty,
Expand Down Expand Up @@ -109,6 +113,11 @@ def _inner(func: T_CC) -> T_CC:


class ModelConverterBase:
relationships_statements: dict[str, Select] = {}
"""
Select statement for relationships.
"""

_converters: dict[str, ConverterCallable] = {}

def __init__(self) -> None:
Expand Down Expand Up @@ -152,19 +161,15 @@ async def _prepare_kwargs(
kwargs.setdefault("render_kw", widget_args)

if isinstance(prop, ColumnProperty):
kwargs = self._prepare_column(
prop=prop, kwargs=kwargs, form_include_pk=form_include_pk
)
kwargs = self._prepare_column(prop=prop, kwargs=kwargs, form_include_pk=form_include_pk)
else:
kwargs = await self._prepare_relationship(
prop=prop, session_maker=session_maker, kwargs=kwargs, loader=loader
)

return kwargs

def _prepare_column(
self, prop: ColumnProperty, form_include_pk: bool, kwargs: dict
) -> Union[dict, None]:
def _prepare_column(self, prop: ColumnProperty, form_include_pk: bool, kwargs: dict) -> Union[dict, None]:
assert len(prop.columns) == 1, "Multiple-column properties not supported"
column = prop.columns[0]

Expand All @@ -180,11 +185,7 @@ def _prepare_column(

if callable_default is not None:
# ColumnDefault(val).arg can be also a plain value
default = (
callable_default(None)
if callable(callable_default)
else callable_default
)
default = callable_default(None) if callable(callable_default) else callable_default

kwargs["default"] = default
optional_types = (Boolean,)
Expand Down Expand Up @@ -217,9 +218,7 @@ async def _prepare_relationship(
kwargs["allow_blank"] = nullable

if not loader:
kwargs.setdefault(
"data", await self._prepare_select_options(prop, session_maker)
)
kwargs.setdefault("data", await self._prepare_select_options(prop, session_maker))

return kwargs

Expand All @@ -229,22 +228,19 @@ async def _prepare_select_options(
session_maker: sessionmaker,
) -> list[tuple[str, Any]]:
target_model = prop.mapper.class_
stmt = select(target_model)
if prop.key in self.relationships_statements:
stmt = self.relationships_statements[prop.key]
else:
stmt = select(target_model)

if is_async_session_maker(session_maker):
async with session_maker() as session:
objects = await session.execute(stmt)
return [
(str(self._get_identifier_value(obj)), str(obj))
for obj in objects.scalars().unique().all()
]
return [(str(self._get_identifier_value(obj)), str(obj)) for obj in objects.scalars().unique().all()]
else:
with session_maker() as session:
objects = await anyio.to_thread.run_sync(session.execute, stmt)
return [
(str(self._get_identifier_value(obj)), str(obj))
for obj in objects.scalars().unique().all()
]
return [(str(self._get_identifier_value(obj)), str(obj)) for obj in objects.scalars().unique().all()]

def get_converter(self, prop: MODEL_PROPERTY) -> ConverterCallable:
if isinstance(prop, RelationshipProperty):
Expand Down Expand Up @@ -310,11 +306,7 @@ async def convert(
assert issubclass(override, Field)
return override(**kwargs)

multiple = (
is_relationship(prop)
and prop.direction.name in ("ONETOMANY", "MANYTOMANY")
and prop.uselist
)
multiple = is_relationship(prop) and prop.direction.name in ("ONETOMANY", "MANYTOMANY") and prop.uselist

if loader:
if multiple:
Expand All @@ -339,27 +331,21 @@ def _string_common(prop: ColumnProperty) -> list[Validator]:
return li

@converts("String", "CHAR") # includes Unicode
def conv_string(
self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]
) -> UnboundField:
def conv_string(self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]) -> UnboundField:
extra_validators = self._string_common(prop)
kwargs.setdefault("validators", [])
kwargs["validators"].extend(extra_validators)
return StringField(**kwargs)

@converts("Text", "LargeBinary", "Binary") # includes UnicodeText
def conv_text(
self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]
) -> UnboundField:
def conv_text(self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]) -> UnboundField:
kwargs.setdefault("validators", [])
extra_validators = self._string_common(prop)
kwargs["validators"].extend(extra_validators)
return TextAreaField(**kwargs)

@converts("Boolean", "dialects.mssql.base.BIT")
def conv_boolean(
self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]
) -> UnboundField:
def conv_boolean(self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]) -> UnboundField:
if not prop.columns[0].nullable:
kwargs.setdefault("render_kw", {})
kwargs["render_kw"]["class"] = "form-check-input"
Expand All @@ -371,27 +357,19 @@ def conv_boolean(
return SelectField(**kwargs)

@converts("Date")
def conv_date(
self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]
) -> UnboundField:
def conv_date(self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]) -> UnboundField:
return DateField(**kwargs)

@converts("Time")
def conv_time(
self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]
) -> UnboundField:
def conv_time(self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]) -> UnboundField:
return TimeField(**kwargs)

@converts("DateTime")
def conv_datetime(
self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]
) -> UnboundField:
def conv_datetime(self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]) -> UnboundField:
return DateTimeField(**kwargs)

@converts("Enum")
def conv_enum(
self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]
) -> UnboundField:
def conv_enum(self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]) -> UnboundField:
available_choices = [(e, e) for e in prop.columns[0].type.enums]
accepted_values = [choice[0] for choice in available_choices]

Expand All @@ -409,29 +387,21 @@ def conv_enum(
return SelectField(**kwargs)

@converts("Integer") # includes BigInteger and SmallInteger
def conv_integer(
self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]
) -> UnboundField:
def conv_integer(self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]) -> UnboundField:
return IntegerField(**kwargs)

@converts("Numeric") # includes DECIMAL, Float/FLOAT, REAL, and DOUBLE
def conv_decimal(
self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]
) -> UnboundField:
def conv_decimal(self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]) -> UnboundField:
# override default decimal places limit, use database defaults instead
kwargs.setdefault("places", None)
return DecimalField(**kwargs)

@converts("JSON", "JSONB")
def conv_json(
self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]
) -> UnboundField:
def conv_json(self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]) -> UnboundField:
return JSONField(**kwargs)

@converts("Interval")
def conv_interval(
self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]
) -> UnboundField:
def conv_interval(self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]) -> UnboundField:
kwargs["render_kw"]["placeholder"] = "Like: 1 day 1:25:33.652"
return IntervalField(**kwargs)

Expand All @@ -440,9 +410,7 @@ def conv_interval(
"sqlalchemy.dialects.postgresql.types.INET",
"sqlalchemy_utils.types.ip_address.IPAddressType",
)
def conv_ip_address(
self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]
) -> UnboundField:
def conv_ip_address(self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]) -> UnboundField:
kwargs.setdefault("validators", [])
kwargs["validators"].append(validators.IPAddress(ipv4=True, ipv6=True))
return StringField(**kwargs)
Expand All @@ -451,9 +419,7 @@ def conv_ip_address(
"sqlalchemy.dialects.postgresql.base.MACADDR",
"sqlalchemy.dialects.postgresql.types.MACADDR",
)
def conv_mac_address(
self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]
) -> UnboundField:
def conv_mac_address(self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]) -> UnboundField:
kwargs.setdefault("validators", [])
kwargs["validators"].append(validators.MacAddress())
return StringField(**kwargs)
Expand All @@ -464,76 +430,54 @@ def conv_mac_address(
"sqlalchemy.sql.sqltypes.Uuid",
"sqlalchemy_utils.types.uuid.UUIDType",
)
def conv_uuid(
self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]
) -> UnboundField:
def conv_uuid(self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]) -> UnboundField:
kwargs.setdefault("validators", [])
kwargs["validators"].append(validators.UUID())
return StringField(**kwargs)

@converts(
"sqlalchemy.dialects.postgresql.base.ARRAY", "sqlalchemy.sql.sqltypes.ARRAY"
)
def conv_ARRAY(
self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]
) -> UnboundField:
@converts("sqlalchemy.dialects.postgresql.base.ARRAY", "sqlalchemy.sql.sqltypes.ARRAY")
def conv_ARRAY(self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]) -> UnboundField:
return Select2TagsField(**kwargs)

@converts("sqlalchemy_utils.types.email.EmailType")
def conv_email(
self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]
) -> UnboundField:
def conv_email(self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]) -> UnboundField:
kwargs.setdefault("validators", [])
kwargs["validators"].append(validators.Email())
return StringField(**kwargs)

@converts("sqlalchemy_utils.types.url.URLType")
def conv_url(
self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]
) -> UnboundField:
def conv_url(self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]) -> UnboundField:
kwargs.setdefault("validators", [])
kwargs["validators"].append(validators.URL())
return StringField(**kwargs)

@converts("sqlalchemy_utils.types.currency.CurrencyType")
def conv_currency(
self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]
) -> UnboundField:
def conv_currency(self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]) -> UnboundField:
kwargs.setdefault("validators", [])
kwargs["validators"].append(CurrencyValidator())
return StringField(**kwargs)

@converts("sqlalchemy_utils.types.timezone.TimezoneType")
def conv_timezone(
self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]
) -> UnboundField:
def conv_timezone(self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]) -> UnboundField:
kwargs.setdefault("validators", [])
kwargs["validators"].append(
TimezoneValidator(coerce_function=prop.columns[0].type._coerce)
)
kwargs["validators"].append(TimezoneValidator(coerce_function=prop.columns[0].type._coerce))
return StringField(**kwargs)

@converts("sqlalchemy_utils.types.phone_number.PhoneNumberType")
def conv_phone_number(
self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]
) -> UnboundField:
def conv_phone_number(self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]) -> UnboundField:
kwargs.setdefault("validators", [])
kwargs["validators"].append(PhoneNumberValidator())
return StringField(**kwargs)

@converts("sqlalchemy_utils.types.color.ColorType")
def conv_color(
self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]
) -> UnboundField:
def conv_color(self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]) -> UnboundField:
kwargs.setdefault("validators", [])
kwargs["validators"].append(ColorValidator())
return StringField(**kwargs)

@converts("sqlalchemy_utils.types.choice.ChoiceType")
@no_type_check
def convert_choice_type(
self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]
) -> UnboundField:
def convert_choice_type(self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]) -> UnboundField:
available_choices = []
column = prop.columns[0]

Expand All @@ -542,10 +486,7 @@ def convert_choice_type(
else:
available_choices = column.type.choices

accepted_values = [
choice[0] if isinstance(choice, tuple) else choice.value
for choice in available_choices
]
accepted_values = [choice[0] if isinstance(choice, tuple) else choice.value for choice in available_choices]

if column.nullable:
kwargs["allow_blank"] = column.nullable
Expand All @@ -560,34 +501,24 @@ def convert_choice_type(
return SelectField(**kwargs)

@converts("fastapi_storages.integrations.sqlalchemy.FileType")
def conv_file(
self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]
) -> UnboundField:
def conv_file(self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]) -> UnboundField:
return FileField(**kwargs)

@converts("fastapi_storages.integrations.sqlalchemy.ImageType")
def conv_image(
self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]
) -> UnboundField:
def conv_image(self, model: type, prop: ColumnProperty, kwargs: dict[str, Any]) -> UnboundField:
return FileField(**kwargs)

@converts("ONETOONE")
def conv_one_to_one(
self, model: type, prop: RelationshipProperty, kwargs: dict[str, Any]
) -> UnboundField:
def conv_one_to_one(self, model: type, prop: RelationshipProperty, kwargs: dict[str, Any]) -> UnboundField:
kwargs["allow_blank"] = True
return QuerySelectField(**kwargs)

@converts("MANYTOONE")
def conv_many_to_one(
self, model: type, prop: RelationshipProperty, kwargs: dict[str, Any]
) -> UnboundField:
def conv_many_to_one(self, model: type, prop: RelationshipProperty, kwargs: dict[str, Any]) -> UnboundField:
return QuerySelectField(**kwargs)

@converts("MANYTOMANY", "ONETOMANY")
def conv_many_to_many(
self, model: type, prop: RelationshipProperty, kwargs: dict[str, Any]
) -> UnboundField:
def conv_many_to_many(self, model: type, prop: RelationshipProperty, kwargs: dict[str, Any]) -> UnboundField:
return QuerySelectMultipleField(**kwargs)


Expand Down Expand Up @@ -618,9 +549,7 @@ async def get_model_form(
names = only or mapper.attrs.keys()
for name in names:
attr = mapper.attrs[name]
if (exclude and name in exclude) or (
isinstance(attr, ColumnProperty) and isinstance(attr.expression, Label)
):
if (exclude and name in exclude) or (isinstance(attr, ColumnProperty) and isinstance(attr.expression, Label)):
continue
attributes.append((name, attr))

Expand Down
Loading