From ff81d4a486c87783ea38ea2eb678d44d3df393f1 Mon Sep 17 00:00:00 2001 From: Jos van der Velde Date: Fri, 17 Nov 2023 16:29:37 +0100 Subject: [PATCH 1/3] Using a DbSession and an EngineSingleton to inject the engine, instead of giving it to every router. This is cleaner (I think) and makes it easy to patch the engine for unittests. Note that FastAPI Depends would work as well for the routers, but not for other parts of the code. --- src/connectors/synchronization.py | 10 +- src/database/deletion/hard_delete.py | 12 +- src/database/session.py | 52 +++++++ src/database/setup.py | 64 ++------- src/main.py | 20 ++- src/routers/enum_routers/enum_router.py | 10 +- src/routers/parent_router.py | 12 +- src/routers/resource_ai_asset_router.py | 24 ++-- src/routers/resource_router.py | 114 +++++++-------- src/routers/upload_router_huggingface.py | 5 +- .../test_dataset_dcatap_converter.py | 4 +- .../test_dataset_schemaDotOrg_converter.py | 4 +- .../database/deletion/test_hard_delete.py | 21 ++- .../database/model/agent/test_agent_delete.py | 11 +- .../model/agent/test_person_delete.py | 11 +- .../model/ai_asset/test_ai_asset_delete.py | 11 +- .../model/dataset/test_dataset_delete.py | 11 +- .../test_ml_model_delete.py | 11 +- .../model/resource/test_resource_delete.py | 11 +- .../test_router_aiassets_retrieve_content.py | 12 +- .../enum_routers/test_license_router.py | 4 +- .../routers/generic/test_router_delete.py | 25 ++-- .../generic/test_router_deprecation.py | 2 +- .../routers/generic/test_router_get_all.py | 17 +-- .../routers/generic/test_router_get_count.py | 19 +-- .../generic/test_router_platform_get_all.py | 7 +- .../routers/generic/test_router_relations.py | 9 +- .../routers/generic/test_router_scheme.py | 4 +- .../parent_routers/test_agent_router.py | 6 +- .../parent_routers/test_ai_asset_router.py | 4 +- .../parent_routers/test_ai_resource_router.py | 4 +- .../resource_routers/test_router_dataset.py | 4 +- .../test_router_dataset_generic_fields.py | 33 ++--- 
.../resource_routers/test_router_event.py | 6 +- .../resource_routers/test_router_ml_model.py | 6 +- .../test_router_organisation.py | 6 +- .../resource_routers/test_router_person.py | 6 +- .../resource_routers/test_router_project.py | 6 +- .../resource_routers/test_router_team.py | 6 +- src/tests/testutils/default_instances.py | 28 ++-- src/tests/testutils/default_sqlalchemy.py | 77 +++++----- src/tests/testutils/test_resource.py | 2 +- .../huggingface/test_dataset_uploader.py | 12 +- src/uploader/hugging_face_uploader.py | 132 +++++++++--------- 44 files changed, 389 insertions(+), 466 deletions(-) create mode 100644 src/database/session.py diff --git a/src/connectors/synchronization.py b/src/connectors/synchronization.py index 866f595f..5bd223d2 100644 --- a/src/connectors/synchronization.py +++ b/src/connectors/synchronization.py @@ -8,13 +8,14 @@ from datetime import datetime from typing import Optional -from sqlmodel import Session, select +from sqlmodel import select, Session from connectors.abstract.resource_connector import ResourceConnector, RESOURCE from connectors.record_error import RecordError from connectors.resource_with_relations import ResourceWithRelations from database.model.concept.concept import AIoDConcept -from database.setup import _create_or_fetch_related_objects, _get_existing_resource, sqlmodel_engine +from database.session import DbSession +from database.setup import _create_or_fetch_related_objects, _get_existing_resource from routers import ResourceRouter, resource_routers, enum_routers RELATIVE_PATH_STATE_JSON = pathlib.Path("state.json") @@ -141,8 +142,7 @@ def main(): state_path = working_dir / RELATIVE_PATH_STATE_JSON first_run = not state_path.exists() - engine = sqlmodel_engine(rebuild_db="never") - with Session(engine) as session: + with DbSession() as session: db_empty = session.scalars(select(connector.resource_class)).first() is None if first_run or db_empty: @@ -165,7 +165,7 @@ def main(): if router.resource_class == 
connector.resource_class ] - with Session(engine) as session: + with DbSession() as session: for i, item in enumerate(items): error = save_to_database(router=router, connector=connector, session=session, item=item) if error: diff --git a/src/database/deletion/hard_delete.py b/src/database/deletion/hard_delete.py index 45b158f2..714750aa 100644 --- a/src/database/deletion/hard_delete.py +++ b/src/database/deletion/hard_delete.py @@ -12,19 +12,17 @@ from typing import Type from sqlalchemy import delete, and_ -from sqlalchemy.engine import Engine from sqlalchemy.sql.operators import is_not -from sqlmodel import Session from database.model.concept.concept import AIoDConcept from database.model.helper_functions import non_abstract_subclasses -from database.setup import sqlmodel_engine +from database.session import DbSession -def hard_delete_older_than(engine: Engine, time_threshold: timedelta): +def hard_delete_older_than(time_threshold: timedelta): classes: list[Type[AIoDConcept]] = non_abstract_subclasses(AIoDConcept) date_threshold = datetime.datetime.now() - time_threshold - with Session(engine) as session: + with DbSession() as session: for concept in classes: filter_ = and_( is_not(concept.date_deleted, None), @@ -50,10 +48,8 @@ def _parse_args() -> argparse.Namespace: def main(): args = _parse_args() - - engine = sqlmodel_engine(rebuild_db="never") time_threshold = timedelta(minutes=args.time_threshold_minutes) - hard_delete_older_than(engine=engine, time_threshold=time_threshold) + hard_delete_older_than(time_threshold=time_threshold) if __name__ == "__main__": diff --git a/src/database/session.py b/src/database/session.py new file mode 100644 index 00000000..49ba51f0 --- /dev/null +++ b/src/database/session.py @@ -0,0 +1,52 @@ +""" +Enabling access to database sessions. 
+""" + +from contextlib import contextmanager + +from sqlalchemy.engine import Engine +from sqlmodel import Session, create_engine + +from config import DB_CONFIG + + +class EngineSingleton: + """Making sure the engine is created only once.""" + + __monostate = None + + def __init__(self): + if not EngineSingleton.__monostate: + EngineSingleton.__monostate = self.__dict__ + self.engine = create_engine(db_url(), echo=False, pool_recycle=3600) + else: + self.__dict__ = EngineSingleton.__monostate + + def patch(self, engine: Engine): + self.__monostate["engine"] = engine # type: ignore + + +def db_url(including_db=True): + username = DB_CONFIG.get("name", "root") + password = DB_CONFIG.get("password", "ok") + host = DB_CONFIG.get("host", "demodb") + port = DB_CONFIG.get("port", 3306) + database = DB_CONFIG.get("database", "aiod") + if including_db: + return f"mysql://{username}:{password}@{host}:{port}/{database}" + return f"mysql://{username}:{password}@{host}:{port}" + + +@contextmanager +def DbSession() -> Session: + """ + Returning a SQLModel session bound to the (configured) database engine. + + Alternatively, we could have used FastAPI Depends, but that only works for FastAPI - while + the synchronization, for instance, also needs a Session, but doesn't use FastAPI. + """ + session = Session(EngineSingleton().engine) + try: + yield session + finally: + session.close() diff --git a/src/database/setup.py b/src/database/setup.py index bc04c3d9..c743d47e 100644 --- a/src/database/setup.py +++ b/src/database/setup.py @@ -1,60 +1,33 @@ """ Utility functions for initializing the database and tables through SQLAlchemy. 
""" - from operator import and_ -from sqlalchemy import text -from sqlalchemy.engine import Engine -from sqlmodel import create_engine, Session, SQLModel, select +import sqlmodel +from sqlalchemy import text, create_engine +from sqlmodel import SQLModel, select from config import DB_CONFIG from connectors.resource_with_relations import ResourceWithRelations from database.model.concept.concept import AIoDConcept from database.model.named_relation import NamedRelation from database.model.platform.platform_names import PlatformName +from database.session import db_url from routers import resource_routers -def connect_to_database( - url: str = "mysql://root:ok@127.0.0.1:3307/aiod", - create_if_not_exists: bool = True, - delete_first: bool = False, -) -> Engine: - """Connect to server, optionally creating the database if it does not exist. - - Params - ------ - url: URL to the database, see https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls # noqa - create_if_not_exists: create the database if it does not exist - delete_first: drop the database before creating it again, to start with an empty database. - IMPORTANT: Using `delete_first` means ALL data in that database will be lost permanently. 
- - Returns - ------- - engine: Engine SQLAlchemy Engine configured with a database connection - """ - - if delete_first or create_if_not_exists: - drop_or_create_database(url, delete_first) - engine = create_engine(url, echo=False, pool_recycle=3600) - return engine - - -def drop_or_create_database(url: str, delete_first: bool): - server, database = url.rsplit("/", 1) - engine = create_engine(server, echo=False) # Temporary engine, not connected to a database - +def drop_or_create_database(delete_first: bool): + url = db_url(including_db=False) + engine = create_engine(url, echo=False) # Temporary engine, not connected to a database with engine.connect() as connection: + database = DB_CONFIG.get("database", "aiod") if delete_first: connection.execute(text(f"DROP DATABASE IF EXISTS {database}")) connection.execute(text(f"CREATE DATABASE IF NOT EXISTS {database}")) - connection.commit() - engine.dispose() def _get_existing_resource( - session: Session, resource: AIoDConcept, clazz: type[SQLModel] + session: sqlmodel.Session, resource: AIoDConcept, clazz: type[SQLModel] ) -> AIoDConcept | None: """Selecting a resource based on platform and platform_resource_identifier""" is_enum = NamedRelation in clazz.__mro__ @@ -70,7 +43,7 @@ def _get_existing_resource( return session.scalars(query).first() -def _create_or_fetch_related_objects(session: Session, item: ResourceWithRelations): +def _create_or_fetch_related_objects(session: sqlmodel.Session, item: ResourceWithRelations): """ For all resources in the `related_resources`, get the identifier, by either inserting them in the database, or retrieving the existing values, and put the identifiers @@ -111,20 +84,3 @@ def _create_or_fetch_related_objects(session: Session, item: ResourceWithRelatio item.resource.__setattr__(field_name, id_) # E.g. Dataset.license_identifier = 1 else: item.resource.__setattr__(field_name, identifiers) # E.g. 
Dataset.keywords = [1, 4] - - -def sqlmodel_engine(rebuild_db: str) -> Engine: - """ - Return a SQLModel engine, backed by the MySql connection as configured in the configuration - file. - """ - username = DB_CONFIG.get("name", "root") - password = DB_CONFIG.get("password", "ok") - host = DB_CONFIG.get("host", "demodb") - port = DB_CONFIG.get("port", 3306) - database = DB_CONFIG.get("database", "aiod") - - db_url = f"mysql://{username}:{password}@{host}:{port}/{database}" - - delete_before_create = rebuild_db == "always" - return connect_to_database(db_url, delete_first=delete_before_create) diff --git a/src/main.py b/src/main.py index a676eb52..9e401dd0 100644 --- a/src/main.py +++ b/src/main.py @@ -10,8 +10,7 @@ from fastapi import Depends, FastAPI from fastapi.responses import HTMLResponse from pydantic import Json -from sqlalchemy.engine import Engine -from sqlmodel import Session, select +from sqlmodel import select import routers from authentication import get_current_user @@ -20,7 +19,8 @@ from database.model.concept.concept import AIoDConcept from database.model.platform.platform import Platform from database.model.platform.platform_names import PlatformName -from database.setup import sqlmodel_engine +from database.session import EngineSingleton, DbSession +from database.setup import drop_or_create_database from routers import resource_routers, parent_routers, enum_routers @@ -42,7 +42,7 @@ def _parse_args() -> argparse.Namespace: return parser.parse_args() -def add_routes(app: FastAPI, engine: Engine, url_prefix=""): +def add_routes(app: FastAPI, url_prefix=""): """Add routes to the FastAPI application""" @app.get(url_prefix + "/", response_class=HTMLResponse) @@ -73,7 +73,7 @@ def test_authorization(user: Json = Depends(get_current_user)) -> dict: + parent_routers.router_list + enum_routers.router_list ): - app.include_router(router.create(engine, url_prefix)) + app.include_router(router.create(url_prefix)) def create_app() -> FastAPI: @@ -91,11 +91,9 @@ 
def create_app() -> FastAPI: "scopes": KEYCLOAK_CONFIG.get("scopes"), }, ) - engine = sqlmodel_engine(args.rebuild_db) - with engine.connect() as connection: - AIoDConcept.metadata.create_all(connection, checkfirst=True) - connection.commit() - with Session(engine) as session: + drop_or_create_database(delete_first=args.rebuild_db == "always") + AIoDConcept.metadata.create_all(EngineSingleton().engine, checkfirst=True) + with DbSession() as session: existing_platforms = session.scalars(select(Platform)).all() if not any(existing_platforms): session.add_all([Platform(name=name) for name in PlatformName]) @@ -106,7 +104,7 @@ def create_app() -> FastAPI: # empty, and so the triggers should still be added. add_delete_triggers(AIoDConcept) - add_routes(app, engine, url_prefix=args.url_prefix) + add_routes(app, url_prefix=args.url_prefix) return app diff --git a/src/routers/enum_routers/enum_router.py b/src/routers/enum_routers/enum_router.py index 0efb1b88..87d71306 100644 --- a/src/routers/enum_routers/enum_router.py +++ b/src/routers/enum_routers/enum_router.py @@ -2,10 +2,10 @@ from typing import Type from fastapi import APIRouter -from sqlalchemy.engine import Engine from sqlmodel import select, Session from database.model.named_relation import NamedRelation +from database.session import DbSession class EnumRouter(abc.ABC): @@ -21,7 +21,7 @@ def __init__(self, resource_class: Type[NamedRelation]): self.resource_name + "s" if not self.resource_name.endswith("s") else self.resource_name ) - def create(self, engine: Engine, url_prefix: str) -> APIRouter: + def create(self, url_prefix: str) -> APIRouter: router = APIRouter() version = "v1" default_kwargs = { @@ -30,16 +30,16 @@ def create(self, engine: Engine, url_prefix: str) -> APIRouter: } router.add_api_route( path=url_prefix + f"/{self.resource_name_plural}/{version}", - endpoint=self.get_resources_func(engine), + endpoint=self.get_resources_func(), response_model=list[str], name=self.resource_name, 
**default_kwargs, ) return router - def get_resources_func(self, engine: Engine): + def get_resources_func(self): def get_resources(): - with Session(engine) as session: + with DbSession() as session: query = select(self.resource_class) resources = session.scalars(query).all() return [r.name for r in resources] diff --git a/src/routers/parent_router.py b/src/routers/parent_router.py index 4f3cf120..cd7ea2ee 100644 --- a/src/routers/parent_router.py +++ b/src/routers/parent_router.py @@ -2,12 +2,12 @@ from typing import Union from fastapi import APIRouter, HTTPException -from sqlalchemy.engine import Engine -from sqlmodel import SQLModel, select, Session +from sqlmodel import SQLModel, select from starlette.status import HTTP_404_NOT_FOUND, HTTP_500_INTERNAL_SERVER_ERROR from database.model.concept.concept import AIoDConcept from database.model.helper_functions import non_abstract_subclasses +from database.session import DbSession from routers import resource_routers @@ -38,7 +38,7 @@ def parent_class(self): def parent_class_table(self): """The table class of the resource. E.g. 
AgentTable""" - def create(self, engine: Engine, url_prefix: str) -> APIRouter: + def create(self, url_prefix: str) -> APIRouter: router = APIRouter() version = "v1" default_kwargs = { @@ -53,16 +53,16 @@ def create(self, engine: Engine, url_prefix: str) -> APIRouter: router.add_api_route( path=url_prefix + f"/{self.resource_name_plural}/{version}/{{identifier}}", - endpoint=self.get_resource_func(engine, classes_dict, read_classes_dict), + endpoint=self.get_resource_func(classes_dict, read_classes_dict), response_model=response_model, # type: ignore name=self.resource_name, **default_kwargs, ) return router - def get_resource_func(self, engine: Engine, classes_dict: dict, read_classes_dict: dict): + def get_resource_func(self, classes_dict: dict, read_classes_dict: dict): def get_resource(identifier: int): - with Session(engine) as session: + with DbSession() as session: query = select(self.parent_class_table).where( self.parent_class_table.identifier == identifier ) diff --git a/src/routers/resource_ai_asset_router.py b/src/routers/resource_ai_asset_router.py index 22a34b7a..65eee8b5 100644 --- a/src/routers/resource_ai_asset_router.py +++ b/src/routers/resource_ai_asset_router.py @@ -1,15 +1,13 @@ -from fastapi.responses import Response -from fastapi import APIRouter, HTTPException, status import requests -from sqlalchemy.engine import Engine +from fastapi import APIRouter, HTTPException, status +from fastapi.responses import Response from database.model.ai_asset.ai_asset import AIAsset - from .resource_router import ResourceRouter, _wrap_as_http_exception class ResourceAIAssetRouter(ResourceRouter): - def create(self, engine: Engine, url_prefix: str) -> APIRouter: + def create(self, url_prefix: str) -> APIRouter: version = "v1" default_kwargs = { "response_model_exclude_none": True, @@ -17,11 +15,11 @@ def create(self, engine: Engine, url_prefix: str) -> APIRouter: "tags": [self.resource_name_plural], } - router = super().create(engine, url_prefix) + router = 
super().create(url_prefix) router.add_api_route( path=f"{url_prefix}/{self.resource_name_plural}/{version}/{{identifier}}/content", - endpoint=self.get_resource_content_func(engine, default=True), + endpoint=self.get_resource_content_func(default=True), name=self.resource_name, response_model=str, **default_kwargs, @@ -30,7 +28,7 @@ def create(self, engine: Engine, url_prefix: str) -> APIRouter: router.add_api_route( path=f"{url_prefix}/{self.resource_name_plural}/{version}/{{identifier}}/content/" f"{{distribution_idx}}", - endpoint=self.get_resource_content_func(engine, default=False), + endpoint=self.get_resource_content_func(default=False), name=self.resource_name, response_model=str, **default_kwargs, @@ -38,19 +36,23 @@ def create(self, engine: Engine, url_prefix: str) -> APIRouter: return router - def get_resource_content_func(self, engine: Engine, default: bool): + def get_resource_content_func(self, default: bool): """ Returns a function to download the content from resources. This function returns a function (instead of being that function directly) because the docstring and the variables are dynamic, and used in Swagger. 
""" - def get_resource_content(identifier: str, distribution_idx: int, default: bool = False): + def get_resource_content( + identifier: str, + distribution_idx: int, + default: bool = False, + ): f"""Retrieve a distribution of the content for {self.resource_name} identified by its identifier.""" metadata: AIAsset = self.get_resource( - engine=engine, identifier=identifier, schema="aiod", platform=None + identifier=identifier, schema="aiod", platform=None ) # type: ignore distributions = metadata.distribution diff --git a/src/routers/resource_router.py b/src/routers/resource_router.py index a43b5e72..35584772 100644 --- a/src/routers/resource_router.py +++ b/src/routers/resource_router.py @@ -10,7 +10,6 @@ from fastapi.encoders import jsonable_encoder from pydantic import BaseModel from sqlalchemy import and_ -from sqlalchemy.engine import Engine from sqlalchemy.sql.operators import is_ from sqlmodel import SQLModel, Session, select from starlette.responses import JSONResponse @@ -27,6 +26,7 @@ resource_read, ) from database.model.serializers import deserialize_resource_relationships +from database.session import DbSession class Pagination(BaseModel): @@ -104,7 +104,7 @@ def schema_converters(self) -> dict[str, SchemaConverter[RESOURCE, Any]]: """ return {} - def create(self, engine: Engine, url_prefix: str) -> APIRouter: + def create(self, url_prefix: str) -> APIRouter: router = APIRouter() version = f"v{self.version}" default_kwargs = { @@ -120,14 +120,15 @@ def create(self, engine: Engine, url_prefix: str) -> APIRouter: router.add_api_route( path=f"{url_prefix}/{self.resource_name_plural}/{version}", - endpoint=self.get_resources_func(engine), + endpoint=self.get_resources_func(), response_model=response_model_plural, # type: ignore name=f"List {self.resource_name_plural}", + description=f"Retrieve all meta-data of the {self.resource_name_plural}.", **default_kwargs, ) router.add_api_route( path=f"{url_prefix}/counts/{self.resource_name_plural}/v1", - 
endpoint=self.get_resource_count_func(engine), + endpoint=self.get_resource_count_func(), response_model=int, # type: ignore name=f"Count of {self.resource_name_plural}", **default_kwargs, @@ -135,13 +136,13 @@ def create(self, engine: Engine, url_prefix: str) -> APIRouter: router.add_api_route( path=f"{url_prefix}/{self.resource_name_plural}/{version}", methods={"POST"}, - endpoint=self.register_resource_func(engine), + endpoint=self.register_resource_func(), name=self.resource_name, **default_kwargs, ) router.add_api_route( path=url_prefix + f"/{self.resource_name_plural}/{version}/{{identifier}}", - endpoint=self.get_resource_func(engine), + endpoint=self.get_resource_func(), response_model=response_model, # type: ignore name=self.resource_name, **default_kwargs, @@ -149,21 +150,21 @@ def create(self, engine: Engine, url_prefix: str) -> APIRouter: router.add_api_route( path=f"{url_prefix}/{self.resource_name_plural}/{version}/{{identifier}}", methods={"PUT"}, - endpoint=self.put_resource_func(engine), + endpoint=self.put_resource_func(), name=self.resource_name, **default_kwargs, ) router.add_api_route( path=f"{url_prefix}/{self.resource_name_plural}/{version}/{{identifier}}", methods={"DELETE"}, - endpoint=self.delete_resource_func(engine), + endpoint=self.delete_resource_func(), name=self.resource_name, **default_kwargs, ) if hasattr(self.resource_class, "platform"): router.add_api_route( path=f"{url_prefix}/platforms/{{platform}}/{self.resource_name_plural}/{version}", - endpoint=self.get_platform_resources_func(engine), + endpoint=self.get_platform_resources_func(), response_model=response_model_plural, # type: ignore name=f"List {self.resource_name_plural}", **default_kwargs, @@ -171,20 +172,18 @@ def create(self, engine: Engine, url_prefix: str) -> APIRouter: router.add_api_route( path=f"{url_prefix}/platforms/{{platform}}/{self.resource_name_plural}/{version}" f"/{{identifier}}", - endpoint=self.get_platform_resource_func(engine), + 
endpoint=self.get_platform_resource_func(), response_model=response_model, # type: ignore name=self.resource_name, **default_kwargs, ) return router - def get_resources( - self, engine: Engine, schema: str, pagination: Pagination, platform: str | None = None - ): + def get_resources(self, schema: str, pagination: Pagination, platform: str | None = None): """Fetch all resources of this platform in given schema, using pagination""" _raise_error_on_invalid_schema(self._possible_schemas, schema) - try: - with Session(engine) as session: + with DbSession() as session: + try: convert_schema = ( partial(self.schema_converters[schema].convert, session) if schema != "aiod" @@ -204,19 +203,17 @@ def get_resources( return self._wrap_with_headers( [convert_schema(resource) for resource in session.scalars(query).all()] ) - except Exception as e: - raise _wrap_as_http_exception(e) + except Exception as e: + raise _wrap_as_http_exception(e) - def get_resource( - self, engine: Engine, identifier: str, schema: str, platform: str | None = None - ): + def get_resource(self, identifier: str, schema: str, platform: str | None = None): """ Get the resource identified by AIoD identifier (if platform is None) or by platform AND platform-identifier (if platform is not None), return in given schema. """ _raise_error_on_invalid_schema(self._possible_schemas, schema) try: - with Session(engine) as session: + with DbSession() as session: resource = self._retrieve_resource(session, identifier, platform=platform) if schema != "aiod": return self.schema_converters[schema].convert(session, resource) @@ -224,7 +221,7 @@ def get_resource( except Exception as e: raise _wrap_as_http_exception(e) - def get_resources_func(self, engine: Engine): + def get_resources_func(self): """ Return a function that can be used to retrieve a list of resources. 
This function returns a function (instead of being that function directly) because the @@ -235,15 +232,12 @@ def get_resources( pagination: Pagination = Depends(Pagination), schema: Literal[tuple(self._possible_schemas)] = "aiod", # type:ignore ): - f"""Retrieve all meta-data of the {self.resource_name_plural}.""" - resources = self.get_resources( - engine=engine, pagination=pagination, schema=schema, platform=None - ) + resources = self.get_resources(pagination=pagination, schema=schema, platform=None) return resources return get_resources - def get_resource_count_func(self, engine: Engine): + def get_resource_count_func(self): """ Gets the total number of resources from the database. This function returns a function (instead of being that function directly) because the @@ -253,7 +247,7 @@ def get_resource_count_func(self, engine: Engine): def get_resource_count(): f"""Retrieve the number of {self.resource_name_plural}.""" try: - with Session(engine) as session: + with DbSession() as session: return ( session.query(self.resource_class) .where(is_(self.resource_class.date_deleted, None)) @@ -264,7 +258,7 @@ def get_resource_count(): return get_resource_count - def get_platform_resources_func(self, engine: Engine): + def get_platform_resources_func(self): """ Return a function that can be used to retrieve a list of resources for a platform. 
This function returns a function (instead of being that function directly) because the @@ -277,14 +271,12 @@ def get_resources( schema: Literal[tuple(self._possible_schemas)] = "aiod", # type:ignore ): f"""Retrieve all meta-data of the {self.resource_name_plural} of given platform.""" - resources = self.get_resources( - engine=engine, pagination=pagination, schema=schema, platform=platform - ) + resources = self.get_resources(pagination=pagination, schema=schema, platform=platform) return resources return get_resources - def get_resource_func(self, engine: Engine): + def get_resource_func(self): """ Return a function that can be used to retrieve a single resource. This function returns a function (instead of being that function directly) because the @@ -292,19 +284,18 @@ def get_resource_func(self, engine: Engine): """ def get_resource( - identifier: str, schema: Literal[tuple(self._possible_schemas)] = "aiod" # type:ignore + identifier: str, + schema: Literal[tuple(self._possible_schemas)] = "aiod", # type:ignore ): f""" Retrieve all meta-data for a {self.resource_name} identified by the AIoD identifier. """ - resource = self.get_resource( - engine=engine, identifier=identifier, schema=schema, platform=None - ) + resource = self.get_resource(identifier=identifier, schema=schema, platform=None) return self._wrap_with_headers(resource) return get_resource - def get_platform_resource_func(self, engine: Engine): + def get_platform_resource_func(self): """ Return a function that can be used to retrieve a single resource of a platform. 
This function returns a function (instead of being that function directly) because the @@ -318,13 +309,11 @@ def get_resource( ): f"""Retrieve all meta-data for a {self.resource_name} identified by the platform-specific-identifier.""" - return self.get_resource( - engine=engine, identifier=identifier, schema=schema, platform=platform - ) + return self.get_resource(identifier=identifier, schema=schema, platform=platform) return get_resource - def register_resource_func(self, engine: Engine): + def register_resource_func(self): """ Return a function that can be used to register a resource. This function returns a function (instead of being that function directly) because the @@ -343,7 +332,7 @@ def register_resource( detail="You do not have permission to edit Aiod resources.", ) try: - with Session(engine) as session: + with DbSession() as session: try: resource = self.create_resource(session, resource_create) return self._wrap_with_headers({"identifier": resource.identifier}) @@ -364,7 +353,7 @@ def create_resource(self, session: Session, resource_create_instance: SQLModel): session.commit() return resource - def put_resource_func(self, engine: Engine): + def put_resource_func(self): """ Return a function that can be used to update a resource. 
This function returns a function (instead of being that function directly) because the @@ -384,8 +373,8 @@ def put_resource( detail="You do not have permission to edit Aiod resources.", ) - try: - with Session(engine) as session: + with DbSession() as session: + try: resource = self._retrieve_resource(session, identifier) for attribute_name in resource.schema()["properties"]: if hasattr(resource_create_instance, attribute_name): @@ -401,27 +390,30 @@ def put_resource( session.commit() except Exception as e: self._raise_clean_http_exception(e, session, resource_create_instance) - return self._wrap_with_headers(None) - except Exception as e: - raise self._raise_clean_http_exception(e, session, resource_create_instance) + return self._wrap_with_headers(None) + except Exception as e: + raise self._raise_clean_http_exception(e, session, resource_create_instance) return put_resource - def delete_resource_func(self, engine: Engine): + def delete_resource_func(self): """ Return a function that can be used to delete a resource. This function returns a function (instead of being that function directly) because the docstring is dynamic and used in Swagger. 
""" - def delete_resource(identifier: str, user: dict = Depends(get_current_user)): - if "groups" in user and KEYCLOAK_CONFIG.get("role") not in user["groups"]: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="You do not have permission to delete Aiod resources.", - ) - try: - with Session(engine) as session: + def delete_resource( + identifier: str, + user: dict = Depends(get_current_user), + ): + with DbSession() as session: + if "groups" in user and KEYCLOAK_CONFIG.get("role") not in user["groups"]: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You do not have permission to delete Aiod resources.", + ) + try: # Raise error if it does not exist resource = self._retrieve_resource(session, identifier) if ( @@ -433,9 +425,9 @@ def delete_resource(identifier: str, user: dict = Depends(get_current_user)): resource.date_deleted = datetime.datetime.utcnow() session.add(resource) session.commit() - return self._wrap_with_headers(None) - except Exception as e: - raise _wrap_as_http_exception(e) + return self._wrap_with_headers(None) + except Exception as e: + raise _wrap_as_http_exception(e) return delete_resource diff --git a/src/routers/upload_router_huggingface.py b/src/routers/upload_router_huggingface.py index c3e254ca..740244a3 100644 --- a/src/routers/upload_router_huggingface.py +++ b/src/routers/upload_router_huggingface.py @@ -1,12 +1,11 @@ from fastapi import APIRouter from fastapi import File, Query, UploadFile -from sqlalchemy.engine import Engine from uploader.hugging_face_uploader import handle_upload class UploadRouterHuggingface: - def create(self, engine: Engine, url_prefix: str) -> APIRouter: + def create(self, url_prefix: str) -> APIRouter: router = APIRouter() @router.post(url_prefix + "/upload/datasets/{identifier}/huggingface", tags=["upload"]) @@ -22,6 +21,6 @@ def huggingFaceUpload( ..., title="Huggingface username", description="The username of HuggingFace" ), ) -> int: - return 
handle_upload(engine, identifier, file, token, username) + return handle_upload(identifier, file, token, username) return router diff --git a/src/tests/converters/schema_converters/test_dataset_dcatap_converter.py b/src/tests/converters/schema_converters/test_dataset_dcatap_converter.py index 4c14761c..5df56c36 100644 --- a/src/tests/converters/schema_converters/test_dataset_dcatap_converter.py +++ b/src/tests/converters/schema_converters/test_dataset_dcatap_converter.py @@ -1,7 +1,6 @@ import datetime from sqlalchemy.engine import Engine -from sqlmodel import Session from converters.schema_converters import dataset_converter_dcatap_instance from database.model.agent.person import Person @@ -10,6 +9,7 @@ from database.model.dataset.dataset import Dataset from database.model.dataset.size import DatasetSizeORM from database.model.knowledge_asset.publication import Publication +from database.session import DbSession from tests.testutils.paths import path_test_resources @@ -27,7 +27,7 @@ def test_aiod_to_dcatap_happy_path(engine: Engine, dataset: Dataset): dataset.date_published = datetime.datetime(2023, 8, 22, 4, 5, 6) converter = dataset_converter_dcatap_instance - with Session(engine) as session: + with DbSession() as session: dcat_ap = converter.convert(session, dataset) actual = dcat_ap.json(by_alias=True, indent=4) diff --git a/src/tests/converters/schema_converters/test_dataset_schemaDotOrg_converter.py b/src/tests/converters/schema_converters/test_dataset_schemaDotOrg_converter.py index cc71348b..514bcbeb 100644 --- a/src/tests/converters/schema_converters/test_dataset_schemaDotOrg_converter.py +++ b/src/tests/converters/schema_converters/test_dataset_schemaDotOrg_converter.py @@ -1,7 +1,6 @@ import datetime from sqlalchemy.engine import Engine -from sqlmodel import Session from converters.schema_converters import dataset_converter_schema_dot_org_instance from database.model.agent.agent_table import AgentTable @@ -12,6 +11,7 @@ from 
database.model.dataset.dataset import Dataset from database.model.dataset.size import DatasetSizeORM from database.model.knowledge_asset.publication import Publication +from database.session import DbSession from tests.testutils.paths import path_test_resources @@ -31,7 +31,7 @@ def test_aiod_to_schema_dot_org_happy_path(engine: Engine, dataset: Dataset): dataset.date_published = datetime.datetime(2023, 8, 22, 4, 5, 6) converter = dataset_converter_schema_dot_org_instance - with Session(engine) as session: + with DbSession() as session: session.add(creator) session.add(funder) session.commit() diff --git a/src/tests/database/deletion/test_hard_delete.py b/src/tests/database/deletion/test_hard_delete.py index 7bcf6de3..bc4dfe0e 100644 --- a/src/tests/database/deletion/test_hard_delete.py +++ b/src/tests/database/deletion/test_hard_delete.py @@ -1,18 +1,17 @@ import datetime from unittest.mock import Mock -from sqlalchemy.future import Engine -from sqlmodel import Session, select +from sqlmodel import select from authentication import keycloak_openid from database.deletion import hard_delete from database.model.concept.aiod_entry import AIoDEntryORM from database.model.concept.status import Status -from tests.testutils.test_resource import test_resource_factory, TestResource +from database.session import DbSession +from tests.testutils.test_resource import factory, TestResource def test_hard_delete( - engine_test_resource: Engine, mocked_privileged_token: Mock, draft: Status, ): @@ -20,31 +19,31 @@ def test_hard_delete( now = datetime.datetime.now() deletion_time = now - datetime.timedelta(seconds=10) - with Session(engine_test_resource) as session: + with DbSession() as session: session.add_all( [ - test_resource_factory( + factory( title="test_resource_to_keep", platform="example", platform_resource_identifier=1, status=draft, date_deleted=None, ), - test_resource_factory( + factory( title="test_resource_to_keep_2", platform="example", 
platform_resource_identifier=2, status=draft, date_deleted=now, ), - test_resource_factory( + factory( title="my_test_resource", platform="example", platform_resource_identifier=3, status=draft, date_deleted=deletion_time, ), - test_resource_factory( + factory( title="second_test_resource", platform="example", platform_resource_identifier=4, @@ -55,8 +54,8 @@ def test_hard_delete( ) session.commit() - hard_delete.hard_delete_older_than(engine_test_resource, datetime.timedelta(seconds=5)) - with Session(engine_test_resource) as session: + hard_delete.hard_delete_older_than(datetime.timedelta(seconds=5)) + with DbSession() as session: resources = session.scalars(select(TestResource)).all() assert len(resources) == 2 assert {r.platform_resource_identifier for r in resources} == {"1", "2"} diff --git a/src/tests/database/model/agent/test_agent_delete.py b/src/tests/database/model/agent/test_agent_delete.py index 5356d235..6d56721b 100644 --- a/src/tests/database/model/agent/test_agent_delete.py +++ b/src/tests/database/model/agent/test_agent_delete.py @@ -1,16 +1,13 @@ -from sqlalchemy.engine import Engine -from sqlmodel import Session, select +from sqlmodel import select from starlette.testclient import TestClient from database.model.agent.agent_table import AgentTable from database.model.agent.organisation import Organisation from database.model.agent.person import Person +from database.session import DbSession -def test_happy_path( - client: TestClient, - engine: Engine, -): +def test_happy_path(client: TestClient): organisation = Organisation( name="organisation", agent_identifier=AgentTable(type="organisation"), @@ -20,7 +17,7 @@ def test_happy_path( agent_identifier=AgentTable(type="person"), ) - with Session(engine) as session: + with DbSession() as session: session.add(person) session.merge(organisation) session.commit() diff --git a/src/tests/database/model/agent/test_person_delete.py b/src/tests/database/model/agent/test_person_delete.py index 
53692023..2acebd0b 100644 --- a/src/tests/database/model/agent/test_person_delete.py +++ b/src/tests/database/model/agent/test_person_delete.py @@ -1,15 +1,12 @@ -from sqlalchemy.engine import Engine -from sqlmodel import Session, select +from sqlmodel import select from starlette.testclient import TestClient from database.model.agent.expertise import Expertise from database.model.agent.person import Person +from database.session import DbSession -def test_happy_path( - client: TestClient, - engine: Engine, -): +def test_happy_path(client: TestClient): expertise_a = Expertise(name="just") expertise_b = Expertise(name="an") expertise_c = Expertise(name="example") @@ -22,7 +19,7 @@ def test_happy_path( expertise=[expertise_a, expertise_c], ) - with Session(engine) as session: + with DbSession() as session: session.add(person_a) session.add(person_b) session.commit() diff --git a/src/tests/database/model/ai_asset/test_ai_asset_delete.py b/src/tests/database/model/ai_asset/test_ai_asset_delete.py index c1a816c2..e3a23ee0 100644 --- a/src/tests/database/model/ai_asset/test_ai_asset_delete.py +++ b/src/tests/database/model/ai_asset/test_ai_asset_delete.py @@ -1,16 +1,13 @@ -from sqlalchemy.engine import Engine -from sqlmodel import Session, select +from sqlmodel import select from starlette.testclient import TestClient from database.model.ai_asset.ai_asset_table import AIAssetTable from database.model.dataset.dataset import Dataset from database.model.knowledge_asset.publication import Publication +from database.session import DbSession -def test_happy_path( - client: TestClient, - engine: Engine, -): +def test_happy_path(client: TestClient): dataset_distribution = Dataset.__annotations__["distribution"].__args__[0] publication_distribution = Publication.__annotations__["distribution"].__args__[0] dataset_1 = Dataset( @@ -38,7 +35,7 @@ def test_happy_path( ai_asset_identifier=AIAssetTable(type="publication"), ) - with Session(engine) as session: + with DbSession() as 
session: session.add_all([dataset_1, dataset_2, publication]) session.commit() session.delete(dataset_1) diff --git a/src/tests/database/model/dataset/test_dataset_delete.py b/src/tests/database/model/dataset/test_dataset_delete.py index 60301834..8641410b 100644 --- a/src/tests/database/model/dataset/test_dataset_delete.py +++ b/src/tests/database/model/dataset/test_dataset_delete.py @@ -1,5 +1,4 @@ -from sqlalchemy.engine import Engine -from sqlmodel import Session, select +from sqlmodel import select from starlette.testclient import TestClient from database.model.agent.agent_table import AgentTable @@ -7,12 +6,10 @@ from database.model.ai_asset.ai_asset_table import AIAssetTable from database.model.dataset.dataset import Dataset from database.model.dataset.size import DatasetSizeORM +from database.session import DbSession -def test_happy_path( - client: TestClient, - engine: Engine, -): +def test_happy_path(client: TestClient): dataset = Dataset( ai_asset_identifier=AIAssetTable(type="dataset"), @@ -24,7 +21,7 @@ def test_happy_path( funder=[AgentTable(type="person")], ) - with Session(engine) as session: + with DbSession() as session: session.add(dataset) session.commit() assert len(session.scalars(select(Dataset)).all()) == 1 diff --git a/src/tests/database/model/models_and_experiments/test_ml_model_delete.py b/src/tests/database/model/models_and_experiments/test_ml_model_delete.py index 6e02f6c9..e0f87429 100644 --- a/src/tests/database/model/models_and_experiments/test_ml_model_delete.py +++ b/src/tests/database/model/models_and_experiments/test_ml_model_delete.py @@ -1,20 +1,17 @@ -from sqlalchemy.engine import Engine -from sqlmodel import Session, select +from sqlmodel import select from starlette.testclient import TestClient from database.model.models_and_experiments.experiment import Experiment from database.model.models_and_experiments.ml_model import MLModel +from database.session import DbSession -def test_happy_path( - client: TestClient, - engine: 
Engine, -): +def test_happy_path(client: TestClient): experiment = Experiment(name="experiment") ml_model = MLModel(name="model", related_experiment=[experiment]) link_model = ml_model.__sqlmodel_relationships__["related_experiment"].link_model - with Session(engine) as session: + with DbSession() as session: session.add(ml_model) session.commit() assert len(session.scalars(select(Experiment)).all()) == 1 diff --git a/src/tests/database/model/resource/test_resource_delete.py b/src/tests/database/model/resource/test_resource_delete.py index 9741ad6c..25b3fa5b 100644 --- a/src/tests/database/model/resource/test_resource_delete.py +++ b/src/tests/database/model/resource/test_resource_delete.py @@ -1,5 +1,4 @@ -from sqlalchemy.engine import Engine -from sqlmodel import Session, select +from sqlmodel import select from starlette.testclient import TestClient from database.model.ai_resource.alternate_name import AlternateName @@ -8,12 +7,10 @@ from database.model.ai_resource.resource_table import AIResourceORM from database.model.dataset.dataset import Dataset from database.model.knowledge_asset.publication import Publication +from database.session import DbSession -def test_happy_path( - client: TestClient, - engine: Engine, -): +def test_happy_path(client: TestClient): dataset_media = Dataset.__annotations__["media"].__args__[0] dataset_note = Dataset.__annotations__["note"].__args__[0] @@ -62,7 +59,7 @@ def test_happy_path( ai_resource_identifier=AIResourceORM(type="publication"), ) - with Session(engine) as session: + with DbSession() as session: session.add_all([dataset_1, dataset_2, publication]) session.commit() session.delete(dataset_1) diff --git a/src/tests/routers/ai_asset_routers/test_router_aiassets_retrieve_content.py b/src/tests/routers/ai_asset_routers/test_router_aiassets_retrieve_content.py index d413ac87..74bfe5be 100644 --- a/src/tests/routers/ai_asset_routers/test_router_aiassets_retrieve_content.py +++ 
b/src/tests/routers/ai_asset_routers/test_router_aiassets_retrieve_content.py @@ -1,20 +1,18 @@ import copy -import pytest -import responses - -from pytest import FixtureRequest from unittest.mock import Mock +import pytest +import responses from fastapi import status +from pytest import FixtureRequest from sqlalchemy.engine import Engine -from sqlmodel import Session from starlette.testclient import TestClient from authentication import keycloak_openid from database.model.agent.person import Person +from database.session import DbSession from tests.testutils.paths import path_test_resources - TEST_URL1 = "https://www.example.com/example1.csv/content" TEST_URL2 = "https://www.example.com/example2.tsv/content" @@ -66,7 +64,7 @@ def set_up( specified endpoint for the given resource. """ keycloak_openid.userinfo = mocked_privileged_token - with Session(engine) as session: + with DbSession() as session: session.add(person) session.commit() diff --git a/src/tests/routers/enum_routers/test_license_router.py b/src/tests/routers/enum_routers/test_license_router.py index 6289bea8..57df77fd 100644 --- a/src/tests/routers/enum_routers/test_license_router.py +++ b/src/tests/routers/enum_routers/test_license_router.py @@ -1,17 +1,17 @@ from sqlalchemy.engine import Engine -from sqlmodel import Session from starlette.testclient import TestClient from database.model.ai_asset.license import License from database.model.dataset.dataset import Dataset from database.model.knowledge_asset.publication import Publication +from database.session import DbSession def test_happy_path(client: TestClient, engine: Engine, dataset: Dataset, publication: Publication): dataset.license = License(name="license 1") publication.license = License(name="license 2") - with Session(engine) as session: + with DbSession() as session: session.add(dataset) session.merge(publication) session.commit() diff --git a/src/tests/routers/generic/test_router_delete.py b/src/tests/routers/generic/test_router_delete.py 
index acfeea63..9d798d62 100644 --- a/src/tests/routers/generic/test_router_delete.py +++ b/src/tests/routers/generic/test_router_delete.py @@ -1,34 +1,33 @@ +from unittest.mock import Mock + import pytest -from sqlmodel import Session from starlette.testclient import TestClient -from sqlalchemy.future import Engine -from database.model.concept.status import Status -from tests.testutils.test_resource import test_resource_factory from authentication import keycloak_openid -from unittest.mock import Mock +from database.model.concept.status import Status +from database.session import DbSession +from tests.testutils.test_resource import factory @pytest.mark.parametrize("identifier", [1, 2]) def test_happy_path( client_test_resource: TestClient, - engine_test_resource: Engine, identifier: int, mocked_privileged_token: Mock, draft: Status, ): keycloak_openid.userinfo = mocked_privileged_token - with Session(engine_test_resource) as session: + with DbSession() as session: session.add_all( [ - test_resource_factory( + factory( title="my_test_resource", platform="example", platform_resource_identifier=1, status=draft, ), - test_resource_factory( + factory( title="second_test_resource", platform="example", platform_resource_identifier=2, @@ -51,22 +50,21 @@ def test_happy_path( @pytest.mark.parametrize("identifier", [3, 4]) def test_non_existent( client_test_resource: TestClient, - engine_test_resource: Engine, identifier: int, mocked_privileged_token: Mock, draft: Status, ): keycloak_openid.userinfo = mocked_privileged_token - with Session(engine_test_resource) as session: + with DbSession() as session: session.add_all( [ - test_resource_factory( + factory( title="my_test_resource", platform="example", platform_resource_identifier=1, status=draft, ), - test_resource_factory( + factory( title="second_test_resource", platform="example", platform_resource_identifier=2, @@ -84,7 +82,6 @@ def test_non_existent( def test_add_after_deletion( client_test_resource: TestClient, - 
engine_test_resource: Engine, mocked_privileged_token: Mock, ): keycloak_openid.userinfo = mocked_privileged_token diff --git a/src/tests/routers/generic/test_router_deprecation.py b/src/tests/routers/generic/test_router_deprecation.py index 3e570de3..3f04442d 100644 --- a/src/tests/routers/generic/test_router_deprecation.py +++ b/src/tests/routers/generic/test_router_deprecation.py @@ -39,7 +39,7 @@ def test_deprecated_router( ): keycloak_openid.userinfo = mocked_privileged_token app = FastAPI() - app.include_router(DeprecatedRouter().create(engine_test_resource_filled, "")) + app.include_router(DeprecatedRouter().create("")) client = TestClient(app) kwargs = {} diff --git a/src/tests/routers/generic/test_router_get_all.py b/src/tests/routers/generic/test_router_get_all.py index c3bcd66f..75a3e08e 100644 --- a/src/tests/routers/generic/test_router_get_all.py +++ b/src/tests/routers/generic/test_router_get_all.py @@ -1,21 +1,16 @@ -from sqlalchemy.future import Engine -from sqlmodel import Session from starlette.testclient import TestClient from database.model.concept.status import Status -from tests.testutils.test_resource import test_resource_factory +from database.session import DbSession +from tests.testutils.test_resource import factory -def test_get_all_happy_path( - client_test_resource: TestClient, engine_test_resource: Engine, draft: Status -): - with Session(engine_test_resource) as session: +def test_get_all_happy_path(client_test_resource: TestClient, draft: Status): + with DbSession() as session: session.add_all( [ - test_resource_factory( - title="my_test_resource_1", status=draft, platform_resource_identifier="2" - ), - test_resource_factory( + factory(title="my_test_resource_1", status=draft, platform_resource_identifier="2"), + factory( title="My second test resource", status=draft, platform_resource_identifier="3" ), ] diff --git a/src/tests/routers/generic/test_router_get_count.py b/src/tests/routers/generic/test_router_get_count.py index 
db70abd2..9347304d 100644 --- a/src/tests/routers/generic/test_router_get_count.py +++ b/src/tests/routers/generic/test_router_get_count.py @@ -1,26 +1,21 @@ import datetime -from sqlalchemy.future import Engine -from sqlmodel import Session from starlette.testclient import TestClient from database.model.concept.status import Status -from tests.testutils.test_resource import test_resource_factory +from database.session import DbSession +from tests.testutils.test_resource import factory -def test_get_count_happy_path( - client_test_resource: TestClient, engine_test_resource: Engine, draft: Status -): - with Session(engine_test_resource) as session: +def test_get_count_happy_path(client_test_resource: TestClient, draft: Status): + with DbSession() as session: session.add_all( [ - test_resource_factory( - title="my_test_resource_1", status=draft, platform_resource_identifier="1" - ), - test_resource_factory( + factory(title="my_test_resource_1", status=draft, platform_resource_identifier="1"), + factory( title="My second test resource", status=draft, platform_resource_identifier="2" ), - test_resource_factory( + factory( title="My third test resource", status=draft, platform_resource_identifier="3", diff --git a/src/tests/routers/generic/test_router_platform_get_all.py b/src/tests/routers/generic/test_router_platform_get_all.py index d2d74aa7..18d53c8b 100644 --- a/src/tests/routers/generic/test_router_platform_get_all.py +++ b/src/tests/routers/generic/test_router_platform_get_all.py @@ -1,12 +1,11 @@ -from sqlalchemy.future import Engine -from sqlmodel import Session from starlette.testclient import TestClient +from database.session import DbSession from tests.testutils.test_resource import TestResource -def test_get_all_happy_path(client_test_resource: TestClient, engine_test_resource: Engine): - with Session(engine_test_resource) as session: +def test_get_all_happy_path(client_test_resource: TestClient): + with DbSession() as session: session.add_all( [ 
TestResource( diff --git a/src/tests/routers/generic/test_router_relations.py b/src/tests/routers/generic/test_router_relations.py index 845f4c79..387d00fd 100644 --- a/src/tests/routers/generic/test_router_relations.py +++ b/src/tests/routers/generic/test_router_relations.py @@ -3,7 +3,7 @@ import pytest from fastapi import FastAPI -from sqlmodel import Session, Field, Relationship, SQLModel +from sqlmodel import Field, Relationship, SQLModel from starlette.testclient import TestClient from authentication import keycloak_openid @@ -13,6 +13,7 @@ from database.model.named_relation import NamedRelation from database.model.relationships import ManyToOne, ManyToMany from database.model.serializers import AttributeSerializer, FindByNameDeserializer, CastDeserializer +from database.session import DbSession from routers import ResourceRouter @@ -126,8 +127,8 @@ def resource_class(self) -> Type[TestObject]: @pytest.fixture -def client_with_testobject(engine_test_resource) -> TestClient: - with Session(engine_test_resource) as session: +def client_with_testobject() -> TestClient: + with DbSession() as session: named1, named2 = TestEnum(name="named_string1"), TestEnum(name="named_string2") enum1, enum2, enum3 = TestEnum2(name="1"), TestEnum2(name="2"), TestEnum2(name="3") draft = Status(name="draft") @@ -158,7 +159,7 @@ def client_with_testobject(engine_test_resource) -> TestClient: ) session.commit() app = FastAPI() - app.include_router(RouterTestObject().create(engine_test_resource, "")) + app.include_router(RouterTestObject().create("")) return TestClient(app) diff --git a/src/tests/routers/generic/test_router_scheme.py b/src/tests/routers/generic/test_router_scheme.py index 9e111427..61cb8898 100644 --- a/src/tests/routers/generic/test_router_scheme.py +++ b/src/tests/routers/generic/test_router_scheme.py @@ -37,11 +37,11 @@ def schema_converters(self) -> dict[str, SchemaConverter[TestResource, Any]]: @pytest.fixture(scope="module") -def 
client_test_resource_other_schema(engine_test_resource: Engine) -> TestClient: +def client_test_resource_other_schema() -> TestClient: """A Startlette TestClient including routes to the TestResource, using schemas "aiod" and "other-schema" """ app = FastAPI() - app.include_router(RouterWithOtherSchema().create(engine_test_resource, "")) + app.include_router(RouterWithOtherSchema().create("")) return TestClient(app) diff --git a/src/tests/routers/parent_routers/test_agent_router.py b/src/tests/routers/parent_routers/test_agent_router.py index 1418c72e..bcb00189 100644 --- a/src/tests/routers/parent_routers/test_agent_router.py +++ b/src/tests/routers/parent_routers/test_agent_router.py @@ -1,11 +1,11 @@ import datetime from sqlalchemy.engine import Engine -from sqlmodel import Session from starlette.testclient import TestClient from database.model.agent.organisation import Organisation from database.model.agent.person import Person +from database.session import DbSession def test_happy_path( @@ -17,7 +17,7 @@ def test_happy_path( organisation.name = "Organisation" person.name = "Person" - with Session(engine) as session: + with DbSession() as session: session.add(organisation) session.merge(person) session.commit() @@ -47,7 +47,7 @@ def test_ignore_deleted( organisation.name = "Organisation" organisation.date_deleted = datetime.datetime.now() person.name = "Person" - with Session(engine) as session: + with DbSession() as session: session.add(organisation) session.merge(person) session.commit() diff --git a/src/tests/routers/parent_routers/test_ai_asset_router.py b/src/tests/routers/parent_routers/test_ai_asset_router.py index 9cc227ad..a13b087d 100644 --- a/src/tests/routers/parent_routers/test_ai_asset_router.py +++ b/src/tests/routers/parent_routers/test_ai_asset_router.py @@ -1,9 +1,9 @@ from sqlalchemy.engine import Engine -from sqlmodel import Session from starlette.testclient import TestClient from database.model.dataset.dataset import Dataset from 
database.model.knowledge_asset.publication import Publication +from database.session import DbSession def test_happy_path( @@ -15,7 +15,7 @@ def test_happy_path( dataset.name = "Dataset" publication.name = "Publication" - with Session(engine) as session: + with DbSession() as session: session.add(dataset) session.merge(publication) session.commit() diff --git a/src/tests/routers/parent_routers/test_ai_resource_router.py b/src/tests/routers/parent_routers/test_ai_resource_router.py index 14b10e42..94e60316 100644 --- a/src/tests/routers/parent_routers/test_ai_resource_router.py +++ b/src/tests/routers/parent_routers/test_ai_resource_router.py @@ -1,9 +1,9 @@ from sqlalchemy.engine import Engine -from sqlmodel import Session from starlette.testclient import TestClient from database.model.agent.organisation import Organisation from database.model.agent.person import Person +from database.session import DbSession def test_happy_path( @@ -15,7 +15,7 @@ def test_happy_path( organisation.name = "Organisation" person.name = "Person" - with Session(engine) as session: + with DbSession() as session: session.add(organisation) session.merge(person) session.commit() diff --git a/src/tests/routers/resource_routers/test_router_dataset.py b/src/tests/routers/resource_routers/test_router_dataset.py index 984c3451..c8a5dcbd 100644 --- a/src/tests/routers/resource_routers/test_router_dataset.py +++ b/src/tests/routers/resource_routers/test_router_dataset.py @@ -2,11 +2,11 @@ from unittest.mock import Mock from sqlalchemy.engine import Engine -from sqlmodel import Session from starlette.testclient import TestClient from authentication import keycloak_openid from database.model.agent.person import Person +from database.session import DbSession def test_happy_path( @@ -18,7 +18,7 @@ def test_happy_path( ): keycloak_openid.userinfo = mocked_privileged_token - with Session(engine) as session: + with DbSession() as session: session.add(person) session.commit() diff --git 
a/src/tests/routers/resource_routers/test_router_dataset_generic_fields.py b/src/tests/routers/resource_routers/test_router_dataset_generic_fields.py index 7974b4f2..14c80367 100644 --- a/src/tests/routers/resource_routers/test_router_dataset_generic_fields.py +++ b/src/tests/routers/resource_routers/test_router_dataset_generic_fields.py @@ -1,13 +1,12 @@ import copy import time +import typing from datetime import datetime from unittest.mock import Mock import dateutil.parser import pytz -import typing -from sqlalchemy.engine import Engine -from sqlmodel import Session, select +from sqlmodel import select from starlette.testclient import TestClient from authentication import keycloak_openid @@ -19,11 +18,11 @@ from database.model.dataset.dataset import Dataset from database.model.helper_functions import all_annotations from database.model.knowledge_asset.publication import Publication +from database.session import DbSession def test_happy_path( client: TestClient, - engine: Engine, mocked_privileged_token: Mock, body_asset: dict, person: Person, @@ -31,7 +30,7 @@ def test_happy_path( contact: Contact, ): keycloak_openid.userinfo = mocked_privileged_token - with Session(engine) as session: + with DbSession() as session: session.add(person) session.merge(publication) session.add(contact) @@ -151,7 +150,6 @@ def test_happy_path( def test_post_duplicate_named_relations( client: TestClient, - engine: Engine, mocked_privileged_token: Mock, ): """ @@ -221,7 +219,6 @@ def create_body(i: int, *keywords): def test_post_duplicate_named_relations_with_different_capitals( client: TestClient, - engine: Engine, mocked_privileged_token: Mock, ): keycloak_openid.userinfo = mocked_privileged_token @@ -238,7 +235,6 @@ def create_body(i: int, *keywords): def test_post_editors( client: TestClient, - engine: Engine, mocked_privileged_token: Mock, ): """ @@ -269,7 +265,7 @@ def assert_editors_are_stored(id_: str, *editors: int): assert_editors_are_stored("36", 1, 2) -def 
test_create_aiod_entry(client: TestClient, engine: Engine, mocked_privileged_token: Mock): +def test_create_aiod_entry(client: TestClient, mocked_privileged_token: Mock): keycloak_openid.userinfo = mocked_privileged_token body = {"name": "news"} start = datetime.now(pytz.utc) @@ -288,7 +284,7 @@ def test_create_aiod_entry(client: TestClient, engine: Engine, mocked_privileged assert resource_json["ai_resource_identifier"] == 1 -def test_update_aiod_entry(client: TestClient, engine: Engine, mocked_privileged_token: Mock): +def test_update_aiod_entry(client: TestClient, mocked_privileged_token: Mock): keycloak_openid.userinfo = mocked_privileged_token body = {"name": "news"} start = datetime.now(pytz.utc) @@ -310,38 +306,38 @@ def test_update_aiod_entry(client: TestClient, engine: Engine, mocked_privileged assert end < date_modified assert resource_json["aiod_entry"]["status"] == "published" - with Session(engine) as session: + with DbSession() as session: entries = session.scalars(select(AIoDEntryORM)).all() assert len(entries) == 1 -def assert_distributions(client: TestClient, engine: Engine, *content_urls: str): +def assert_distributions(client: TestClient, *content_urls: str): response = client.get("/datasets/v1/1") distributions = response.json()["distribution"] assert {distribution["content_url"] for distribution in distributions} == set(content_urls) (distribution_class,) = typing.get_args(all_annotations(Dataset)["distribution"]) - with Session(engine) as session: + with DbSession() as session: distributions = session.scalars(select(distribution_class)).all() assert {distribution.content_url for distribution in distributions} == set(content_urls) -def test_update_distribution(client: TestClient, engine: Engine, mocked_privileged_token: Mock): +def test_update_distribution(client: TestClient, mocked_privileged_token: Mock): keycloak_openid.userinfo = mocked_privileged_token body = {"name": "dataset", "distribution": [{"content_url": "url"}]} response = 
client.post("/datasets/v1", json=body, headers={"Authorization": "Fake token"}) assert response.status_code == 200, response.json() - assert_distributions(client, engine, "url") + assert_distributions(client, "url") body = {"name": "dataset", "distribution": [{"content_url": "url2"}, {"content_url": "test"}]} response = client.put("/datasets/v1/1", json=body, headers={"Authorization": "Fake token"}) assert response.status_code == 200, response.json() - assert_distributions(client, engine, "url2", "test") + assert_distributions(client, "url2", "test") body = {"name": "dataset", "distribution": [{"content_url": "url"}]} response = client.put("/datasets/v1/1", json=body, headers={"Authorization": "Fake token"}) assert response.status_code == 200, response.json() - assert_distributions(client, engine, "url") + assert_distributions(client, "url") def assert_relations( @@ -363,7 +359,6 @@ def assert_relations( def test_relations_between_resources( client: TestClient, - engine: Engine, mocked_privileged_token: Mock, body_asset: dict, dataset: Dataset, @@ -372,7 +367,7 @@ def test_relations_between_resources( ): keycloak_openid.userinfo = mocked_privileged_token - with Session(engine) as session: + with DbSession() as session: session.add(dataset) session.merge(publication) session.merge(organisation) diff --git a/src/tests/routers/resource_routers/test_router_event.py b/src/tests/routers/resource_routers/test_router_event.py index c82a66cb..78d04d88 100644 --- a/src/tests/routers/resource_routers/test_router_event.py +++ b/src/tests/routers/resource_routers/test_router_event.py @@ -1,23 +1,21 @@ import copy from unittest.mock import Mock -from sqlalchemy.engine import Engine -from sqlmodel import Session from starlette.testclient import TestClient from authentication import keycloak_openid from database.model.agent.person import Person +from database.session import DbSession def test_happy_path( client: TestClient, - engine: Engine, mocked_privileged_token: Mock, 
body_resource: dict, person: Person, ): - with Session(engine) as session: + with DbSession() as session: session.add(person) session.commit() diff --git a/src/tests/routers/resource_routers/test_router_ml_model.py b/src/tests/routers/resource_routers/test_router_ml_model.py index fdc20204..a2ac5060 100644 --- a/src/tests/routers/resource_routers/test_router_ml_model.py +++ b/src/tests/routers/resource_routers/test_router_ml_model.py @@ -1,24 +1,22 @@ import copy from unittest.mock import Mock -from sqlalchemy.engine import Engine -from sqlalchemy.orm import Session from starlette.testclient import TestClient from authentication import keycloak_openid from database.model.models_and_experiments.experiment import Experiment +from database.session import DbSession def test_happy_path( client: TestClient, - engine: Engine, mocked_privileged_token: Mock, experiment: Experiment, body_asset: dict, ): keycloak_openid.userinfo = mocked_privileged_token - with Session(engine) as session: + with DbSession() as session: session.add(experiment) session.commit() diff --git a/src/tests/routers/resource_routers/test_router_organisation.py b/src/tests/routers/resource_routers/test_router_organisation.py index 4532f13e..ba30845c 100644 --- a/src/tests/routers/resource_routers/test_router_organisation.py +++ b/src/tests/routers/resource_routers/test_router_organisation.py @@ -1,18 +1,16 @@ import copy from unittest.mock import Mock -from sqlalchemy.engine import Engine -from sqlalchemy.orm import Session from starlette.testclient import TestClient from authentication import keycloak_openid from database.model.agent.contact import Contact from database.model.agent.organisation import Organisation +from database.session import DbSession def test_happy_path( client: TestClient, - engine: Engine, mocked_privileged_token: Mock, organisation: Organisation, contact: Contact, @@ -20,7 +18,7 @@ def test_happy_path( ): keycloak_openid.userinfo = mocked_privileged_token - with Session(engine) 
as session: + with DbSession() as session: session.add(organisation) # The new organisation will be a member of this organisation session.add(contact) session.commit() diff --git a/src/tests/routers/resource_routers/test_router_person.py b/src/tests/routers/resource_routers/test_router_person.py index d5ee394a..f301bcdb 100644 --- a/src/tests/routers/resource_routers/test_router_person.py +++ b/src/tests/routers/resource_routers/test_router_person.py @@ -1,18 +1,16 @@ import copy from unittest.mock import Mock -from sqlalchemy.engine import Engine -from sqlmodel import Session from starlette.testclient import TestClient from authentication import keycloak_openid from database.model.agent.contact import Contact from database.model.agent.person import Person +from database.session import DbSession def test_happy_path( client: TestClient, - engine: Engine, mocked_privileged_token: Mock, body_agent: dict, person: Person, @@ -20,7 +18,7 @@ def test_happy_path( ): keycloak_openid.userinfo = mocked_privileged_token - with Session(engine) as session: + with DbSession() as session: person.platform_resource_identifier = "2" session.add(person) session.add(contact) diff --git a/src/tests/routers/resource_routers/test_router_project.py b/src/tests/routers/resource_routers/test_router_project.py index 1b019a94..2b565067 100644 --- a/src/tests/routers/resource_routers/test_router_project.py +++ b/src/tests/routers/resource_routers/test_router_project.py @@ -1,8 +1,6 @@ import copy from unittest.mock import Mock -from sqlalchemy.engine import Engine -from sqlmodel import Session from starlette.testclient import TestClient from authentication import keycloak_openid @@ -10,11 +8,11 @@ from database.model.agent.person import Person from database.model.dataset.dataset import Dataset from database.model.knowledge_asset.publication import Publication +from database.session import DbSession def test_happy_path( client: TestClient, - engine: Engine, mocked_privileged_token: Mock, 
body_resource: dict, person: Person, @@ -24,7 +22,7 @@ def test_happy_path( ): keycloak_openid.userinfo = mocked_privileged_token - with Session(engine) as session: + with DbSession() as session: session.add(person) session.merge(organisation) session.merge(dataset) diff --git a/src/tests/routers/resource_routers/test_router_team.py b/src/tests/routers/resource_routers/test_router_team.py index 0a64392b..35b85843 100644 --- a/src/tests/routers/resource_routers/test_router_team.py +++ b/src/tests/routers/resource_routers/test_router_team.py @@ -1,18 +1,16 @@ import copy from unittest.mock import Mock -from sqlalchemy.engine import Engine -from sqlmodel import Session from starlette.testclient import TestClient from authentication import keycloak_openid from database.model.agent.organisation import Organisation from database.model.agent.person import Person +from database.session import DbSession def test_happy_path( client: TestClient, - engine: Engine, mocked_privileged_token: Mock, body_resource: dict, person: Person, @@ -20,7 +18,7 @@ def test_happy_path( ): keycloak_openid.userinfo = mocked_privileged_token - with Session(engine) as session: + with DbSession() as session: session.add(person) session.merge(organisation) session.commit() diff --git a/src/tests/testutils/default_instances.py b/src/tests/testutils/default_instances.py index 715be542..c6f3f15d 100644 --- a/src/tests/testutils/default_instances.py +++ b/src/tests/testutils/default_instances.py @@ -8,7 +8,6 @@ import pytest from sqlalchemy.engine import Engine -from sqlmodel import Session from database.model.agent.contact import Contact from database.model.agent.organisation import Organisation @@ -19,6 +18,7 @@ from database.model.models_and_experiments.experiment import Experiment from database.model.resource_read_and_create import resource_create from database.model.serializers import deserialize_resource_relationships +from database.session import DbSession from tests.testutils.paths import 
path_test_resources @@ -61,13 +61,13 @@ def body_agent(body_resource: dict) -> dict: @pytest.fixture -def publication(body_asset: dict, engine: Engine) -> Publication: +def publication(body_asset: dict) -> Publication: body = copy.copy(body_asset) body["permanent_identifier"] = "http://dx.doi.org/10.1093/ajae/aaq063" body["isbn"] = "9783161484100" body["issn"] = "20493630" body["type"] = "journal" - return _create_class_with_body(Publication, body, engine) + return _create_class_with_body(Publication, body) @pytest.fixture @@ -81,46 +81,46 @@ def contact(body_concept, engine: Engine) -> Contact: "geo": {"latitude": 37.42242, "longitude": -122.08585, "elevation_millimeters": 2000}, } ] - return _create_class_with_body(Contact, body, engine) + return _create_class_with_body(Contact, body) @pytest.fixture -def dataset(body_asset: dict, engine: Engine) -> Dataset: +def dataset(body_asset: dict) -> Dataset: body = copy.copy(body_asset) body["issn"] = "20493630" body["measurement_technique"] = "mass spectrometry" body["temporal_coverage"] = "2011/2012" - return _create_class_with_body(Dataset, body, engine) + return _create_class_with_body(Dataset, body) @pytest.fixture -def organisation(body_agent, engine: Engine) -> Organisation: +def organisation(body_agent) -> Organisation: body = copy.copy(body_agent) body["date_founded"] = "2022-01-01" body["legal_name"] = "Legal Name" body["ai_relevance"] = "Description of relevance in AI" - return _create_class_with_body(Organisation, body, engine) + return _create_class_with_body(Organisation, body) @pytest.fixture -def person(body_agent, engine: Engine) -> Person: +def person(body_agent) -> Person: body = copy.copy(body_agent) body["expertise"] = ["machine learning"] body["language"] = ["eng", "nld"] - return _create_class_with_body(Person, body, engine) + return _create_class_with_body(Person, body) @pytest.fixture -def experiment(body_asset, engine: Engine) -> Experiment: +def experiment(body_asset) -> Experiment: body = 
copy.copy(body_asset) - return _create_class_with_body(Experiment, body, engine) + return _create_class_with_body(Experiment, body) -def _create_class_with_body(clz, body: dict, engine: Engine): +def _create_class_with_body(clz, body: dict): pydantic_class = resource_create(clz) res_create = pydantic_class(**body) res = clz.from_orm(res_create) - with Session(engine) as session: + with DbSession() as session: deserialize_resource_relationships(session, clz, res, res_create) session.commit() if hasattr(res, "ai_resource"): diff --git a/src/tests/testutils/default_sqlalchemy.py b/src/tests/testutils/default_sqlalchemy.py index 2822a47f..537c8804 100644 --- a/src/tests/testutils/default_sqlalchemy.py +++ b/src/tests/testutils/default_sqlalchemy.py @@ -7,15 +7,16 @@ from fastapi import FastAPI from sqlalchemy import event, text from sqlalchemy.engine import Engine -from sqlmodel import create_engine, SQLModel, Session +from sqlmodel import create_engine, SQLModel from starlette.testclient import TestClient from database.deletion.triggers import add_delete_triggers from database.model.concept.concept import AIoDConcept from database.model.platform.platform import Platform from database.model.platform.platform_names import PlatformName +from database.session import EngineSingleton, DbSession from main import add_routes -from tests.testutils.test_resource import RouterTestResource, test_resource_factory +from tests.testutils.test_resource import RouterTestResource, factory @pytest.fixture(scope="session") @@ -31,49 +32,47 @@ def engine(deletion_triggers) -> Iterator[Engine]: Create a SqlAlchemy engine for tests, backed by a temporary sqlite file. 
""" temporary_file = tempfile.NamedTemporaryFile() - engine = create_engine(f"sqlite:///{temporary_file.name}") - SQLModel.metadata.create_all(engine) + engine = create_engine(f"sqlite:///{temporary_file.name}?check_same_thread=False") + AIoDConcept.metadata.create_all(engine) + EngineSingleton().patch(engine) + # Yielding is essential, the temporary file will be closed after the engine is used yield engine -@pytest.fixture(scope="session") -def engine_test_resource(deletion_triggers) -> Iterator[Engine]: - """Create a SqlAlchemy Engine populated with an instance of the TestResource""" - temporary_file = tempfile.NamedTemporaryFile() - engine = create_engine(f"sqlite:///{temporary_file.name}") - SQLModel.metadata.create_all(engine) +@pytest.fixture +def engine_test_resource_filled(engine: Engine) -> Iterator[Engine]: + """ + Engine will be filled with an example value after before each test, in clear_db. + """ yield engine @pytest.fixture(autouse=True) -def clear_db(request): +def clear_db(request, engine: Engine): """ This fixture will be used by every test and checks if the test uses an engine. If it does, it deletes the content of the database, so the test has a fresh db to work with. 
""" - for engine_name in ("engine", "engine_test_resource", "engine_test_resource_filled"): - if engine_name in request.fixturenames: - engine = request.getfixturevalue(engine_name) - with engine.connect() as connection: - transaction = connection.begin() - connection.execute(text("PRAGMA foreign_keys=OFF")) - for table in SQLModel.metadata.tables.values(): - connection.execute(table.delete()) - connection.execute(text("PRAGMA foreign_keys=ON")) - transaction.commit() - with Session(engine) as session: - session.add_all([Platform(name=name) for name in PlatformName]) - if "filled" in engine_name: - session.add( - test_resource_factory( - title="A title", - platform="example", - platform_resource_identifier="1", - ) - ) - session.commit() + with engine.connect() as connection: + transaction = connection.begin() + connection.execute(text("PRAGMA foreign_keys=OFF")) + for table in SQLModel.metadata.sorted_tables: + connection.execute(table.delete()) + connection.execute(text("PRAGMA foreign_keys=ON")) + transaction.commit() + with DbSession() as session: + session.add_all([Platform(name=name) for name in PlatformName]) + if any("engine" in fixture and "filled" in fixture for fixture in request.fixturenames): + session.add( + factory( + title="A title", + platform="example", + platform_resource_identifier="1", + ) + ) + session.commit() @event.listens_for(Engine, "connect") @@ -87,29 +86,21 @@ def sqlite_enable_foreign_key_constraints(dbapi_connection, connection_record): cursor.close() -@pytest.fixture -def engine_test_resource_filled(engine_test_resource: Engine) -> Iterator[Engine]: - """ - Engine will be filled with an example value after before each test, in clear_db. 
- """ - yield engine_test_resource - - @pytest.fixture(scope="session") def client(engine: Engine) -> TestClient: """ Create a TestClient that can be used to mock sending requests to our application """ app = FastAPI() - add_routes(app, engine) + add_routes(app) return TestClient(app, base_url="http://localhost") @pytest.fixture(scope="session") -def client_test_resource(engine_test_resource) -> TestClient: +def client_test_resource(engine: Engine) -> TestClient: """A Startlette TestClient including routes to the TestResource, only in "aiod" schema""" app = FastAPI() - app.include_router(RouterTestResource().create(engine_test_resource, "")) + app.include_router(RouterTestResource().create("")) return TestClient(app, base_url="http://localhost") diff --git a/src/tests/testutils/test_resource.py b/src/tests/testutils/test_resource.py index d3975745..b8d59f48 100644 --- a/src/tests/testutils/test_resource.py +++ b/src/tests/testutils/test_resource.py @@ -20,7 +20,7 @@ class TestResource(TestResourceBase, AIoDConcept, table=True): # type: ignore [ identifier: int = Field(default=None, primary_key=True) -def test_resource_factory( +def factory( title=None, status=None, platform="example", platform_resource_identifier="1", date_deleted=None ): if status is None: diff --git a/src/tests/uploader/huggingface/test_dataset_uploader.py b/src/tests/uploader/huggingface/test_dataset_uploader.py index d0aceeec..38ee8952 100644 --- a/src/tests/uploader/huggingface/test_dataset_uploader.py +++ b/src/tests/uploader/huggingface/test_dataset_uploader.py @@ -1,27 +1,22 @@ from unittest.mock import Mock import huggingface_hub -import pytest import responses from sqlalchemy.engine import Engine -from sqlalchemy.orm import Session from starlette.testclient import TestClient from authentication import keycloak_openid from database.model.ai_asset.ai_asset_table import AIAssetTable from database.model.dataset.dataset import Dataset +from database.session import DbSession from 
tests.testutils.paths import path_test_resources -@pytest.mark.skip(reason="We'll fix this in a separate PR") -# TODO: there are errors when running these tests: "... is not bound to a Session; lazy load -# operation of attribute 'license' cannot proceed". -# See TODOs at hugging_face_uploader.py. def test_happy_path_new_repository( client: TestClient, engine: Engine, mocked_privileged_token: Mock, dataset: Dataset ): keycloak_openid.userinfo = mocked_privileged_token - with Session(engine) as session: + with DbSession() as session: session.add(dataset) session.commit() @@ -53,11 +48,10 @@ def test_happy_path_new_repository( assert id_response == 1 -@pytest.mark.skip(reason="We'll fix this in a separate PR") def test_repo_already_exists(client: TestClient, engine: Engine, mocked_privileged_token: Mock): keycloak_openid.userinfo = mocked_privileged_token dataset_id = 1 - with Session(engine) as session: + with DbSession() as session: session.add_all( [ AIAssetTable(type="dataset"), diff --git a/src/uploader/hugging_face_uploader.py b/src/uploader/hugging_face_uploader.py index 689ee4c9..fc7f18a2 100644 --- a/src/uploader/hugging_face_uploader.py +++ b/src/uploader/hugging_face_uploader.py @@ -1,83 +1,78 @@ import io + import huggingface_hub from fastapi import HTTPException, UploadFile, status from requests import HTTPError -from sqlalchemy.engine import Engine -from sqlalchemy.orm import joinedload from sqlmodel import Session from database.model.dataset.dataset import Dataset +from database.session import DbSession from .utils import huggingface_license_identifiers def handle_upload( - engine: Engine, identifier: int, file: UploadFile, token: str, username: str, ): - dataset = _get_resource(engine=engine, identifier=identifier) # TODO: place this inside session - dataset_name_cleaned = "".join(c if c.isalnum() else "_" for c in dataset.name) - repo_id = f"{username}/{dataset_name_cleaned}" + with DbSession() as session: + dataset = _get_resource(session=session, 
identifier=identifier) + dataset_name_cleaned = "".join(c if c.isalnum() else "_" for c in dataset.name) + repo_id = f"{username}/{dataset_name_cleaned}" - url = _create_or_get_repo_url(repo_id, token) - metadata_file = _generate_metadata_file(dataset) - try: - huggingface_hub.upload_file( - path_or_fileobj=metadata_file, - path_in_repo="README.md", - repo_id=repo_id, - repo_type="dataset", - token=token, - ) - except HTTPError: - msg = "Error updating the metadata, huggingface api returned a http error: {e.strerror}" - raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=msg) - - except ValueError as e: - msg = f"Error updating the metadata, bad format: {e}" - raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=msg) - except Exception: - msg = "Error updating the metadata, unexpected error" - raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=msg) + url = _create_or_get_repo_url(repo_id, token) + metadata_file = _generate_metadata_file(dataset) + try: + huggingface_hub.upload_file( + path_or_fileobj=metadata_file, + path_in_repo="README.md", + repo_id=repo_id, + repo_type="dataset", + token=token, + ) + except HTTPError: + msg = "Error updating the metadata, huggingface api returned a http error: {e.strerror}" + raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=msg) + + except ValueError as e: + msg = f"Error updating the metadata, bad format: {e}" + raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=msg) + except Exception: + msg = "Error updating the metadata, unexpected error" + raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=msg) - try: - huggingface_hub.upload_file( - path_or_fileobj=io.BufferedReader(file.file), - path_in_repo=f"/data/{file.filename}", - repo_id=repo_id, - repo_type="dataset", - token=token, - ) - except HTTPError as e: - msg = f"Error uploading the file, huggingface api returned a http error: {e.strerror}" - raise 
HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=msg) - - except ValueError: - msg = "Error uploading the file, bad format" - raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=msg) - except Exception as e: - msg = f"Error uploading the file, unexpected error: {e.with_traceback}" - raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=msg) + try: + huggingface_hub.upload_file( + path_or_fileobj=io.BufferedReader(file.file), + path_in_repo=f"/data/{file.filename}", + repo_id=repo_id, + repo_type="dataset", + token=token, + ) + except HTTPError as e: + msg = f"Error uploading the file, huggingface api returned a http error: {e.strerror}" + raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=msg) + + except ValueError: + msg = "Error uploading the file, bad format" + raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=msg) + except Exception as e: + msg = f"Error uploading the file, unexpected error: {e.with_traceback}" + raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=msg) - if not any(data.name == repo_id for data in dataset.distribution): - _store_resource_updated(engine, dataset, url, repo_id) + if not any(data.name == repo_id for data in dataset.distribution): + _store_resource_updated(session, dataset, url, repo_id) - return dataset.identifier + return dataset.identifier -def _get_resource(engine: Engine, identifier: int) -> Dataset: +def _get_resource(session: Session, identifier: int) -> Dataset: """ Get the resource identified by AIoD identifier """ - with Session(engine) as session: - query = ( - session.query(Dataset) # TODO: remove these joinedloads - .options(joinedload(Dataset.keyword), joinedload(Dataset.distribution)) - .filter(Dataset.identifier == identifier) - ) + query = session.query(Dataset).filter(Dataset.identifier == identifier) resource = query.first() if not resource: @@ -86,21 +81,20 @@ def _get_resource(engine: Engine, identifier: int) -> 
Dataset: return resource -def _store_resource_updated(engine: Engine, resource: Dataset, url: str, repo_id: str): - with Session(engine) as session: - try: - # Hack to get the right DistributionORM class (for each class, such as Dataset - # and Publication, there is a different DistributionORM table). - dist = resource.RelationshipConfig.distribution.deserializer.clazz # type: ignore - distribution = dist(content_url=url, name=repo_id, dataset=resource) - resource.distribution.append(distribution) - session.merge(resource) - session.commit() - except Exception as e: - raise HTTPException( - status_code=status.HTTP_502_BAD_GATEWAY, - detail="Dataset metadata could not be uploaded", - ) from e +def _store_resource_updated(session: Session, resource: Dataset, url: str, repo_id: str): + try: + # Hack to get the right DistributionORM class (for each class, such as Dataset + # and Publication, there is a different DistributionORM table). + dist = resource.RelationshipConfig.distribution.deserializer.clazz # type: ignore + distribution = dist(content_url=url, name=repo_id, dataset=resource) + resource.distribution.append(distribution) + session.merge(resource) + session.commit() + except Exception as e: + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail="Dataset metadata could not be uploaded", + ) from e def _create_or_get_repo_url(repo_id, token): From 6b3e06761a6584b4f562ef7fc4a8668e54e88008 Mon Sep 17 00:00:00 2001 From: Jos van der Velde Date: Tue, 21 Nov 2023 15:39:17 +0100 Subject: [PATCH 2/3] Fixed typo: allways --- src/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main.py b/src/main.py index 948e22a0..e0b1f3a1 100644 --- a/src/main.py +++ b/src/main.py @@ -100,7 +100,7 @@ def create_app() -> FastAPI: "scopes": KEYCLOAK_CONFIG.get("scopes"), }, ) - drop_or_create_database(delete_first=args.rebuild_db == "allways") + drop_or_create_database(delete_first=args.rebuild_db == "always") 
AIoDConcept.metadata.create_all(EngineSingleton().engine, checkfirst=True) with DbSession() as session: existing_platforms = session.scalars(select(Platform)).all() From 47b254c7e5d729f8767fbb1f45020f69660ae5e7 Mon Sep 17 00:00:00 2001 From: Jos van der Velde Date: Tue, 21 Nov 2023 15:39:52 +0100 Subject: [PATCH 3/3] Removed unused engine from tests --- .../test_dataset_dcatap_converter.py | 4 +--- .../test_dataset_schemaDotOrg_converter.py | 4 +--- .../test_router_aiassets_retrieve_content.py | 19 ++++++------------- .../enum_routers/test_license_router.py | 3 +-- .../parent_routers/test_agent_router.py | 3 --- .../parent_routers/test_ai_asset_router.py | 2 -- .../parent_routers/test_ai_resource_router.py | 2 -- .../test_router_case_study.py | 2 -- .../resource_routers/test_router_dataset.py | 2 -- .../test_router_educational_resource.py | 2 -- .../test_router_experiment.py | 2 -- .../test_router_publication.py | 2 -- .../resource_routers/test_router_service.py | 2 -- .../huggingface/test_dataset_uploader.py | 5 ++--- 14 files changed, 11 insertions(+), 43 deletions(-) diff --git a/src/tests/converters/schema_converters/test_dataset_dcatap_converter.py b/src/tests/converters/schema_converters/test_dataset_dcatap_converter.py index 5df56c36..bafb410c 100644 --- a/src/tests/converters/schema_converters/test_dataset_dcatap_converter.py +++ b/src/tests/converters/schema_converters/test_dataset_dcatap_converter.py @@ -1,7 +1,5 @@ import datetime -from sqlalchemy.engine import Engine - from converters.schema_converters import dataset_converter_dcatap_instance from database.model.agent.person import Person from database.model.ai_asset.license import License @@ -13,7 +11,7 @@ from tests.testutils.paths import path_test_resources -def test_aiod_to_dcatap_happy_path(engine: Engine, dataset: Dataset): +def test_aiod_to_dcatap_happy_path(dataset: Dataset): dataset.identifier = 1 dataset.license = License(name="a license") dataset.alternate_name = 
[AlternateName(name="alias1"), AlternateName(name="alias2")] diff --git a/src/tests/converters/schema_converters/test_dataset_schemaDotOrg_converter.py b/src/tests/converters/schema_converters/test_dataset_schemaDotOrg_converter.py index 514bcbeb..7e452cfa 100644 --- a/src/tests/converters/schema_converters/test_dataset_schemaDotOrg_converter.py +++ b/src/tests/converters/schema_converters/test_dataset_schemaDotOrg_converter.py @@ -1,7 +1,5 @@ import datetime -from sqlalchemy.engine import Engine - from converters.schema_converters import dataset_converter_schema_dot_org_instance from database.model.agent.agent_table import AgentTable from database.model.agent.contact import Contact @@ -15,7 +13,7 @@ from tests.testutils.paths import path_test_resources -def test_aiod_to_schema_dot_org_happy_path(engine: Engine, dataset: Dataset): +def test_aiod_to_schema_dot_org_happy_path(dataset: Dataset): dataset.identifier = 1 dataset.license = License(name="a license") dataset.alternate_name = [AlternateName(name="alias1"), AlternateName(name="alias2")] diff --git a/src/tests/routers/ai_asset_routers/test_router_aiassets_retrieve_content.py b/src/tests/routers/ai_asset_routers/test_router_aiassets_retrieve_content.py index 74bfe5be..ca33c7a9 100644 --- a/src/tests/routers/ai_asset_routers/test_router_aiassets_retrieve_content.py +++ b/src/tests/routers/ai_asset_routers/test_router_aiassets_retrieve_content.py @@ -5,7 +5,6 @@ import responses from fastapi import status from pytest import FixtureRequest -from sqlalchemy.engine import Engine from starlette.testclient import TestClient from authentication import keycloak_openid @@ -51,7 +50,6 @@ def mock_response2(mocked_requests: responses.RequestsMock): def set_up( client: TestClient, - engine: Engine, mocked_privileged_token: Mock, body: dict, person: Person, @@ -79,9 +77,8 @@ def resource_name(request: FixtureRequest) -> str: return request.param -def test_ai_asset_has_endopoints( +def test_ai_asset_has_endpoints( client: 
TestClient, - engine: Engine, mocked_privileged_token: Mock, body_asset_with_single_distribution: dict, person: Person, @@ -96,7 +93,7 @@ def test_ai_asset_has_endopoints( return a response with status code 200. """ body = copy.deepcopy(body_asset_with_single_distribution) - set_up(client, engine, mocked_privileged_token, body, person, resource_name) + set_up(client, mocked_privileged_token, body, person, resource_name) default_endpoint = f"{resource_name}/v1/1/content" @@ -111,7 +108,6 @@ def test_ai_asset_has_endopoints( def test_endpoints_when_empty_distribution( client: TestClient, - engine: Engine, mocked_privileged_token: Mock, body_asset: dict, person: Person, @@ -125,7 +121,7 @@ def test_endpoints_when_empty_distribution( """ body = copy.deepcopy(body_asset) body["distribution"] = [] - set_up(client, engine, mocked_privileged_token, body, person, SAMPLE_RESOURCE_NAME) + set_up(client, mocked_privileged_token, body, person, SAMPLE_RESOURCE_NAME) response = client.get(SAMPLE_ENDPOINT) assert response.status_code == status.HTTP_404_NOT_FOUND, response.json() @@ -147,7 +143,6 @@ def body_asset_with_single_distribution(body_asset: dict) -> dict: def test_endpoints_when_single_distribution( client: TestClient, - engine: Engine, mocked_privileged_token: Mock, body_asset_with_single_distribution: dict, person: Person, @@ -160,7 +155,7 @@ def test_endpoints_when_single_distribution( content, headers, and filename are returned. 
""" body = copy.deepcopy(body_asset_with_single_distribution) - set_up(client, engine, mocked_privileged_token, body, person, SAMPLE_RESOURCE_NAME) + set_up(client, mocked_privileged_token, body, person, SAMPLE_RESOURCE_NAME) with responses.RequestsMock() as mocked_requests: mock_response1(mocked_requests) @@ -203,7 +198,6 @@ def body_asset_with_two_distributions(body_asset_with_single_distribution: dict) def test_endpoints_when_two_distributions( client: TestClient, - engine: Engine, mocked_privileged_token: Mock, body_asset_with_two_distributions: dict, person: Person, @@ -216,7 +210,7 @@ def test_endpoints_when_two_distributions( content, headers, and filename are returned. """ body = copy.deepcopy(body_asset_with_two_distributions) - set_up(client, engine, mocked_privileged_token, body, person, SAMPLE_RESOURCE_NAME) + set_up(client, mocked_privileged_token, body, person, SAMPLE_RESOURCE_NAME) with responses.RequestsMock() as mocked_requests: mock_response1(mocked_requests) @@ -259,7 +253,6 @@ def encoding_format(request: FixtureRequest) -> str: def test_headers_when_distribution_has_missing_fields( client: TestClient, - engine: Engine, mocked_privileged_token: Mock, body_asset_with_single_distribution: dict, person: Person, @@ -279,7 +272,7 @@ def test_headers_when_distribution_has_missing_fields( alternate_filename = body["distribution"][0]["content_url"].split("/")[-1] - set_up(client, engine, mocked_privileged_token, body, person, SAMPLE_RESOURCE_NAME) + set_up(client, mocked_privileged_token, body, person, SAMPLE_RESOURCE_NAME) with responses.RequestsMock() as mocked_requests: mock_response1(mocked_requests) diff --git a/src/tests/routers/enum_routers/test_license_router.py b/src/tests/routers/enum_routers/test_license_router.py index 57df77fd..9360cf92 100644 --- a/src/tests/routers/enum_routers/test_license_router.py +++ b/src/tests/routers/enum_routers/test_license_router.py @@ -1,4 +1,3 @@ -from sqlalchemy.engine import Engine from starlette.testclient 
import TestClient from database.model.ai_asset.license import License @@ -7,7 +6,7 @@ from database.session import DbSession -def test_happy_path(client: TestClient, engine: Engine, dataset: Dataset, publication: Publication): +def test_happy_path(client: TestClient, dataset: Dataset, publication: Publication): dataset.license = License(name="license 1") publication.license = License(name="license 2") diff --git a/src/tests/routers/parent_routers/test_agent_router.py b/src/tests/routers/parent_routers/test_agent_router.py index bcb00189..f3409027 100644 --- a/src/tests/routers/parent_routers/test_agent_router.py +++ b/src/tests/routers/parent_routers/test_agent_router.py @@ -1,6 +1,5 @@ import datetime -from sqlalchemy.engine import Engine from starlette.testclient import TestClient from database.model.agent.organisation import Organisation @@ -10,7 +9,6 @@ def test_happy_path( client: TestClient, - engine: Engine, organisation: Organisation, person: Person, ): @@ -39,7 +37,6 @@ def test_happy_path( def test_ignore_deleted( client: TestClient, - engine: Engine, organisation: Organisation, person: Person, ): diff --git a/src/tests/routers/parent_routers/test_ai_asset_router.py b/src/tests/routers/parent_routers/test_ai_asset_router.py index a13b087d..40ad8a22 100644 --- a/src/tests/routers/parent_routers/test_ai_asset_router.py +++ b/src/tests/routers/parent_routers/test_ai_asset_router.py @@ -1,4 +1,3 @@ -from sqlalchemy.engine import Engine from starlette.testclient import TestClient from database.model.dataset.dataset import Dataset @@ -8,7 +7,6 @@ def test_happy_path( client: TestClient, - engine: Engine, dataset: Dataset, publication: Publication, ): diff --git a/src/tests/routers/parent_routers/test_ai_resource_router.py b/src/tests/routers/parent_routers/test_ai_resource_router.py index 94e60316..f7239535 100644 --- a/src/tests/routers/parent_routers/test_ai_resource_router.py +++ b/src/tests/routers/parent_routers/test_ai_resource_router.py @@ -1,4 +1,3 @@ 
-from sqlalchemy.engine import Engine from starlette.testclient import TestClient from database.model.agent.organisation import Organisation @@ -8,7 +7,6 @@ def test_happy_path( client: TestClient, - engine: Engine, organisation: Organisation, person: Person, ): diff --git a/src/tests/routers/resource_routers/test_router_case_study.py b/src/tests/routers/resource_routers/test_router_case_study.py index 8340735d..b162515f 100644 --- a/src/tests/routers/resource_routers/test_router_case_study.py +++ b/src/tests/routers/resource_routers/test_router_case_study.py @@ -1,7 +1,6 @@ import copy from unittest.mock import Mock -from sqlalchemy.engine import Engine from starlette.testclient import TestClient from authentication import keycloak_openid @@ -9,7 +8,6 @@ def test_happy_path( client: TestClient, - engine: Engine, mocked_privileged_token: Mock, body_asset: dict, ): diff --git a/src/tests/routers/resource_routers/test_router_dataset.py b/src/tests/routers/resource_routers/test_router_dataset.py index c8a5dcbd..4b21b712 100644 --- a/src/tests/routers/resource_routers/test_router_dataset.py +++ b/src/tests/routers/resource_routers/test_router_dataset.py @@ -1,7 +1,6 @@ import copy from unittest.mock import Mock -from sqlalchemy.engine import Engine from starlette.testclient import TestClient from authentication import keycloak_openid @@ -11,7 +10,6 @@ def test_happy_path( client: TestClient, - engine: Engine, mocked_privileged_token: Mock, body_asset: dict, person: Person, diff --git a/src/tests/routers/resource_routers/test_router_educational_resource.py b/src/tests/routers/resource_routers/test_router_educational_resource.py index dbefd7c3..c411f48e 100644 --- a/src/tests/routers/resource_routers/test_router_educational_resource.py +++ b/src/tests/routers/resource_routers/test_router_educational_resource.py @@ -1,7 +1,6 @@ import copy from unittest.mock import Mock -from sqlalchemy.engine import Engine from starlette.testclient import TestClient from authentication 
import keycloak_openid @@ -9,7 +8,6 @@ def test_happy_path( client: TestClient, - engine: Engine, mocked_privileged_token: Mock, body_asset: dict, ): diff --git a/src/tests/routers/resource_routers/test_router_experiment.py b/src/tests/routers/resource_routers/test_router_experiment.py index 1b9ce76b..d98f1903 100644 --- a/src/tests/routers/resource_routers/test_router_experiment.py +++ b/src/tests/routers/resource_routers/test_router_experiment.py @@ -1,7 +1,6 @@ import copy from unittest.mock import Mock -from sqlalchemy.engine import Engine from starlette.testclient import TestClient from authentication import keycloak_openid @@ -9,7 +8,6 @@ def test_happy_path( client: TestClient, - engine: Engine, mocked_privileged_token: Mock, body_asset: dict, ): diff --git a/src/tests/routers/resource_routers/test_router_publication.py b/src/tests/routers/resource_routers/test_router_publication.py index ddbe0cc0..8f7aff8f 100644 --- a/src/tests/routers/resource_routers/test_router_publication.py +++ b/src/tests/routers/resource_routers/test_router_publication.py @@ -1,7 +1,6 @@ import copy from unittest.mock import Mock -from sqlalchemy.engine import Engine from starlette.testclient import TestClient from authentication import keycloak_openid @@ -10,7 +9,6 @@ def test_happy_path( client: TestClient, - engine: Engine, mocked_privileged_token: Mock, body_asset: dict, dataset: Dataset, diff --git a/src/tests/routers/resource_routers/test_router_service.py b/src/tests/routers/resource_routers/test_router_service.py index e7cd0914..6ea64aff 100644 --- a/src/tests/routers/resource_routers/test_router_service.py +++ b/src/tests/routers/resource_routers/test_router_service.py @@ -1,7 +1,6 @@ import copy from unittest.mock import Mock -from sqlalchemy.engine import Engine from starlette.testclient import TestClient from authentication import keycloak_openid @@ -9,7 +8,6 @@ def test_happy_path( client: TestClient, - engine: Engine, mocked_privileged_token: Mock, body_resource: dict, 
): diff --git a/src/tests/uploader/huggingface/test_dataset_uploader.py b/src/tests/uploader/huggingface/test_dataset_uploader.py index 38ee8952..8b8e09e7 100644 --- a/src/tests/uploader/huggingface/test_dataset_uploader.py +++ b/src/tests/uploader/huggingface/test_dataset_uploader.py @@ -2,7 +2,6 @@ import huggingface_hub import responses -from sqlalchemy.engine import Engine from starlette.testclient import TestClient from authentication import keycloak_openid @@ -13,7 +12,7 @@ def test_happy_path_new_repository( - client: TestClient, engine: Engine, mocked_privileged_token: Mock, dataset: Dataset + client: TestClient, mocked_privileged_token: Mock, dataset: Dataset ): keycloak_openid.userinfo = mocked_privileged_token with DbSession() as session: @@ -48,7 +47,7 @@ def test_happy_path_new_repository( assert id_response == 1 -def test_repo_already_exists(client: TestClient, engine: Engine, mocked_privileged_token: Mock): +def test_repo_already_exists(client: TestClient, mocked_privileged_token: Mock): keycloak_openid.userinfo = mocked_privileged_token dataset_id = 1 with DbSession() as session: