Skip to content

Commit

Permalink
Merge pull request #199 from aiondemand/feature/cleaner-sql-engine-de…
Browse files Browse the repository at this point in the history
…pendency-injection

Using a DbSession and an EngineSingleton to inject the engine
  • Loading branch information
josvandervelde authored Nov 21, 2023
2 parents 429e8ac + 47b254c commit 799f63e
Show file tree
Hide file tree
Showing 49 changed files with 409 additions and 523 deletions.
10 changes: 5 additions & 5 deletions src/connectors/synchronization.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
from datetime import datetime
from typing import Optional

from sqlmodel import Session, select
from sqlmodel import select, Session

from connectors.abstract.resource_connector import ResourceConnector, RESOURCE
from connectors.record_error import RecordError
from connectors.resource_with_relations import ResourceWithRelations
from database.model.concept.concept import AIoDConcept
from database.setup import _create_or_fetch_related_objects, _get_existing_resource, sqlmodel_engine
from database.session import DbSession
from database.setup import _create_or_fetch_related_objects, _get_existing_resource
from routers import ResourceRouter, resource_routers, enum_routers

RELATIVE_PATH_STATE_JSON = pathlib.Path("state.json")
Expand Down Expand Up @@ -141,8 +142,7 @@ def main():
state_path = working_dir / RELATIVE_PATH_STATE_JSON
first_run = not state_path.exists()

engine = sqlmodel_engine(rebuild_db="never")
with Session(engine) as session:
with DbSession() as session:
db_empty = session.scalars(select(connector.resource_class)).first() is None

if first_run or db_empty:
Expand All @@ -165,7 +165,7 @@ def main():
if router.resource_class == connector.resource_class
]

with Session(engine) as session:
with DbSession() as session:
for i, item in enumerate(items):
error = save_to_database(router=router, connector=connector, session=session, item=item)
if error:
Expand Down
12 changes: 4 additions & 8 deletions src/database/deletion/hard_delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,17 @@
from typing import Type

from sqlalchemy import delete, and_
from sqlalchemy.engine import Engine
from sqlalchemy.sql.operators import is_not
from sqlmodel import Session

from database.model.concept.concept import AIoDConcept
from database.model.helper_functions import non_abstract_subclasses
from database.setup import sqlmodel_engine
from database.session import DbSession


def hard_delete_older_than(engine: Engine, time_threshold: timedelta):
def hard_delete_older_than(time_threshold: timedelta):
classes: list[Type[AIoDConcept]] = non_abstract_subclasses(AIoDConcept)
date_threshold = datetime.datetime.now() - time_threshold
with Session(engine) as session:
with DbSession() as session:
for concept in classes:
filter_ = and_(
is_not(concept.date_deleted, None),
Expand All @@ -50,10 +48,8 @@ def _parse_args() -> argparse.Namespace:

def main():
args = _parse_args()

engine = sqlmodel_engine(rebuild_db="never")
time_threshold = timedelta(minutes=args.time_threshold_minutes)
hard_delete_older_than(engine=engine, time_threshold=time_threshold)
hard_delete_older_than(time_threshold=time_threshold)


if __name__ == "__main__":
Expand Down
52 changes: 52 additions & 0 deletions src/database/session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""
Enabling access to database sessions.
"""

from contextlib import contextmanager
from typing import Iterator
from urllib.parse import quote_plus

from sqlalchemy.engine import Engine
from sqlmodel import Session, create_engine

from config import DB_CONFIG


class EngineSingleton:
    """
    Holder that shares a single database engine between all of its instances.

    The first instance that is constructed registers its own attribute dictionary as the
    class-wide shared state and creates the engine. Every later instance adopts that same
    dictionary, so `EngineSingleton().engine` always yields the same engine object.
    """

    # Attribute dictionary shared by all instances; falsy until an engine has been stored in it.
    __monostate = None

    def __init__(self):
        shared_state = EngineSingleton.__monostate
        if shared_state:
            # Already initialised: reuse the shared attributes (and thereby the engine).
            self.__dict__ = shared_state
            return
        # Register the (still empty) attribute dict first; assigning `self.engine` below
        # fills that same dict, which makes the shared state truthy from then on.
        EngineSingleton.__monostate = self.__dict__
        self.engine = create_engine(db_url(), echo=False, pool_recycle=3600)

    def patch(self, engine: Engine):
        """Replace the shared engine with `engine` for all current and future instances."""
        # Presumably intended for tests that swap in a different engine - verify against callers.
        self.__monostate["engine"] = engine  # type: ignore


def db_url(including_db=True):
    """
    Build the MySQL connection URL from the values in `DB_CONFIG`.

    Missing configuration keys fall back to the defaults used below. The username and
    password are percent-encoded, so that characters with a special meaning inside a URL
    (such as `@`, `:` or `/`) cannot corrupt it.

    Params
    ------
    including_db: if True, the database name is appended to the URL. If False, the URL only
        points to the server.

    Returns
    -------
    url: the connection URL, `mysql://user:password@host:port[/database]`
    """
    # str() guards against non-string configuration values, e.g. a purely numeric password.
    username = quote_plus(str(DB_CONFIG.get("name", "root")))
    password = quote_plus(str(DB_CONFIG.get("password", "ok")))
    host = DB_CONFIG.get("host", "demodb")
    port = DB_CONFIG.get("port", 3306)
    database = DB_CONFIG.get("database", "aiod")

    server_url = f"mysql://{username}:{password}@{host}:{port}"
    if including_db:
        return f"{server_url}/{database}"
    return server_url


@contextmanager
def DbSession() -> Iterator[Session]:
    """
    Context manager yielding a SQLModel session bound to the (configured) database engine.

    The session is created from the engine held by `EngineSingleton` and is always closed
    when the `with` block is left, also when the block raises.

    Alternatively, we could have used FastAPI Depends, but that only works for FastAPI - while
    the synchronization, for instance, also needs a Session, but doesn't use FastAPI.

    Yields
    ------
    session: a Session bound to the shared engine
    """
    session = Session(EngineSingleton().engine)
    try:
        yield session
    finally:
        # Always release the connection, whether the block finished normally or raised.
        session.close()
64 changes: 10 additions & 54 deletions src/database/setup.py
Original file line number Diff line number Diff line change
@@ -1,60 +1,33 @@
"""
Utility functions for initializing the database and tables through SQLAlchemy.
"""

from operator import and_

from sqlalchemy import text
from sqlalchemy.engine import Engine
from sqlmodel import create_engine, Session, SQLModel, select
import sqlmodel
from sqlalchemy import text, create_engine
from sqlmodel import SQLModel, select

from config import DB_CONFIG
from connectors.resource_with_relations import ResourceWithRelations
from database.model.concept.concept import AIoDConcept
from database.model.named_relation import NamedRelation
from database.model.platform.platform_names import PlatformName
from database.session import db_url
from routers import resource_routers


def connect_to_database(
url: str = "mysql://root:[email protected]:3307/aiod",
create_if_not_exists: bool = True,
delete_first: bool = False,
) -> Engine:
"""Connect to server, optionally creating the database if it does not exist.
Params
------
url: URL to the database, see https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls # noqa
create_if_not_exists: create the database if it does not exist
delete_first: drop the database before creating it again, to start with an empty database.
IMPORTANT: Using `delete_first` means ALL data in that database will be lost permanently.
Returns
-------
engine: Engine SQLAlchemy Engine configured with a database connection
"""

if delete_first or create_if_not_exists:
drop_or_create_database(url, delete_first)
engine = create_engine(url, echo=False, pool_recycle=3600)
return engine


def drop_or_create_database(url: str, delete_first: bool):
server, database = url.rsplit("/", 1)
engine = create_engine(server, echo=False) # Temporary engine, not connected to a database

def drop_or_create_database(delete_first: bool):
url = db_url(including_db=False)
engine = create_engine(url, echo=False) # Temporary engine, not connected to a database
with engine.connect() as connection:
database = DB_CONFIG.get("database", "aiod")
if delete_first:
connection.execute(text(f"DROP DATABASE IF EXISTS {database}"))
connection.execute(text(f"CREATE DATABASE IF NOT EXISTS {database}"))
connection.commit()
engine.dispose()


def _get_existing_resource(
session: Session, resource: AIoDConcept, clazz: type[SQLModel]
session: sqlmodel.Session, resource: AIoDConcept, clazz: type[SQLModel]
) -> AIoDConcept | None:
"""Selecting a resource based on platform and platform_resource_identifier"""
is_enum = NamedRelation in clazz.__mro__
Expand All @@ -70,7 +43,7 @@ def _get_existing_resource(
return session.scalars(query).first()


def _create_or_fetch_related_objects(session: Session, item: ResourceWithRelations):
def _create_or_fetch_related_objects(session: sqlmodel.Session, item: ResourceWithRelations):
"""
For all resources in the `related_resources`, get the identifier, by either
inserting them in the database, or retrieving the existing values, and put the identifiers
Expand Down Expand Up @@ -111,20 +84,3 @@ def _create_or_fetch_related_objects(session: Session, item: ResourceWithRelatio
item.resource.__setattr__(field_name, id_) # E.g. Dataset.license_identifier = 1
else:
item.resource.__setattr__(field_name, identifiers) # E.g. Dataset.keywords = [1, 4]


def sqlmodel_engine(rebuild_db: str) -> Engine:
"""
Return a SQLModel engine, backed by the MySql connection as configured in the configuration
file.
"""
username = DB_CONFIG.get("name", "root")
password = DB_CONFIG.get("password", "ok")
host = DB_CONFIG.get("host", "demodb")
port = DB_CONFIG.get("port", 3306)
database = DB_CONFIG.get("database", "aiod")

db_url = f"mysql://{username}:{password}@{host}:{port}/{database}"

delete_before_create = rebuild_db == "always"
return connect_to_database(db_url, delete_first=delete_before_create)
22 changes: 10 additions & 12 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
from fastapi import Depends, FastAPI
from fastapi.responses import HTMLResponse
from pydantic import Json
from sqlalchemy.engine import Engine
from sqlmodel import Session, select
from sqlmodel import select

import routers
from authentication import get_current_user
Expand All @@ -20,7 +19,8 @@
from database.model.concept.concept import AIoDConcept
from database.model.platform.platform import Platform
from database.model.platform.platform_names import PlatformName
from database.setup import sqlmodel_engine
from database.session import EngineSingleton, DbSession
from database.setup import drop_or_create_database
from routers import resource_routers, parent_routers, enum_routers


Expand All @@ -42,7 +42,7 @@ def _parse_args() -> argparse.Namespace:
return parser.parse_args()


def add_routes(app: FastAPI, engine: Engine, url_prefix=""):
def add_routes(app: FastAPI, url_prefix=""):
"""Add routes to the FastAPI application"""

@app.get(url_prefix + "/", response_class=HTMLResponse)
Expand Down Expand Up @@ -73,7 +73,7 @@ def counts() -> dict:
router.resource_name_plural: count
for router in resource_routers.router_list
if issubclass(router.resource_class, AIoDConcept)
and (count := router.get_resource_count_func(engine)(detailed=True))
and (count := router.get_resource_count_func()(detailed=True))
}

for router in (
Expand All @@ -82,7 +82,7 @@ def counts() -> dict:
+ parent_routers.router_list
+ enum_routers.router_list
):
app.include_router(router.create(engine, url_prefix))
app.include_router(router.create(url_prefix))


def create_app() -> FastAPI:
Expand All @@ -100,11 +100,9 @@ def create_app() -> FastAPI:
"scopes": KEYCLOAK_CONFIG.get("scopes"),
},
)
engine = sqlmodel_engine(args.rebuild_db)
with engine.connect() as connection:
AIoDConcept.metadata.create_all(connection, checkfirst=True)
connection.commit()
with Session(engine) as session:
drop_or_create_database(delete_first=args.rebuild_db == "always")
AIoDConcept.metadata.create_all(EngineSingleton().engine, checkfirst=True)
with DbSession() as session:
existing_platforms = session.scalars(select(Platform)).all()
if not any(existing_platforms):
session.add_all([Platform(name=name) for name in PlatformName])
Expand All @@ -115,7 +113,7 @@ def create_app() -> FastAPI:
# empty, and so the triggers should still be added.
add_delete_triggers(AIoDConcept)

