Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[sqlalchemy] allow create/update with object for one/many 2 many #171

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 48 additions & 6 deletions fastapi_crudrouter/core/sqlalchemy.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Callable, List, Type, Generator, Optional, Union
from collections.abc import Sequence

from fastapi import Depends, HTTPException

Expand All @@ -9,10 +10,12 @@
from sqlalchemy.orm import Session
from sqlalchemy.ext.declarative import DeclarativeMeta as Model
from sqlalchemy.exc import IntegrityError
from sqlalchemy import column
except ImportError:
Model = None
Session = None
IntegrityError = None
column = None
sqlalchemy_installed = False
else:
sqlalchemy_installed = True
Expand All @@ -39,7 +42,7 @@ def __init__(
update_route: Union[bool, DEPENDENCIES] = True,
delete_one_route: Union[bool, DEPENDENCIES] = True,
delete_all_route: Union[bool, DEPENDENCIES] = True,
**kwargs: Any
**kwargs: Any,
) -> None:
assert (
sqlalchemy_installed
Expand All @@ -63,7 +66,7 @@ def __init__(
update_route=update_route,
delete_one_route=delete_one_route,
delete_all_route=delete_all_route,
**kwargs
**kwargs,
)

def _get_all(self, *args: Any, **kwargs: Any) -> CALLABLE_LIST:
Expand Down Expand Up @@ -97,13 +100,49 @@ def route(

return route

def _get_orm_object(self, db: Session, orm_model: Model, model: Model) -> Any:
query = db.query(orm_model)
filter_items = 0
for key, val in model.dict().items():
if val:
filter_items += 1
query = query.filter(column(key) == val)
if filter_items == 0:
raise Exception("No attributes for filter found")
return query.one()

def _get_orm_object_or_value(self, db: Session, val: Any) -> Any:
"""Return an inflated database object or a plain value.

If a `val` is a SqlModel type and has defined a Meta.orm model
attribute, lookup the object from the `db` and return it.
Otherwise, just return the `val`. If `val` is a sequence of
objects, return the sequence of objects from the db.
"""
# we want to iterate through sequences but not strings
if not val or isinstance(val, str):
return val

if isinstance(val, Sequence):
return [self._get_orm_object_or_value(db, v) for v in val]
else:
if meta_class := getattr(val, "Meta", None):
if orm_model := getattr(meta_class, "orm_model", None):
return self._get_orm_object(db, orm_model, val)
return val

def _create(self, *args: Any, **kwargs: Any) -> CALLABLE:
def route(
model: self.create_schema, # type: ignore
db: Session = Depends(self.db_func),
) -> Model:
try:
db_model: Model = self.db_model(**model.dict())
db_model: Model = self.db_model()

for key, val in model:
if val:
setattr(db_model, key, self._get_orm_object_or_value(db, val))

db.add(db_model)
db.commit()
db.refresh(db_model)
Expand All @@ -123,9 +162,12 @@ def route(
try:
db_model: Model = self._get_one()(item_id, db)

for key, value in model.dict(exclude={self._pk}).items():
if hasattr(db_model, key):
setattr(db_model, key, value)
for key, val in model:
if key != self._pk:
if hasattr(db_model, key):
setattr(
db_model, key, self._get_orm_object_or_value(db, val)
)

db.commit()
db.refresh(db_model)
Expand Down
Loading