Skip to content

Commit

Permalink
use second db handle only for user admin and writes (#1184)
Browse files Browse the repository at this point in the history
* change 'user_engine' to a 'WriteSession' instead, so the master db connection is used for writes [and associated admin session reads] only

* eager-load roles, remove unnecessary methods, add @default_session, move session ctx mgrs to admin page

* make sure sql statements and timing are logged for all engines, plus tag engines with id and log those too, and superfluous user method cleanup

* sqlalchemy cleanup: removed superfluous bits, improved argument passing for engine creation

* _assign_roles() does its own commit() and returns an instance of the newly updated User

* raise Exception when trying to update non-existent User, return UserRole on creation.

* use more appropriate receiver for static method call, and expand comment on static vs bound methods in User.
  • Loading branch information
melange396 authored Jun 26, 2023
1 parent 13fcfe3 commit 6fe6e7a
Show file tree
Hide file tree
Showing 6 changed files with 159 additions and 134 deletions.
9 changes: 5 additions & 4 deletions src/server/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from flask import Flask, g, request
from sqlalchemy import event
from sqlalchemy.engine import Connection
from sqlalchemy.engine import Connection, Engine
from werkzeug.exceptions import Unauthorized
from werkzeug.local import LocalProxy

Expand Down Expand Up @@ -85,12 +85,12 @@ def log_info_with_request_and_response(message, response, **kwargs):
**kwargs
)

@event.listens_for(engine, "before_cursor_execute")
@event.listens_for(Engine, "before_cursor_execute")
def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
context._query_start_time = time.time()


@event.listens_for(engine, "after_cursor_execute")
@event.listens_for(Engine, "after_cursor_execute")
def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
# this timing info may be suspect, at least in terms of dbms cpu time...
# it is likely that it includes that time as well as any overhead that
Expand All @@ -101,7 +101,8 @@ def after_cursor_execute(conn, cursor, statement, parameters, context, executema
# Convert to milliseconds
total_time *= 1000
get_structured_logger("server_api").info(
"Executed SQL", statement=statement, params=parameters, elapsed_time_ms=total_time
"Executed SQL", statement=statement, params=parameters, elapsed_time_ms=total_time,
engine_id=conn.get_execution_options().get('engine_id')
)


Expand Down
61 changes: 53 additions & 8 deletions src/server/_db.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from sqlalchemy import create_engine, MetaData
import functools
from inspect import signature, Parameter

from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker

Expand All @@ -9,15 +12,57 @@
# previously `_common` imported from `_security` which imported from `admin.models`, which imported (back again) from `_common` for database connection objects


engine: Engine = create_engine(SQLALCHEMY_DATABASE_URI, **SQLALCHEMY_ENGINE_OPTIONS)

if SQLALCHEMY_DATABASE_URI_PRIMARY:
user_engine: Engine = create_engine(SQLALCHEMY_DATABASE_URI_PRIMARY, **SQLALCHEMY_ENGINE_OPTIONS)
else:
user_engine: Engine = engine
# a decorator to automatically provide a sqlalchemy session by default, if an existing session is not explicitly
# specified to override it. it is preferred to use a single session for a sequence of operations logically grouped
# together, but this allows individual operations to be run by themselves without having to provide an
# already-established session. requires an argument to the wrapped function named 'session'.
# for instance:
#
# @default_session(WriteSession)
# def foo(session):
# pass
#
# # calling:
# foo()
# # is identical to:
# with WriteSession() as s:
# foo(s)
def default_session(sess):
def decorator__default_session(func):
# make sure `func` is compatible w/ this decorator
func_params = signature(func).parameters
if 'session' not in func_params or func_params['session'].kind == Parameter.POSITIONAL_ONLY:
raise Exception(f"@default_session(): function {func.__name__}() must accept an argument 'session' that can be specified by keyword.")
# save position of 'session' arg, to later check if it's been passed in by position/order
sess_index = list(func_params).index('session')

@functools.wraps(func)
def wrapper__default_session(*args, **kwargs):
if 'session' in kwargs or len(args) >= sess_index+1:
# 'session' has been specified by the caller, so we have nothing to do here. pass along all args unchanged.
return func(*args, **kwargs)
# otherwise, we will wrap this call with a context manager for the default session provider, and pass that session instance to the wrapped function.
with sess() as session:
return func(*args, **kwargs, session=session)

metadata = MetaData(bind=user_engine)
return wrapper__default_session

Session = sessionmaker(bind=user_engine)
return decorator__default_session


engine: Engine = create_engine(SQLALCHEMY_DATABASE_URI, **SQLALCHEMY_ENGINE_OPTIONS, execution_options={'engine_id': 'default'})
Session = sessionmaker(bind=engine)

if SQLALCHEMY_DATABASE_URI_PRIMARY and SQLALCHEMY_DATABASE_URI_PRIMARY != SQLALCHEMY_DATABASE_URI:
# if available, use the main/primary DB for write operations. DB replication processes should be in place to
# propagate any written changes to the regular (load balanced) replicas.
write_engine: Engine = create_engine(SQLALCHEMY_DATABASE_URI_PRIMARY, **SQLALCHEMY_ENGINE_OPTIONS, execution_options={'engine_id': 'write_engine'})
WriteSession = sessionmaker(bind=write_engine)
# TODO: insert log statement acknowledging this second session handle is in use?
else:
write_engine: Engine = engine
WriteSession = Session
# NOTE: `WriteSession` could be called `AdminSession`, as it's (currently) only used by the admin page, and the admin
# page is the only thing that should be writing to the db. it's tempting to let the admin page read from the
# regular `Session` and write with `WriteSession`, but concurrency problems may arise from sync/replication lag.
6 changes: 1 addition & 5 deletions src/server/_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
TEMPORARY_API_KEY,
URL_PREFIX,
)
from .admin.models import User, UserRole
from .admin.models import User

API_KEY_HARD_WARNING = API_KEY_REQUIRED_STARTING_AT - timedelta(days=14)
API_KEY_SOFT_WARNING = API_KEY_HARD_WARNING - timedelta(days=14)
Expand Down Expand Up @@ -91,10 +91,6 @@ def _get_current_user():
current_user: User = cast(User, LocalProxy(_get_current_user))


def register_user_role(role_name: str) -> None:
UserRole.create_role(role_name)


def _is_public_route() -> bool:
public_routes_list = ["lib", "admin", "version"]
for route in public_routes_list:
Expand Down
114 changes: 51 additions & 63 deletions src/server/admin/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from sqlalchemy.orm import relationship
from copy import deepcopy

from .._db import Session
from .._db import Session, WriteSession, default_session
from delphi.epidata.common.logger import get_structured_logger

from typing import Set, Optional, List
Expand All @@ -25,7 +25,7 @@ def _default_date_now():
class User(Base):
__tablename__ = "api_user"
id = Column(Integer, primary_key=True, autoincrement=True)
roles = relationship("UserRole", secondary=association_table)
roles = relationship("UserRole", secondary=association_table, lazy="joined") # last arg does an eager load of this property from foreign tables
api_key = Column(String(50), unique=True, nullable=False)
email = Column(String(320), unique=True, nullable=False)
created = Column(Date, default=_default_date_now)
Expand All @@ -35,97 +35,85 @@ def __init__(self, api_key: str, email: str = None) -> None:
self.api_key = api_key
self.email = email

@staticmethod
def list_users() -> List["User"]:
with Session() as session:
return session.query(User).all()

@property
def as_dict(self):
return {
"id": self.id,
"api_key": self.api_key,
"email": self.email,
"roles": User.get_user_roles(self.id),
"roles": set(role.name for role in self.roles),
"created": self.created,
"last_time_used": self.last_time_used
}

@staticmethod
def get_user_roles(user_id: int) -> Set[str]:
with Session() as session:
user = session.query(User).filter(User.id == user_id).first()
return set([role.name for role in user.roles])

def has_role(self, required_role: str) -> bool:
return required_role in User.get_user_roles(self.id)
return required_role in set(role.name for role in self.roles)