add_routes(app, engine, url_prefix=args.url_prefix)
add_routes(app, url_prefix=args.url_prefix)
return app


Expand Down
10 changes: 5 additions & 5 deletions src/routers/enum_routers/enum_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from typing import Type

from fastapi import APIRouter
from sqlalchemy.engine import Engine
from sqlmodel import select, Session

from database.model.named_relation import NamedRelation
from database.session import DbSession


class EnumRouter(abc.ABC):
Expand All @@ -21,7 +21,7 @@ def __init__(self, resource_class: Type[NamedRelation]):
self.resource_name + "s" if not self.resource_name.endswith("s") else self.resource_name
)

def create(self, engine: Engine, url_prefix: str) -> APIRouter:
def create(self, url_prefix: str) -> APIRouter:
router = APIRouter()
version = "v1"
default_kwargs = {
Expand All @@ -30,16 +30,16 @@ def create(self, engine: Engine, url_prefix: str) -> APIRouter:
}
router.add_api_route(
path=url_prefix + f"/{self.resource_name_plural}/{version}",
endpoint=self.get_resources_func(engine),
endpoint=self.get_resources_func(),
response_model=list[str],
name=self.resource_name,
**default_kwargs,
)
return router

def get_resources_func(self, engine: Engine):
def get_resources_func(self):
def get_resources():
with Session(engine) as session:
with DbSession() as session:
query = select(self.resource_class)
resources = session.scalars(query).all()
return [r.name for r in resources]
Expand Down
12 changes: 6 additions & 6 deletions src/routers/parent_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
from typing import Union

