forked from openml-labs/server-demo
-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #199 from aiondemand/feature/cleaner-sql-engine-de…
…pendency-injection Using a DbSession and an EngineSingleton to inject the engine
- Loading branch information
Showing
49 changed files
with
409 additions
and
523 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
""" | ||
Enabling access to database sessions. | ||
""" | ||
|
||
from contextlib import contextmanager | ||
|
||
from sqlalchemy.engine import Engine | ||
from sqlmodel import Session, create_engine | ||
|
||
from config import DB_CONFIG | ||
|
||
|
||
class EngineSingleton: | ||
"""Making sure the engine is created only once.""" | ||
|
||
__monostate = None | ||
|
||
def __init__(self): | ||
if not EngineSingleton.__monostate: | ||
EngineSingleton.__monostate = self.__dict__ | ||
self.engine = create_engine(db_url(), echo=False, pool_recycle=3600) | ||
else: | ||
self.__dict__ = EngineSingleton.__monostate | ||
|
||
def patch(self, engine: Engine): | ||
self.__monostate["engine"] = engine # type: ignore | ||
|
||
|
||
def db_url(including_db=True): | ||
username = DB_CONFIG.get("name", "root") | ||
password = DB_CONFIG.get("password", "ok") | ||
host = DB_CONFIG.get("host", "demodb") | ||
port = DB_CONFIG.get("port", 3306) | ||
database = DB_CONFIG.get("database", "aiod") | ||
if including_db: | ||
return f"mysql://{username}:{password}@{host}:{port}/{database}" | ||
return f"mysql://{username}:{password}@{host}:{port}" | ||
|
||
|
||
@contextmanager | ||
def DbSession() -> Session: | ||
""" | ||
Returning a SQLModel session bound to the (configured) database engine. | ||
Alternatively, we could have used FastAPI Depends, but that only works for FastAPI - while | ||
the synchronization, for instance, also needs a Session, but doesn't use FastAPI. | ||
""" | ||
session = Session(EngineSingleton().engine) | ||
try: | ||
yield session | ||
finally: | ||
session.close() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,60 +1,33 @@ | ||
""" | ||
Utility functions for initializing the database and tables through SQLAlchemy. | ||
""" | ||
|
||
from operator import and_ | ||
|
||
from sqlalchemy import text | ||
from sqlalchemy.engine import Engine | ||
from sqlmodel import create_engine, Session, SQLModel, select | ||
import sqlmodel | ||
from sqlalchemy import text, create_engine | ||
from sqlmodel import SQLModel, select | ||
|
||
from config import DB_CONFIG | ||
from connectors.resource_with_relations import ResourceWithRelations | ||
from database.model.concept.concept import AIoDConcept | ||
from database.model.named_relation import NamedRelation | ||
from database.model.platform.platform_names import PlatformName | ||
from database.session import db_url | ||
from routers import resource_routers | ||
|
||
|
||
def connect_to_database( | ||
url: str = "mysql://root:[email protected]:3307/aiod", | ||
create_if_not_exists: bool = True, | ||
delete_first: bool = False, | ||
) -> Engine: | ||
"""Connect to server, optionally creating the database if it does not exist. | ||
Params | ||
------ | ||
url: URL to the database, see https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls # noqa | ||
create_if_not_exists: create the database if it does not exist | ||
delete_first: drop the database before creating it again, to start with an empty database. | ||
IMPORTANT: Using `delete_first` means ALL data in that database will be lost permanently. | ||
Returns | ||
------- | ||
engine: Engine SQLAlchemy Engine configured with a database connection | ||
""" | ||
|
||
if delete_first or create_if_not_exists: | ||
drop_or_create_database(url, delete_first) | ||
engine = create_engine(url, echo=False, pool_recycle=3600) | ||
return engine | ||
|
||
|
||
def drop_or_create_database(url: str, delete_first: bool): | ||
server, database = url.rsplit("/", 1) | ||
engine = create_engine(server, echo=False) # Temporary engine, not connected to a database | ||
|
||
def drop_or_create_database(delete_first: bool): | ||
url = db_url(including_db=False) | ||
engine = create_engine(url, echo=False) # Temporary engine, not connected to a database | ||
with engine.connect() as connection: | ||
database = DB_CONFIG.get("database", "aiod") | ||
if delete_first: | ||
connection.execute(text(f"DROP DATABASE IF EXISTS {database}")) | ||
connection.execute(text(f"CREATE DATABASE IF NOT EXISTS {database}")) | ||
connection.commit() | ||
engine.dispose() | ||
|
||
|
||
def _get_existing_resource( | ||
session: Session, resource: AIoDConcept, clazz: type[SQLModel] | ||
session: sqlmodel.Session, resource: AIoDConcept, clazz: type[SQLModel] | ||
) -> AIoDConcept | None: | ||
"""Selecting a resource based on platform and platform_resource_identifier""" | ||
is_enum = NamedRelation in clazz.__mro__ | ||
|
@@ -70,7 +43,7 @@ def _get_existing_resource( | |
return session.scalars(query).first() | ||
|
||
|
||
def _create_or_fetch_related_objects(session: Session, item: ResourceWithRelations): | ||
def _create_or_fetch_related_objects(session: sqlmodel.Session, item: ResourceWithRelations): | ||
""" | ||
For all resources in the `related_resources`, get the identifier, by either | ||
inserting them in the database, or retrieving the existing values, and put the identifiers | ||
|
@@ -111,20 +84,3 @@ def _create_or_fetch_related_objects(session: Session, item: ResourceWithRelatio | |
item.resource.__setattr__(field_name, id_) # E.g. Dataset.license_identifier = 1 | ||
else: | ||
item.resource.__setattr__(field_name, identifiers) # E.g. Dataset.keywords = [1, 4] | ||
|
||
|
||
def sqlmodel_engine(rebuild_db: str) -> Engine: | ||
""" | ||
Return a SQLModel engine, backed by the MySql connection as configured in the configuration | ||
file. | ||
""" | ||
username = DB_CONFIG.get("name", "root") | ||
password = DB_CONFIG.get("password", "ok") | ||
host = DB_CONFIG.get("host", "demodb") | ||
port = DB_CONFIG.get("port", 3306) | ||
database = DB_CONFIG.get("database", "aiod") | ||
|
||
db_url = f"mysql://{username}:{password}@{host}:{port}/{database}" | ||
|
||
delete_before_create = rebuild_db == "always" | ||
return connect_to_database(db_url, delete_first=delete_before_create) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.