@staticmethod
def _assign_roles(user: "User", roles: Optional[Set[str]], session) -> None:
# NOTE: this uses a borrowed/existing `session`, and thus does not do a `session.commit()`...
# that is the responsibility of the caller!
get_structured_logger("api_user_models").info("setting roles", roles=roles, user_id=user.id, api_key=user.api_key)
db_user = session.query(User).filter(User.id == user.id).first()
# TODO: would it be sufficient to use the passed-in `user` instead of looking up this `db_user`?
# or even use this as a bound method instead of a static??
# same goes for `update_user()` and `delete_user()` below...
if roles:
roles_to_assign = session.query(UserRole).filter(UserRole.name.in_(roles)).all()
db_user.roles = roles_to_assign
db_user.roles = session.query(UserRole).filter(UserRole.name.in_(roles)).all()
else:
db_user.roles = []
session.commit()
# retrieve the newly updated User object
return session.query(User).filter(User.id == user.id).first()

@staticmethod
@default_session(Session)
def find_user(*, # asterisk forces explicit naming of all arguments when calling this method
user_id: Optional[int] = None, api_key: Optional[str] = None, user_email: Optional[str] = None
session,
user_id: Optional[int] = None, api_key: Optional[str] = None, user_email: Optional[str] = None
) -> "User":
# NOTE: be careful, using multiple arguments could match multiple users, but this will return only one!
with Session() as session:
user = (
session.query(User)
.filter((User.id == user_id) | (User.api_key == api_key) | (User.email == user_email))
.first()
)
user = (
session.query(User)
.filter((User.id == user_id) | (User.api_key == api_key) | (User.email == user_email))
.first()
)
return user if user else None

@staticmethod
def create_user(api_key: str, email: str, user_roles: Optional[Set[str]] = None) -> "User":
@default_session(WriteSession)
def create_user(api_key: str, email: str, session, user_roles: Optional[Set[str]] = None) -> "User":
get_structured_logger("api_user_models").info("creating user", api_key=api_key)
with Session() as session:
new_user = User(api_key=api_key, email=email)
# TODO: we may need to populate 'created' field/column here, if the default
# specified above gets bound to the time of when that line of python was evaluated.
session.add(new_user)
session.commit()
User._assign_roles(new_user, user_roles, session)
session.commit()
return new_user
new_user = User(api_key=api_key, email=email)
session.add(new_user)
session.commit()
return User._assign_roles(new_user, user_roles, session)

@staticmethod
@default_session(WriteSession)
def update_user(
user: "User",
email: Optional[str],
api_key: Optional[str],
roles: Optional[Set[str]]
roles: Optional[Set[str]],
session
) -> "User":
get_structured_logger("api_user_models").info("updating user", user_id=user.id, new_api_key=api_key)
with Session() as session:
user = User.find_user(user_id=user.id)
if user:
update_stmt = (
update(User)
.where(User.id == user.id)
.values(api_key=api_key, email=email)
)
session.execute(update_stmt)
User._assign_roles(user, roles, session)
session.commit()
return user
user = User.find_user(user_id=user.id, session=session)
if not user:
raise Exception('user not found')
update_stmt = (
update(User)
.where(User.id == user.id)
.values(api_key=api_key, email=email)
)
session.execute(update_stmt)
return User._assign_roles(user, roles, session)

@staticmethod
def delete_user(user_id: int) -> None:
@default_session(WriteSession)
def delete_user(user_id: int, session) -> None:
get_structured_logger("api_user_models").info("deleting user", user_id=user_id)
with Session() as session:
session.execute(delete(User).where(User.id == user_id))
session.commit()
session.execute(delete(User).where(User.id == user_id))
session.commit()


class UserRole(Base):
Expand All @@ -134,23 +122,23 @@ class UserRole(Base):
name = Column(String(50), unique=True)

@staticmethod
def create_role(name: str) -> None:
@default_session(WriteSession)
def create_role(name: str, session) -> None:
get_structured_logger("api_user_models").info("creating user role", role=name)
with Session() as session:
session.execute(
f"""
# TODO: check role doesn't already exist
session.execute(f"""
INSERT INTO user_role (name)
SELECT '{name}'
WHERE NOT EXISTS
(SELECT *
FROM user_role
WHERE name='{name}')
"""
)
session.commit()
""")
session.commit()
return session.query(UserRole).filter(UserRole.name == name).first()

@staticmethod
def list_all_roles():
with Session() as session:
roles = session.query(UserRole).all()
@default_session(Session)
def list_all_roles(session):
roles = session.query(UserRole).all()
return [role.name for role in roles]
Loading

0 comments on commit 6fe6e7a

Please sign in to comment.