from fastapi import APIRouter, HTTPException
from sqlalchemy.engine import Engine
from sqlmodel import SQLModel, select, Session
from sqlmodel import SQLModel, select
from starlette.status import HTTP_404_NOT_FOUND, HTTP_500_INTERNAL_SERVER_ERROR

from database.model.concept.concept import AIoDConcept
from database.model.helper_functions import non_abstract_subclasses
from database.session import DbSession
from routers import resource_routers


Expand Down Expand Up @@ -38,7 +38,7 @@ def parent_class(self):
def parent_class_table(self):
"""The table class of the resource. E.g. AgentTable"""

def create(self, engine: Engine, url_prefix: str) -> APIRouter:
def create(self, url_prefix: str) -> APIRouter:
router = APIRouter()
version = "v1"
default_kwargs = {
Expand All @@ -53,16 +53,16 @@ def create(self, engine: Engine, url_prefix: str) -> APIRouter:

router.add_api_route(
path=url_prefix + f"/{self.resource_name_plural}/{version}/{{identifier}}",
endpoint=self.get_resource_func(engine, classes_dict, read_classes_dict),
endpoint=self.get_resource_func(classes_dict, read_classes_dict),
response_model=response_model, # type: ignore
name=self.resource_name,
**default_kwargs,
)
return router

def get_resource_func(self, engine: Engine, classes_dict: dict, read_classes_dict: dict):
def get_resource_func(self, classes_dict: dict, read_classes_dict: dict):
def get_resource(identifier: int):
with Session(engine) as session:
with DbSession() as session:
query = select(self.parent_class_table).where(
self.parent_class_table.identifier == identifier
)
Expand Down
Loading

0 comments on commit 799f63e

Please sign in to comment.